1+ import tensorflow as tf
2+
3+ from tensor_train_base import TensorTrainBase
4+ from tensor_train_batch import TensorTrainBatch
5+ import ops
6+
7+
def concat_along_batch_dim(tt_list):
  """Concat all TensorTrainBatch objects along batch dimension.

  Args:
    tt_list: a list of TensorTrainBatch objects. A single TensorTrainBase
      object is also accepted and returned as-is (nothing to concat).

  Returns:
    TensorTrainBatch

  Raises:
    ValueError: if the elements of the list are not TensorTrainBatch objects,
      or if their shapes or TT-ranks do not coincide.
  """
  # Check for the single-object case BEFORE indexing: indexing a non-batch
  # TT object with [0] would trigger slicing semantics (or fail) instead of
  # hitting this early return.
  if isinstance(tt_list, TensorTrainBase):
    # Not a list but just one element, nothing to concat.
    return tt_list

  for tt in tt_list:
    if not isinstance(tt, TensorTrainBatch):
      raise ValueError('All objects in the list should be TTBatch objects, got '
                       '%s' % tt)
  # All batches must agree with the first one on raw shape and TT-ranks,
  # otherwise the per-core concat below would produce ragged cores.
  for batch_idx in range(1, len(tt_list)):
    if tt_list[batch_idx].get_raw_shape() != tt_list[0].get_raw_shape():
      raise ValueError('Shapes of all TT-batch objects should coincide, got %s '
                       'and %s' % (tt_list[0].get_raw_shape(),
                                   tt_list[batch_idx].get_raw_shape()))
    if tt_list[batch_idx].get_tt_ranks() != tt_list[0].get_tt_ranks():
      raise ValueError('TT-ranks of all TT-batch objects should coincide, got '
                       '%s and %s' % (tt_list[0].get_tt_ranks(),
                                      tt_list[batch_idx].get_tt_ranks()))

  # Safe to index now: validation above guarantees a non-empty list of
  # compatible TensorTrainBatch objects.
  ndims = tt_list[0].ndims()
  res_cores = []
  for core_idx in range(ndims):
    # Batch dimension is axis 0 of every TT-core, so concat core-by-core.
    curr_core = tf.concat([tt.tt_cores[core_idx] for tt in tt_list], axis=0)
    res_cores.append(curr_core)

  batch_size = sum(tt.batch_size for tt in tt_list)

  return TensorTrainBatch(res_cores, tt_list[0].get_raw_shape(),
                          tt_list[0].get_tt_ranks(), batch_size)
46+
47+
def gram_matrix(tt_vectors, matrix=None):
  """Computes Gramian matrix of a batch of TT-vectors.

  If matrix is None, computes
    res[i, j] = t3f.flat_inner(tt_vectors[i], tt_vectors[j]).
  If matrix is present, computes
    res[i, j] = t3f.flat_inner(tt_vectors[i], t3f.matmul(matrix, tt_vectors[j]))
  or more shortly
    res[i, j] = tt_vectors[i]^T * matrix * tt_vectors[j]

  Args:
    tt_vectors: TensorTrainBatch.
    matrix: None, or TensorTrain matrix.

  Returns:
    tf.tensor with the Gram matrix.

  Raises:
    ValueError: if `matrix` is given and tt_vectors are not vectors (or their
      shape is undefined at graph-construction time).
  """
  ndims = tt_vectors.ndims()
  if matrix is None:
    # Contract the batch of TT objects against itself core-by-core.
    # Einsum index convention (assumed core layout — batch, left rank, two
    # mode indices, right rank; TODO confirm against TensorTrainBatch):
    #   p, q — batch indices of the two copies;
    #   a, b, c, d — TT-rank indices; i, j — mode indices.
    curr_core = tt_vectors.tt_cores[0]
    res = tf.einsum('paijb,qcijd->pqbd', curr_core, curr_core)
    for core_idx in range(1, ndims):
      # Absorb each subsequent core pair into the running contraction `res`,
      # which carries the open right-rank indices (b, d) per batch pair (p, q).
      curr_core = tt_vectors.tt_cores[core_idx]
      res = tf.einsum('pqac,paijb,qcijd->pqbd', res, curr_core, curr_core)
  else:
    # res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
    vectors_shape = tt_vectors.get_shape()
    if vectors_shape[2] == 1 and vectors_shape[1] != 1:
      # Looks like column vectors were passed; transpose so the dummy mode
      # ends up in position 1 as the einsums below expect.
      # TODO: not very efficient, better to use different order in einsum.
      tt_vectors = ops.transpose(tt_vectors)
      vectors_shape = tt_vectors.get_shape()
    if vectors_shape[1] != 1:
      # TODO: do something so that in case the shape is undefined on compilation
      # it still works.
      raise ValueError('The tt_vectors argument should be vectors (not '
                       'matrices) with shape defined on compilation.')
    curr_core = tt_vectors.tt_cores[0]
    curr_matrix_core = matrix.tt_cores[0]
    # We enumerate the dummy dimension (that takes 1 value) with `k`.
    # Three-way contraction: left vector core, matrix core, right vector core.
    res = tf.einsum('pakib,cijd,qekjf->pqbdf', curr_core, curr_matrix_core,
                    curr_core)
    for core_idx in range(1, ndims):
      curr_core = tt_vectors.tt_cores[core_idx]
      curr_matrix_core = matrix.tt_cores[core_idx]
      # `res` now carries three open rank indices (b, d, f) — one per TT
      # object in the sandwich v^T * M * v — for each batch pair (p, q).
      res = tf.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core,
                      curr_matrix_core, curr_core)

  # Squeeze to make the result of size batch_size x batch_size instead of
  # batch_size x batch_size x 1 x 1.
  return tf.squeeze(res)