Skip to content

Commit 52f04b6

Browse files
author
Alexander Novikov
authored
Merge pull request #36 from Bihaqo/develop
0.2.0
2 parents 53e6345 + de08139 commit 52f04b6

20 files changed

+2379
-423
lines changed

CHANGELOG.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Change Log
2+
All notable changes to this project will be documented in this file.
3+
4+
The format is based on [Keep a Changelog](http://keepachangelog.com/)
5+
and this project adheres to [Semantic Versioning](http://semver.org/).
6+
7+
## [Unreleased]
8+
9+
## [0.2.0] - 2017-03-23
10+
### Added
11+
- (Partial) support for batches of TT-tensors.
12+
- Riemannian module (projection on the tangent space).
13+
- op property and str method for TensorTrain
14+
- concat_along_batch_dim
15+
- expand_batch_dim
16+
- gram_matrix
17+
- Multiplication by a number
18+
19+
### Changed
20+
- Fixed the add function for dtypes not equal to tf.float32
21+
- flat_inner and quadratic_form now return numbers (instead of 1 x 1 tensors)
22+
23+
## [0.1.0] - 2017-03-12
24+
### Added
25+
- Indexing (e.g. TensorTrain[:, 3, 2:4])
26+
- Full (converting TT to dense)
27+
- TT-SVD and rounding
28+
- Basic arithmetic (add, multiply, matmul, flat_inner)
29+
- Variables support
30+
- Kronecker module (functions for TT-rank 1 TT-matrices)
31+
- quadratic_form
32+
- frobenius_norm
33+
34+
[Unreleased]: https://github.com/Bihaqo/t3f/compare/master...develop
35+
[0.2.0]: https://github.com/Bihaqo/t3f/compare/0.1.0...0.2.0
36+
[0.1.0]: https://github.com/Bihaqo/t3f/compare/f24409508...0.1.0

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
TensorFlow implementation of the Tensor Train (TT) Toolbox.
22

33
# Installation
4-
First, [install TensorFlow](https://www.tensorflow.org/install/). Then simply run
4+
First, [install TensorFlow](https://www.tensorflow.org/install/) v1 or higher. Then simply run
55
```bash
66
pip install t3f
77
```

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='t3f',
4-
version='0.1.0',
4+
version='0.2.0',
55
description='Tensor Train decomposition on TensorFlow',
66
url='https://github.com/Bihaqo/t3f',
77
author='Alexander Novikov',

t3f/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from tensor_train import *
1+
from tensor_train_base import TensorTrainBase
2+
from tensor_train import TensorTrain
3+
from tensor_train_batch import TensorTrainBatch
24
from variables import *
35
from ops import *
6+
from batch_ops import *
47
from initializers import *
58
from regularizers import *
9+
from riemannian import *
610
from shapes import *
711
from decompositions import *

t3f/batch_ops.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import tensorflow as tf
2+
3+
from tensor_train_base import TensorTrainBase
4+
from tensor_train_batch import TensorTrainBatch
5+
import ops
6+
7+
8+
def concat_along_batch_dim(tt_list):
  """Concat all TensorTrainBatch objects along batch dimension.

  Args:
    tt_list: a list of TensorTrainBatch objects.

  Returns:
    TensorTrainBatch

  Raises:
    ValueError: if tt_list is empty, if any element of the list is not a
      TensorTrainBatch, or if the raw shapes / TT-ranks of the elements
      do not coincide.
  """
  if isinstance(tt_list, TensorTrainBase):
    # Not a list but just one element, nothing to concat.
    # NOTE: this check must come before any tt_list[0] access — indexing a
    # single TT object would slice into it rather than select a list element.
    return tt_list

  if not tt_list:
    raise ValueError('tt_list should be a non-empty list of TTBatch objects.')

  for batch_idx in range(len(tt_list)):
    if not isinstance(tt_list[batch_idx], TensorTrainBatch):
      raise ValueError('All objects in the list should be TTBatch objects, got '
                       '%s' % tt_list[batch_idx])
  for batch_idx in range(1, len(tt_list)):
    if tt_list[batch_idx].get_raw_shape() != tt_list[0].get_raw_shape():
      raise ValueError('Shapes of all TT-batch objects should coincide, got %s '
                       'and %s' % (tt_list[0].get_raw_shape(),
                                   tt_list[batch_idx].get_raw_shape()))
    if tt_list[batch_idx].get_tt_ranks() != tt_list[0].get_tt_ranks():
      raise ValueError('TT-ranks of all TT-batch objects should coincide, got '
                       '%s and %s' % (tt_list[0].get_tt_ranks(),
                                      tt_list[batch_idx].get_tt_ranks()))

  ndims = tt_list[0].ndims()
  res_cores = []
  for core_idx in range(ndims):
    # Axis 0 of every TT-core is the batch dimension, so concatenating the
    # per-object cores along it concatenates the batches.
    curr_core = tf.concat([tt.tt_cores[core_idx] for tt in tt_list], axis=0)
    res_cores.append(curr_core)

  batch_size = sum(tt.batch_size for tt in tt_list)

  return TensorTrainBatch(res_cores, tt_list[0].get_raw_shape(),
                          tt_list[0].get_tt_ranks(), batch_size)
46+
47+
48+
def gram_matrix(tt_vectors, matrix=None):
  """Computes Gramian matrix of a batch of TT-vectors.

  If matrix is None, computes
  res[i, j] = t3f.flat_inner(tt_vectors[i], tt_vectors[j]).
  If matrix is present, computes
    res[i, j] = t3f.flat_inner(tt_vectors[i], t3f.matmul(matrix, tt_vectors[j]))
  or more shortly
    res[i, j] = tt_vectors[i]^T * matrix * tt_vectors[j]

  Args:
    tt_vectors: TensorTrainBatch.
    matrix: None, or TensorTrain matrix.

  Returns:
    tf.tensor with the Gram matrix.

  Raises:
    ValueError: if matrix is provided and tt_vectors are not vectors
      (or their shape is undefined at graph-construction time).
  """
  # Einsum index conventions below: p, q enumerate the two batch elements
  # being paired; i, j enumerate the raw (mode) dimensions; the remaining
  # letters (a, b, c, d, e, f) are TT-rank indices that get contracted as
  # we sweep through the cores left to right.
  # Assumes batch TT-cores are laid out as (batch, rank, i, j, rank) —
  # consistent with concat_along_batch_dim using axis=0 for batch; confirm
  # against TensorTrainBatch.
  ndims = tt_vectors.ndims()
  if matrix is None:
    # Plain Gram matrix: contract each core of vector p against the same
    # core of vector q, carrying a (p, q, rank, rank) accumulator.
    curr_core = tt_vectors.tt_cores[0]
    res = tf.einsum('paijb,qcijd->pqbd', curr_core, curr_core)
    for core_idx in range(1, ndims):
      curr_core = tt_vectors.tt_cores[core_idx]
      res = tf.einsum('pqac,paijb,qcijd->pqbd', res, curr_core, curr_core)
  else:
    # res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
    vectors_shape = tt_vectors.get_shape()
    if vectors_shape[2] == 1 and vectors_shape[1] != 1:
      # The batch holds column vectors (n x 1); transpose to row form so the
      # contraction below lines up with the matrix cores.
      # TODO: not very efficient, better to use different order in einsum.
      tt_vectors = ops.transpose(tt_vectors)
      vectors_shape = tt_vectors.get_shape()
    if vectors_shape[1] != 1:
      # TODO: do something so that in case the shape is undefined on compilation
      # it still works.
      raise ValueError('The tt_vectors argument should be vectors (not '
                       'matrices) with shape defined on compilation.')
    curr_core = tt_vectors.tt_cores[0]
    curr_matrix_core = matrix.tt_cores[0]
    # We enumerate the dummy dimension (that takes 1 value) with `k`.
    # The accumulator now carries three rank indices (b, d, f): one for each
    # of the left vector, the matrix, and the right vector.
    res = tf.einsum('pakib,cijd,qekjf->pqbdf', curr_core, curr_matrix_core,
                    curr_core)
    for core_idx in range(1, ndims):
      curr_core = tt_vectors.tt_cores[core_idx]
      curr_matrix_core = matrix.tt_cores[core_idx]
      res = tf.einsum('pqace,pakib,cijd,qekjf->pqbdf', res, curr_core,
                      curr_matrix_core, curr_core)

  # Squeeze to make the result of size batch_size x batch_size instead of
  # batch_size x batch_size x 1 x 1.
  return tf.squeeze(res)

t3f/batch_ops_test.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
4+
from tensor_train import TensorTrain
5+
from tensor_train_batch import TensorTrainBatch
6+
import ops
7+
import batch_ops
8+
import initializers
9+
10+
11+
class BatchOpsTest(tf.test.TestCase):
  """Tests for batch_ops: batch concatenation and Gram matrices."""

  def testConcatMatrix(self):
    # Concatenating TT-matrix batches along the batch axis should agree
    # with tf.concat of their dense representations.
    batch_a = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1)
    batch_b = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4)
    batch_c = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
    # A single (non-list) argument should pass through unchanged.
    actual_a = ops.full(batch_ops.concat_along_batch_dim(batch_a))
    actual_ab = ops.full(batch_ops.concat_along_batch_dim((batch_a, batch_b)))
    actual_abc = ops.full(
        batch_ops.concat_along_batch_dim((batch_a, batch_b, batch_c)))

    dense_a = ops.full(batch_a)
    dense_b = ops.full(batch_b)
    dense_c = ops.full(batch_c)
    desired_a = dense_a
    desired_ab = tf.concat((dense_a, dense_b), axis=0)
    desired_abc = tf.concat((dense_a, dense_b, dense_c), axis=0)
    with self.test_session() as sess:
      (actual_a_val, actual_ab_val, actual_abc_val, desired_a_val,
       desired_ab_val, desired_abc_val) = sess.run(
           (actual_a, actual_ab, actual_abc, desired_a, desired_ab,
            desired_abc))
      self.assertAllClose(actual_a_val, desired_a_val)
      self.assertAllClose(actual_ab_val, desired_ab_val)
      self.assertAllClose(actual_abc_val, desired_abc_val)

  def testGramMatrix(self):
    # The Gram matrix of a batch of TT-vectors should match the dense
    # computation V V^T.
    vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5)
    actual = batch_ops.gram_matrix(vectors)
    dense_vectors = tf.reshape(ops.full(vectors), (5, 6))
    desired = tf.squeeze(tf.matmul(dense_vectors, tf.transpose(dense_vectors)))
    with self.test_session() as sess:
      actual_val, desired_val = sess.run((actual, desired))
      self.assertAllClose(desired_val, actual_val)

  def testGramMatrixWithMatrix(self):
    # When a weight matrix is provided the result should be
    # res[i, j] = vectors[i]^T * matrix * vectors[j].
    vectors = initializers.random_matrix_batch((None, (2, 3)), batch_size=4)
    weight = initializers.random_matrix(((2, 3), (2, 3)))
    actual = batch_ops.gram_matrix(vectors, weight)
    dense_vectors = tf.reshape(ops.full(vectors), (4, 6))
    with self.test_session() as sess:
      actual_val, vectors_val, weight_val = sess.run(
          (actual, dense_vectors, ops.full(weight)))
      desired_val = np.zeros((4, 4))
      for i in range(4):
        for j in range(4):
          desired_val[i, j] = vectors_val[i].dot(weight_val).dot(vectors_val[j])
      self.assertAllClose(desired_val, actual_val, atol=1e-5, rtol=1e-5)
77+
if __name__ == "__main__":
  # Discover and run all tests in this module via TensorFlow's test runner.
  tf.test.main()

0 commit comments

Comments
 (0)