Skip to content

Commit a51e0f9

Browse files
committed
WIP: Use Dynamo Tracer
1 parent 71977cb commit a51e0f9

File tree

4 files changed

+95
-33
lines changed

4 files changed

+95
-33
lines changed

DeepQuant/Pipeline/Injection.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MHATransformation,
1717
)
1818
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
19-
from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace
19+
from DeepQuant.Utils.CustomTracer import QuantTracer
2020
from DeepQuant.Utils.GraphPrinter import GraphModulePrinter
2121

2222

@@ -40,14 +40,17 @@ def injectCustomForwards(
4040
executor = TransformationExecutor(transformations, debug=debug, tracer=tracer)
4141
transformedModel = executor.execute(model, exampleInput)
4242

43-
fxModel = customBrevitasTrace(
44-
root=transformedModel,
45-
tracer=tracer,
46-
)
47-
fxModel.recompile()
43+
# fxModel = customBrevitasTrace(
44+
# root=transformedModel,
45+
# tracer=tracer,
46+
# )
47+
# fxModel.recompile()
48+
import IPython; IPython.embed()
49+
fxModel = tracer.trace(transformedModel, exampleInput)
50+
import IPython; IPython.embed()
4851

49-
with torch.no_grad():
50-
output = fxModel(exampleInput)
52+
# with torch.no_grad():
53+
output = fxModel(exampleInput)
5154

5255
if torch.allclose(referenceOutput, output, atol=1e-5):
5356
if debug:

DeepQuant/Pipeline/OriginalTracing.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,44 @@
1414
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
1515
from DeepQuant.Utils.GraphPrinter import GraphModulePrinter
1616

17+
from torch._dynamo import allow_in_graph
1718

1819
def traceOriginalModel(
1920
model: nn.Module, exampleInput: torch.Tensor, debug: bool = False
2021
) -> Tuple[nn.Module, torch.Tensor]:
2122
"""Symbolically trace the original model using Brevitas."""
2223
printer = GraphModulePrinter()
2324

24-
tracedModel = brevitas_symbolic_trace(model)
25+
# tracedModel = brevitas_symbolic_trace(model)
26+
graphs = []
27+
28+
def dynamo_graph_extract_compiler(gm, inputs: torch.Tensor):
29+
graphs.append(gm)
30+
return gm.forward
31+
32+
33+
torch._dynamo.reset()
34+
torch._dynamo.config.verbose = True
35+
36+
allow_in_graph(model.inputQuant)
37+
allow_in_graph(model.inputQuant.__class__)
38+
allow_in_graph(model.inputQuant.forward)
39+
40+
allow_in_graph(model.inputQuant)
41+
allow_in_graph(model.linear1.__class__)
42+
allow_in_graph(model.linear1.forward)
43+
# JUNGVI: For Philip, dynamo uses the id of the thing passed in allow_in_graph to filter them. But it does not seems to work at least for brevitas layers, IDK if they have smth special...
44+
45+
import IPython; IPython.embed()
46+
47+
model_fn = torch.compile(model, backend = dynamo_graph_extract_compiler, dynamic = False)
48+
49+
with torch.no_grad():
50+
_ = model_fn(exampleInput)
51+
52+
import IPython; IPython.embed()
53+
54+
tracedModel = graphs[0]
2555

2656
if debug:
2757
print(cc.header("1. Original Network"))

DeepQuant/QuantManipulation/QuantNodesDivider.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ def insertQuantDequantPair(
4343
Dequant(originalModule, scaleVal, zpVal, bwVal, signed=signedVal),
4444
)
4545

46-
with graph.inserting_after(node):
47-
quantNode = graph.call_module(quantName, args=(mainArg,))
46+
with fxModel.graph.inserting_after(node):
47+
import IPython; IPython.embed()
48+
quantNode = fxModel.graph.call_module(quantName, args=(mainArg,))
4849

4950
with graph.inserting_after(quantNode):
5051
dequantNode = graph.call_module(dequantName, args=(quantNode,))

DeepQuant/Utils/CustomTracer.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,22 @@
44
#
55
# Federico Brancasi <[email protected]>
66

7-
from typing import List, Optional, Type
7+
from functools import partial
8+
from typing import List, Optional, Type, Callable
89

10+
import torch
911
import torch.nn as nn
12+
from torch._dynamo import allow_in_graph
13+
from torch.fx.graph_module import GraphModule
14+
1015
from brevitas.fx.brevitas_tracer import (
1116
Tracer,
1217
_is_brevitas_leaf_module,
1318
_symbolic_trace,
1419
)
15-
from torch.fx.graph_module import GraphModule
1620

1721

18-
class QuantTracer(Tracer):
22+
class QuantTracer():
1923
"""Enhanced tracer with fine-grained control over module tracing."""
2024

2125
def __init__(
@@ -24,34 +28,58 @@ def __init__(
2428
nonLeafClasses: Optional[List[Type[nn.Module]]] = None,
2529
debug: bool = False,
2630
) -> None:
27-
super().__init__()
2831
self.leafClasses = leafClasses if leafClasses is not None else []
2932
self.nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else []
3033
self.debug = debug
3134

3235
def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None:
33-
"""Register a module class as a leaf module."""
3436
if moduleCls not in self.leafClasses:
3537
self.leafClasses.append(moduleCls)
3638

3739
def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None:
38-
"""Register a module class as a non-leaf module."""
3940
if moduleCls not in self.nonLeafClasses:
4041
self.nonLeafClasses.append(moduleCls)
4142

42-
def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool:
43-
"""Determine if a module should be treated as a leaf module."""
44-
if any(isinstance(m, lc) for lc in self.leafClasses):
45-
return True
46-
if any(isinstance(m, nlc) for nlc in self.nonLeafClasses):
47-
return False
48-
return _is_brevitas_leaf_module(m, moduleQualifiedName)
49-
50-
51-
def customBrevitasTrace(
52-
root: nn.Module, concreteArgs=None, tracer: Optional[QuantTracer] = None
53-
) -> GraphModule:
54-
"""Create an FX GraphModule using the QuantTracer (a custom Brevitas tracer)."""
55-
if tracer is None:
56-
tracer = QuantTracer()
57-
return _symbolic_trace(tracer, root, concreteArgs)
43+
def trace(self, model: nn.Module, exampleInput):
44+
45+
brevitasClasses = [(m, id(m.__class__)) for _, m in model.named_modules() if m.__module__.startswith('brevitas.nn') or m.__module__.startswith('brevitas.core') or m.__module__.startswith('brevitas.proxy')]
46+
47+
# leafClasses = (set(brevitasClasses) | set(self.leafClasses)) - set(self.nonLeafClasses)
48+
leafClasses = brevitasClasses
49+
graphs: List[torch.fx.GraphModule] = []
50+
51+
def dynamo_graph_extract_compiler(gm: GraphModule, inputs: torch.Tensor) -> Callable:
52+
graphs.append(gm)
53+
return gm.forward
54+
55+
torch._dynamo.reset()
56+
torch._dynamo.config.verbose = True
57+
58+
# for op, _ in leafClasses:
59+
# allow_in_graph(op.forward)
60+
61+
# for _, m in model.named_modules():
62+
# if m.__module__.startswith('brevitas.nn') or m.__module__.startswith('brevitas.core') or m.__module__.startswith('brevitas.proxy'):
63+
# allow_in_graph(m)
64+
65+
allow_in_graph(model.inputQuant.__class__)
66+
allow_in_graph(model.inputQuant.forward)
67+
68+
allow_in_graph(model.linear1.__class__)
69+
allow_in_graph(model.linear1.forward)
70+
71+
model_fn = torch.compile(model, backend = dynamo_graph_extract_compiler, dynamic = False)
72+
from brevitas.export.inference import quant_inference_mode
73+
with torch.no_grad(), quant_inference_mode(model_fn):
74+
_ = model_fn(exampleInput)
75+
76+
import IPython; IPython.embed()
77+
return graphs[0]
78+
79+
# def customBrevitasTrace(
80+
# root: nn.Module, concreteArgs=None, tracer: Optional[QuantTracer] = None
81+
# ) -> GraphModule:
82+
# """Create an FX GraphModule using the QuantTracer (a custom Brevitas tracer)."""
83+
# if tracer is None:
84+
# tracer = QuantTracer()
85+
# return _symbolic_trace(tracer, root, concreteArgs)

0 commit comments

Comments
 (0)