Skip to content

Commit 8a33160

Browse files
committed
Added the dynamic check in the validator
1 parent 40d4d41 commit 8a33160

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
get_positive_dim,
2424
is_only_operator_on_placeholder,
2525
)
26+
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
2627

2728
_LOGGER: logging.Logger = logging.getLogger(__name__)
2829

@@ -2694,6 +2695,13 @@ def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) -
26942695

26952696

26962697
def topk_sort_validator(k: int) -> bool:
2698+
2699+
# topk layer supports dynamic k value but we cannot determine supported dynamic topk value at
2700+
# compile time.
2701+
if k == DYNAMIC_DIM:
2702+
_LOGGER.debug("k value cannot be dynamic!")
2703+
return False
2704+
26972705
if k > 3840:
26982706
_LOGGER.debug(
26992707
f"Currently only topk values up to 3840 are supported, got k={k}."
@@ -3160,7 +3168,9 @@ def aten_ops_upsample_bicubic2d(
31603168

31613169

31623170
@dynamo_tensorrt_converter(
3163-
torch.ops.aten.topk.default, capability_validator=topk_validator
3171+
torch.ops.aten.topk.default,
3172+
capability_validator=topk_validator,
3173+
supports_dynamic_shapes=True,
31643174
)
31653175
@enforce_tensor_types(
31663176
{

0 commit comments

Comments
 (0)