File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
py/torch_tensorrt/dynamo/lowering Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 44
55import torch
66from torch ._decomp import register_decomposition
7+ from torch ._export .utils import (
8+ _collect_all_valid_cia_ops_for_aten_namespace ,
9+ _get_decomp_for_cia ,
10+ )
711from torch ._ops import OpOverload
8- from torch .export import default_decompositions
912from torch_tensorrt .dynamo ._defaults import default_device
1013from torch_tensorrt .dynamo .conversion .converter_utils import get_positive_dim
1114from torch_tensorrt .dynamo .utils import to_torch_device
@@ -432,7 +435,10 @@ def get_decompositions(
432435 return {** CORE_ATEN_DECOMPOSITIONS_FILTERED , ** TORCH_TRT_DECOMPOSITIONS }
433436 else :
434437 # changes made here due to torch2.6 changes https://github.com/pytorch/pytorch/pull/135080
435- decomp_table = default_decompositions ()
438+ decomp_table = {}
439+ for op in _collect_all_valid_cia_ops_for_aten_namespace ():
440+ decomp_table [op ] = _get_decomp_for_cia (op )
441+
436442 DECOMP_TABLE_FILTERED : Dict [OpOverload , Callable [[Any ], Any ]] = {
437443 decomp : decomp_table [decomp ]
438444 for decomp in decomp_table
You can’t perform that action at this time.
0 commit comments