Skip to content

Commit fa31fa7

Browse files
authored
Add name parameter to the __init__ of all op classes. (#21376)
Follow-up to #21373.

This is to solve an inconsistency in the saving / reloading of ops. Some ops behave differently.

Consider this code:

```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.abs(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```

The reloaded model JSON is the same as the original model JSON. The `abs` op is serialized as:

- `{"module": "keras.src.ops.numpy", "class_name": "Absolute", "config": {"name": "absolute"}, "registered_name": "Absolute", "name": "absolute", ...}`

Consider the same code with `abs` replaced with `sum`:

```python
input = keras.layers.Input(shape=(4,), dtype="float32")
output = keras.ops.sum(input)
model = keras.models.Model(input, output)
json = model.to_json()
reloaded_model = keras.models.model_from_json(json)
reloaded_json = reloaded_model.to_json()
```

The reloaded model JSON is different from the original JSON:

- `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum", ...}`
- `{"module": "keras.src.ops.numpy", "class_name": "Sum", "config": {"axis": null, "keepdims": false}, "registered_name": "Sum", "name": "sum_1", ...}`

The reloaded `sum` op now has a name `"sum_1"` instead of `"sum"`. This is because:

- `Abs` does not define `__init__` and inherits the `Operation.__init__` that has a `name` parameter.
- `Sum` defines `__init__` without a `name` parameter.

We want saving / reloading to be idempotent. Even though users cannot control the name of ops (although it would be easy to add), the auto-assigned names should be saved and reloaded. For this, `name` has to be supported in `__init__`.

This PR adds a `name` parameter to all existing `__init__` of op classes, which is passed to `super().__init__()`.
Note that it is defined as a keyword-only argument for forward compatibility in case more parameters get added in the future. Empty `__init__`s were removed. This also fixes `UnravelIndex.__init__`, which was not calling `super().__init__()`.
1 parent fb0b1e6 commit fa31fa7

File tree

7 files changed

+380
-357
lines changed

7 files changed

+380
-357
lines changed

keras/src/ops/core.py

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212

1313

1414
class Map(Operation):
15-
def __init__(self):
16-
super().__init__()
17-
1815
def call(self, f, xs):
1916
return backend.core.map(f, xs)
2017

@@ -78,8 +75,8 @@ def map(f, xs):
7875

7976

8077
class Scan(Operation):
81-
def __init__(self, length=None, reverse=False, unroll=1):
82-
super().__init__()
78+
def __init__(self, length=None, reverse=False, unroll=1, *, name=None):
79+
super().__init__(name=name)
8380
self.length = length
8481
self.reverse = reverse
8582
self.unroll = unroll
@@ -191,8 +188,8 @@ def scan(f, init, xs, length=None):
191188

192189

193190
class AssociativeScan(Operation):
194-
def __init__(self, reverse=False, axis=0):
195-
super().__init__()
191+
def __init__(self, reverse=False, axis=0, *, name=None):
192+
super().__init__(name=name)
196193
self.reverse = reverse
197194
self.axis = axis
198195

@@ -289,8 +286,8 @@ def associative_scan(f, elems, reverse=False, axis=0):
289286

290287

291288
class Scatter(Operation):
292-
def __init__(self, shape):
293-
super().__init__()
289+
def __init__(self, shape, *, name=None):
290+
super().__init__(name=name)
294291
self.shape = shape
295292

296293
def call(self, indices, values):
@@ -392,8 +389,8 @@ def scatter_update(inputs, indices, updates):
392389

393390

394391
class Slice(Operation):
395-
def __init__(self, shape):
396-
super().__init__()
392+
def __init__(self, shape, *, name=None):
393+
super().__init__(name=name)
397394
self.shape = shape
398395

399396
def call(self, inputs, start_indices):
@@ -530,8 +527,8 @@ def switch(index, branches, *operands):
530527

531528

532529
class WhileLoop(Operation):
533-
def __init__(self, cond, body, maximum_iterations=None):
534-
super().__init__()
530+
def __init__(self, cond, body, maximum_iterations=None, *, name=None):
531+
super().__init__(name=name)
535532
self.cond = cond
536533
self.body = body
537534
self.maximum_iterations = maximum_iterations
@@ -599,9 +596,6 @@ def while_loop(
599596

600597

601598
class StopGradient(Operation):
602-
def __init__(self):
603-
super().__init__()
604-
605599
def call(self, variable):
606600
return backend.core.stop_gradient(variable)
607601

@@ -634,8 +628,8 @@ def stop_gradient(variable):
634628

635629

636630
class ForiLoop(Operation):
637-
def __init__(self, lower, upper, body_fun):
638-
super().__init__()
631+
def __init__(self, lower, upper, body_fun, *, name=None):
632+
super().__init__(name=name)
639633
self.lower = lower
640634
self.upper = upper
641635
self.body_fun = body_fun
@@ -682,8 +676,8 @@ def fori_loop(lower, upper, body_fun, init_val):
682676

683677

684678
class Unstack(Operation):
685-
def __init__(self, num=None, axis=0):
686-
super().__init__()
679+
def __init__(self, num=None, axis=0, *, name=None):
680+
super().__init__(name=name)
687681
self.num = num
688682
self.axis = axis
689683

@@ -787,8 +781,8 @@ def dtype(x):
787781

788782

789783
class Cast(Operation):
790-
def __init__(self, dtype):
791-
super().__init__()
784+
def __init__(self, dtype, *, name=None):
785+
super().__init__(name=name)
792786
self.dtype = backend.standardize_dtype(dtype)
793787

794788
def call(self, x):
@@ -822,8 +816,8 @@ def cast(x, dtype):
822816

823817

824818
class SaturateCast(Operation):
825-
def __init__(self, dtype):
826-
super().__init__()
819+
def __init__(self, dtype, *, name=None):
820+
super().__init__(name=name)
827821
self.dtype = backend.standardize_dtype(dtype)
828822

829823
def call(self, x):
@@ -928,8 +922,8 @@ def get_dtype_min_max(dtype):
928922

929923

930924
class ConvertToTensor(Operation):
931-
def __init__(self, dtype=None, sparse=None, ragged=None):
932-
super().__init__()
925+
def __init__(self, dtype=None, sparse=None, ragged=None, *, name=None):
926+
super().__init__(name=name)
933927
self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
934928
self.sparse = sparse
935929
self.ragged = ragged

keras/src/ops/image.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99

1010
class RGBToGrayscale(Operation):
11-
def __init__(self, data_format=None):
12-
super().__init__()
11+
def __init__(self, data_format=None, *, name=None):
12+
super().__init__(name=name)
1313
self.data_format = backend.standardize_data_format(data_format)
1414

1515
def call(self, images):
@@ -77,8 +77,8 @@ def rgb_to_grayscale(images, data_format=None):
7777

7878

7979
class RGBToHSV(Operation):
80-
def __init__(self, data_format=None):
81-
super().__init__()
80+
def __init__(self, data_format=None, *, name=None):
81+
super().__init__(name=name)
8282
self.data_format = backend.standardize_data_format(data_format)
8383

8484
def call(self, images):
@@ -149,8 +149,8 @@ def rgb_to_hsv(images, data_format=None):
149149

150150

151151
class HSVToRGB(Operation):
152-
def __init__(self, data_format=None):
153-
super().__init__()
152+
def __init__(self, data_format=None, *, name=None):
153+
super().__init__(name=name)
154154
self.data_format = backend.standardize_data_format(data_format)
155155

156156
def call(self, images):
@@ -228,8 +228,10 @@ def __init__(
228228
fill_mode="constant",
229229
fill_value=0.0,
230230
data_format=None,
231+
*,
232+
name=None,
231233
):
232-
super().__init__()
234+
super().__init__(name=name)
233235
self.size = tuple(size)
234236
self.interpolation = interpolation
235237
self.antialias = antialias
@@ -413,8 +415,10 @@ def __init__(
413415
fill_mode="constant",
414416
fill_value=0,
415417
data_format=None,
418+
*,
419+
name=None,
416420
):
417-
super().__init__()
421+
super().__init__(name=name)
418422
self.interpolation = interpolation
419423
self.fill_mode = fill_mode
420424
self.fill_value = fill_value
@@ -554,8 +558,10 @@ def __init__(
554558
dilation_rate=1,
555559
padding="valid",
556560
data_format=None,
561+
*,
562+
name=None,
557563
):
558-
super().__init__()
564+
super().__init__(name=name)
559565
if isinstance(size, int):
560566
size = (size, size)
561567
self.size = size
@@ -707,8 +713,8 @@ def _extract_patches(
707713

708714

709715
class MapCoordinates(Operation):
710-
def __init__(self, order, fill_mode="constant", fill_value=0):
711-
super().__init__()
716+
def __init__(self, order, fill_mode="constant", fill_value=0, *, name=None):
717+
super().__init__(name=name)
712718
self.order = order
713719
self.fill_mode = fill_mode
714720
self.fill_value = fill_value
@@ -803,8 +809,10 @@ def __init__(
803809
target_height=None,
804810
target_width=None,
805811
data_format=None,
812+
*,
813+
name=None,
806814
):
807-
super().__init__()
815+
super().__init__(name=name)
808816
self.top_padding = top_padding
809817
self.left_padding = left_padding
810818
self.bottom_padding = bottom_padding
@@ -1014,8 +1022,10 @@ def __init__(
10141022
target_height=None,
10151023
target_width=None,
10161024
data_format=None,
1025+
*,
1026+
name=None,
10171027
):
1018-
super().__init__()
1028+
super().__init__(name=name)
10191029
self.top_cropping = top_cropping
10201030
self.bottom_cropping = bottom_cropping
10211031
self.left_cropping = left_cropping
@@ -1238,8 +1248,10 @@ def __init__(
12381248
interpolation="bilinear",
12391249
fill_value=0,
12401250
data_format=None,
1251+
*,
1252+
name=None,
12411253
):
1242-
super().__init__()
1254+
super().__init__(name=name)
12431255
self.interpolation = interpolation
12441256
self.fill_value = fill_value
12451257
self.data_format = backend.standardize_data_format(data_format)
@@ -1381,8 +1393,10 @@ def __init__(
13811393
kernel_size=(3, 3),
13821394
sigma=(1.0, 1.0),
13831395
data_format=None,
1396+
*,
1397+
name=None,
13841398
):
1385-
super().__init__()
1399+
super().__init__(name=name)
13861400
self.kernel_size = kernel_size
13871401
self.sigma = sigma
13881402
self.data_format = backend.standardize_data_format(data_format)
@@ -1470,8 +1484,10 @@ def __init__(
14701484
fill_value=0.0,
14711485
seed=None,
14721486
data_format=None,
1487+
*,
1488+
name=None,
14731489
):
1474-
super().__init__()
1490+
super().__init__(name=name)
14751491
self.alpha = alpha
14761492
self.sigma = sigma
14771493
self.interpolation = interpolation

keras/src/ops/linalg.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77

88

99
class Cholesky(Operation):
10-
def __init__(self):
11-
super().__init__()
12-
1310
def call(self, x):
1411
return _cholesky(x)
1512

@@ -47,9 +44,6 @@ def _cholesky(x):
4744

4845

4946
class Det(Operation):
50-
def __init__(self):
51-
super().__init__()
52-
5347
def call(self, x):
5448
return _det(x)
5549

@@ -83,9 +77,6 @@ def _det(x):
8377

8478

8579
class Eig(Operation):
86-
def __init__(self):
87-
super().__init__()
88-
8980
def call(self, x):
9081
return _eig(x)
9182

@@ -122,9 +113,6 @@ def _eig(x):
122113

123114

124115
class Eigh(Operation):
125-
def __init__(self):
126-
super().__init__()
127-
128116
def call(self, x):
129117
return _eigh(x)
130118

@@ -162,9 +150,6 @@ def _eigh(x):
162150

163151

164152
class Inv(Operation):
165-
def __init__(self):
166-
super().__init__()
167-
168153
def call(self, x):
169154
return _inv(x)
170155

@@ -198,9 +183,6 @@ def _inv(x):
198183

199184

200185
class LuFactor(Operation):
201-
def __init__(self):
202-
super().__init__()
203-
204186
def call(self, x):
205187
return _lu_factor(x)
206188

@@ -248,8 +230,8 @@ def _lu_factor(x):
248230

249231

250232
class Norm(Operation):
251-
def __init__(self, ord=None, axis=None, keepdims=False):
252-
super().__init__()
233+
def __init__(self, ord=None, axis=None, keepdims=False, *, name=None):
234+
super().__init__(name=name)
253235
if isinstance(ord, str):
254236
if ord not in ("fro", "nuc"):
255237
raise ValueError(
@@ -367,8 +349,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
367349

368350

369351
class Qr(Operation):
370-
def __init__(self, mode="reduced"):
371-
super().__init__()
352+
def __init__(self, mode="reduced", *, name=None):
353+
super().__init__(name=name)
372354
if mode not in {"reduced", "complete"}:
373355
raise ValueError(
374356
"`mode` argument value not supported. "
@@ -440,9 +422,6 @@ def qr(x, mode="reduced"):
440422

441423

442424
class Solve(Operation):
443-
def __init__(self):
444-
super().__init__()
445-
446425
def call(self, a, b):
447426
return _solve(a, b)
448427

@@ -484,8 +463,8 @@ def _solve(a, b):
484463

485464

486465
class SolveTriangular(Operation):
487-
def __init__(self, lower=False):
488-
super().__init__()
466+
def __init__(self, lower=False, *, name=None):
467+
super().__init__(name=name)
489468
self.lower = lower
490469

491470
def call(self, a, b):
@@ -531,8 +510,8 @@ def _solve_triangular(a, b, lower=False):
531510

532511

533512
class SVD(Operation):
534-
def __init__(self, full_matrices=True, compute_uv=True):
535-
super().__init__()
513+
def __init__(self, full_matrices=True, compute_uv=True, *, name=None):
514+
super().__init__(name=name)
536515
self.full_matrices = full_matrices
537516
self.compute_uv = compute_uv
538517

@@ -586,8 +565,8 @@ def _svd(x, full_matrices=True, compute_uv=True):
586565

587566

588567
class Lstsq(Operation):
589-
def __init__(self, rcond=None):
590-
super().__init__()
568+
def __init__(self, rcond=None, *, name=None):
569+
super().__init__(name=name)
591570
self.rcond = rcond
592571

593572
def call(self, a, b):

0 commit comments

Comments
 (0)