Skip to content

Commit 3d9f858

Browse files
committed
Added support for arrays for fill_falue for full() function
1 parent 2c86ef6 commit 3d9f858

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,8 @@ def full(
763763
order = order[0].upper()
764764
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
765765
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
766+
if dtype is None and isinstance(fill_value, (dpt.usm_ndarray, np.ndarray)):
767+
dtype = fill_value.dtype
766768
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
767769
res = dpt.usm_ndarray(
768770
sh,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,17 @@ def test_full_dtype_inference():
989989
assert np.issubdtype(dpt.full(10, 0.3 - 2j).dtype, np.complexfloating)
990990

991991

992+
def test_full_fill_array():
993+
q = get_queue_or_skip()
994+
995+
dtype = np.float16
996+
X = dpt.full(10, dpt.usm_ndarray(1, dtype=dtype), sycl_queue=q)
997+
assert dtype == X.dtype
998+
999+
X = dpt.full(10, np.ndarray(1, dtype=dtype), sycl_queue=q)
1000+
assert dtype == X.dtype
1001+
1002+
9921003
@pytest.mark.parametrize(
9931004
"dt",
9941005
_all_dtypes[1:],

0 commit comments

Comments
 (0)