
Commit 07d8786

Fixed the comments
1 parent 2d6053e commit 07d8786


2 files changed: 50 additions & 13 deletions


examples/dynamo/low_cpu_memory_compilation.py

Lines changed: 26 additions & 0 deletions
@@ -82,3 +82,29 @@ def forward(self, x):
 
 # Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
 print(trt_gm)
+
+
+"""
+You should be able to see two back-to-back TensorRT engines in the graph
+Graph Structure:
+
+Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
+ ...
+ TRT Engine #1 - Submodule name: _run_on_acc_0
+  Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
+  Number of Operators in Engine: 9
+  Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32]
+ ...
+ TRT Engine #2 - Submodule name: _run_on_acc_1
+  Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32]
+  Number of Operators in Engine: 3
+  Engine Outputs: List[Tensor: (1, 10)@float32]
+ ...
+Outputs: List[Tensor: (1, 10)@float32]
+
+
+GraphModule(
+  (_run_on_acc_0): TorchTensorRTModule()
+  (_run_on_acc_1): TorchTensorRTModule()
+)
+"""

py/torch_tensorrt/dynamo/partitioning/_atomic_subgraphs.py

Lines changed: 24 additions & 13 deletions
@@ -9,17 +9,17 @@
 
 
 def register_atomic_subgraph(
-    is_aten: bool = False,
+    is_core_aten: bool = False,
 ) -> Callable[[torch.nn.Module], torch.nn.Module]:
 
     def decorator(subgraph: torch.nn.Module) -> torch.nn.Module:
-        ATOMIC_SUBGRAPHS.append((subgraph, is_aten))
+        ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten))
         return subgraph
 
     return decorator
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
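
For context on the renamed flag, a minimal sketch of how an additional fused pattern could be registered alongside the ones above; the ConvAdd module is hypothetical and not part of this commit, and it assumes the same aten-level call style used by the existing patterns.

# Hypothetical pattern: keep convolution followed by an elementwise add
# inside a single TensorRT engine.
@register_atomic_subgraph(is_core_aten=True)
class ConvAdd(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        other: torch.Tensor,
    ) -> torch.Tensor:
        x = torch.ops.aten.convolution.default(
            x, weight, bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1
        )
        return torch.ops.aten.add.Tensor(x, other)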
@@ -60,7 +60,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class ConvReLU(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -92,7 +92,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class ConvGelu(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -124,7 +124,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class ConvSilu(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -139,7 +139,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class MulAdd(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -152,7 +152,7 @@ def forward(
         return x
 
 
-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
 class MulMul(torch.nn.Module):  # type: ignore[misc]
     def __init__(self) -> None:
         super().__init__()
@@ -192,19 +192,30 @@ def get_node_in_fusion_pattern(
     return fusion_nodes
 
 
-@lru_cache(maxsize=None)
 def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
     """
     This function gets the compiled atomic subgraphs from the graph.
     LRU cache the result to avoid recompiling the same pattern multiple times.
     """
     compiled_atomic_subgraphs = []
-    for pattern, is_aten in ATOMIC_SUBGRAPHS:
-        pattern_graph = torch.fx.symbolic_trace(pattern())
-        if not is_aten:
-            # TODO: Add decomposition and lowering if is_aten is False
+    for pattern, is_core_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = trace_atomic_graph(pattern, is_core_aten)
+        if not is_core_aten:
+            # TODO: Add decomposition and lowering if is_core_aten is False
             raise NotImplementedError(
                 "Atomic subgraphs are not supported for non-aten subgraphs yet."
            )
         compiled_atomic_subgraphs.append(pattern_graph)
     return compiled_atomic_subgraphs
+
+
+@lru_cache(maxsize=None)
+def trace_atomic_graph(
+    graph: torch.nn.Module, is_core_aten: bool = True
+) -> torch.fx.GraphModule:
+    if is_core_aten:
+        return torch.fx.symbolic_trace(graph())
+    else:
+        raise NotImplementedError(
+            "Resource partitioner currently does not support unlowered atomic subgraphs"
+        )
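
One effect of moving lru_cache from get_compiled_atomic_subgraphs onto the new trace_atomic_graph is that each registered pattern class is symbolically traced at most once per process, while get_compiled_atomic_subgraphs itself now rebuilds its list on every call. A minimal sketch of exercising the cache, assuming the module path follows the repository layout (the check is not part of this commit):

# Hypothetical check: re-tracing the same pattern should be a cache hit.
from torch_tensorrt.dynamo.partitioning._atomic_subgraphs import (
    ConvReLU,
    trace_atomic_graph,
)

first = trace_atomic_graph(ConvReLU, True)
second = trace_atomic_graph(ConvReLU, True)
assert first is second  # same cached torch.fx.GraphModule instance
print(trace_atomic_graph.cache_info())  # hits increase on the second call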
