Skip to content

Commit f84e520

Browse files
authored
Merge pull request #6 from master/upgrade
Upgrade to TensorFlow 2.11
2 parents 78d49b7 + 2c643c2 commit f84e520

File tree

14 files changed

+78
-59
lines changed

14 files changed

+78
-59
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
python-version: [3.7, 3.8]
10+
python-version: ["3.8", "3.9", "3.10"]
1111

1212
steps:
1313
- uses: actions/checkout@v2

examples/shared/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def download_data(data_dir, url, unpack=True, block_size=10 * 1024):
1616
print("{} already exists. Skipping download".format(filename))
1717
return
1818

19-
print("Downloading {0} to {1}".format(url, filename))
19+
print("Downloading {} to {}".format(url, filename))
2020
response = requests.get(url, stream=True)
2121
total = int(response.headers.get("content-length", 0))
2222
progress_bar = tqdm.tqdm(total=total, unit="iB", unit_scale=True)
@@ -33,7 +33,7 @@ def download_data(data_dir, url, unpack=True, block_size=10 * 1024):
3333
with open(filename, "rb") as f:
3434
with zipfile.ZipFile(f) as zip_ref:
3535
zip_ref.extractall(data_dir)
36-
print("Unzipped {0} to {1}".format(filename, data_dir))
36+
print("Unzipped {} to {}".format(filename, data_dir))
3737

3838

3939
def load_matlab_data(key, data_dir, *folders):

examples/tutorial.ipynb

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"%matplotlib inline\n",
1414
"\n",
1515
"import sys\n",
16+
"\n",
1617
"sys.path.append(\"../\")\n",
1718
"import tensorflow_riemopt as riemopt"
1819
]
@@ -119,7 +120,7 @@
119120
"\n",
120121
"opt = riemopt.optimizers.RiemannianAdam(learning_rate=0.2)\n",
121122
"\n",
122-
"npole = tf.constant([0., 1.])\n",
123+
"npole = tf.constant([0.0, 1.0])\n",
123124
"phi = np.linspace(-np.pi, np.pi, 100)\n",
124125
"\n",
125126
"for _ in range(STEPS):\n",
@@ -136,8 +137,14 @@
136137
" plt.plot(np.cos(phi), np.sin(phi))\n",
137138
" plt.plot(var_np[:, 0], var_np[:, 1], '+', color='black')\n",
138139
" for i in range(len(egrad_np)):\n",
139-
" plt.arrow(var_np[i][0], var_np[i][1], -egrad_np[i][0], -egrad_np[i][1],\n",
140-
" width=0.01, color='green')\n",
140+
" plt.arrow(\n",
141+
" var_np[i][0],\n",
142+
" var_np[i][1],\n",
143+
" -egrad_np[i][0],\n",
144+
" -egrad_np[i][1],\n",
145+
" width=0.01,\n",
146+
" color='green',\n",
147+
" )\n",
141148
" plt.plot(var_t_np[:, 0], var_t_np[:, 1], '+', color='red')\n",
142149
" plt.show()"
143150
]
@@ -152,7 +159,7 @@
152159
],
153160
"metadata": {
154161
"kernelspec": {
155-
"display_name": "Python 3",
162+
"display_name": "Python 3 (ipykernel)",
156163
"language": "python",
157164
"name": "python3"
158165
},
@@ -166,7 +173,7 @@
166173
"name": "python",
167174
"nbconvert_exporter": "python",
168175
"pygments_lexer": "ipython3",
169-
"version": "3.7.3"
176+
"version": "3.10.6"
170177
}
171178
},
172179
"nbformat": 4,

examples/usage.ipynb

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"from mpl_toolkits.mplot3d import Axes3D\n",
1515
"\n",
1616
"import sys\n",
17+
"\n",
1718
"sys.path.append(\"../\")\n",
1819
"import tensorflow_riemopt as riemopt"
1920
]
@@ -32,15 +33,17 @@
3233
" s_z = np.outer(np.abs(np.cos(phi)), np.ones_like(psi))\n",
3334
" return ax.plot_wireframe(s_x, s_y, s_z, color=color, alpha=0.3)\n",
3435
"\n",
36+
"\n",
3537
"def plot_vector(ax, x, u, color=\"darkorange\"):\n",
3638
" return ax.quiver(*x, *u, length=0.6, normalize=True, color=color)\n",
3739
"\n",
40+
"\n",
3841
"def plot_hyperplane(ax, p, u, v, color=\"limegreen\"):\n",
39-
" xx = np.linspace(-0.05, 1., 10)\n",
40-
" yy = np.linspace(-1., 0.1, 10)\n",
42+
" xx = np.linspace(-0.05, 1.0, 10)\n",
43+
" yy = np.linspace(-1.0, 0.1, 10)\n",
4144
" x, y = np.meshgrid(xx, yy)\n",
4245
" n = np.cross(u, v)\n",
43-
" z = (- n[0] * x - n[1] * y + p.dot(n)) / n[2]\n",
46+
" z = (-n[0] * x - n[1] * y + p.dot(n)) / n[2]\n",
4447
" return ax.plot_wireframe(x, y, z, color=color, alpha=0.4)"
4548
]
4649
},
@@ -53,7 +56,7 @@
5356
"S = riemopt.manifolds.Sphere()\n",
5457
"\n",
5558
"x = S.projx(tf.constant([0.1, -0.1, 0.1]))\n",
56-
"u = S.proju(x, tf.constant([1., 1., 1.]))\n",
59+
"u = S.proju(x, tf.constant([1.0, 1.0, 1.0]))\n",
5760
"v = S.proju(x, tf.constant([-0.7, -1.4, 1.4]))\n",
5861
"y = S.exp(x, v)\n",
5962
"u_ = S.transp(x, y, u)\n",
@@ -866,9 +869,9 @@
866869
"fig = plt.figure()\n",
867870
"ax = fig.gca(projection=\"3d\")\n",
868871
"ax.axis(\"off\")\n",
869-
"ax.set_zlim3d(-1.5, 1.5) \n",
870-
"ax.set_ylim3d(-1.5, 1.5) \n",
871-
"ax.set_xlim3d(-1.5, 1.5) \n",
872+
"ax.set_zlim3d(-1.5, 1.5)\n",
873+
"ax.set_ylim3d(-1.5, 1.5)\n",
874+
"ax.set_xlim3d(-1.5, 1.5)\n",
872875
"\n",
873876
"plot_halfsphere(ax)\n",
874877
"plot_hyperplane(ax, x, u, v)\n",
@@ -885,7 +888,7 @@
885888
],
886889
"metadata": {
887890
"kernelspec": {
888-
"display_name": "Python 3",
891+
"display_name": "Python 3 (ipykernel)",
889892
"language": "python",
890893
"name": "python3"
891894
},
@@ -899,7 +902,7 @@
899902
"name": "python",
900903
"nbconvert_exporter": "python",
901904
"pygments_lexer": "ipython3",
902-
"version": "3.7.3"
905+
"version": "3.10.6"
903906
}
904907
},
905908
"nbformat": 4,

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
tensorflow==2.6.3
2-
protobuf==3.19.6
1+
tensorflow<2.12.0
2+
keras<2.12.0
3+
protobuf<3.20,>=3.9.2

setup.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
setup(
55
name="tensorflow-riemopt",
6-
version="0.1.2",
6+
version="0.2.0",
77
description="a library for optimization on Riemannian manifolds",
8-
long_description=open("README.md", "r").read(),
8+
long_description=open("README.md").read(),
99
long_description_content_type="text/markdown",
1010
author="Oleg Smirnov",
1111
author_email="[email protected]",
1212
packages=find_packages(),
13-
install_requires=["tensorflow==2.6.3", "protobuf==3.19.6"],
13+
install_requires=["tensorflow<2.12.0", "keras<2.12.0", "protobuf<3.20,>=3.9.2"],
1414
python_requires=">=3.6.0",
1515
url="https://github.com/master/tensorflow-riemopt",
1616
zip_safe=True,
@@ -22,6 +22,8 @@
2222
"Programming Language :: Python :: 3.6",
2323
"Programming Language :: Python :: 3.7",
2424
"Programming Language :: Python :: 3.8",
25+
"Programming Language :: Python :: 3.9",
26+
"Programming Language :: Python :: 3.10",
2527
"Topic :: Scientific/Engineering :: Mathematics",
2628
"Topic :: Software Development :: Libraries :: Python Modules",
2729
"Topic :: Software Development :: Libraries",

tensorflow_riemopt/manifolds/hyperboloid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, k=1.0):
2424
super().__init__()
2525

2626
def __repr__(self):
27-
return "{0} (k={1}, ndims={2}) manifold".format(
27+
return "{} (k={}, ndims={}) manifold".format(
2828
self.name, self.k, self.ndims
2929
)
3030

tensorflow_riemopt/manifolds/manifold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class Manifold(metaclass=abc.ABCMeta):
88

99
def __repr__(self):
1010
"""Returns a string representation of the particular manifold."""
11-
return "{0} (ndims={1}) manifold".format(self.name, self.ndims)
11+
return "{} (ndims={}) manifold".format(self.name, self.ndims)
1212

1313
def check_shape(self, shape_or_tensor):
1414
"""Check if given shape is compatible with the manifold."""

tensorflow_riemopt/manifolds/poincare.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, k=1.0):
3131
super().__init__()
3232

3333
def __repr__(self):
34-
return "{0} (k={1}, ndims={2}) manifold".format(
34+
return "{} (k={}, ndims={}) manifold".format(
3535
self.name, self.k, self.ndims
3636
)
3737

@@ -45,10 +45,10 @@ def _check_vector_on_tangent(self, x, u, atol, rtol):
4545

4646
def _mobius_add(self, x, y):
4747
"""Compute the Möbius addition of :math:`x` and :math:`y` in
48-
:math:`\mathcal{D}^{n}_{k}`
48+
:math:`\\mathcal{D}^{n}_{k}`
4949
50-
:math:`x \oplus y = \frac{(1 + 2k\langle x, y\rangle + k||y||^2)x + (1
51-
- k||x||^2)y}{1 + 2k\langle x,y\rangle + k^2||x||^2||y||^2}`
50+
:math:`x \\oplus y = \frac{(1 + 2k\\langle x, y\rangle + k||y||^2)x + (1
51+
- k||x||^2)y}{1 + 2k\\langle x,y\rangle + k^2||x||^2||y||^2}`
5252
"""
5353
x_2 = tf.reduce_sum(tf.math.square(x), axis=-1, keepdims=True)
5454
y_2 = tf.reduce_sum(tf.math.square(y), axis=-1, keepdims=True)
@@ -59,11 +59,11 @@ def _mobius_add(self, x, y):
5959
)
6060

6161
def _mobius_scal_mul(self, x, r):
62-
"""Compute the Möbius scalar multiplication of :math:`x \in
63-
\mathcal{D}^{n}_{k} \ {0}` by :math:`r`
62+
"""Compute the Möbius scalar multiplication of :math:`x \\in
63+
\\mathcal{D}^{n}_{k} \\ {0}` by :math:`r`
6464
65-
:math:`x \otimes r = (1/\sqrt{k})\tanh(r
66-
\atanh(\sqrt{k}||x||))\frac{x}{||x||}`
65+
:math:`x \\otimes r = (1/\\sqrt{k})\tanh(r
66+
\atanh(\\sqrt{k}||x||))\frac{x}{||x||}`
6767
6868
"""
6969
sqrt_k = tf.math.sqrt(tf.cast(self.k, x.dtype))
@@ -73,7 +73,7 @@ def _mobius_scal_mul(self, x, r):
7373
return (1 / sqrt_k) * tf.math.tanh(r * tf.math.atanh(tan)) * x / norm_x
7474

7575
def _gyration(self, u, v, w):
76-
"""Compute the gyration of :math:`u`, :math:`v`, :math:`w`:
76+
r"""Compute the gyration of :math:`u`, :math:`v`, :math:`w`:
7777
7878
:math:`\operatorname{gyr}[u, v]w =
7979
\ominus (u \oplus_\kappa v) \oplus (u \oplus_\kappa (v \oplus_\kappa w))`

tensorflow_riemopt/optimizers/constrained_rmsprop.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
from tensorflow.python.ops import math_ops
1616
from tensorflow.python.ops import state_ops
1717
from tensorflow.python.training import gen_training_ops
18-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18+
19+
try:
20+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
21+
except ImportError:
22+
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
1923

2024
from tensorflow_riemopt.variable import get_manifold
2125

@@ -66,7 +70,7 @@ def __init__(
6670
allow time inverse decay of learning rate. `lr` is included for backward
6771
compatibility, recommended to use `learning_rate` instead.
6872
"""
69-
super(ConstrainedRMSprop, self).__init__(name, **kwargs)
73+
super().__init__(name, **kwargs)
7074
self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
7175
self._set_hyper("decay", self._initial_decay)
7276
self._set_hyper("rho", rho)
@@ -83,9 +87,7 @@ def _create_slots(self, var_list):
8387
self.add_slot(var, "mg")
8488

8589
def _prepare_local(self, var_device, var_dtype, apply_state):
86-
super(ConstrainedRMSprop, self)._prepare_local(
87-
var_device, var_dtype, apply_state
88-
)
90+
super()._prepare_local(var_device, var_dtype, apply_state)
8991

9092
rho = array_ops.identity(self._get_hyper("rho", var_dtype))
9193
apply_state[(var_device, var_dtype)].update(
@@ -197,10 +199,10 @@ def set_weights(self, weights):
197199
params = self.weights
198200
if len(params) == len(weights) + 1:
199201
weights = [np.array(0)] + weights
200-
super(ConstrainedRMSprop, self).set_weights(weights)
202+
super().set_weights(weights)
201203

202204
def get_config(self):
203-
config = super(ConstrainedRMSprop, self).get_config()
205+
config = super().get_config()
204206
config.update(
205207
{
206208
"learning_rate": self._serialize_hyperparameter(

0 commit comments

Comments
 (0)