@@ -849,9 +849,168 @@ def prepare_slice_index(val):
 
 
 def slice_update(inputs, start_indices, updates):
-    raise NotImplementedError(
-        "`slice_update` is not supported with openvino backend"
-    )
+    inputs = get_ov_output(inputs)
+    updates_tensor = get_ov_output(updates)
+
+    if isinstance(start_indices, (list, np.ndarray)):
+        start_indices = tuple(start_indices)
+    if not isinstance(start_indices, tuple):
+        raise ValueError(
+            "`slice_update` is not supported by openvino backend"
+            " for `start_indices` of type {}".format(type(start_indices))
+        )
+
+    zero_scalar = ov_opset.constant(0, Type.i32)
+    one_scalar = ov_opset.constant(1, Type.i32)
+    zero_tensor = ov_opset.constant([0], Type.i32)
+    one_tensor = ov_opset.constant([1], Type.i32)
+
+    processed_start_indices = []
+    for idx in start_indices:
+        val = get_ov_output(idx)
+        if not val.get_element_type().is_integral():
+            raise ValueError("`slice_update` requires integral start_indices")
+        if val.get_element_type() != Type.i32:
+            val = ov_opset.convert(val, Type.i32).output(0)
+        if val.get_partial_shape().rank.get_length() == 0:
+            val = ov_opset.unsqueeze(val, zero_scalar).output(0)
+        processed_start_indices.append(val)
+
+    updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)
+    rank = updates_tensor.get_partial_shape().rank.get_length()
+    if rank == 0:
+        # Handle scalar update.
+        start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(
+            0
+        )
+        # For scatter_nd_update,
+        # indices should be of shape [num_updates, rank_of_inputs]
+        # and updates should be of shape [num_updates]. Here num_updates is 1.
+        absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(
+            0
+        )
+        updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)
+        result = ov_opset.scatter_nd_update(
+            inputs, absolute_indices, updates_flat
+        ).output(0)
+        return OpenVINOKerasTensor(result)
+
+    # Compute the total number of elements in the updates tensor.
+    # Example:
+    # if updates.shape = [2, 3], total_elements = 6.
+    total_elements = ov_opset.reduce_prod(
+        updates_shape, zero_tensor, keep_dims=False
+    ).output(0)
+
+    # Generate a flat range [0, 1, ..., total_elements - 1].
+    # This will be used to enumerate all positions in the updates tensor.
+    flat_indices = ov_opset.range(
+        zero_scalar, total_elements, one_scalar, output_type=Type.i32
+    ).output(0)
+
+    dim_sizes = []
+    strides = []
+
+    # For each dimension, compute its size and its stride (the number of
+    # elements to skip to move to the next index in this dimension).
+    # Example:
+    # for shape [2, 3], strides = [3, 1].
+    for dim in range(rank):
+        dim_size = ov_opset.gather(
+            updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar
+        ).output(0)
+        dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)
+        dim_sizes.append(dim_size_scalar)
+
+        # Strides convert a flat index into a multi-dimensional index.
+        # This allows us to map each element in the flattened updates tensor
+        # to its correct N-dimensional position, so we can compute the absolute
+        # index in the input tensor for the scatter update.
+        # The stride for a dimension is the product of all dimensions after it.
+        # For the last dimension, the stride is 1.
+        # Example:
+        # For a 3D tensor with shape [2, 3, 4]:
+        # - stride for dim=0 (first axis) is 3*4=12
+        #   (to move to the next "block" along axis 0)
+        # - stride for dim=1 is 4 (to move to the next row along axis 1)
+        # - stride for dim=2 is 1 (to move to the next element along axis 2)
+        # This is equivalent to how numpy flattens multi-dimensional arrays.
+        if dim < rank - 1:
+            remaining_dims = ov_opset.slice(
+                updates_shape,
+                ov_opset.constant([dim + 1], Type.i32),
+                ov_opset.constant([rank], Type.i32),
+                one_tensor,
+                zero_tensor,
+            ).output(0)
+            stride = ov_opset.reduce_prod(
+                remaining_dims, zero_tensor, keep_dims=False
+            ).output(0)
+        else:
+            stride = one_scalar
+        strides.append(stride)
+
+    coord_tensors = []
+    # For each dimension, compute the coordinate for every flat index.
+    # Example:
+    # for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).
+    for dim in range(rank):
+        coords = ov_opset.mod(
+            ov_opset.divide(flat_indices, strides[dim]).output(0),
+            dim_sizes[dim],
+        ).output(0)
+        coord_tensors.append(coords)
+
+    coord_tensors_unsqueezed = []
+    for coord in coord_tensors:
+        # Unsqueeze to make each coordinate a column vector for concatenation.
+        coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)
+        coord_tensors_unsqueezed.append(coord_unsqueezed)
+
+    # Concatenate the coordinate columns into a [total_elements, rank] matrix.
+    # Each row is a multi-dimensional index into the updates tensor.
+    # Example:
+    # for shape [2, 3], row 4 = [1, 1].
+    indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)
+
+    # Broadcast start indices to match the number of updates.
+    # Example:
+    # start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],
+    # start_broadcast = [[2,3],[2,3],...]
+    start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(0)
+    start_reshaped = ov_opset.reshape(
+        start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False
+    ).output(0)
+
+    broadcast_shape = ov_opset.concat(
+        [
+            ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
+            one_tensor,
+        ],
+        axis=0,
+    ).output(0)
+
+    start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)
+
+    # Add the broadcasted start indices to the relative indices
+    # to get absolute indices in the input tensor.
+    # Example:
+    # if start=(2,3), update index [1,1] -> absolute index [3,4].
+    absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)
+
+    # Flatten the updates tensor to match the flat indices.
+    updates_flat = ov_opset.reshape(
+        updates_tensor,
+        ov_opset.unsqueeze(total_elements, zero_tensor).output(0),
+        special_zero=False,
+    ).output(0)
+
+    # Perform the scatter update: for each absolute index,
+    # set the corresponding value from updates_flat.
+    result = ov_opset.scatter_nd_update(
+        inputs, absolute_indices, updates_flat
+    ).output(0)
+    return OpenVINOKerasTensor(result)
 
 
 def while_loop(
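
For reference, the behaviour this diff implements corresponds to an out-of-place slice assignment: `updates` is written into `inputs` starting at `start_indices`, and everything else is left untouched. A minimal numpy sketch of that expected semantics (illustrative only, not part of the diff; `slice_update_reference` is a hypothetical name):

```python
import numpy as np

# Illustrative reference only: write `updates` into a copy of `inputs`
# starting at `start_indices`, leaving all other entries unchanged.
def slice_update_reference(inputs, start_indices, updates):
    out = np.array(inputs, copy=True)
    region = tuple(
        slice(s, s + d) for s, d in zip(start_indices, np.shape(updates))
    )
    out[region] = updates
    return out

x = np.zeros((4, 6), dtype=np.int32)
u = np.arange(6, dtype=np.int32).reshape(2, 3)
print(slice_update_reference(x, (2, 3), u))
# Rows 2-3, columns 3-5 now hold `u`; all other entries stay zero.
```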
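The stride and coordinate arithmetic built above with `reduce_prod`, `divide`, `mod`, and `add` mirrors the standard row-major flat-index decomposition. A small numpy sketch of the same mapping (illustrative only; `flat_to_absolute_indices` is a hypothetical helper, not code from this PR):

```python
import numpy as np

# Illustrative only: the flat-index -> coordinate -> absolute-index mapping
# that the OpenVINO graph constructs for the scatter_nd_update indices.
def flat_to_absolute_indices(update_shape, start_indices):
    rank = len(update_shape)
    total = int(np.prod(update_shape))
    # stride[d] = product of the dimensions after d; 1 for the last dimension.
    strides = [int(np.prod(update_shape[d + 1:])) for d in range(rank)]
    flat = np.arange(total)
    # coord[d] = (flat // stride[d]) % dim_size[d]  (row-major order).
    coords = np.stack(
        [(flat // strides[d]) % update_shape[d] for d in range(rank)], axis=1
    )
    # Shift relative coordinates by the slice start to get absolute indices.
    return coords + np.asarray(start_indices)

# For updates of shape (2, 3) placed at start (2, 3), flat index 4 maps to
# relative coordinate [1, 1] and absolute index [3, 4].
print(flat_to_absolute_indices((2, 3), (2, 3)))
```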