We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 0d6432d + 56788ad commit cf7bc9fCopy full SHA for cf7bc9f
array_api_tests/test_linalg.py
@@ -961,14 +961,6 @@ def true_trace(x_stack, offset=0):
961
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
962
963
964
-def _conj(x):
965
- # XXX: replace with xp.dtype when all array libraries implement it
966
- if x.dtype in (xp.complex64, xp.complex128):
967
- return xp.conj(x)
968
- else:
969
- return x
970
-
971
972
def _test_vecdot(namespace, x1, x2, data):
973
vecdot = namespace.vecdot
974
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
@@ -994,7 +986,7 @@ def _test_vecdot(namespace, x1, x2, data):
994
986
out_shape=res.shape, expected=expected_shape)
995
987
996
988
def true_val(x, y, axis=-1):
997
- return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)
989
+ return xp.sum(xp.multiply(xp.conj(x), y), dtype=res.dtype)
998
990
999
991
_test_stacks(vecdot, x1, x2, res=res, dims=0,
1000
992
matrix_axes=(axis,), true_val=true_val)
0 commit comments