Skip to content

Commit fe2c83a

Browse files
authored
Get decompositions only for CIA ops (#3297)
1 parent 04d68bd commit fe2c83a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44

55
import torch
66
from torch._decomp import register_decomposition
7+
from torch._export.utils import (
8+
_collect_all_valid_cia_ops_for_aten_namespace,
9+
_get_decomp_for_cia,
10+
)
711
from torch._ops import OpOverload
8-
from torch.export import default_decompositions
912
from torch_tensorrt.dynamo._defaults import default_device
1013
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
1114
from torch_tensorrt.dynamo.utils import to_torch_device
@@ -432,7 +435,10 @@ def get_decompositions(
432435
return {**CORE_ATEN_DECOMPOSITIONS_FILTERED, **TORCH_TRT_DECOMPOSITIONS}
433436
else:
434437
# changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
435-
decomp_table = default_decompositions()
438+
decomp_table = {}
439+
for op in _collect_all_valid_cia_ops_for_aten_namespace():
440+
decomp_table[op] = _get_decomp_for_cia(op)
441+
436442
DECOMP_TABLE_FILTERED: Dict[OpOverload, Callable[[Any], Any]] = {
437443
decomp: decomp_table[decomp]
438444
for decomp in decomp_table

0 commit comments

Comments
 (0)