@@ -1340,7 +1340,49 @@ def sort(x, axis=-1):
13401340
13411341
13421342def split (x , indices_or_sections , axis = 0 ):
1343- raise NotImplementedError ("`split` is not supported with openvino backend" )
1343+ x = get_ov_output (x )
1344+ axis_tensor = ov_opset .constant (axis , dtype = Type .i32 ).output (0 )
1345+
1346+ shape_tensor = ov_opset .shape_of (x )
1347+ axis_i32 = ov_opset .constant ([axis ], dtype = Type .i32 )
1348+ dim_at_axis_tensor = ov_opset .gather (
1349+ shape_tensor , axis_i32 , ov_opset .constant (0 , dtype = Type .i32 )
1350+ )
1351+
1352+ if isinstance (indices_or_sections , int ):
1353+ num_splits = indices_or_sections
1354+ splits = ov_opset .split (x , axis_tensor , num_splits = num_splits )
1355+ result = []
1356+ for i in range (num_splits ):
1357+ result .append (OpenVINOKerasTensor (splits .output (i )))
1358+ return result
1359+
1360+ if isinstance (indices_or_sections , (list , tuple , np .ndarray )):
1361+ indices = list (indices_or_sections )
1362+ split_lengths = []
1363+ split_lengths .append (indices [0 ])
1364+ for i in range (1 , len (indices )):
1365+ split_lengths .append (indices [i ] - indices [i - 1 ])
1366+
1367+ last_index_tensor = ov_opset .constant (indices [- 1 ], dtype = Type .i64 )
1368+ remaining_length_tensor = ov_opset .subtract (
1369+ dim_at_axis_tensor , last_index_tensor
1370+ )
1371+
1372+ length_parts = []
1373+ length_parts .append (ov_opset .constant (split_lengths , dtype = Type .i64 ))
1374+ length_parts .append (remaining_length_tensor )
1375+ length_tensor = ov_opset .concat (length_parts , axis = 0 )
1376+
1377+ splits = ov_opset .variadic_split (x , axis_tensor , length_tensor )
1378+ result = []
1379+ for i in range (len (split_lengths ) + 1 ):
1380+ result .append (OpenVINOKerasTensor (splits .output (i )))
1381+ return result
1382+
1383+ raise TypeError (
1384+ f"unsupported type of indices_or_sections: { type (indices_or_sections )} "
1385+ )
13441386
13451387
13461388def stack (x , axis = 0 ):
0 commit comments