Skip to content

Commit 77bf22e

Browse files
committed
Fixed syntax errors
1 parent 30cd60c commit 77bf22e

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

scikit_tt/data_driven/tgedmd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import List
2+
from typing import List, Union
33
from scikit_tt.data_driven.transform import Function
44
from scikit_tt.tensor_train import TT
55

@@ -440,7 +440,7 @@ def _reduced_matrix_tgedmd(u: List['TT'],
440440
def _contraction_step_LPsi_u(psi_k: list,
441441
x: np.ndarray, bx: np.ndarray, sig_x: np.ndarray,
442442
u_k: np.ndarray, position: str="middle",
443-
v: np.ndarray=None) np.ndarray:
443+
v: np.ndarray=None) -> np.ndarray:
444444
"""
445445
Helper function for the tensor network contraction Lpsi(x)^T \times \hat{U}, for
446446
one data point x, within non-reversible tgEDMD. The contraction is performed using

scikit_tt/tensor_train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,15 +1723,15 @@ def squeeze(self) -> 'TT':
17231723

17241724
# find cores with row and column dimension equal to 1
17251725
no_mode_list = []
1726-
for i in range(t.order):
1726+
for i in range(self.order):
17271727
if self.row_dims[i] == 1 and self.col_dims[i] == 1:
17281728
no_mode_list.append(i)
17291729

17301730
# cores with row or column dimension (or both) larger than 1
1731-
mode_list = list(np.setdiff1d(np.arange(t.order), no_mode_list))
1731+
mode_list = list(np.setdiff1d(np.arange(self.order), no_mode_list))
17321732

17331733
# append t.order for later loop
1734-
mode_list += [t.order]
1734+
mode_list += [self.order]
17351735

17361736
# define core list
17371737
cores = []
@@ -1741,7 +1741,7 @@ def squeeze(self) -> 'TT':
17411741
if mode_list[0]>0:
17421742
core_tmp = self.cores[0][0,0,0,:][None,:]
17431743
for i in range(1,mode_list[0]):
1744-
core_tmp = core_tmp@t.cores[i][:,0,0,:]
1744+
core_tmp = core_tmp@self.cores[i][:,0,0,:]
17451745
self.cores[mode_list[0]] = np.tensordot(core_tmp, self.cores[mode_list[0]], axes=(1,0))
17461746

17471747
# contract cores with row and column dimension with relevant cores from the right

0 commit comments

Comments
 (0)