Skip to content

Commit 525c65b

Browse files
authored
Implement angle function in keras.ops (#21200)
* Add first version of angle operation on numpy * Skip test with bfloat16 on numpy * Remove bfloat16 checking on Angle * Fix test case for float16 on torch cuda * exclude openvino test case * exclude openvino test case * exclude openvino test case * Update init files
1 parent 55adc57 commit 525c65b

File tree

11 files changed

+117
-0
lines changed

11 files changed

+117
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
from keras.src.ops.numpy import all as all
119119
from keras.src.ops.numpy import amax as amax
120120
from keras.src.ops.numpy import amin as amin
121+
from keras.src.ops.numpy import angle as angle
121122
from keras.src.ops.numpy import any as any
122123
from keras.src.ops.numpy import append as append
123124
from keras.src.ops.numpy import arange as arange

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.numpy import all as all
1111
from keras.src.ops.numpy import amax as amax
1212
from keras.src.ops.numpy import amin as amin
13+
from keras.src.ops.numpy import angle as angle
1314
from keras.src.ops.numpy import any as any
1415
from keras.src.ops.numpy import append as append
1516
from keras.src.ops.numpy import arange as arange

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@
118118
from keras.src.ops.numpy import all as all
119119
from keras.src.ops.numpy import amax as amax
120120
from keras.src.ops.numpy import amin as amin
121+
from keras.src.ops.numpy import angle as angle
121122
from keras.src.ops.numpy import any as any
122123
from keras.src.ops.numpy import append as append
123124
from keras.src.ops.numpy import arange as arange

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.ops.numpy import all as all
1111
from keras.src.ops.numpy import amax as amax
1212
from keras.src.ops.numpy import amin as amin
13+
from keras.src.ops.numpy import angle as angle
1314
from keras.src.ops.numpy import any as any
1415
from keras.src.ops.numpy import append as append
1516
from keras.src.ops.numpy import arange as arange

keras/src/backend/jax/numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,16 @@ def all(x, axis=None, keepdims=False):
246246
return jnp.all(x, axis=axis, keepdims=keepdims)
247247

248248

249+
def angle(x):
250+
x = convert_to_tensor(x)
251+
if standardize_dtype(x.dtype) == "int64":
252+
dtype = config.floatx()
253+
else:
254+
dtype = dtypes.result_type(x.dtype, float)
255+
x = cast(x, dtype)
256+
return jnp.angle(x)
257+
258+
249259
def any(x, axis=None, keepdims=False):
250260
return jnp.any(x, axis=axis, keepdims=keepdims)
251261

keras/src/backend/numpy/numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,16 @@ def all(x, axis=None, keepdims=False):
138138
return np.all(x, axis=axis, keepdims=keepdims)
139139

140140

141+
def angle(x):
142+
x = convert_to_tensor(x)
143+
if standardize_dtype(x.dtype) == "int64":
144+
dtype = config.floatx()
145+
else:
146+
dtype = dtypes.result_type(x.dtype, float)
147+
x = x.astype(dtype)
148+
return np.angle(x)
149+
150+
141151
def any(x, axis=None, keepdims=False):
142152
axis = standardize_axis_for_numpy(axis)
143153
return np.any(x, axis=axis, keepdims=keepdims)

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ NumpyArrayCreateOpsCorrectnessTest::test_tri
44
NumpyDtypeTest::test_absolute_bool
55
NumpyDtypeTest::test_add_
66
NumpyDtypeTest::test_all
7+
NumpyDtypeTest::test_angle
78
NumpyDtypeTest::test_any
89
NumpyDtypeTest::test_argpartition
910
NumpyDtypeTest::test_array
@@ -71,6 +72,7 @@ NumpyDtypeTest::test_clip_bool
7172
NumpyDtypeTest::test_square_bool
7273
HistogramTest
7374
NumpyOneInputOpsCorrectnessTest::test_all
75+
NumpyOneInputOpsCorrectnessTest::test_angle
7476
NumpyOneInputOpsCorrectnessTest::test_any
7577
NumpyOneInputOpsCorrectnessTest::test_argpartition
7678
NumpyOneInputOpsCorrectnessTest::test_array
@@ -150,3 +152,5 @@ NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
150152
NumpyTwoInputOpsCorrectnessTest::test_tensordot
151153
NumpyTwoInputOpsCorrectnessTest::test_vdot
152154
NumpyTwoInputOpsCorrectnessTest::test_where
155+
NumpyOneInputOpsDynamicShapeTest::test_angle
156+
NumpyOneInputOpsStaticShapeTest::test_angle

keras/src/backend/tensorflow/numpy.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,16 @@ def all(x, axis=None, keepdims=False):
728728
return tf.reduce_all(x, axis=axis, keepdims=keepdims)
729729

730730

731+
def angle(x):
732+
x = convert_to_tensor(x)
733+
if standardize_dtype(x.dtype) == "int64":
734+
dtype = config.floatx()
735+
else:
736+
dtype = dtypes.result_type(x.dtype, float)
737+
x = tf.cast(x, dtype)
738+
return tf.math.angle(x)
739+
740+
731741
def any(x, axis=None, keepdims=False):
732742
x = tf.cast(x, "bool")
733743
return tf.reduce_any(x, axis=axis, keepdims=keepdims)

keras/src/backend/torch/numpy.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,17 @@ def all(x, axis=None, keepdims=False):
264264
return cast(x, "bool")
265265

266266

267+
def angle(x):
268+
x = convert_to_tensor(x)
269+
ori_dtype = standardize_dtype(x.dtype)
270+
271+
# torch.angle doesn't support float16 with cuda
272+
if get_device() != "cpu" and ori_dtype == "float16":
273+
x = cast(x, "float32")
274+
return cast(torch.angle(x), "float16")
275+
return torch.angle(x)
276+
277+
267278
def any(x, axis=None, keepdims=False):
268279
x = convert_to_tensor(x)
269280
if axis is None:

keras/src/ops/numpy.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,41 @@ def compute_output_spec(self, x):
328328
)
329329

330330

331+
class Angle(Operation):
332+
def call(self, x):
333+
return backend.numpy.angle(x)
334+
335+
def compute_output_spec(self, x):
336+
dtype = backend.standardize_dtype(getattr(x, "dtype", backend.floatx()))
337+
if dtype == "int64":
338+
dtype = backend.floatx()
339+
else:
340+
dtype = dtypes.result_type(dtype, float)
341+
return KerasTensor(x.shape, dtype=dtype)
342+
343+
344+
@keras_export(["keras.ops.angle", "keras.ops.numpy.angle"])
345+
def angle(x):
346+
"""Element-wise angle of a complex tensor.
347+
348+
Arguments:
349+
x: Input tensor. Can be real or complex.
350+
351+
Returns:
352+
Output tensor of same shape as x. containing the angle of each element
353+
(in radians).
354+
355+
Example:
356+
>>> x = keras.ops.convert_to_tensor([[1 + 3j, 2 - 5j], [4 - 3j, 3 + 2j]])
357+
>>> keras.ops.angle(x)
358+
array([[ 1.2490457, -1.19029 ],
359+
[-0.6435011, 0.5880026]], dtype=float32)
360+
"""
361+
if any_symbolic_tensors((x,)):
362+
return Angle().symbolic_call(x)
363+
return backend.numpy.angle(x)
364+
365+
331366
@keras_export(["keras.ops.any", "keras.ops.numpy.any"])
332367
def any(x, axis=None, keepdims=False):
333368
"""Test whether any array element along a given axis evaluates to `True`.

0 commit comments

Comments
 (0)