99
1010
1111def register_atomic_subgraph (
12- is_aten : bool = False ,
12+ is_core_aten : bool = False ,
1313) -> Callable [[torch .nn .Module ], torch .nn .Module ]:
1414
1515 def decorator (subgraph : torch .nn .Module ) -> torch .nn .Module :
16- ATOMIC_SUBGRAPHS .append ((subgraph , is_aten ))
16+ ATOMIC_SUBGRAPHS .append ((subgraph , is_core_aten ))
1717 return subgraph
1818
1919 return decorator
2020
2121
22- @register_atomic_subgraph (is_aten = True )
22+ @register_atomic_subgraph (is_core_aten = True )
2323class ConvBNReLU (torch .nn .Module ): # type: ignore[misc]
2424 def __init__ (self ) -> None :
2525 super ().__init__ ()
@@ -60,7 +60,7 @@ def forward(
6060 return x
6161
6262
63- @register_atomic_subgraph (is_aten = True )
63+ @register_atomic_subgraph (is_core_aten = True )
6464class ConvReLU (torch .nn .Module ): # type: ignore[misc]
6565 def __init__ (self ) -> None :
6666 super ().__init__ ()
@@ -92,7 +92,7 @@ def forward(
9292 return x
9393
9494
95- @register_atomic_subgraph (is_aten = True )
95+ @register_atomic_subgraph (is_core_aten = True )
9696class ConvGelu (torch .nn .Module ): # type: ignore[misc]
9797 def __init__ (self ) -> None :
9898 super ().__init__ ()
@@ -124,7 +124,7 @@ def forward(
124124 return x
125125
126126
127- @register_atomic_subgraph (is_aten = True )
127+ @register_atomic_subgraph (is_core_aten = True )
128128class ConvSilu (torch .nn .Module ): # type: ignore[misc]
129129 def __init__ (self ) -> None :
130130 super ().__init__ ()
@@ -139,7 +139,7 @@ def forward(
139139 return x
140140
141141
142- @register_atomic_subgraph (is_aten = True )
142+ @register_atomic_subgraph (is_core_aten = True )
143143class MulAdd (torch .nn .Module ): # type: ignore[misc]
144144 def __init__ (self ) -> None :
145145 super ().__init__ ()
@@ -152,7 +152,7 @@ def forward(
152152 return x
153153
154154
155- @register_atomic_subgraph (is_aten = True )
155+ @register_atomic_subgraph (is_core_aten = True )
156156class MulMul (torch .nn .Module ): # type: ignore[misc]
157157 def __init__ (self ) -> None :
158158 super ().__init__ ()
@@ -192,19 +192,30 @@ def get_node_in_fusion_pattern(
192192 return fusion_nodes
193193
194194
195- @lru_cache (maxsize = None )
196195def get_compiled_atomic_subgraphs () -> List [torch .fx .GraphModule ]:
197196 """
198197 This function gets the compiled atomic subgraphs from the graph.
199198 LRU cache the result to avoid recompiling the same pattern multiple times.
200199 """
201200 compiled_atomic_subgraphs = []
202- for pattern , is_aten in ATOMIC_SUBGRAPHS :
203- pattern_graph = torch . fx . symbolic_trace (pattern () )
204- if not is_aten :
205- # TODO: Add decomposition and lowering if is_aten is False
201+ for pattern , is_core_aten in ATOMIC_SUBGRAPHS :
202+ pattern_graph = trace_atomic_graph (pattern , is_core_aten )
203+ if not is_core_aten :
204+ # TODO: Add decomposition and lowering if is_core_aten is False
206205 raise NotImplementedError (
207206 "Atomic subgraphs are not supported for non-aten subgraphs yet."
208207 )
209208 compiled_atomic_subgraphs .append (pattern_graph )
210209 return compiled_atomic_subgraphs
210+
211+
212+ @lru_cache (maxsize = None )
213+ def trace_atomic_graph (
214+ graph : torch .nn .Module , is_core_aten : bool = True
215+ ) -> torch .fx .GraphModule :
216+ if is_core_aten :
217+ return torch .fx .symbolic_trace (graph ())
218+ else :
219+ raise NotImplementedError (
220+ "Resource partitioner currently does not support unlowered atomic subgraphs"
221+ )
0 commit comments