@@ -625,31 +625,55 @@ def scatter_update(inputs, indices, updates):
625
625
626
626
def slice (inputs , start_indices , shape ):
627
627
inputs = get_ov_output (inputs )
628
+ if isinstance (start_indices , (list , np .ndarray )):
629
+ start_indices = tuple (start_indices )
630
+ if isinstance (shape , (list , np .ndarray )):
631
+ shape = tuple (shape )
628
632
assert isinstance (start_indices , tuple ), (
629
633
"`slice` is not supported by openvino backend"
630
- " for `start_indices` of type {}" .format (type (shape ))
634
+ " for `start_indices` of type {}" .format (type (start_indices ))
631
635
)
632
636
assert isinstance (shape , tuple ), (
633
637
"`slice` is not supported by openvino backend"
634
- " for `lengths ` of type {}" .format (type (shape ))
638
+ " for `shape ` of type {}" .format (type (shape ))
635
639
)
636
640
637
641
axes = []
638
642
start = []
639
643
stop = []
644
+
645
+ def prepare_slice_index (val ):
646
+ val_type = val .get_element_type ()
647
+ if not val_type .is_integral ():
648
+ raise ValueError (
649
+ "`slice` is not supported by OpenVINO backend "
650
+ "for `start_indices` or `shape` with non-integer types"
651
+ )
652
+ if val_type != Type .i32 :
653
+ val = ov_opset .convert (val , Type .i32 ).output (0 )
654
+ if len (val .get_partial_shape ()) == 0 :
655
+ val = ov_opset .unsqueeze (
656
+ val , ov_opset .constant (0 , Type .i32 )
657
+ ).output (0 )
658
+ return val
659
+
640
660
for idx , length in enumerate (shape ):
641
661
if length is not None and length >= 0 :
642
662
axes .append (idx )
643
- start .append (start_indices [idx ])
644
- stop .append (start_indices [idx ] + length )
663
+ start_val = prepare_slice_index (get_ov_output (start_indices [idx ]))
664
+ stop_val = prepare_slice_index (
665
+ get_ov_output (start_indices [idx ] + length )
666
+ )
667
+ start .append (start_val )
668
+ stop .append (stop_val )
645
669
646
670
if len (axes ) == 0 :
647
671
return inputs
648
672
649
673
step = [1 ] * len (start )
650
674
step = ov_opset .constant (step , Type .i32 ).output (0 )
651
- start = ov_opset .constant (start , Type . i32 ).output (0 )
652
- stop = ov_opset .constant (stop , Type . i32 ).output (0 )
675
+ start = ov_opset .concat (start , axis = 0 ).output (0 )
676
+ stop = ov_opset .concat (stop , axis = 0 ).output (0 )
653
677
axes = ov_opset .constant (axes , Type .i32 ).output (0 )
654
678
return OpenVINOKerasTensor (
655
679
ov_opset .slice (inputs , start , stop , step , axes ).output (0 )
0 commit comments