Skip to content

Commit 387fbc9

Browse files
[OpenVINO backend] fix __getitem__ and convert_to_tensor issues (#21545)
* [OpenVINO backend] fix __getitem__ and convert_to_tensor issues * fix ov_type=none * fix consistancy
1 parent 8b7c90e commit 387fbc9

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

keras/src/backend/openvino/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def count_unsqueeze_before(dim):
368368
if not (0 <= actual_dim < rank):
369369
raise IndexError(
370370
f"Index {index} is out of bounds for "
371-
"axis {dim} with rank {rank}"
371+
f"axis {dim} with rank {rank}"
372372
)
373373
length = ov_opset.gather(
374374
partial_shape,
@@ -403,7 +403,7 @@ def count_unsqueeze_before(dim):
403403
if index_type == Type.boolean or not index_type.is_integral():
404404
raise ValueError(
405405
"OpenVINO backend does not "
406-
"support {index_type} indexing"
406+
f"support {index_type} indexing"
407407
)
408408
axes.append(dim)
409409
if len(index_shape) > 1:
@@ -654,13 +654,20 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
654654
if dtype and dtype != x.dtype:
655655
x = cast(x, dtype)
656656
return x
657-
if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
658-
return ov.Tensor(np.asarray(x).astype(dtype))
659-
if dtype is None:
660-
dtype = result_type(
661-
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
657+
original_type = type(x)
658+
try:
659+
if dtype is None:
660+
dtype = getattr(x, "dtype", original_type)
661+
ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]
662+
else:
663+
ov_type = OPENVINO_DTYPES[dtype]
664+
x = np.array(x)
665+
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
666+
except Exception as e:
667+
raise TypeError(
668+
f"Cannot convert object of type {original_type} "
669+
f"to OpenVINOKerasTensor: {e}"
662670
)
663-
return ov.Tensor(np.array(x, dtype=dtype))
664671

665672

666673
def convert_to_numpy(x):

0 commit comments

Comments
 (0)