@@ -1458,10 +1458,75 @@ def take(x, indices, axis=None):
1458
1458
1459
1459
1460
1460
def take_along_axis (x , indices , axis = None ):
1461
- raise NotImplementedError (
1462
- "`take_along_axis` is not supported with openvino backend"
1461
+ x = get_ov_output (x )
1462
+ indices = get_ov_output (indices )
1463
+
1464
+ if axis is None :
1465
+ target_shape = ov_opset .constant ([- 1 ], dtype = Type .i32 ).output (0 )
1466
+ x_flat = ov_opset .reshape (x , target_shape , False ).output (0 )
1467
+ indices_flat = ov_opset .reshape (indices , target_shape , False ).output (0 )
1468
+ result = ov_opset .gather_elements (x_flat , indices_flat , 0 ).output (0 )
1469
+ return OpenVINOKerasTensor (result )
1470
+
1471
+ x_rank = len (x .get_partial_shape ())
1472
+ if axis < 0 :
1473
+ axis += x_rank
1474
+
1475
+ x_shape = ov_opset .shape_of (x , Type .i32 ).output (0 )
1476
+ indices_shape = ov_opset .shape_of (indices , Type .i32 ).output (0 )
1477
+
1478
+ zero_const = ov_opset .constant (0 , dtype = Type .i32 ).output (0 )
1479
+ axis_index = ov_opset .constant ([axis ], dtype = Type .i32 ).output (0 )
1480
+
1481
+ # Fix negative indices
1482
+ dim_size = ov_opset .squeeze (
1483
+ ov_opset .gather (x_shape , axis_index , zero_const ).output (0 ), zero_const
1484
+ ).output (0 )
1485
+ zero_scalar = ov_opset .constant (0 , indices .get_element_type ()).output (0 )
1486
+ is_neg = ov_opset .less (indices , zero_scalar ).output (0 )
1487
+ dim_size_cast = ov_opset .convert (
1488
+ dim_size , indices .get_element_type ()
1489
+ ).output (0 )
1490
+ indices = ov_opset .select (
1491
+ is_neg , ov_opset .add (indices , dim_size_cast ).output (0 ), indices
1492
+ ).output (0 )
1493
+ indices = ov_opset .convert (indices , Type .i32 ).output (0 )
1494
+
1495
+ x_target_parts , indices_target_parts = [], []
1496
+
1497
+ for i in range (x_rank ):
1498
+ dim_idx = ov_opset .constant ([i ], dtype = Type .i32 ).output (0 )
1499
+ x_dim = ov_opset .gather (x_shape , dim_idx , zero_const ).output (0 )
1500
+ indices_dim = ov_opset .gather (
1501
+ indices_shape , dim_idx , zero_const
1502
+ ).output (0 )
1503
+
1504
+ if i == axis :
1505
+ # For axis dimension: keep original dimensions
1506
+ x_target_parts .append (x_dim )
1507
+ indices_target_parts .append (indices_dim )
1508
+ else :
1509
+ # For other dimensions: use maximum for broadcasting
1510
+ max_dim = ov_opset .maximum (x_dim , indices_dim ).output (0 )
1511
+ x_target_parts .append (max_dim )
1512
+ indices_target_parts .append (max_dim )
1513
+
1514
+ x_target_shape = ov_opset .concat (x_target_parts , axis = 0 ).output (0 )
1515
+ indices_target_shape = ov_opset .concat (indices_target_parts , axis = 0 ).output (
1516
+ 0
1463
1517
)
1464
1518
1519
+ # Broadcast to target shapes and gather elements
1520
+ x_broadcasted = ov_opset .broadcast (x , x_target_shape ).output (0 )
1521
+ indices_broadcasted = ov_opset .broadcast (
1522
+ indices , indices_target_shape
1523
+ ).output (0 )
1524
+ result = ov_opset .gather_elements (
1525
+ x_broadcasted , indices_broadcasted , axis
1526
+ ).output (0 )
1527
+
1528
+ return OpenVINOKerasTensor (result )
1529
+
1465
1530
1466
1531
def tan (x ):
1467
1532
x = get_ov_output (x )
0 commit comments