File tree Expand file tree Collapse file tree 1 file changed +23
-1
lines changed
py/torch_tensorrt/dynamo/conversion Expand file tree Collapse file tree 1 file changed +23
-1
lines changed Original file line number Diff line number Diff line change @@ -881,8 +881,30 @@ def aten_ops_select(
881
881
)
882
882
883
883
884
+ def index_put_indices_continuity_validator (
885
+ node : Node , settings : Optional [CompilationSettings ] = None
886
+ ) -> bool :
887
+ idxs = node .args [1 ] # this is a list of indices
888
+ present_indices = [idx is not None for idx in idxs ]
889
+ if len (present_indices ) == 0 :
890
+ return True
891
+ first_present_index = next ((i for i , v in enumerate (present_indices ) if v ), None )
892
+ if first_present_index is None :
893
+ return False
894
+ rev_index = next ((i for i , v in enumerate (reversed (present_indices )) if v ), None )
895
+ if rev_index is None :
896
+ return False
897
+ last_present_index = len (present_indices ) - 1 - rev_index
898
+ for i in range (first_present_index , last_present_index + 1 ):
899
+ if not present_indices [i ]:
900
+ return False
901
+ return True
902
+
903
+
884
904
@dynamo_tensorrt_converter (
885
- torch .ops .aten .index_put .default , supports_dynamic_shapes = True
905
+ torch .ops .aten .index_put .default ,
906
+ capability_validator = index_put_indices_continuity_validator ,
907
+ supports_dynamic_shapes = True ,
886
908
)
887
909
@enforce_tensor_types (
888
910
{
You can’t perform that action at this time.
0 commit comments