Skip to content

Commit d8704f0

Browse files
Merge pull request #1002 from IntelPython/fixed_dtype_full
Fixed error in cast dtype for full() function.
2 parents a2edadb + e7e8508 commit d8704f0

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,14 +774,20 @@ def full(
774774

775775
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
776776
usm_type = usm_type if usm_type is not None else "device"
777-
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
777+
fill_value_type = type(fill_value)
778+
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
778779
res = dpt.usm_ndarray(
779780
sh,
780781
dtype=dtype,
781782
buffer=usm_type,
782783
order=order,
783784
buffer_ctor_kwargs={"queue": sycl_queue},
784785
)
786+
if fill_value_type in [float, complex] and np.issubdtype(dtype, np.integer):
787+
fill_value = int(fill_value.real)
788+
elif fill_value_type is complex and np.issubdtype(dtype, np.floating):
789+
fill_value = fill_value.real
790+
785791
hev, _ = ti._full_usm_ndarray(fill_value, res, sycl_queue)
786792
hev.wait()
787793
return res

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,10 @@ def test_full_dtype_inference():
10011001
assert np.issubdtype(dpt.full(10, 12.3).dtype, np.floating)
10021002
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
10031003

1004+
assert np.issubdtype(dpt.full(10, 12.3, dtype=int).dtype, np.integer)
1005+
assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=int).dtype, np.integer)
1006+
assert np.issubdtype(dpt.full(10, 0.3 - 2j, dtype=float).dtype, np.floating)
1007+
10041008

10051009
def test_full_fill_array():
10061010
q = get_queue_or_skip()

0 commit comments

Comments
 (0)