Skip to content

Commit 4b31460

Browse files
Merge pull request #883 from IntelPython/ctor_exception
dpctl.tensor.asarray must check numpy array data-type
2 parents 948bc65 + ce39457 commit 4b31460

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

.github/workflows/os-llvm-sycl-build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
cd /home/runner/work
4242
mkdir -p sycl_bundle
4343
cd sycl_bundle
44-
export LATEST_LLVM_TAG=$(git -c 'versionsort.suffix=-' ls-remote --tags --sort='v:refname' https://github.com/intel/llvm.git | tail --lines=1)
44+
export LATEST_LLVM_TAG=$(git -c 'versionsort.suffix=-' ls-remote --tags --sort='v:refname' https://github.com/intel/llvm.git | grep sycl-nightly | tail --lines=1)
4545
export LATEST_LLVM_TAG_SHA=$(echo ${LATEST_LLVM_TAG} | awk '{print $1}')
4646
export NIGHTLY_TAG=$(python3 -c "import sys, urllib.parse as ul; print (ul.quote_plus(sys.argv[1]))" \
4747
$(echo ${LATEST_LLVM_TAG} | awk '{gsub(/^refs\/tags\//, "", $2)} {print $2}'))

dpctl/tensor/_ctors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _get_dtype(dtype, sycl_obj, ref_type=None):
4444
dtype = ti.default_device_complex_type(sycl_obj)
4545
return np.dtype(dtype)
4646
else:
47-
raise ValueError(f"Reference type {ref_type} not recognized.")
47+
raise TypeError(f"Reference type {ref_type} not recognized.")
4848
else:
4949
return np.dtype(dtype)
5050

@@ -199,6 +199,11 @@ def _asarray_from_numpy_ndarray(
199199
if usm_type is None:
200200
usm_type = "device"
201201
copy_q = normalize_queue_device(sycl_queue=None, device=sycl_queue)
202+
if ary.dtype.char not in "?bBhHiIlLqQefdFD":
203+
raise TypeError(
204+
f"Numpy array of data type {ary.dtype} is not supported. "
205+
"Please convert the input to an array with numeric data type."
206+
)
202207
if dtype is None:
203208
ary_dtype = ary.dtype
204209
dtype = _get_dtype(dtype, copy_q, ref_type=ary_dtype)

dpctl/tests/test_tensor_asarray.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,17 @@ def test_asarray_scalars():
177177
Y = dpt.asarray(5)
178178
assert Y.dtype == np.dtype(int)
179179
Y = dpt.asarray(5.2)
180-
assert Y.dtype == np.dtype(float)
180+
if Y.sycl_device.has_aspect_fp64:
181+
assert Y.dtype == np.dtype(float)
182+
else:
183+
assert Y.dtype == np.dtype(np.float32)
181184
Y = dpt.asarray(np.float32(2.3))
182185
assert Y.dtype == np.dtype(np.float32)
183186
Y = dpt.asarray(1.0j)
184-
assert Y.dtype == np.dtype(complex)
187+
if Y.sycl_device.has_aspect_fp64:
188+
assert Y.dtype == np.dtype(complex)
189+
else:
190+
assert Y.dtype == np.dtype(np.complex64)
185191
Y = dpt.asarray(ctypes.c_int(8))
186192
assert Y.dtype == np.dtype(ctypes.c_int)
187193

@@ -220,3 +226,13 @@ def test_asarray_copy_false():
220226
assert Y6 is Xf
221227
with pytest.raises(ValueError):
222228
dpt.asarray(Xf, copy=False, order="C")
229+
230+
231+
def test_asarray_invalid_dtype():
232+
try:
233+
q = dpctl.SyclQueue()
234+
except dpctl.SyclQueueCreationError:
235+
pytest.skip("Could not create a queue")
236+
Xnp = np.array([1, 2, 3], dtype=object)
237+
with pytest.raises(TypeError):
238+
dpt.asarray(Xnp, sycl_queue=q)

0 commit comments

Comments
 (0)