Skip to content

Commit db387f0

Browse files
committed
Fix TF/Keras compatibility
1 parent f706e2f commit db387f0

File tree

9 files changed

+50
-63
lines changed

9 files changed

+50
-63
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
python-version: ["3.8", "3.9", "3.10"]
10+
python-version: ["3.10", "3.11", "3.12"]
1111

1212
steps:
1313
- uses: actions/checkout@v2

requirements.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
tensorflow<2.12.0
2-
keras<2.12.0
3-
protobuf<3.20,>=3.9.2
1+
tensorflow

tensorflow_riemopt/optimizers/constrained_rmsprop.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,9 @@
1414
from tensorflow.python.ops import control_flow_ops
1515
from tensorflow.python.ops import math_ops
1616
from tensorflow.python.ops import state_ops
17-
from tensorflow.python.training import gen_training_ops
18-
19-
try:
20-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
21-
except ImportError:
22-
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
17+
from tensorflow.python.ops import gen_training_ops
18+
from tensorflow.python.framework.indexed_slices import IndexedSlices
19+
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
2320

2421
from tensorflow_riemopt.variable import get_manifold
2522

@@ -93,7 +90,7 @@ def _prepare_local(self, var_device, var_dtype, apply_state):
9390
apply_state[(var_device, var_dtype)].update(
9491
dict(
9592
neg_lr_t=-apply_state[(var_device, var_dtype)]["lr_t"],
96-
epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
93+
epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
9794
rho=rho,
9895
one_minus_rho=1.0 - rho,
9996
)
@@ -131,10 +128,10 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
131128
rms.assign(manifold.transp(var, var_t, rms_t))
132129
if self.centered:
133130
mg.assign(manifold.transp(var, var_t, mg_t))
134-
var.assign(var_t)
135-
131+
var_update = var.assign(var_t)
136132
if self.stabilize is not None:
137133
self._stabilize(var)
134+
return var_update
138135

139136
@def_function.function(experimental_compile=True)
140137
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
@@ -173,16 +170,16 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
173170
)
174171

175172
rms_t_transp = manifold.transp(var_values, var_t_values, rms_t_values)
176-
rms.scatter_update(ops.IndexedSlices(rms_t_transp, indices))
173+
rms.scatter_update(IndexedSlices(rms_t_transp, indices))
177174

178175
if self.centered:
179176
mg_t_transp = manifold.transp(var_values, var_t_values, mg_t_values)
180-
mg.scatter_update(ops.IndexedSlices(mg_t_transp, indices))
181-
182-
var.scatter_update(ops.IndexedSlices(var_t_values, indices))
177+
mg.scatter_update(IndexedSlices(mg_t_transp, indices))
183178

179+
var_update = var.scatter_update(IndexedSlices(var_t_values, indices))
184180
if self.stabilize is not None:
185181
self._stabilize(var)
182+
return var_update
186183

187184
@def_function.function(experimental_compile=True)
188185
def _stabilize(self, var):

tensorflow_riemopt/optimizers/constrained_rmsprop_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tensorflow.python.ops import math_ops
1313
from tensorflow.python.ops import variables
1414
from tensorflow.python.platform import test
15+
from tensorflow.python.framework.indexed_slices import IndexedSlices
1516

1617
from tensorflow_riemopt.optimizers.constrained_rmsprop import (
1718
ConstrainedRMSprop,
@@ -43,13 +44,13 @@ def testSparse(self):
4344
var0_ref = variables.Variable(var0_np)
4445
var1_ref = variables.Variable(var1_np)
4546
grads0_np_indices = np.array([0, 2], dtype=np.int32)
46-
grads0 = ops.IndexedSlices(
47+
grads0 = IndexedSlices(
4748
constant_op.constant(grads0_np[grads0_np_indices]),
4849
constant_op.constant(grads0_np_indices),
4950
constant_op.constant([3]),
5051
)
5152
grads1_np_indices = np.array([0, 2], dtype=np.int32)
52-
grads1 = ops.IndexedSlices(
53+
grads1 = IndexedSlices(
5354
constant_op.constant(grads1_np[grads1_np_indices]),
5455
constant_op.constant(grads1_np_indices),
5556
constant_op.constant([3]),

tensorflow_riemopt/optimizers/riemannian_adam.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
from tensorflow.python.ops import control_flow_ops
1212
from tensorflow.python.ops import math_ops
1313
from tensorflow.python.ops import state_ops
14-
from tensorflow.python.training import gen_training_ops
15-
16-
try:
17-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18-
except ImportError:
19-
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
14+
from tensorflow.python.ops import gen_training_ops
15+
from tensorflow.python.framework.indexed_slices import IndexedSlices
16+
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
2017

2118
from tensorflow_riemopt.variable import get_manifold
2219

@@ -105,7 +102,7 @@ def _prepare_local(self, var_device, var_dtype, apply_state):
105102
apply_state[(var_device, var_dtype)].update(
106103
dict(
107104
lr=lr,
108-
epsilon=ops.convert_to_tensor_v2(self.epsilon, var_dtype),
105+
epsilon=ops.convert_to_tensor(self.epsilon, var_dtype),
109106
beta_1_t=beta_1_t,
110107
beta_1_power=beta_1_power,
111108
one_minus_beta_1_t=1 - beta_1_t,
@@ -154,10 +151,10 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
154151
var, -(m * alpha) / (math_ops.sqrt(v) + coefficients["epsilon"])
155152
)
156153
m.assign(manifold.transp(var, var_t, m))
157-
var.assign(var_t)
158-
154+
var_update = var.assign(var_t)
159155
if self.stabilize is not None:
160156
self._stabilize(var)
157+
return var_update
161158

162159
@def_function.function(experimental_compile=True)
163160
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
@@ -190,7 +187,7 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
190187

191188
if self.amsgrad:
192189
vhat = self.get_slot(var, "vhat")
193-
vhat.scatter_max(ops.IndexedSlices(v_t_values, indices))
190+
vhat.scatter_max(IndexedSlices(v_t_values, indices))
194191
v_t_values = array_ops.gather(vhat, indices)
195192

196193
var_values = array_ops.gather(var, indices)
@@ -201,12 +198,12 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
201198
)
202199
m_t_transp = manifold.transp(var_values, var_t_values, m_t_values)
203200

204-
m.scatter_update(ops.IndexedSlices(m_t_transp, indices))
205-
v.scatter_update(ops.IndexedSlices(v_t_values, indices))
206-
var.scatter_update(ops.IndexedSlices(var_t_values, indices))
207-
201+
m.scatter_update(IndexedSlices(m_t_transp, indices))
202+
v.scatter_update(IndexedSlices(v_t_values, indices))
203+
var_update = var.scatter_update(IndexedSlices(var_t_values, indices))
208204
if self.stabilize is not None:
209205
self._stabilize(var)
206+
return var_update
210207

211208
@def_function.function(experimental_compile=True)
212209
def _stabilize(self, var):

tensorflow_riemopt/optimizers/riemannian_adam_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from tensorflow.python.ops import math_ops
1414
from tensorflow.python.ops import variables
1515
from tensorflow.python.platform import test
16+
from tensorflow.python.framework.indexed_slices import IndexedSlices
1617

1718
from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam
1819

@@ -141,13 +142,13 @@ def testSparse(self):
141142
var0_ref = variables.Variable(var0_np)
142143
var1_ref = variables.Variable(var1_np)
143144
grads0_np_indices = np.array([0, 2], dtype=np.int32)
144-
grads0 = ops.IndexedSlices(
145+
grads0 = IndexedSlices(
145146
constant_op.constant(grads0_np[grads0_np_indices]),
146147
constant_op.constant(grads0_np_indices),
147148
constant_op.constant([3]),
148149
)
149150
grads1_np_indices = np.array([0, 2], dtype=np.int32)
150-
grads1 = ops.IndexedSlices(
151+
grads1 = IndexedSlices(
151152
constant_op.constant(grads1_np[grads1_np_indices]),
152153
constant_op.constant(grads1_np_indices),
153154
constant_op.constant([3]),

tensorflow_riemopt/optimizers/riemannian_gradient_descent.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
from tensorflow.python.ops import control_flow_ops
1212
from tensorflow.python.ops import math_ops
1313
from tensorflow.python.ops import state_ops
14-
from tensorflow.python.training import gen_training_ops
15-
16-
try:
17-
from keras.optimizer_v2.optimizer_v2 import OptimizerV2
18-
except ImportError:
19-
from tensorflow.keras.optimizers.legacy import Optimizer as OptimizerV2
14+
from tensorflow.python.ops import gen_training_ops
15+
from tensorflow.python.framework.tensor import Tensor
16+
from tensorflow.python.framework.indexed_slices import IndexedSlices
17+
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
2018

2119
from tensorflow_riemopt.variable import get_manifold
2220

@@ -66,7 +64,7 @@ def __init__(
6664
self._set_hyper("decay", self._initial_decay)
6765
self._momentum = False
6866
if (
69-
isinstance(momentum, ops.Tensor)
67+
isinstance(momentum, Tensor)
7068
or callable(momentum)
7169
or momentum > 0
7270
):
@@ -111,12 +109,13 @@ def _resource_apply_dense(self, grad, var, apply_state=None):
111109
else:
112110
var_t = manifold.retr(var, momentum_t)
113111
momentum.assign(manifold.transp(var, var_t, momentum_t))
114-
var.assign(var_t)
112+
var_update = var.assign(var_t)
115113
else:
116-
var.assign(manifold.retr(var, -grad * coefficients["lr_t"]))
114+
var_update = var.assign(manifold.retr(var, -grad * coefficients["lr_t"]))
117115

118116
if self.stabilize is not None:
119117
self._stabilize(var)
118+
return var_update
120119

121120
@def_function.function(experimental_compile=True)
122121
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
@@ -147,18 +146,15 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
147146
momentum_transp_values = manifold.transp(
148147
var_values, var_t_values, momentum_t_values
149148
)
150-
momentum.scatter_update(
151-
ops.IndexedSlices(momentum_transp_values, indices)
152-
)
149+
momentum.scatter_update(IndexedSlices(momentum_transp_values, indices))
150+
var_update = var.scatter_update(IndexedSlices(var_t_values, indices))
153151
else:
154-
var_t_values = manifold.retr(
155-
var_values, -grad * coefficients["lr_t"]
156-
)
157-
158-
var.scatter_update(ops.IndexedSlices(var_t_values, indices))
152+
var_t_values = manifold.retr(var_values, -grad * coefficients["lr_t"])
153+
var_update = var.scatter_update(IndexedSlices(var_t_values, indices))
159154

160155
if self.stabilize is not None:
161156
self._stabilize(var)
157+
return var_update
162158

163159
@def_function.function(experimental_compile=True)
164160
def _stabilize(self, var):

tensorflow_riemopt/optimizers/riemannian_gradient_descent_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from tensorflow.python.ops import math_ops
1414
from tensorflow.python.ops import variables
1515
from tensorflow.python.platform import test
16+
from tensorflow.python.framework.indexed_slices import IndexedSlices
1617

1718
from tensorflow_riemopt.optimizers.riemannian_gradient_descent import (
1819
RiemannianSGD,
@@ -45,13 +46,13 @@ def testSparse(self):
4546
var0_ref = variables.Variable(var0_np)
4647
var1_ref = variables.Variable(var1_np)
4748
grads0_np_indices = np.array([0, 2], dtype=np.int32)
48-
grads0 = ops.IndexedSlices(
49+
grads0 = IndexedSlices(
4950
constant_op.constant(grads0_np[grads0_np_indices]),
5051
constant_op.constant(grads0_np_indices),
5152
constant_op.constant([3]),
5253
)
5354
grads1_np_indices = np.array([0, 2], dtype=np.int32)
54-
grads1 = ops.IndexedSlices(
55+
grads1 = IndexedSlices(
5556
constant_op.constant(grads1_np[grads1_np_indices]),
5657
constant_op.constant(grads1_np_indices),
5758
constant_op.constant([3]),

tensorflow_riemopt/variable.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
import tensorflow as tf
2-
32
from tensorflow_riemopt.manifolds import Euclidean
43

54

65
def assign_to_manifold(var, manifold):
7-
if not isinstance(var, tf.Variable):
8-
raise ValueError("var should be a TensorFlow variable")
6+
if not hasattr(var, "shape"):
7+
raise ValueError("var should be a valid variable with a 'shape' attribute")
98
if not manifold.check_shape(var):
109
raise ValueError("Invalid variable shape {}".format(var.shape))
1110
setattr(var, "manifold", manifold)
1211

1312

1413
def get_manifold(var, default_manifold=Euclidean()):
15-
if not isinstance(var, tf.Variable):
16-
raise ValueError("var should be a TensorFlow variable")
17-
if hasattr(var, "manifold"):
18-
return var.manifold
19-
else:
20-
return default_manifold
14+
if not hasattr(var, "shape"):
15+
raise ValueError("var should be a valid variable with a 'shape' attribute")
16+
return getattr(var, "manifold", default_manifold)

0 commit comments

Comments
 (0)