Skip to content

Commit 503bcf5

Browse files
authored
Implement heaviside function in keras.ops (#21474)
* Add heaviside for numpy.py * Add heaviside for numpy.py * Add heaviside for ops * Add test case * Add test cases * Correct code by gemini assist * Correct code by gemini assist * update excluded_concrete_tests.txt * Update description to use backticks
1 parent a9688b4 commit 503bcf5

File tree

12 files changed

+157
-0
lines changed

12 files changed

+157
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
from keras.src.ops.numpy import greater_equal as greater_equal
192192
from keras.src.ops.numpy import hamming as hamming
193193
from keras.src.ops.numpy import hanning as hanning
194+
from keras.src.ops.numpy import heaviside as heaviside
194195
from keras.src.ops.numpy import histogram as histogram
195196
from keras.src.ops.numpy import hstack as hstack
196197
from keras.src.ops.numpy import identity as identity

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from keras.src.ops.numpy import greater_equal as greater_equal
8181
from keras.src.ops.numpy import hamming as hamming
8282
from keras.src.ops.numpy import hanning as hanning
83+
from keras.src.ops.numpy import heaviside as heaviside
8384
from keras.src.ops.numpy import histogram as histogram
8485
from keras.src.ops.numpy import hstack as hstack
8586
from keras.src.ops.numpy import identity as identity

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
from keras.src.ops.numpy import greater_equal as greater_equal
192192
from keras.src.ops.numpy import hamming as hamming
193193
from keras.src.ops.numpy import hanning as hanning
194+
from keras.src.ops.numpy import heaviside as heaviside
194195
from keras.src.ops.numpy import histogram as histogram
195196
from keras.src.ops.numpy import hstack as hstack
196197
from keras.src.ops.numpy import identity as identity

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from keras.src.ops.numpy import greater_equal as greater_equal
8181
from keras.src.ops.numpy import hamming as hamming
8282
from keras.src.ops.numpy import hanning as hanning
83+
from keras.src.ops.numpy import heaviside as heaviside
8384
from keras.src.ops.numpy import histogram as histogram
8485
from keras.src.ops.numpy import hstack as hstack
8586
from keras.src.ops.numpy import identity as identity

keras/src/backend/jax/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ def hanning(x):
5252
return jnp.hanning(x)
5353

5454

55+
def heaviside(x1, x2):
56+
x1 = convert_to_tensor(x1)
57+
x2 = convert_to_tensor(x2)
58+
return jnp.heaviside(x1, x2)
59+
60+
5561
def kaiser(x, beta):
5662
x = convert_to_tensor(x)
5763
return jnp.kaiser(x, beta)

keras/src/backend/numpy/numpy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,19 @@ def hanning(x):
320320
return np.hanning(x).astype(config.floatx())
321321

322322

323+
def heaviside(x1, x2):
324+
x1 = convert_to_tensor(x1)
325+
x2 = convert_to_tensor(x2)
326+
327+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
328+
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
329+
dtype = config.floatx()
330+
elif dtype in ["int64"]:
331+
dtype = "float64"
332+
333+
return np.heaviside(x1, x2).astype(dtype)
334+
335+
323336
def kaiser(x, beta):
324337
x = convert_to_tensor(x)
325338
return np.kaiser(x, beta).astype(config.floatx())

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ NumpyDtypeTest::test_bartlett
1212
NumpyDtypeTest::test_blackman
1313
NumpyDtypeTest::test_hamming
1414
NumpyDtypeTest::test_hanning
15+
NumpyDtypeTest::test_heaviside
1516
NumpyDtypeTest::test_kaiser
1617
NumpyDtypeTest::test_bitwise
1718
NumpyDtypeTest::test_cbrt
@@ -145,6 +146,7 @@ NumpyTwoInputOpsCorrectnessTest::test_cross
145146
NumpyTwoInputOpsCorrectnessTest::test_digitize
146147
NumpyTwoInputOpsCorrectnessTest::test_divide_no_nan
147148
NumpyTwoInputOpsCorrectnessTest::test_einsum
149+
NumpyTwoInputOpsCorrectnessTest::test_heaviside
148150
NumpyTwoInputOpsCorrectnessTest::test_inner
149151
NumpyTwoInputOpsCorrectnessTest::test_linspace
150152
NumpyTwoInputOpsCorrectnessTest::test_logspace
@@ -164,6 +166,8 @@ NumpyOneInputOpsDynamicShapeTest::test_kaiser
164166
NumpyOneInputOpsStaticShapeTest::test_angle
165167
NumpyOneInputOpsStaticShapeTest::test_cbrt
166168
NumpyOneInputOpsStaticShapeTest::test_deg2rad
169+
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
170+
NumpyTwoInputOpsStaticShapeTest::test_heaviside
167171
CoreOpsBehaviorTests::test_associative_scan_invalid_arguments
168172
CoreOpsBehaviorTests::test_scan_invalid_arguments
169173
CoreOpsCallsTests::test_associative_scan_basic_call

keras/src/backend/openvino/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ def hamming(x):
481481
)
482482

483483

484+
def heaviside(x1, x2):
485+
raise NotImplementedError(
486+
"`heaviside` is not supported with openvino backend"
487+
)
488+
489+
484490
def kaiser(x, beta):
485491
raise NotImplementedError("`kaiser` is not supported with openvino backend")
486492

keras/src/backend/tensorflow/numpy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,26 @@ def hanning(x):
156156
return tf.signal.hann_window(x, periodic=False)
157157

158158

159+
def heaviside(x1, x2):
160+
x1 = convert_to_tensor(x1)
161+
x2 = convert_to_tensor(x2)
162+
163+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
164+
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
165+
dtype = config.floatx()
166+
elif dtype in ["int64"]:
167+
dtype = "float64"
168+
169+
x1 = tf.cast(x1, dtype)
170+
x2 = tf.cast(x2, dtype)
171+
172+
return tf.where(
173+
x1 < 0,
174+
tf.zeros_like(x1),
175+
tf.where(x1 > 0, tf.ones_like(x1), x2),
176+
)
177+
178+
159179
def kaiser(x, beta):
160180
x = convert_to_tensor(x, dtype=tf.int32)
161181
return tf.signal.kaiser_window(x, beta=beta)

keras/src/backend/torch/numpy.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,22 @@ def hanning(x):
445445
return torch.signal.windows.hann(x)
446446

447447

448+
def heaviside(x1, x2):
449+
x1 = convert_to_tensor(x1)
450+
x2 = convert_to_tensor(x2)
451+
452+
dtype = dtypes.result_type(x1.dtype, x2.dtype)
453+
if dtype in ["int8", "int16", "int32", "uint8", "uint16", "uint32"]:
454+
dtype = config.floatx()
455+
elif dtype == "int64":
456+
dtype = "float64"
457+
458+
x1 = cast(x1, dtype)
459+
x2 = cast(x2, dtype)
460+
461+
return torch.heaviside(x1, x2)
462+
463+
448464
def kaiser(x, beta):
449465
x = convert_to_tensor(x)
450466
return torch.signal.windows.kaiser(x, beta=beta)

0 commit comments

Comments
 (0)