Skip to content

Commit 0dbc7e9

Browse files
authored
Update LinAlg functions from BLAS routine (#1919)
1 parent 8ac7f88 commit 0dbc7e9

File tree

2 files changed

+120
-218
lines changed

2 files changed

+120
-218
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,7 @@ def clip(a, a_min, a_max, *, out=None, order="K", **kwargs):
673673

674674
if kwargs:
675675
raise NotImplementedError(f"kwargs={kwargs} is currently not supported")
676+
676677
if a_min is None and a_max is None:
677678
raise ValueError("One of max or min must be given")
678679

@@ -923,11 +924,13 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
923924
if not isinstance(axis, int):
924925
raise TypeError(f"axis should be an integer but got, {type(axis)}.")
925926
axisa, axisb, axisc = (axis,) * 3
927+
926928
dpnp.check_supported_arrays_type(a, b)
927929
if a.dtype == dpnp.bool and b.dtype == dpnp.bool:
928930
raise TypeError(
929931
"Input arrays with boolean data type are not supported."
930932
)
933+
931934
# Check axisa and axisb are within bounds
932935
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix="axisa")
933936
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix="axisb")
@@ -944,6 +947,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
944947
# Modify the shape of input arrays if necessary
945948
a_shape = a.shape
946949
b_shape = b.shape
950+
947951
# TODO: replace with dpnp.broadcast_shapes once implemented
948952
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
949953
if a_shape[:-1] != res_shape:
@@ -957,6 +961,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
957961
res_shape += (3,)
958962
# Check axisc is within bounds
959963
axisc = normalize_axis_index(axisc, len(res_shape), msg_prefix="axisc")
964+
960965
# Create the output array
961966
dtype = dpnp.result_type(a, b)
962967
res_usm_type, exec_q = get_usm_allocations([a, b])
@@ -968,7 +973,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
968973
a = a.astype(dtype, copy=False)
969974
b = b.astype(dtype, copy=False)
970975

971-
cp = dpnp_cross(a, b, cp, exec_q)
976+
cp = dpnp_cross(a, b, cp)
972977
if a_shape[-1] == 2 and b_shape[-1] == 2:
973978
return cp
974979

@@ -3184,6 +3189,9 @@ def sum(
31843189
sycl_sum = get_sum(input, output)
31853190

31863191
if sycl_sum:
3192+
# TODO: pass dep events into _get_sum_over_axis_0 to remove sync
3193+
dpnp.synchronize_array_data(input)
3194+
31873195
sycl_sum(input, output, []).wait()
31883196
result = dpnp_array._create_from_usm_ndarray(output)
31893197

0 commit comments

Comments
 (0)