|
84 | 84 | ] |
85 | 85 |
|
86 | 86 |
|
| 87 | +def _isclose_scalar_tol(a, b, rtol, atol, equal_nan): |
| 88 | + """ |
| 89 | + Specialized implementation of dpnp.isclose() for scalar rtol and atol |
| 90 | + using a dedicated SYCL kernel. |
| 91 | + """ |
| 92 | + dt = dpnp.result_type(a, b, 1.0) |
| 93 | + |
| 94 | + if dpnp.isscalar(a): |
| 95 | + usm_type = b.usm_type |
| 96 | + exec_q = b.sycl_queue |
| 97 | + a = dpnp.array( |
| 98 | + a, |
| 99 | + dt, |
| 100 | + usm_type=usm_type, |
| 101 | + sycl_queue=exec_q, |
| 102 | + ) |
| 103 | + elif dpnp.isscalar(b): |
| 104 | + usm_type = a.usm_type |
| 105 | + exec_q = a.sycl_queue |
| 106 | + b = dpnp.array( |
| 107 | + b, |
| 108 | + dt, |
| 109 | + usm_type=usm_type, |
| 110 | + sycl_queue=exec_q, |
| 111 | + ) |
| 112 | + else: |
| 113 | + usm_type, exec_q = get_usm_allocations([a, b]) |
| 114 | + |
| 115 | + a = dpnp.astype(a, dt, casting="same_kind", copy=False) |
| 116 | + b = dpnp.astype(b, dt, casting="same_kind", copy=False) |
| 117 | + |
| 118 | + # Convert complex rtol/atol to their real parts |
| 119 | + # to avoid pybind11 cast errors and match NumPy behavior |
| 120 | + if isinstance(rtol, complex): |
| 121 | + rtol = rtol.real |
| 122 | + if isinstance(atol, complex): |
| 123 | + atol = atol.real |
| 124 | + |
| 125 | + # pylint: disable=W0707 |
| 126 | + try: |
| 127 | + res_shape = dpnp.broadcast_shapes(a.shape, b.shape) |
| 128 | + except ValueError: |
| 129 | + raise ValueError( |
| 130 | + "operands could not be broadcast together with shapes " |
| 131 | + f"{a.shape} and {b.shape}" |
| 132 | + ) |
| 133 | + |
| 134 | + if a.shape != res_shape: |
| 135 | + a = dpnp.broadcast_to(a, res_shape) |
| 136 | + if b.shape != res_shape: |
| 137 | + b = dpnp.broadcast_to(b, res_shape) |
| 138 | + |
| 139 | + out_dtype = dpnp.bool |
| 140 | + output = dpnp.empty( |
| 141 | + res_shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type |
| 142 | + ) |
| 143 | + |
| 144 | + _manager = dpu.SequentialOrderManager[exec_q] |
| 145 | + mem_ev, ht_ev = ufi._isclose_scalar( |
| 146 | + a.get_array(), |
| 147 | + b.get_array(), |
| 148 | + rtol, |
| 149 | + atol, |
| 150 | + equal_nan, |
| 151 | + output.get_array(), |
| 152 | + exec_q, |
| 153 | + depends=_manager.submitted_events, |
| 154 | + ) |
| 155 | + _manager.add_event_pair(mem_ev, ht_ev) |
| 156 | + |
| 157 | + return output |
| 158 | + |
| 159 | + |
87 | 160 | def all(a, /, axis=None, out=None, keepdims=False, *, where=True): |
88 | 161 | """ |
89 | 162 | Test whether all array elements along a given axis evaluate to ``True``. |
@@ -874,72 +947,7 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): |
874 | 947 |
|
875 | 948 | # Use own SYCL kernel for scalar rtol/atol |
876 | 949 | if dpnp.isscalar(rtol) and dpnp.isscalar(atol): |
877 | | - dt = dpnp.result_type(a, b, 1.0) |
878 | | - |
879 | | - if dpnp.isscalar(a): |
880 | | - usm_type = b.usm_type |
881 | | - exec_q = b.sycl_queue |
882 | | - a = dpnp.array( |
883 | | - a, |
884 | | - dt, |
885 | | - usm_type=usm_type, |
886 | | - sycl_queue=exec_q, |
887 | | - ) |
888 | | - elif dpnp.isscalar(b): |
889 | | - usm_type = a.usm_type |
890 | | - exec_q = a.sycl_queue |
891 | | - b = dpnp.array( |
892 | | - b, |
893 | | - dt, |
894 | | - usm_type=usm_type, |
895 | | - sycl_queue=exec_q, |
896 | | - ) |
897 | | - else: |
898 | | - usm_type, exec_q = get_usm_allocations([a, b]) |
899 | | - |
900 | | - a = dpnp.astype(a, dt, casting="same_kind", copy=False) |
901 | | - b = dpnp.astype(b, dt, casting="same_kind", copy=False) |
902 | | - |
903 | | - # Convert complex rtol/atol to to their real parts |
904 | | - # to avoid pybind11 cast errors and match NumPy behavior |
905 | | - if isinstance(rtol, complex): |
906 | | - rtol = rtol.real |
907 | | - if isinstance(atol, complex): |
908 | | - atol = atol.real |
909 | | - |
910 | | - # pylint: disable=W0707 |
911 | | - try: |
912 | | - res_shape = dpnp.broadcast_shapes(a.shape, b.shape) |
913 | | - except ValueError: |
914 | | - raise ValueError( |
915 | | - "operands could not be broadcast together with shapes " |
916 | | - f"{a.shape} and {b.shape}" |
917 | | - ) |
918 | | - |
919 | | - if a.shape != res_shape: |
920 | | - a = dpnp.broadcast_to(a, res_shape) |
921 | | - if b.shape != res_shape: |
922 | | - b = dpnp.broadcast_to(b, res_shape) |
923 | | - |
924 | | - out_dtype = dpnp.bool |
925 | | - output = dpnp.empty( |
926 | | - res_shape, dtype=out_dtype, sycl_queue=exec_q, usm_type=usm_type |
927 | | - ) |
928 | | - |
929 | | - _manager = dpu.SequentialOrderManager[exec_q] |
930 | | - mem_ev, ht_ev = ufi._isclose_scalar( |
931 | | - a.get_array(), |
932 | | - b.get_array(), |
933 | | - rtol, |
934 | | - atol, |
935 | | - equal_nan, |
936 | | - output.get_array(), |
937 | | - exec_q, |
938 | | - depends=_manager.submitted_events, |
939 | | - ) |
940 | | - _manager.add_event_pair(mem_ev, ht_ev) |
941 | | - |
942 | | - return output |
| 950 | + return _isclose_scalar_tol(a, b, rtol, atol, equal_nan) |
943 | 951 |
|
944 | 952 | # make sure b is an inexact type to avoid bad behavior on abs(MIN_INT) |
945 | 953 | if dpnp.isscalar(b): |
|
0 commit comments