Skip to content

Commit b3f6f59

Browse files
committed
index_put validator for discontinuous cases
1 parent 4210eba commit b3f6f59

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,8 +881,30 @@ def aten_ops_select(
881881
)
882882

883883

884+
def index_put_indices_continuity_validator(
885+
node: Node, settings: Optional[CompilationSettings] = None
886+
) -> bool:
887+
idxs = node.args[1] # this is a list of indices
888+
present_indices = [idx is not None for idx in idxs]
889+
if len(present_indices) == 0:
890+
return True
891+
first_present_index = next((i for i, v in enumerate(present_indices) if v), None)
892+
if first_present_index is None:
893+
return False
894+
rev_index = next((i for i, v in enumerate(reversed(present_indices)) if v), None)
895+
if rev_index is None:
896+
return False
897+
last_present_index = len(present_indices) - 1 - rev_index
898+
for i in range(first_present_index, last_present_index + 1):
899+
if not present_indices[i]:
900+
return False
901+
return True
902+
903+
884904
@dynamo_tensorrt_converter(
885-
torch.ops.aten.index_put.default, supports_dynamic_shapes=True
905+
torch.ops.aten.index_put.default,
906+
capability_validator=index_put_indices_continuity_validator,
907+
supports_dynamic_shapes=True,
886908
)
887909
@enforce_tensor_types(
888910
{

0 commit comments

Comments
 (0)