diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index abf721198b..f1a86e3f45 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -23,6 +23,7 @@ get_positive_dim, is_only_operator_on_placeholder, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -2694,9 +2695,18 @@ def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) - def topk_sort_validator(k: int) -> bool: + + # topk layer supports dynamic k value but we cannot determine supported dynamic topk value at + # compile time. + if k == DYNAMIC_DIM: + _LOGGER.debug( + "[top_k validator] Converter does not support k being a dynamic value. Therefore, aten::topk will run in PyTorch" + ) + return False + if k > 3840: _LOGGER.debug( - f"Currently only topk values up to 3840 are supported, got k={k}." + f"[top_k validator] Currently only topk values up to 3840 are supported, got k={k}. Therefore, aten::topk will run in PyTorch" ) return False return True @@ -3160,7 +3170,9 @@ def aten_ops_upsample_bicubic2d( @dynamo_tensorrt_converter( - torch.ops.aten.topk.default, capability_validator=topk_validator + torch.ops.aten.topk.default, + capability_validator=topk_validator, + supports_dynamic_shapes=True, ) @enforce_tensor_types( { diff --git a/py/torch_tensorrt/dynamo/conversion/impl/topk.py b/py/torch_tensorrt/dynamo/conversion/impl/topk.py index 053a46ce2b..638cbf599e 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/topk.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/topk.py @@ -209,10 +209,6 @@ def topk( get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))), ) - # topk layer supports dynamic k value but we cannot dertermin supported dynamic topk value at - # compile time. - assert k != DYNAMIC_DIM, "k value cannot be dynamic!" - # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements # so here no matter sorted is True or False the returned the topk Tensor object is always sorted set_layer_name(topk_layer, target, f"{name}_topk", source_ir)