Skip to content

Commit 79685b6

Browse files
Refactor ops.core tests and update the backend. (#21466)
1 parent 5aef6e4 commit 79685b6

File tree

5 files changed

+976
-932
lines changed

5 files changed

+976
-932
lines changed

keras/src/backend/openvino/core.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -584,17 +584,6 @@ def _is_scalar(elem):
584584
return not isinstance(elem, (list, tuple, set, dict))
585585

586586

587-
def _get_first_element(x):
588-
if isinstance(x, (tuple, list)):
589-
for elem_in_x in x:
590-
elem = _get_first_element(elem_in_x)
591-
if elem is not None:
592-
return elem
593-
elif _is_scalar(x):
594-
return x
595-
return None
596-
597-
598587
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
599588
if sparse:
600589
raise ValueError("`sparse=True` is not supported with openvino backend")
@@ -603,24 +592,29 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
603592
if dtype is not None:
604593
dtype = standardize_dtype(dtype)
605594
if isinstance(x, OpenVINOKerasTensor):
595+
if dtype and dtype != standardize_dtype(x.dtype):
596+
x = cast(x, dtype)
606597
return x
607598
elif isinstance(x, np.ndarray):
608599
if dtype is not None:
609600
ov_type = OPENVINO_DTYPES[dtype]
610-
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
611-
return OpenVINOKerasTensor(ov_opset.constant(x).output(0))
601+
else:
602+
ov_type = OPENVINO_DTYPES[standardize_dtype(x.dtype)]
603+
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
612604
elif isinstance(x, (list, tuple)):
613605
if dtype is None:
614-
# try to properly deduce element type
615-
elem = _get_first_element(x)
616-
if isinstance(elem, float):
617-
dtype = "float32"
618-
elif isinstance(elem, int):
619-
dtype = "int32"
606+
dtype = result_type(
607+
*[
608+
getattr(item, "dtype", type(item))
609+
for item in tree.flatten(x)
610+
]
611+
)
620612
x = np.array(x, dtype=dtype)
621-
return OpenVINOKerasTensor(ov_opset.constant(x).output(0), x)
613+
ov_type = OPENVINO_DTYPES[dtype]
614+
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
622615
elif isinstance(x, (float, int, bool)):
623-
dtype = standardize_dtype(dtype)
616+
if dtype is None:
617+
dtype = standardize_dtype(type(x))
624618
ov_type = OPENVINO_DTYPES[dtype]
625619
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
626620
if isinstance(x, Variable):

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -186,13 +186,3 @@ CoreOpsCorrectnessTest::test_slice_update
186186
CoreOpsCorrectnessTest::test_switch
187187
CoreOpsCorrectnessTest::test_unstack
188188
CoreOpsCorrectnessTest::test_vectorized_map
189-
CoreOpsDtypeTest::test_convert_to_tensor0
190-
CoreOpsDtypeTest::test_convert_to_tensor1
191-
CoreOpsDtypeTest::test_convert_to_tensor2
192-
CoreOpsDtypeTest::test_convert_to_tensor3
193-
CoreOpsDtypeTest::test_convert_to_tensor8
194-
CoreOpsDtypeTest::test_convert_to_tensor11
195-
CoreOpsDtypeTest::test_convert_to_tensor12
196-
CoreOpsDtypeTest::test_convert_to_tensor14
197-
CoreOpsDtypeTest::test_convert_to_tensor25
198-
CoreOpsDtypeTest::test_convert_to_tensor37

keras/src/backend/torch/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def scatter(indices, values, shape):
575575
def scatter_update(inputs, indices, updates):
576576
inputs = convert_to_tensor(inputs)
577577
indices = convert_to_tensor(indices, dtype="int64")
578-
updates = convert_to_tensor(updates)
578+
updates = convert_to_tensor(updates, dtype=inputs.dtype)
579579
indices = torch.transpose(indices, 0, 1)
580580

581581
outputs = torch.clone(inputs)

keras/src/ops/core.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,9 @@ def call(self, loop_vars):
542542
)
543543

544544
def compute_output_spec(self, loop_vars):
545-
return [KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars]
545+
return tree.map_structure(
546+
lambda v: KerasTensor(v.shape, dtype=v.dtype), loop_vars
547+
)
546548

547549

548550
@keras_export("keras.ops.while_loop")
@@ -587,6 +589,10 @@ def while_loop(
587589
>>> keras.ops.while_loop(cond, body, (x, y))
588590
10, 11
589591
"""
592+
if any_symbolic_tensors((loop_vars,)):
593+
return WhileLoop(
594+
cond, body, maximum_iterations=maximum_iterations
595+
).symbolic_call(loop_vars)
590596
return backend.core.while_loop(
591597
cond,
592598
body,
@@ -808,8 +814,6 @@ def cast(x, dtype):
808814
>>> x = keras.ops.arange(4)
809815
>>> x = keras.ops.cast(x, dtype="float16")
810816
"""
811-
dtype = backend.standardize_dtype(dtype)
812-
813817
if any_symbolic_tensors((x,)):
814818
return Cast(dtype=dtype)(x)
815819
return backend.core.cast(x, dtype)
@@ -874,8 +878,6 @@ def saturate_cast(x, dtype):
874878
>>> # [255 255 255 255]]
875879
876880
"""
877-
dtype = backend.standardize_dtype(dtype)
878-
879881
if any_symbolic_tensors((x,)):
880882
return SaturateCast(dtype=dtype)(x)
881883
return _saturate_cast(x, dtype)

0 commit comments

Comments
 (0)