@@ -571,22 +571,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = []
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    # Dynamic first dimension: resolve its size at runtime
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1
 
     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1
 
     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
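For orientation, the semantics being lowered, as a minimal PyTorch sketch (sizes and names are illustrative, not taken from the converter): with indexed dimension `I = [0]` and free dimensions `F = [1, 2]`, each of the `N` index entries selects a full `F_shapes` block.

```python
import torch

# Illustrative sizes: input (5, 4, 3), indexed dim I = [0], free dims F = [1, 2].
x = torch.zeros(5, 4, 3)
idx = torch.tensor([1, 3])       # N = 2 index entries
vals = torch.ones(2, 4, 3)       # broadcast target: (N,) + F_shapes
out = x.index_put((idx,), vals)  # rows 1 and 3 are replaced by `vals`
assert torch.equal(out[1], torch.ones(4, 3))
```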
@@ -608,46 +626,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)
 
-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
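The block above materializes coordinates for the free dimensions; both the old and new code are equivalent to stacking a meshgrid of aranges along a trailing axis. A minimal PyTorch sketch of the same construction, assuming `F_shapes = [4, 3]`:

```python
import torch

F_shapes = [4, 3]  # assumed free-dimension sizes
grids = torch.meshgrid(*[torch.arange(s) for s in F_shapes], indexing="ij")
mesh = torch.stack(grids, dim=-1)  # (4, 3, 2): coordinates of every free position
assert mesh.shape == (4, 3, 2)
# With a single free dimension the stack is a no-op, which is why the new
# `len(arange_tensors) == 1` fast path can use the lone arange directly.
```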
@@ -672,21 +694,15 @@ def index_put_converter(
 
     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []
 
     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -695,24 +711,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -731,20 +735,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )
 
     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)
 
     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )
 
     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
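A quick shape walk-through under assumed static sizes (input `(5, 4, 3)`, `I = [0]`, `N = 2`), showing why reshaping with `-1` matches the old `(N * F_volume, rank)` while also covering an `N` known only at runtime:

```python
# Assumed sizes, for illustration only.
N, F_shapes, rank = 2, [4, 3], 3
F_volume = 4 * 3  # 12
# ii_list holds one (N, F_volume, 1) tensor per input dimension; concatenated on
# dim=2 they form (N, F_volume, rank), and flattening with -1 yields
# (N * F_volume, rank) = (24, 3) without requiring N as a Python int.
assert (N * F_volume, rank) == (24, 3)
```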
@@ -842,16 +850,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Scatter the updates into a zero tensor, then add the input back
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
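The accumulate path now avoids `trt.ScatterMode.ND_ELEMENTWISE_ADD` by scattering the updates into a zero tensor and adding the input back. A hedged PyTorch sketch of that equivalence; note it relies on index rows being unique, since plain ND scatter overwrites duplicates rather than summing them:

```python
import torch

x = torch.arange(6.0).reshape(3, 2)
idx = torch.tensor([[0, 1], [2, 0]])  # (num_updates, rank), unique rows
vals = torch.tensor([10.0, 20.0])

zeros = torch.zeros_like(x)
zeros[idx[:, 0], idx[:, 1]] = vals    # analogue of ScatterMode.ND into zeros
decomposed = x + zeros                # analogue of the final elementwise add

reference = x.index_put((idx[:, 0], idx[:, 1]), vals, accumulate=True)
assert torch.equal(decomposed, reference)
```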