44#
55# Federico Brancasi <[email protected] > 66
7- from typing import List , Optional , Type
7+ from functools import partial
8+ from typing import List , Optional , Type , Callable
89
10+ import torch
911import torch .nn as nn
12+ from torch ._dynamo import allow_in_graph
13+ from torch .fx .graph_module import GraphModule
14+
1015from brevitas .fx .brevitas_tracer import (
1116 Tracer ,
1217 _is_brevitas_leaf_module ,
1318 _symbolic_trace ,
1419)
15- from torch .fx .graph_module import GraphModule
1620
1721
18- class QuantTracer (Tracer ):
22+ class QuantTracer ():
1923 """Enhanced tracer with fine-grained control over module tracing."""
2024
2125 def __init__ (
@@ -24,34 +28,58 @@ def __init__(
2428 nonLeafClasses : Optional [List [Type [nn .Module ]]] = None ,
2529 debug : bool = False ,
2630 ) -> None :
27- super ().__init__ ()
2831 self .leafClasses = leafClasses if leafClasses is not None else []
2932 self .nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else []
3033 self .debug = debug
3134
3235 def registerLeafModule (self , moduleCls : Type [nn .Module ]) -> None :
33- """Register a module class as a leaf module."""
3436 if moduleCls not in self .leafClasses :
3537 self .leafClasses .append (moduleCls )
3638
3739 def registerNonLeafModule (self , moduleCls : Type [nn .Module ]) -> None :
38- """Register a module class as a non-leaf module."""
3940 if moduleCls not in self .nonLeafClasses :
4041 self .nonLeafClasses .append (moduleCls )
4142
42- def is_leaf_module (self , m : nn .Module , moduleQualifiedName : str ) -> bool :
43- """Determine if a module should be treated as a leaf module."""
44- if any (isinstance (m , lc ) for lc in self .leafClasses ):
45- return True
46- if any (isinstance (m , nlc ) for nlc in self .nonLeafClasses ):
47- return False
48- return _is_brevitas_leaf_module (m , moduleQualifiedName )
49-
50-
51- def customBrevitasTrace (
52- root : nn .Module , concreteArgs = None , tracer : Optional [QuantTracer ] = None
53- ) -> GraphModule :
54- """Create an FX GraphModule using the QuantTracer (a custom Brevitas tracer)."""
55- if tracer is None :
56- tracer = QuantTracer ()
57- return _symbolic_trace (tracer , root , concreteArgs )
43+ def trace (self , model : nn .Module , exampleInput ):
44+
45+ brevitasClasses = [(m , id (m .__class__ )) for _ , m in model .named_modules () if m .__module__ .startswith ('brevitas.nn' ) or m .__module__ .startswith ('brevitas.core' ) or m .__module__ .startswith ('brevitas.proxy' )]
46+
47+ # leafClasses = (set(brevitasClasses) | set(self.leafClasses)) - set(self.nonLeafClasses)
48+ leafClasses = brevitasClasses
49+ graphs : List [torch .fx .GraphModule ] = []
50+
51+ def dynamo_graph_extract_compiler (gm : GraphModule , inputs : torch .Tensor ) -> Callable :
52+ graphs .append (gm )
53+ return gm .forward
54+
55+ torch ._dynamo .reset ()
56+ torch ._dynamo .config .verbose = True
57+
58+ # for op, _ in leafClasses:
59+ # allow_in_graph(op.forward)
60+
61+ # for _, m in model.named_modules():
62+ # if m.__module__.startswith('brevitas.nn') or m.__module__.startswith('brevitas.core') or m.__module__.startswith('brevitas.proxy'):
63+ # allow_in_graph(m)
64+
65+ allow_in_graph (model .inputQuant .__class__ )
66+ allow_in_graph (model .inputQuant .forward )
67+
68+ allow_in_graph (model .linear1 .__class__ )
69+ allow_in_graph (model .linear1 .forward )
70+
71+ model_fn = torch .compile (model , backend = dynamo_graph_extract_compiler , dynamic = False )
72+ from brevitas .export .inference import quant_inference_mode
73+ with torch .no_grad (), quant_inference_mode (model_fn ):
74+ _ = model_fn (exampleInput )
75+
76+ import IPython ; IPython .embed ()
77+ return graphs [0 ]
78+
79+ # def customBrevitasTrace(
80+ # root: nn.Module, concreteArgs=None, tracer: Optional[QuantTracer] = None
81+ # ) -> GraphModule:
82+ # """Create an FX GraphModule using the QuantTracer (a custom Brevitas tracer)."""
83+ # if tracer is None:
84+ # tracer = QuantTracer()
85+ # return _symbolic_trace(tracer, root, concreteArgs)
0 commit comments