@@ -502,22 +502,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1

     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1

     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
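
The hunk above replaces the static list comprehension with a loop that falls back to a runtime shape query whenever an index tensor's leading dimension is dynamic, and swaps the static `(idx.shape[0], 1)` reshape for an `unsqueeze`, which needs no dimension value at build time. A minimal sketch of the dispatch pattern, assuming `DYNAMIC_DIM` is the converter's -1 sentinel and `query_dim` stands in for `get_shape` (both stand-ins are hypothetical here):

DYNAMIC_DIM = -1  # assumed sentinel meaning "size unknown until runtime"

def leading_dim(tensor, query_dim):
    # Static dim: return a plain Python int, usable at build time.
    if tensor.shape[0] != DYNAMIC_DIM:
        return tensor.shape[0]
    # Dynamic dim: defer to a runtime shape query; the result is a tensor,
    # which is why later code branches on isinstance(N, TRTTensor).
    return query_dim(tensor, 0)
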
@@ -539,46 +557,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)

-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
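
The block above builds an N-D meshgrid over the free dimensions by reshaping each arange to be 1 on every axis but its own, broadcasting it to `F_shapes`, and concatenating along a trailing coordinate axis; the new code skips the stacking entirely when there is a single free dimension. A standalone PyTorch sketch of the same construction (illustration only, not converter code):

import torch

def meshgrid_coords(F_shapes):
    grids = []
    for i, size in enumerate(F_shapes):
        shape = [1] * len(F_shapes)
        shape[i] = size  # e.g. (1, size, 1) for axis 1 of a 3-D grid
        grids.append(torch.arange(size).reshape(shape).expand(F_shapes))
    # The trailing axis holds the per-dimension coordinates, as in the cat above.
    return torch.stack(grids, dim=-1)

# meshgrid_coords([2, 3])[1, 2] == tensor([1, 2]): the coordinate of cell (1, 2).
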
@@ -603,21 +625,15 @@ def index_put_converter(

     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []

     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
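
The refactor above keeps the per-dimension index tensors in a plain Python list instead of concatenating them into one TRT tensor that later has to be sliced column-by-column; picking an element from the list is free at conversion time. A quick equivalence check of the two data flows in PyTorch (hypothetical standalone sketch):

import torch

t0 = torch.zeros(4, 6, 1)
t1 = torch.ones(4, 6, 1)

# Old flow: concatenate, then slice each (N, F_volume, 1) column back out.
combined = torch.cat([t0, t1], dim=2)
col1_old = combined[:, :, 1:2]

# New flow: keep the tensors in a list and pick by position.
col1_new = [t0, t1][1]

assert torch.equal(col1_old, col1_new)
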
@@ -626,24 +642,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -662,20 +666,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )

     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)

     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )

     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
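
Both the `(-1, rank)` reshape and the `(-1,) + tuple(F_shapes)` expected shape lean on the reshape wildcard: exactly one dimension may be -1, and its size is inferred from the total element count, so the graph stays valid even when `N` is only known at runtime. The same rule illustrated in PyTorch, for reference:

import torch

x = torch.arange(24).reshape(2, 3, 4)
flat = x.reshape(-1, 4)  # first dimension inferred as 24 / 4 = 6
assert flat.shape == (6, 4)
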
@@ -773,16 +781,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
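
The accumulate path above no longer uses `ND_ELEMENTWISE_ADD`; instead it scatters the values into a zero tensor shaped like the input (built dimension-by-dimension so dynamic shapes survive) and adds the result back onto the input. A PyTorch sketch of that decomposition, assuming each destination index appears at most once (with duplicates, an overwrite scatter keeps only one contribution, unlike a true accumulating scatter):

import torch

def index_put_accumulate(inp, indices, values):
    # Overwrite-scatter the values into zeros, then add the input back:
    # untouched positions contribute 0, touched ones become inp + value.
    scattered = torch.zeros_like(inp).index_put(indices, values)
    return inp + scattered

inp = torch.tensor([10.0, 20.0, 30.0])
out = index_put_accumulate(inp, (torch.tensor([0, 2]),), torch.tensor([1.0, 2.0]))
assert torch.equal(out, torch.tensor([11.0, 20.0, 32.0]))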