Skip to content

Commit f379556

Browse files
[OpenVINO Backend] update slice method (#21361)
* [OpenVINO Backend] update slice * [OpenVINO Backend] update slice
1 parent 1d0358f commit f379556

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

keras/src/backend/openvino/core.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -625,31 +625,55 @@ def scatter_update(inputs, indices, updates):
625625

626626
def slice(inputs, start_indices, shape):
627627
inputs = get_ov_output(inputs)
628+
if isinstance(start_indices, (list, np.ndarray)):
629+
start_indices = tuple(start_indices)
630+
if isinstance(shape, (list, np.ndarray)):
631+
shape = tuple(shape)
628632
assert isinstance(start_indices, tuple), (
629633
"`slice` is not supported by openvino backend"
630-
" for `start_indices` of type {}".format(type(shape))
634+
" for `start_indices` of type {}".format(type(start_indices))
631635
)
632636
assert isinstance(shape, tuple), (
633637
"`slice` is not supported by openvino backend"
634-
" for `lengths` of type {}".format(type(shape))
638+
" for `shape` of type {}".format(type(shape))
635639
)
636640

637641
axes = []
638642
start = []
639643
stop = []
644+
645+
def prepare_slice_index(val):
646+
val_type = val.get_element_type()
647+
if not val_type.is_integral():
648+
raise ValueError(
649+
"`slice` is not supported by OpenVINO backend "
650+
"for `start_indices` or `shape` with non-integer types"
651+
)
652+
if val_type != Type.i32:
653+
val = ov_opset.convert(val, Type.i32).output(0)
654+
if len(val.get_partial_shape()) == 0:
655+
val = ov_opset.unsqueeze(
656+
val, ov_opset.constant(0, Type.i32)
657+
).output(0)
658+
return val
659+
640660
for idx, length in enumerate(shape):
641661
if length is not None and length >= 0:
642662
axes.append(idx)
643-
start.append(start_indices[idx])
644-
stop.append(start_indices[idx] + length)
663+
start_val = prepare_slice_index(get_ov_output(start_indices[idx]))
664+
stop_val = prepare_slice_index(
665+
get_ov_output(start_indices[idx] + length)
666+
)
667+
start.append(start_val)
668+
stop.append(stop_val)
645669

646670
if len(axes) == 0:
647671
return inputs
648672

649673
step = [1] * len(start)
650674
step = ov_opset.constant(step, Type.i32).output(0)
651-
start = ov_opset.constant(start, Type.i32).output(0)
652-
stop = ov_opset.constant(stop, Type.i32).output(0)
675+
start = ov_opset.concat(start, axis=0).output(0)
676+
stop = ov_opset.concat(stop, axis=0).output(0)
653677
axes = ov_opset.constant(axes, Type.i32).output(0)
654678
return OpenVINOKerasTensor(
655679
ov_opset.slice(inputs, start, stop, step, axes).output(0)

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,7 @@ CoreOpsCallsTests::test_map_basic_call
161161
CoreOpsCallsTests::test_scan_basic_call
162162
CoreOpsCallsTests::test_scatter_basic_call
163163
CoreOpsCallsTests::test_scatter_update_basic_call
164-
CoreOpsCallsTests::test_slice_basic_call
165164
CoreOpsCallsTests::test_slice_update_basic_call
166-
CoreOpsCallsTests::test_slice_with_non_symbolic_tensors
167165
CoreOpsCallsTests::test_switch_basic_call
168166
CoreOpsCallsTests::test_unstack_basic_functionality
169167
CoreOpsCallsTests::test_while_loop_basic_functionality
@@ -175,7 +173,6 @@ CoreOpsCorrectnessTest::test_fori_loop
175173
CoreOpsCorrectnessTest::test_map
176174
CoreOpsCorrectnessTest::test_scan
177175
CoreOpsCorrectnessTest::test_scatter
178-
CoreOpsCorrectnessTest::test_slice
179176
CoreOpsCorrectnessTest::test_slice_update
180177
CoreOpsCorrectnessTest::test_switch
181178
CoreOpsCorrectnessTest::test_unstack

0 commit comments

Comments
 (0)