Skip to content

Commit c7c7011

Browse files
Refactor isclose() by using separate _isclose_scalar_tol() for scalar rtol/atol
1 parent 725891c commit c7c7011

File tree

1 file changed

+74
-66
lines changed

1 file changed

+74
-66
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 74 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,79 @@
8484
]
8585

8686

87+
def _isclose_scalar_tol(a, b, rtol, atol, equal_nan):
88+
"""
89+
Specialized implementation of dpnp.isclose() for scalar rtol and atol
90+
using a dedicated SYCL kernel.
91+
"""
92+
dt = dpnp.result_type(a, b, 1.0)
93+
94+
if dpnp.isscalar(a):
95+
usm_type = b.usm_type
96+
exec_q = b.sycl_queue
97+
a = dpnp.array(
98+
a,
99+
dt,
100+
usm_type=usm_type,
101+
sycl_queue=exec_q,
102+
)
103+
elif dpnp.isscalar(b):
104+
usm_type = a.usm_type
105+
exec_q = a.sycl_queue
106+
b = dpnp.array(
107+
b,
108+
dt,
109+
usm_type=usm_type,
110+
sycl_queue=exec_q,
111+
)
112+
else:
113+
usm_type, exec_q = get_usm_allocations([a, b])
114+
115+
a = dpnp.astype(a, dt, casting="same_kind", copy=False)
116+
b = dpnp.astype(b, dt, casting="same_kind", copy=False)
117+
118+
# Convert complex rtol/atol to to their real parts
119+
# to avoid pybind11 cast errors and match NumPy behavior
120+
if isinstance(rtol, complex):
121+
rtol = rtol.real
122+
if isinstance(atol, complex):
123+
atol = atol.real
124+
125+
# pylint: disable=W0707
126+
try:
127+
res_shape = dpnp.broadcast_shapes(a.shape, b.shape)
128+
except ValueError:
129+
raise ValueError(
130+
"operands could not be broadcast together with shapes "
131+
f"{a.shape} and {b.shape}"
132+
)
133+
134+
if a.shape != res_shape:
135+
a = dpnp.broadcast_to(a, res_shape)
136+
if b.shape != res_shape:
137+
b = dpnp.broadcast_to(b, res_shape)
138+
139+
out_dtype = dpnp.bool
140+
output = dpnp.empty(
141+
res_shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
142+
)
143+
144+
_manager = dpu.SequentialOrderManager[exec_q]
145+
mem_ev, ht_ev = ufi._isclose_scalar(
146+
a.get_array(),
147+
b.get_array(),
148+
rtol,
149+
atol,
150+
equal_nan,
151+
output.get_array(),
152+
exec_q,
153+
depends=_manager.submitted_events,
154+
)
155+
_manager.add_event_pair(mem_ev, ht_ev)
156+
157+
return output
158+
159+
87160
def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
88161
"""
89162
Test whether all array elements along a given axis evaluate to ``True``.
@@ -874,72 +947,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
874947

875948
# Use own SYCL kernel for scalar rtol/atol
876949
if dpnp.isscalar(rtol) and dpnp.isscalar(atol):
877-
dt = dpnp.result_type(a, b, 1.0)
878-
879-
if dpnp.isscalar(a):
880-
usm_type = b.usm_type
881-
exec_q = b.sycl_queue
882-
a = dpnp.array(
883-
a,
884-
dt,
885-
usm_type=usm_type,
886-
sycl_queue=exec_q,
887-
)
888-
elif dpnp.isscalar(b):
889-
usm_type = a.usm_type
890-
exec_q = a.sycl_queue
891-
b = dpnp.array(
892-
b,
893-
dt,
894-
usm_type=usm_type,
895-
sycl_queue=exec_q,
896-
)
897-
else:
898-
usm_type, exec_q = get_usm_allocations([a, b])
899-
900-
a = dpnp.astype(a, dt, casting="same_kind", copy=False)
901-
b = dpnp.astype(b, dt, casting="same_kind", copy=False)
902-
903-
# Convert complex rtol/atol to to their real parts
904-
# to avoid pybind11 cast errors and match NumPy behavior
905-
if isinstance(rtol, complex):
906-
rtol = rtol.real
907-
if isinstance(atol, complex):
908-
atol = atol.real
909-
910-
# pylint: disable=W0707
911-
try:
912-
res_shape = dpnp.broadcast_shapes(a.shape, b.shape)
913-
except ValueError:
914-
raise ValueError(
915-
"operands could not be broadcast together with shapes "
916-
f"{a.shape} and {b.shape}"
917-
)
918-
919-
if a.shape != res_shape:
920-
a = dpnp.broadcast_to(a, res_shape)
921-
if b.shape != res_shape:
922-
b = dpnp.broadcast_to(b, res_shape)
923-
924-
out_dtype = dpnp.bool
925-
output = dpnp.empty(
926-
res_shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type
927-
)
928-
929-
_manager = dpu.SequentialOrderManager[exec_q]
930-
mem_ev, ht_ev = ufi._isclose_scalar(
931-
a.get_array(),
932-
b.get_array(),
933-
rtol,
934-
atol,
935-
equal_nan,
936-
output.get_array(),
937-
exec_q,
938-
depends=_manager.submitted_events,
939-
)
940-
_manager.add_event_pair(mem_ev, ht_ev)
941-
942-
return output
950+
return _isclose_scalar_tol(a, b, rtol, atol, equal_nan)
943951

944952
# make sure b is an inexact type to avoid bad behavior on abs(MIN_INT)
945953
if dpnp.isscalar(b):

0 commit comments

Comments
 (0)