Skip to content

Commit a01b54b

Browse files
committed
Initial open source release
0 parents  commit a01b54b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+5531
-0
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
.ipynb_checkpoints
2+
__pycache__
3+
*.pyc
4+
build
5+
dist
6+
*.egg-info

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2020 Oleg Smirnov <[email protected]>
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# TensorFlow ManOpt
2+
3+
[![License](https://img.shields.io/:license-mit-blue.svg)](https://badges.mit-license.org)
4+
5+
A library for manifold-constrained optimization in TensorFlow.
6+
7+
## Installation
8+
9+
To install the latest development version from GitHub:
10+
11+
```bash
12+
pip install git+https://github.com/master/tensorflow-manopt.git
13+
```
14+
15+
To install a package from PyPI:
16+
17+
```bash
18+
pip install tensorflow-manopt
19+
```
20+
21+
## Features
22+
23+
The core package implements concepts in differential geometry, such as
24+
manifolds and Riemannian metrics with associated exponential and logarithmic
25+
maps, geodesics, retractions, and transports. For manifolds, where closed-form
26+
expressions are not available, the library provides numerical approximations.
27+
28+
<img align="right" width="400" src="https://github.com/master/tensorflow-manopt/blob/master/examples/usage.png?raw=true">
29+
30+
```python
31+
S = manopt.manifolds.Sphere()
32+
33+
x = S.projx(tf.constant([0.1, -0.1, 0.1]))
34+
u = S.proju(x, tf.constant([1., 1., 1.]))
35+
v = S.proju(x, tf.constant([-0.7, -1.4, 1.4]))
36+
37+
y = S.exp(x, v)
38+
39+
u_ = S.transp(x, y, u)
40+
v_ = S.transp(x, y, v)
41+
```
42+
43+
### Manifolds
44+
45+
- `manifolds.Cholesky` - manifold of lower triangular matrices with positive diagonal elements
46+
- `manifolds.Euclidian` - unconstrained manifold with the Euclidean metric
47+
- `manifolds.Grassmannian` - manifold of `p`-dimensional linear subspaces of the `n`-dimensional space
48+
- `manifolds.Hyperboloid` - manifold of `n`-dimensional hyperbolic space embedded in the `n+1`-dimensional Minkowski space
49+
- `manifolds.Poincare` - the Poincaré ball model of the hyperbolic space
50+
- `manifolds.Product` - Cartesian product of manifolds
51+
- `manifolds.SPDAffineInvariant` - manifold of symmetric positive definite (SPD) matrices endowed with the affine-invariant metric
52+
- `manifolds.SPDLogCholesky` - SPD manifold with the Log-Cholesky metric
53+
- `manifolds.SPDLogEuclidean` - SPD manifold with the Log-Euclidean metric
54+
- `manifolds.SpecialOrthogonal` - manifold of rotation matrices
55+
- `manifolds.Sphere` - manifold of unit-normalized points
56+
- `manifolds.StiefelEuclidean` - manifold of orthonormal `p`-frames in the `n`-dimensional space endowed with the Euclidean metric
57+
- `manifolds.StiefelCanonical` - Stiefel manifold with the canonical metric
58+
- `manifolds.StiefelCayley` - Stiefel manifold the retraction map via an iterative Cayley transform
59+
60+
61+
### Optimizers
62+
63+
Constrained optimization algorithms work as drop-in replacements for Keras
64+
optimizers for sparse and dense updates in both Eager and Graph modes.
65+
66+
- `optimizers.RiemannianSGD` - Riemannian Gradient Descent
67+
- `optimizers.RiemannianAdam` - Riemannian Adam and AMSGrad
68+
- `optimizers.ConstrainedRMSProp` - Constrained RMSProp
69+
70+
### Layers
71+
72+
- `layers.ManifoldEmbedding` - constrained `keras.layers.Embedding` layer
73+
74+
## Examples
75+
76+
- [SPDNet](examples/spdnet/) - Huang, Zhiwu, and Luc Van Gool. "A Riemannian network for SPD matrix learning." Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence. AAAI Press, 2017.
77+
- [LieNet](examples/lienet/) - Huang, Zhiwu, et al. "Deep learning on Lie groups for skeleton-based action recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
78+
- [GrNet](examples/grnet/) - Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on Grassmann Manifolds." AAAI. AAAI Press, 2018.
79+
- [Hyperbolic Neural Network](examples/hyperbolic_nn/) - Ganea, Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in neural information processing systems. 2018.
80+
- [Poincaré GloVe](examples/poincare_glove/) - Tifrea, Alexandru, Gary Becigneul, and Octavian-Eugen Ganea. "Poincaré Glove: Hyperbolic Word Embeddings." International Conference on Learning Representations. 2018.
81+
82+
## Acknowledgment
83+
84+
TensorFlow ManOpt was inspired by many similar projects:
85+
86+
- [Manopt](https://www.manopt.org/), a matlab toolbox for optimization on manifolds
87+
- [Pymanopt](https://www.pymanopt.org/), a Python toolbox for optimization on manifolds
88+
- [Geoopt](https://geoopt.readthedocs.io/): Riemannian Optimization in PyTorch
89+
- [Geomstats](https://geomstats.github.io/), an open-source Python package for computations and statistics on nonlinear manifolds
90+
91+
## License
92+
93+
The code is MIT-licensed.

examples/grnet/README.md

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# GrNet in TensorFlow
2+
3+
Implementation of GrNet [1], a deep network on Grassmann manifolds.
4+
5+
<img align="center" width="800" src="https://github.com/master/tensorflow-manopt/blob/master/examples/grnet/grnet.png?raw=true">
6+
7+
## Requirements
8+
9+
* Python 3.5+
10+
* SciPy
11+
* NumPy
12+
* TensorFlow 2.3+
13+
* TensorFlow ManOpt
14+
15+
## Training
16+
17+
Configure `gcloud` to use Python 3:
18+
19+
```bash
20+
gcloud config set ml_engine/local_python /usr/bin/python3
21+
```
22+
23+
Train GrNet locally on the Acted Facial Expression in Wild [2] dataset:
24+
25+
```bash
26+
gcloud ai-platform local train \
27+
--module-name grnet.task \
28+
--package-path . \
29+
-- \
30+
--data-dir data
31+
--job-dir ckpt
32+
```
33+
34+
## References
35+
36+
1. Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on
37+
Grassmann Manifolds." AAAI. AAAI Press, 2018.
38+
2. Dhall, Abhinav, et al. "Acted facial expressions in the wild database."
39+
Australian National University, Canberra, Australia, Technical Report
40+
TR-CS-11 2 (2011): 1.

examples/grnet/__init__.py

Whitespace-only changes.

examples/grnet/grnet.png

51.7 KB
Loading

examples/grnet/model.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import tensorflow as tf
2+
3+
from tensorflow_manopt.variable import assign_to_manifold
4+
from tensorflow_manopt.manifolds import Grassmannian
5+
from tensorflow_manopt.manifolds import utils
6+
from tensorflow_manopt.optimizers import RiemannianSGD
7+
8+
9+
@tf.keras.utils.register_keras_serializable(name="FRMap")
10+
class FRMap(tf.keras.layers.Layer):
11+
"""Full Rank Mapping layer."""
12+
13+
def __init__(self, output_dim, num_proj=8, *args, **kwargs):
14+
"""Instantiate the FRMap layer.
15+
16+
Args:
17+
output_dim: projection output dimension
18+
num_proj: number of projections to compute
19+
"""
20+
super().__init__(*args, **kwargs)
21+
self.output_dim = output_dim
22+
self.num_proj = num_proj
23+
24+
def build(self, input_shape):
25+
grassmannian = Grassmannian()
26+
self.w = self.add_weight(
27+
"w",
28+
shape=[self.num_proj, input_shape[-2], self.output_dim],
29+
initializer=grassmannian.random,
30+
)
31+
assign_to_manifold(self.w, grassmannian)
32+
self._expand = len(input_shape) == 3
33+
34+
def call(self, inputs):
35+
if self._expand:
36+
inputs = tf.expand_dims(inputs, -3)
37+
return utils.transposem(self.w) @ inputs
38+
39+
def get_config(self):
40+
config = {"output_dim": self.output_dim, "num_proj": self.num_proj}
41+
return dict(list(super().get_config().items()) + list(config.items()))
42+
43+
44+
@tf.keras.utils.register_keras_serializable(name="ReOrth")
45+
class ReOrth(tf.keras.layers.Layer):
46+
"""Re-Orthonormalization layer."""
47+
48+
def call(self, inputs):
49+
q, _r = tf.linalg.qr(inputs)
50+
return q
51+
52+
53+
@tf.keras.utils.register_keras_serializable(name="ProjMap")
54+
class ProjMap(tf.keras.layers.Layer):
55+
"""Projection Mapping layer."""
56+
57+
def call(self, inputs):
58+
return inputs @ utils.transposem(inputs)
59+
60+
61+
@tf.keras.utils.register_keras_serializable(name="ProjPooling")
62+
class ProjPooling(tf.keras.layers.Layer):
63+
"""Projection Pooling layer."""
64+
65+
def __init__(self, stride=2, *args, **kwargs):
66+
"""Instantiate the ProjPooling layer.
67+
68+
Args:
69+
stride: factor by which to downscale
70+
"""
71+
super().__init__(*args, **kwargs)
72+
self.stride = stride
73+
74+
def call(self, inputs):
75+
shape = tf.shape(inputs)
76+
new_shape = [
77+
shape[0],
78+
shape[1],
79+
self.stride,
80+
shape[2] // self.stride,
81+
shape[3],
82+
]
83+
return tf.reduce_mean(
84+
tf.reshape(inputs, new_shape), axis=-3, keepdims=False
85+
)
86+
87+
def get_config(self):
88+
config = {"stride": self.stride}
89+
return dict(list(super().get_config().items()) + list(config.items()))
90+
91+
92+
@tf.keras.utils.register_keras_serializable(name="OrthMap")
93+
class OrthMap(tf.keras.layers.Layer):
94+
"""Orthonormal Mapping layer."""
95+
96+
def __init__(self, top_eigen, *args, **kwargs):
97+
"""Instantiate the OrthMap layer.
98+
99+
Args:
100+
num_eigen: number of top eigenvectors to retain
101+
"""
102+
super().__init__(*args, **kwargs)
103+
self.top_eigen = top_eigen
104+
105+
def call(self, inputs):
106+
_s, u, _vt = tf.linalg.svd(inputs)
107+
return u[..., : self.top_eigen]
108+
109+
def get_config(self):
110+
config = {"top_eigen": self.top_eigen}
111+
return dict(list(super().get_config().items()) + list(config.items()))
112+
113+
114+
def create_model(
115+
learning_rate,
116+
num_classes,
117+
frmap_dims=[300, 100],
118+
pool_stride=2,
119+
top_eigen=10,
120+
):
121+
"""Instantiate the GrNet architecture.
122+
123+
Huang, Zhiwu, Jiqing Wu, and Luc Van Gool. "Building Deep Networks on
124+
Grassmann Manifolds." AAAI. AAAI Press, 2018.
125+
126+
Args:
127+
learning_rate: model learning rate
128+
num_classes: number of output classes
129+
frmap_dims: dimensions of FrMap layers
130+
pool_stride: pooling stride
131+
top_eigen: number of eigenvectors to retain in OrthMap
132+
"""
133+
model = tf.keras.Sequential()
134+
for output_dim in frmap_dims:
135+
model.add(FRMap(output_dim))
136+
model.add(ReOrth())
137+
model.add(ProjMap())
138+
model.add(ProjPooling(pool_stride))
139+
model.add(OrthMap(top_eigen))
140+
model.add(ProjMap())
141+
model.add(tf.keras.layers.Flatten())
142+
model.add(tf.keras.layers.Dense(num_classes, use_bias=False))
143+
model.compile(
144+
optimizer=RiemannianSGD(learning_rate),
145+
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
146+
metrics=[tf.metrics.SparseCategoricalAccuracy()],
147+
)
148+
return model

examples/grnet/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
scipy
2+
tensorflow-manopt

examples/grnet/task.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/usr/bin/env python3
2+
import argparse
3+
import os
4+
import tensorflow as tf
5+
6+
from . import model
7+
from shared import utils
8+
9+
DATA_URL = "https://data.vision.ee.ethz.ch/zzhiwu/ManifoldNetData/GrData/AFEW_Gr_data.zip"
10+
DATA_FOLDER = "grface_400_inter_histeq"
11+
AFEW_CLASSES = 7
12+
13+
14+
def get_args():
15+
parser = argparse.ArgumentParser()
16+
parser.add_argument(
17+
'--job-dir', type=str, required=True, help='checkpoint dir'
18+
)
19+
parser.add_argument('--data-dir', type=str, required=True, help='data dir')
20+
parser.add_argument(
21+
'--num-epochs',
22+
type=float,
23+
default=50,
24+
help='number of training epochs (default 50)',
25+
)
26+
parser.add_argument(
27+
'--batch-size',
28+
default=30,
29+
type=int,
30+
help='number of examples per batch (default 30)',
31+
)
32+
parser.add_argument(
33+
'--shuffle-buffer',
34+
default=100,
35+
type=int,
36+
help='shuffle buffer size (default 100)',
37+
)
38+
parser.add_argument(
39+
'--learning-rate',
40+
default=0.01,
41+
type=float,
42+
help='learning rate (default .01)',
43+
)
44+
return parser.parse_args()
45+
46+
47+
def train_and_evaluate(args):
48+
utils.download_data(args.data_dir, DATA_URL, unpack=True)
49+
train = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "train")
50+
val = utils.load_matlab_data("Y1", args.data_dir, DATA_FOLDER, "val")
51+
train_dataset = (
52+
tf.data.Dataset.from_tensor_slices(train)
53+
.repeat(args.num_epochs)
54+
.shuffle(args.shuffle_buffer)
55+
.batch(args.batch_size, drop_remainder=True)
56+
)
57+
val_dataset = tf.data.Dataset.from_tensor_slices(val).batch(
58+
args.batch_size, drop_remainder=True
59+
)
60+
61+
grnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)
62+
checkpoint_path = os.path.join(args.job_dir, "afew-grnet.ckpt")
63+
cp_callback = tf.keras.callbacks.ModelCheckpoint(
64+
filepath=checkpoint_path, save_weights_only=True, verbose=1
65+
)
66+
grnet.fit(
67+
train_dataset,
68+
epochs=args.num_epochs,
69+
validation_data=val_dataset,
70+
callbacks=[cp_callback],
71+
)
72+
_, acc = grnet.evaluate(val_dataset, verbose=2)
73+
print("Final accuracy: {}%".format(acc * 100))
74+
75+
76+
if __name__ == "__main__":
77+
tf.get_logger().setLevel("INFO")
78+
train_and_evaluate(get_args())

examples/hyperbolic_nn/.gitkeep

Whitespace-only changes.

0 commit comments

Comments
 (0)