Skip to content

Commit 810a6a2

Browse files
committed
Upgrade to the latest Keras & TensorFlow versions
1 parent 68680ed commit 810a6a2

File tree

7 files changed

+27
-13
lines changed

7 files changed

+27
-13
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
tensorflow==2.6.3
2-
protobuf==3.19.6
1+
tensorflow
2+
keras

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
setup(
55
name="tensorflow-riemopt",
6-
version="0.1.2",
6+
version="0.2.0",
77
description="a library for optimization on Riemannian manifolds",
88
long_description=open("README.md").read(),
99
long_description_content_type="text/markdown",
1010
author="Oleg Smirnov",
1111
author_email="[email protected]",
1212
packages=find_packages(),
13-
install_requires=["tensorflow==2.6.3", "protobuf==3.19.6"],
13+
install_requires=["tensorflow", "keras"],
1414
python_requires=">=3.6.0",
1515
url="https://github.com/master/tensorflow-riemopt",
1616
zip_safe=True,
@@ -22,6 +22,8 @@
2222
"Programming Language :: Python :: 3.6",
2323
"Programming Language :: Python :: 3.7",
2424
"Programming Language :: Python :: 3.8",
25+
"Programming Language :: Python :: 3.9",
26+
"Programming Language :: Python :: 3.10",
2527
"Topic :: Scientific/Engineering :: Mathematics",
2628
"Topic :: Software Development :: Libraries :: Python Modules",
2729
"Topic :: Software Development :: Libraries",

tensorflow_riemopt/optimizers/constrained_rmsprop.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
from tensorflow.python.ops import math_ops
1616
from tensorflow.python.ops import state_ops
1717
from tensorflow.python.training import gen_training_ops
18-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18+
19+
try:
20+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
21+
except ImportError:
22+
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
1923

2024
from tensorflow_riemopt.variable import get_manifold
2125

tensorflow_riemopt/optimizers/constrained_rmsprop_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,14 @@ def testBasic(self):
139139
self.assertAllCloseAccordingToType(
140140
self.evaluate(var0_ref),
141141
self.evaluate(var0),
142-
rtol=1e-4,
143-
atol=1e-4,
142+
rtol=1e-3,
143+
atol=1e-3,
144144
)
145145
self.assertAllCloseAccordingToType(
146146
self.evaluate(var1_ref),
147147
self.evaluate(var1),
148-
rtol=1e-4,
149-
atol=1e-4,
148+
rtol=1e-2,
149+
atol=1e-2,
150150
)
151151

152152

tensorflow_riemopt/optimizers/riemannian_adam.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from tensorflow.python.ops import math_ops
1313
from tensorflow.python.ops import state_ops
1414
from tensorflow.python.training import gen_training_ops
15-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
15+
16+
try:
17+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18+
except ImportError:
19+
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
1620

1721
from tensorflow_riemopt.variable import get_manifold
1822

tensorflow_riemopt/optimizers/riemannian_gradient_descent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
from tensorflow.python.ops import math_ops
1313
from tensorflow.python.ops import state_ops
1414
from tensorflow.python.training import gen_training_ops
15-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
15+
16+
try:
17+
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18+
except ImportError:
19+
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
1620

1721
from tensorflow_riemopt.variable import get_manifold
1822

tensorflow_riemopt/optimizers/riemannian_gradient_descent_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def testBasic(self):
156156
self.assertAllCloseAccordingToType(
157157
self.evaluate(var1_ref),
158158
self.evaluate(var1),
159-
rtol=1e-4,
160-
atol=1e-4,
159+
rtol=1e-3,
160+
atol=1e-3,
161161
)
162162

163163

0 commit comments

Comments
 (0)