Skip to content

Commit ad7bbb8

Browse files
authored
Make ops consistent and test ops consistency. (#21334)
Added tests to verify that:

- All the parameters in the op class `call` (dynamic / tensor parameters) and `__init__` (static / non-tensor parameters) match the parameters in the op function:
  - the names should match
  - default values should match
  - the order should be consistent, usually with the static parameters all at the end
- Ops are exported in `keras.ops`
- Ops are exported in their module, e.g. `keras.ops.numpy`
- Ops are implemented in every backend
- The backend implementation has the same signature as the op wrapper (same parameters in the same order with the same default values)

Note about default values: only the default values declared in the common op function are used by users, which is why those values were not changed and were applied to the class and backend implementations.

- The default values for the class `__init__` are not used because the op function passes all parameters when constructing the class.
- The default values for the backend implementation are not used, because the common op function passes all parameters when calling the backend implementation. The exception to this is when an op implementation calls another op directly.

Therefore:

- The defaults on the backend implementations are just a reminder for the implementer
- The defaults on the class `__init__` are mostly a reminder, but can also serve as a forward compatibility mechanism when deserializing a model after a parameter is added.

Fixed all inconsistencies, mostly inconsistent default values.
1 parent 4d33a9a commit ad7bbb8

File tree

25 files changed

+417
-207
lines changed

25 files changed

+417
-207
lines changed

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from keras.src.ops.numpy import roll as roll
134134
from keras.src.ops.numpy import rot90 as rot90
135135
from keras.src.ops.numpy import round as round
136+
from keras.src.ops.numpy import searchsorted as searchsorted
136137
from keras.src.ops.numpy import select as select
137138
from keras.src.ops.numpy import sign as sign
138139
from keras.src.ops.numpy import signbit as signbit

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
from keras.src.ops.numpy import roll as roll
134134
from keras.src.ops.numpy import rot90 as rot90
135135
from keras.src.ops.numpy import round as round
136+
from keras.src.ops.numpy import searchsorted as searchsorted
136137
from keras.src.ops.numpy import select as select
137138
from keras.src.ops.numpy import sign as sign
138139
from keras.src.ops.numpy import signbit as signbit

keras/src/backend/jax/nn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def log_softmax(x, axis=-1):
155155
return jnn.log_softmax(x, axis=axis)
156156

157157

158-
def sparsemax(logits, axis=-1):
158+
def sparsemax(x, axis=-1):
159159
# Sort logits along the specified axis in descending order
160-
logits = convert_to_tensor(logits)
160+
logits = convert_to_tensor(x)
161161
logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)
162162
logits_cumsum = jnp.cumsum(logits_sorted, axis=axis) # find cumulative sum
163163
r = jnp.arange(1, logits.shape[axis] + 1) # Determine the sparsity
@@ -250,8 +250,8 @@ def max_pool(
250250
def average_pool(
251251
inputs,
252252
pool_size,
253-
strides,
254-
padding,
253+
strides=None,
254+
padding="valid",
255255
data_format=None,
256256
):
257257
data_format = backend.standardize_data_format(data_format)
@@ -485,7 +485,7 @@ def conv_transpose(
485485
)
486486

487487

488-
def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
488+
def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
489489
x = convert_to_tensor(x)
490490
if sparse:
491491
if axis < 0:
@@ -513,7 +513,7 @@ def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
513513
return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)
514514

515515

516-
def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
516+
def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
517517
x = convert_to_tensor(x)
518518
reduction_axis = 1 if len(x.shape) > 1 else 0
519519
if sparse:

keras/src/backend/jax/numpy.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,10 @@ def digitize(x, bins):
630630
return jnp.digitize(x, bins)
631631

632632

633-
def dot(x, y):
634-
x = convert_to_tensor(x)
635-
y = convert_to_tensor(y)
636-
return jnp.dot(x, y)
633+
def dot(x1, x2):
634+
x1 = convert_to_tensor(x1)
635+
x2 = convert_to_tensor(x2)
636+
return jnp.dot(x1, x2)
637637

638638

639639
def empty(shape, dtype=None):
@@ -982,9 +982,9 @@ def ravel(x):
982982
return jnp.ravel(x)
983983

984984

985-
def unravel_index(x, shape):
986-
x = convert_to_tensor(x)
987-
return jnp.unravel_index(x, shape)
985+
def unravel_index(indices, shape):
986+
indices = convert_to_tensor(indices)
987+
return jnp.unravel_index(indices, shape)
988988

989989

990990
@sparse.elementwise_unary(linear=True)
@@ -1212,7 +1212,7 @@ def vectorize(pyfunc, *, excluded=None, signature=None):
12121212
return jnp.vectorize(pyfunc, excluded=excluded, signature=signature)
12131213

12141214

1215-
def where(condition, x1, x2):
1215+
def where(condition, x1=None, x2=None):
12161216
return jnp.where(condition, x1, x2)
12171217

12181218

@@ -1357,5 +1357,5 @@ def argpartition(x, kth, axis=-1):
13571357
return jnp.argpartition(x, kth, axis)
13581358

13591359

1360-
def histogram(x, bins, range):
1360+
def histogram(x, bins=10, range=None):
13611361
return jnp.histogram(x, bins=bins, range=range)

keras/src/backend/numpy/core.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,14 @@ def scatter_update(inputs, indices, updates):
343343
return inputs
344344

345345

346-
def slice(inputs, start_indices, lengths):
346+
def slice(inputs, start_indices, shape):
347347
# Validate inputs
348-
assert len(start_indices) == len(lengths)
348+
assert len(start_indices) == len(shape)
349349

350350
# Generate list of indices arrays for each dimension
351351
indices = [
352352
np.arange(start, start + length)
353-
for start, length in zip(start_indices, lengths)
353+
for start, length in zip(start_indices, shape)
354354
]
355355

356356
# Use np.ix_ to create a multidimensional index array
@@ -407,8 +407,8 @@ def fori_loop(lower, upper, body_fun, init_val):
407407
return val
408408

409409

410-
def stop_gradient(x):
411-
return x
410+
def stop_gradient(variable):
411+
return variable
412412

413413

414414
def unstack(x, num=None, axis=0):

keras/src/backend/numpy/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):
5252
)
5353

5454

55-
def top_k(x, k, sorted=False):
55+
def top_k(x, k, sorted=True):
5656
if sorted:
5757
# Take the k largest values.
5858
sorted_indices = np.argsort(x, axis=-1)[..., ::-1]

keras/src/backend/numpy/nn.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,9 @@ def elu(x, alpha=1.0):
125125
)
126126

127127

128-
def selu(
129-
x,
130-
alpha=1.6732632423543772848170429916717,
131-
scale=1.0507009873554804934193349852946,
132-
):
128+
def selu(x):
129+
alpha = 1.6732632423543772848170429916717
130+
scale = 1.0507009873554804934193349852946
133131
x = convert_to_tensor(x)
134132
return np.array(scale, x.dtype) * elu(x, alpha)
135133

@@ -196,20 +194,20 @@ def threshold(x, threshold, default_value):
196194
return np.where(x > threshold, x, np.array(default_value, dtype=x.dtype))
197195

198196

199-
def softmax(x, axis=None):
197+
def softmax(x, axis=-1):
200198
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
201199
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
202200

203201

204-
def log_softmax(x, axis=None):
202+
def log_softmax(x, axis=-1):
205203
max_x = np.max(x, axis=axis, keepdims=True)
206204
logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))
207205
return x - max_x - logsumexp
208206

209207

210-
def sparsemax(logits, axis=-1):
208+
def sparsemax(x, axis=-1):
211209
# Sort logits along the specified axis in descending order
212-
logits = convert_to_tensor(logits)
210+
logits = convert_to_tensor(x)
213211
logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
214212
logits_cumsum = np.cumsum(logits_sorted, axis=axis)
215213
r = np.arange(1, logits.shape[axis] + 1)
@@ -304,8 +302,8 @@ def max_pool(
304302
def average_pool(
305303
inputs,
306304
pool_size,
307-
strides,
308-
padding,
305+
strides=None,
306+
padding="valid",
309307
data_format=None,
310308
):
311309
data_format = backend.standardize_data_format(data_format)
@@ -543,9 +541,11 @@ def conv_transpose(
543541
)
544542

545543

546-
def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
544+
def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
547545
if sparse:
548546
raise ValueError("Unsupported value `sparse=True` with numpy backend")
547+
if dtype is None:
548+
dtype = "float32"
549549
x = convert_to_tensor(x)
550550
input_shape = x.shape
551551

@@ -569,7 +569,7 @@ def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
569569
return categorical
570570

571571

572-
def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
572+
def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
573573
if sparse:
574574
raise ValueError("Unsupported value `sparse=True` with numpy backend")
575575
x = convert_to_tensor(x)

keras/src/backend/numpy/numpy.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def append(x1, x2, axis=None):
173173
return np.append(x1, x2, axis=axis)
174174

175175

176-
def arange(start, stop=None, step=None, dtype=None):
176+
def arange(start, stop=None, step=1, dtype=None):
177177
if dtype is None:
178178
dtypes_to_resolve = [
179179
getattr(start, "dtype", type(start)),
@@ -537,13 +537,13 @@ def digitize(x, bins):
537537
return np.digitize(x, bins).astype(np.int32)
538538

539539

540-
def dot(x, y):
541-
x = convert_to_tensor(x)
542-
y = convert_to_tensor(y)
543-
dtype = dtypes.result_type(x.dtype, y.dtype)
544-
x = x.astype(dtype)
545-
y = y.astype(dtype)
546-
return np.dot(x, y)
540+
def dot(x1, x2):
541+
x1 = convert_to_tensor(x1)
542+
x2 = convert_to_tensor(x2)
543+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
544+
x1 = x1.astype(dtype)
545+
x2 = x2.astype(dtype)
546+
return np.dot(x1, x2)
547547

548548

549549
def empty(shape, dtype=None):
@@ -898,10 +898,10 @@ def ravel(x):
898898
return np.ravel(x)
899899

900900

901-
def unravel_index(x, shape):
902-
dtype = dtypes.result_type(x.dtype)
901+
def unravel_index(indices, shape):
902+
dtype = dtypes.result_type(indices.dtype)
903903
return tuple(
904-
indices.astype(dtype) for indices in np.unravel_index(x, shape)
904+
indices.astype(dtype) for indices in np.unravel_index(indices, shape)
905905
)
906906

907907

@@ -1114,7 +1114,7 @@ def vectorize(pyfunc, *, excluded=None, signature=None):
11141114
return np.vectorize(pyfunc, excluded=excluded, signature=signature)
11151115

11161116

1117-
def where(condition, x1, x2):
1117+
def where(condition, x1=None, x2=None):
11181118
if x1 is not None and x2 is not None:
11191119
if not isinstance(x1, (int, float)):
11201120
x1 = convert_to_tensor(x1)
@@ -1283,5 +1283,5 @@ def argpartition(x, kth, axis=-1):
12831283
return np.argpartition(x, kth, axis).astype("int32")
12841284

12851285

1286-
def histogram(x, bins, range):
1286+
def histogram(x, bins=10, range=None):
12871287
return np.histogram(x, bins=bins, range=range)

keras/src/backend/openvino/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -565,21 +565,21 @@ def scatter_update(inputs, indices, updates):
565565
)
566566

567567

568-
def slice(inputs, start_indices, lengths):
568+
def slice(inputs, start_indices, shape):
569569
inputs = get_ov_output(inputs)
570570
assert isinstance(start_indices, tuple), (
571571
"`slice` is not supported by openvino backend"
572-
" for `start_indices` of type {}".format(type(lengths))
572+
" for `start_indices` of type {}".format(type(shape))
573573
)
574-
assert isinstance(lengths, tuple), (
574+
assert isinstance(shape, tuple), (
575575
"`slice` is not supported by openvino backend"
576-
" for `lengths` of type {}".format(type(lengths))
576+
" for `lengths` of type {}".format(type(shape))
577577
)
578578

579579
axes = []
580580
start = []
581581
stop = []
582-
for idx, length in enumerate(lengths):
582+
for idx, length in enumerate(shape):
583583
if length is not None and length >= 0:
584584
axes.append(idx)
585585
start.append(start_indices[idx])
@@ -621,8 +621,8 @@ def fori_loop(lower, upper, body_fun, init_val):
621621
)
622622

623623

624-
def stop_gradient(x):
625-
return x
624+
def stop_gradient(variable):
625+
return variable
626626

627627

628628
def unstack(x, num=None, axis=0):

keras/src/backend/openvino/image.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def rgb_to_grayscale(image, data_format="channels_last"):
1+
def rgb_to_grayscale(images, data_format=None):
22
raise NotImplementedError(
33
"`rgb_to_grayscale` is not supported with openvino backend"
44
)
@@ -19,20 +19,20 @@ def resize(
1919

2020

2121
def affine_transform(
22-
image,
22+
images,
2323
transform,
2424
interpolation="bilinear",
2525
fill_mode="constant",
2626
fill_value=0,
27-
data_format="channels_last",
27+
data_format=None,
2828
):
2929
raise NotImplementedError(
3030
"`affine_transform` is not supported with openvino backend"
3131
)
3232

3333

3434
def map_coordinates(
35-
input, coordinates, order, fill_mode="constant", fill_value=0.0
35+
inputs, coordinates, order, fill_mode="constant", fill_value=0
3636
):
3737
raise NotImplementedError(
3838
"`map_coordinates` is not supported with openvino backend"

0 commit comments

Comments
 (0)