Skip to content

Commit 0c6c363

Browse files
[OpenVINO backend] support slice_update (#21549)
* [OpenVINO backend] support slice_update * update core.py
1 parent fa499b4 commit 0c6c363

File tree

2 files changed

+162
-6
lines changed

2 files changed

+162
-6
lines changed

keras/src/backend/openvino/core.py

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -849,9 +849,168 @@ def prepare_slice_index(val):
849849

850850

851851
def slice_update(inputs, start_indices, updates):
852-
raise NotImplementedError(
853-
"`slice_update` is not supported with openvino backend"
854-
)
852+
inputs = get_ov_output(inputs)
853+
updates_tensor = get_ov_output(updates)
854+
855+
if isinstance(start_indices, (list, np.ndarray)):
856+
start_indices = tuple(start_indices)
857+
if not isinstance(start_indices, tuple):
858+
raise ValueError(
859+
"`slice_update` is not supported by openvino backend"
860+
" for `start_indices` of type {}".format(type(start_indices))
861+
)
862+
863+
zero_scalar = ov_opset.constant(0, Type.i32)
864+
one_scalar = ov_opset.constant(1, Type.i32)
865+
zero_tensor = ov_opset.constant([0], Type.i32)
866+
one_tensor = ov_opset.constant([1], Type.i32)
867+
868+
processed_start_indices = []
869+
for idx in start_indices:
870+
val = get_ov_output(idx)
871+
if not val.get_element_type().is_integral():
872+
raise ValueError("`slice_update` requires integral start_indices")
873+
if val.get_element_type() != Type.i32:
874+
val = ov_opset.convert(val, Type.i32).output(0)
875+
if val.get_partial_shape().rank.get_length() == 0:
876+
val = ov_opset.unsqueeze(val, zero_scalar).output(0)
877+
processed_start_indices.append(val)
878+
879+
updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
880+
rank = updates_tensor.get_partial_shape().rank.get_length()
881+
if rank == 0:
882+
# Handle scalar update
883+
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(
884+
0
885+
)
886+
# For scatter_nd_update,
887+
# indices should be of shape [num_updates, rank_of_inputs]
888+
# and updates should be of shape [num_updates]. Here num_updates is 1.
889+
absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(
890+
0
891+
)
892+
updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)
893+
result = ov_opset.scatter_nd_update(
894+
inputs, absolute_indices, updates_flat
895+
).output(0)
896+
return OpenVINOKerasTensor(result)
897+
898+
# Compute the total number of elements in the updates tensor.
899+
# Example:
900+
# if updates.shape = [2, 3], total_elements = 6.
901+
total_elements = ov_opset.reduce_prod(
902+
updates_shape, zero_tensor, keep_dims=False
903+
).output(0)
904+
905+
# Generate a flat range [0, 1, ..., total_elements-1].
906+
# This will be used to enumerate all positions in the updates tensor.
907+
flat_indices = ov_opset.range(
908+
zero_scalar, total_elements, one_scalar, output_type=Type.i32
909+
).output(0)
910+
911+
dim_sizes = []
912+
strides = []
913+
914+
# For each dimension, compute its size and the stride.
915+
# (number of elements to skip to move to the next index in this dimension).
916+
# Example:
917+
# for shape [2, 3], strides = [3, 1].
918+
for dim in range(rank):
919+
dim_size = ov_opset.gather(
920+
updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
921+
).output(0)
922+
dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
923+
dim_sizes.append(dim_size_scalar)
924+
925+
# Strides to convert a flat index into a multi-dimensional index.
926+
# This allows us to map each element in the flattened updates tensor
927+
# to its correct N-dimensional position, so we can compute the absolute
928+
# index in the input tensor for the scatter update.
929+
# Stride for a dimension is the product of all dimensions after it.
930+
# For the last dimension, stride is 1.
931+
# Example:
932+
# For a 3D tensor with shape [2, 3, 4]:
933+
# - stride for dim=0 (first axis) is 3*4=12
934+
# (to move to the next "block" along axis 0)
935+
# - stride for dim=1 is 4 (to move to the next row along axis 1)
936+
# - stride for dim=2 is 1 (to move to the next element along axis 2)
937+
# This is equivalent to how numpy flattens multi-dimensional arrays.
938+
if dim < rank - 1:
939+
remaining_dims = ov_opset.slice(
940+
updates_shape,
941+
ov_opset.constant([dim + 1], Type.i32),
942+
ov_opset.constant([rank], Type.i32),
943+
one_tensor,
944+
zero_tensor,
945+
).output(0)
946+
stride = ov_opset.reduce_prod(
947+
remaining_dims, zero_tensor, keep_dims=False
948+
).output(0)
949+
else:
950+
stride = one_scalar
951+
strides.append(stride)
952+
953+
coord_tensors = []
954+
# For each dimension, compute the coordinate for every flat index.
955+
# Example:
956+
# for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
957+
for dim in range(rank):
958+
coords = ov_opset.mod(
959+
ov_opset.divide(flat_indices, strides[dim]).output(0),
960+
dim_sizes[dim],
961+
).output(0)
962+
coord_tensors.append(coords)
963+
964+
coord_tensors_unsqueezed = []
965+
for coord in coord_tensors:
966+
# Unsqueeze to make each coordinate a column vector for concatenation.
967+
coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
968+
coord_tensors_unsqueezed.append(coord_unsqueezed)
969+
970+
# Concatenate all coordinate columns to form [total_elements, rank] matrix.
971+
# Each row is a multi-dimensional index into the updates tensor.
972+
# Example:
973+
# for shape [2, 3], row 4 = [1, 1].
974+
indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)
975+
976+
# Broadcast start indices to match the number of updates.
977+
# Example:
978+
# start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
979+
# start_broadcast = [[2,3],[2,3],...]
980+
start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
981+
start_reshaped = ov_opset.reshape(
982+
start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
983+
).output(0)
984+
985+
broadcast_shape = ov_opset.concat(
986+
[
987+
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
988+
one_tensor,
989+
],
990+
axis=0,
991+
).output(0)
992+
993+
start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)
994+
995+
# Add the broadcasted start indices to the relative indices
996+
# to get absolute indices in the input tensor.
997+
# Example:
998+
# if start=(2,3), update index [1,1] -> absolute index [3,4].
999+
absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)
1000+
1001+
# Flatten the updates tensor to match the flat indices.
1002+
updates_flat = ov_opset.reshape(
1003+
updates_tensor,
1004+
ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
1005+
special_zero=False,
1006+
).output(0)
1007+
1008+
# Perform the scatter update: for each absolute index,
1009+
# set the corresponding value from updates_flat.
1010+
result = ov_opset.scatter_nd_update(
1011+
inputs, absolute_indices, updates_flat
1012+
).output(0)
1013+
return OpenVINOKerasTensor(result)
8551014

8561015

8571016
def while_loop(

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,14 @@ CoreOpsCallsTests::test_map_basic_call
175175
CoreOpsCallsTests::test_scan_basic_call
176176
CoreOpsCallsTests::test_scatter_basic_call
177177
CoreOpsCallsTests::test_scatter_update_basic_call
178-
CoreOpsCallsTests::test_slice_update_basic_call
179178
CoreOpsCallsTests::test_switch_basic_call
180179
CoreOpsCallsTests::test_unstack_basic_functionality
181180
CoreOpsCorrectnessTest::test_associative_scan
182181
CoreOpsCorrectnessTest::test_cond
183-
CoreOpsCorrectnessTest::test_dynamic_slice
184182
CoreOpsCorrectnessTest::test_fori_loop
185183
CoreOpsCorrectnessTest::test_map
186184
CoreOpsCorrectnessTest::test_scan
187185
CoreOpsCorrectnessTest::test_scatter
188-
CoreOpsCorrectnessTest::test_slice_update
189186
CoreOpsCorrectnessTest::test_switch
190187
CoreOpsCorrectnessTest::test_unstack
191188
CoreOpsCorrectnessTest::test_vectorized_map

0 commit comments

Comments
 (0)