Skip to content

Commit 48c8051

Browse files
committed
Replace branching for fill scalar type with _cast_fill_value
1 parent 4ba7186 commit 48c8051

File tree

1 file changed

+2
-9
lines changed

1 file changed

+2
-9
lines changed

dpnp/dpnp_algo/dpnp_fill.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,11 @@ def dpnp_fill(arr, val):
6464
raise TypeError(
6565
f"array cannot be filled with `val` of type {type(val)}"
6666
)
67-
68-
dt = arr.dtype
69-
val_type = type(val)
70-
if val_type in [float, complex] and dpnp.issubdtype(dt, dpnp.integer):
71-
val = int(val.real)
72-
elif val_type is complex and dpnp.issubdtype(dt, dpnp.floating):
73-
val = val.real
74-
elif val_type is int and dpnp.issubdtype(dt, dpnp.integer):
75-
val = _cast_fill_val(val, dt)
67+
val = _cast_fill_val(val, arr.dtype)
7668

7769
_manager = dpu.SequentialOrderManager[exec_q]
7870
dep_evs = _manager.submitted_events
71+
7972
# can leverage efficient memset when val is 0
8073
if arr.flags["FORC"] and val == 0:
8174
h_ev, zeros_ev = _zeros_usm_ndarray(arr, exec_q, depends=dep_evs)

0 commit comments

Comments
 (0)