Skip to content

Commit fb0b1e6

Browse files
authored
Fix operations which had static parameter passed as dynamic parameters. (#21374)
Switched many parameters in ops from being dynamic (passed to `call` method) to being static (passed to `__init__`). These parameters should be static because they cannot be tensors. The most commonly changed parameters were `shape` and `dtype`. Standardized the way `dtype` parameters are handled, in particular when they can be `None`. Removed unused op classes: `Empty`, `Eye`, `Identity`, `Ones`, `Tri`, `Zeros`. These ops only have static parameters and no tensor inputs, they therefore cannot be part of a functional graph. They were missing the standard `if any_symbolic_tensors(...)`, which is why the class was dead code and only instantiated explicitly by unit tests. Fixed `arange` and `full` ops. The code was missing the standard `if any_symbolic_tensors(...)`, which means that adding them to a functional graph would fail before and the op class was dead code until now.
1 parent 11f737e commit fb0b1e6

File tree

6 files changed

+195
-219
lines changed

6 files changed

+195
-219
lines changed

keras/src/ops/core.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,30 @@ def map(f, xs):
7878

7979

8080
class Scan(Operation):
81-
def __init__(self, reverse=False, unroll=1):
81+
def __init__(self, length=None, reverse=False, unroll=1):
8282
super().__init__()
83+
self.length = length
8384
self.reverse = reverse
8485
self.unroll = unroll
8586

86-
def call(self, f, init, xs=None, length=None):
87+
def call(self, f, init, xs=None):
8788
return backend.core.scan(
88-
f, init, xs, length, reverse=self.reverse, unroll=self.unroll
89+
f,
90+
init,
91+
xs,
92+
length=self.length,
93+
reverse=self.reverse,
94+
unroll=self.unroll,
8995
)
9096

91-
def compute_output_spec(self, f, init, xs=None, length=None):
97+
def compute_output_spec(self, f, init, xs=None):
9298
if xs is None:
93-
n = int(length)
99+
n = int(self.length)
94100
x = None
95101
else:
96102
n = (
97-
int(length)
98-
if length is not None
103+
int(self.length)
104+
if self.length is not None
99105
else tree.flatten(xs)[0].shape[0]
100106
)
101107
x = xs[0]
@@ -176,9 +182,9 @@ def scan(f, init, xs, length=None):
176182
[1, 3, 6, 10, 15]
177183
"""
178184
if any_symbolic_tensors((init, xs)):
179-
return Scan(reverse=reverse, unroll=unroll).symbolic_call(
180-
f, init, xs, length
181-
)
185+
return Scan(
186+
length=length, reverse=reverse, unroll=unroll
187+
).symbolic_call(f, init, xs)
182188
return backend.core.scan(
183189
f, init, xs, length, reverse=reverse, unroll=unroll
184190
)
@@ -283,11 +289,15 @@ def associative_scan(f, elems, reverse=False, axis=0):
283289

284290

285291
class Scatter(Operation):
286-
def call(self, indices, values, shape):
287-
return backend.core.scatter(indices, values, shape)
292+
def __init__(self, shape):
293+
super().__init__()
294+
self.shape = shape
288295

289-
def compute_output_spec(self, indices, values, shape):
290-
return KerasTensor(shape, dtype=values.dtype)
296+
def call(self, indices, values):
297+
return backend.core.scatter(indices, values, self.shape)
298+
299+
def compute_output_spec(self, indices, values):
300+
return KerasTensor(self.shape, dtype=values.dtype)
291301

292302

293303
@keras_export("keras.ops.scatter")
@@ -316,8 +326,8 @@ def scatter(indices, values, shape):
316326
array([[0., 1.],
317327
[0., 1.]])
318328
"""
319-
if any_symbolic_tensors((indices, values, shape)):
320-
return Scatter().symbolic_call(indices, values, shape)
329+
if any_symbolic_tensors((indices, values)):
330+
return Scatter(shape=shape).symbolic_call(indices, values)
321331
return backend.core.scatter(indices, values, shape)
322332

323333

@@ -382,11 +392,15 @@ def scatter_update(inputs, indices, updates):
382392

383393

384394
class Slice(Operation):
385-
def call(self, inputs, start_indices, shape):
386-
return backend.core.slice(inputs, start_indices, shape)
395+
def __init__(self, shape):
396+
super().__init__()
397+
self.shape = shape
387398

388-
def compute_output_spec(self, inputs, start_indices, shape):
389-
return KerasTensor(shape, dtype=inputs.dtype)
399+
def call(self, inputs, start_indices):
400+
return backend.core.slice(inputs, start_indices, self.shape)
401+
402+
def compute_output_spec(self, inputs, start_indices):
403+
return KerasTensor(self.shape, dtype=inputs.dtype)
390404

391405

392406
@keras_export("keras.ops.slice")
@@ -415,8 +429,8 @@ def slice(inputs, start_indices, shape):
415429
Returns:
416430
A tensor, has the same shape and dtype as `inputs`.
417431
"""
418-
if any_symbolic_tensors((inputs, start_indices, shape)):
419-
return Slice().symbolic_call(inputs, start_indices, shape)
432+
if any_symbolic_tensors((inputs, start_indices)):
433+
return Slice(shape=shape).symbolic_call(inputs, start_indices)
420434
return backend.core.slice(inputs, start_indices, shape)
421435

422436

@@ -916,7 +930,7 @@ def get_dtype_min_max(dtype):
916930
class ConvertToTensor(Operation):
917931
def __init__(self, dtype=None, sparse=None, ragged=None):
918932
super().__init__()
919-
self.dtype = backend.standardize_dtype(dtype)
933+
self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
920934
self.sparse = sparse
921935
self.ragged = ragged
922936

@@ -926,7 +940,11 @@ def call(self, x):
926940
)
927941

928942
def compute_output_spec(self, x):
929-
dtype = x.dtype if self.dtype is None else self.dtype
943+
dtype = (
944+
backend.standardize_dtype(x.dtype)
945+
if self.dtype is None
946+
else self.dtype
947+
)
930948
sparse = (
931949
False if self.sparse is not None and not self.sparse else x.sparse
932950
)

keras/src/ops/core_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,7 @@ def cumsum(carry, xs):
10101010
init = np.array(0, dtype="float32")
10111011
xs = np.array([1, 2, 3, 4, 10, 20], dtype="float32")
10121012
scan_op = core.Scan()
1013-
carry, result = scan_op.call(cumsum, init, xs, None)
1013+
carry, result = scan_op.call(cumsum, init, xs)
10141014
self.assertAllClose(carry, 40.0)
10151015
self.assertAllClose(result, ops.cumsum(xs))
10161016

@@ -1025,8 +1025,8 @@ def test_scatter_basic_call(self):
10251025
indices = np.array([[1, 0], [0, 1]])
10261026
values = np.array([10, 20])
10271027
shape = (2, 2)
1028-
scatter = core.Scatter()
1029-
result = scatter.call(indices, values, shape)
1028+
scatter = core.Scatter(shape)
1029+
result = scatter.call(indices, values)
10301030
expected_output = np.array([[0, 20], [10, 0]])
10311031
self.assertAllClose(core.convert_to_numpy(result), expected_output)
10321032

@@ -1043,17 +1043,17 @@ def test_slice_basic_call(self):
10431043
inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
10441044
start_indices = np.array([1, 1])
10451045
shape = (2, 2)
1046-
slice_op = core.Slice()
1047-
result = slice_op.call(inputs, start_indices, shape)
1046+
slice_op = core.Slice(shape)
1047+
result = slice_op.call(inputs, start_indices)
10481048
expected_output = np.array([[5, 6], [8, 9]])
10491049
self.assertAllClose(core.convert_to_numpy(result), expected_output)
10501050

10511051
def test_slice_compute_output_spec(self):
10521052
inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
10531053
start_indices = np.array([1, 1])
10541054
shape = (2, 2)
1055-
slice_op = core.Slice()
1056-
output_spec = slice_op.compute_output_spec(inputs, start_indices, shape)
1055+
slice_op = core.Slice(shape)
1056+
output_spec = slice_op.compute_output_spec(inputs, start_indices)
10571057
self.assertEqual(output_spec.shape, shape)
10581058
self.assertEqual(output_spec.dtype, inputs.dtype)
10591059

keras/src/ops/nn.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ def __init__(self, num_classes, axis=-1, dtype=None, sparse=False):
16931693
super().__init__()
16941694
self.num_classes = num_classes
16951695
self.axis = axis
1696-
self.dtype = dtype or backend.floatx()
1696+
self.dtype = backend.standardize_dtype(dtype)
16971697
self.sparse = sparse
16981698

16991699
def call(self, x):
@@ -2580,9 +2580,13 @@ def psnr(
25802580

25812581

25822582
class DotProductAttention(Operation):
2583-
def __init__(self, is_causal=False):
2583+
def __init__(
2584+
self, is_causal=False, flash_attention=None, attn_logits_soft_cap=None
2585+
):
25842586
super().__init__()
25852587
self.is_causal = is_causal
2588+
self.flash_attention = flash_attention
2589+
self.attn_logits_soft_cap = attn_logits_soft_cap
25862590

25872591
def call(
25882592
self,
@@ -2592,8 +2596,6 @@ def call(
25922596
bias=None,
25932597
mask=None,
25942598
scale=None,
2595-
flash_attention=None,
2596-
attn_logits_soft_cap=None,
25972599
):
25982600
return backend.nn.dot_product_attention(
25992601
query,
@@ -2603,8 +2605,8 @@ def call(
26032605
mask=mask,
26042606
scale=scale,
26052607
is_causal=self.is_causal,
2606-
flash_attention=flash_attention,
2607-
attn_logits_soft_cap=attn_logits_soft_cap,
2608+
flash_attention=self.flash_attention,
2609+
attn_logits_soft_cap=self.attn_logits_soft_cap,
26082610
)
26092611

26102612
def compute_output_spec(
@@ -2615,8 +2617,6 @@ def compute_output_spec(
26152617
bias=None,
26162618
mask=None,
26172619
scale=None,
2618-
flash_attention=None,
2619-
attn_logits_soft_cap=None,
26202620
):
26212621
return KerasTensor(query.shape, dtype=query.dtype)
26222622

@@ -2703,15 +2703,17 @@ def dot_product_attention(
27032703
)
27042704

27052705
if any_symbolic_tensors((query, key, value)):
2706-
return DotProductAttention(is_causal=is_causal).symbolic_call(
2706+
return DotProductAttention(
2707+
is_causal=is_causal,
2708+
flash_attention=flash_attention,
2709+
attn_logits_soft_cap=attn_logits_soft_cap,
2710+
).symbolic_call(
27072711
query,
27082712
key,
27092713
value,
27102714
bias=bias,
27112715
mask=mask,
27122716
scale=scale,
2713-
flash_attention=flash_attention,
2714-
attn_logits_soft_cap=attn_logits_soft_cap,
27152717
)
27162718
return backend.nn.dot_product_attention(
27172719
query,

0 commit comments

Comments
 (0)