Skip to content

Commit f3d8a42

Browse files
authored
Refactor quantization.py to use TSP (#2522)
1 parent 35547b6 commit f3d8a42

File tree

1 file changed

+66
-38
lines changed

1 file changed

+66
-38
lines changed

thunder/transforms/quantization.py

Lines changed: 66 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Sequence
2+
from thunder.core.trace_interpreter import TraceSubstitutionProcessor
23

34
import thunder
45
from thunder.core.transform_common import Transform
@@ -217,49 +218,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo
217218
if psym.shape != csym.shape or psym.dtype != csym.dtype
218219
}
219220

220-
new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map)
221-
bound_symbols = new_computation_trace.bound_symbols
222-
new_computation_trace.bound_symbols = []
223-
224-
new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs)
225-
new_computation_trace.names.update(i.name for i in new_compute_inputs)
226-
new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
221+
# Add new compute inputs to the trace args before processing
222+
computation_trace.args = (*computation_trace.args, *new_compute_inputs)
223+
computation_trace.names.update(i.name for i in new_compute_inputs)
224+
computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args]
227225

228-
with tracectx(new_computation_trace):
226+
# Add unpack_trivial bindings for new inputs in the correct position
227+
with tracectx(computation_trace):
229228
new_bindings = [
230229
thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs
231230
]
232231

233-
for idx, bsym in enumerate(bound_symbols):
234-
if bsym.sym != prims.unpack_trivial:
235-
break
236-
new_computation_trace.bound_symbols.append(bsym.from_bsym())
237-
new_computation_trace.bound_symbols += new_bindings
238-
239-
for bsym in bound_symbols[idx:]:
240-
if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies:
241-
assert len(bsym.args) == 3 # torch.linear(input, weight, bias)
242-
n = quantized_proxies[bsym.args[1].name]
243-
qs = self.quant_states[n]
244-
# signature of the new symbol:
245-
# bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
246-
new_args = (
247-
*bsym.args[:3],
248-
additional_proxies[f"{n}.absmax"],
249-
additional_proxies[f"{n}.code"],
250-
qs["blocksize"],
251-
qs["dtype"],
252-
qs["shape"],
253-
)
254-
mm_bsym = bsym.from_bsym(
255-
sym=bnb_matmul_nf4,
256-
subsymbols=[],
257-
args=new_args,
258-
)
259-
260-
new_computation_trace.bound_symbols.append(mm_bsym)
261-
else:
262-
new_computation_trace.bound_symbols.append(bsym.from_bsym())
232+
# Insert the new bindings after the existing unpack_trivial bindings to maintain arg order
233+
# Find the last unpack_trivial binding and insert after it
234+
insert_idx = len(computation_trace.bound_symbols)
235+
for i, bsym in enumerate(computation_trace.bound_symbols):
236+
if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL:
237+
insert_idx = i + 1
238+
239+
computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings
240+
241+
# Now update metadata for the complete trace
242+
new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map)
243+
244+
class QuantizationProcessor(TraceSubstitutionProcessor):
    """Rewrite a computation trace so quantized linears call ``bnb_matmul_nf4``.

    Every bound symbol is visited once via ``process_bsym``:

    * a ``thunder.torch.linear`` whose weight (``args[1]``) is named in
      ``quantized_proxies`` is re-emitted as ``bnb_matmul_nf4`` with the
      quantization state appended to its arguments;
    * the ``prims.python_return`` symbol is re-emitted with its ``"flat_args"``
      entry replaced by the args of ``self.new_trace``;
    * anything else is copied unchanged.

    Args:
        trace: the trace handed to ``TraceSubstitutionProcessor.__init__``.
        quantized_proxies: mapping from a weight proxy name to the key used to
            index ``quant_states`` and to build the ``additional_proxies`` keys.
        additional_proxies: mapping looked up with ``f"{n}.absmax"`` and
            ``f"{n}.code"``.
        quant_states: mapping whose values are indexed with ``"blocksize"``,
            ``"dtype"`` and ``"shape"``.
        new_compute_inputs: stored on the instance; not read by this class.
    """

    def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs):
        super().__init__(trace)
        self.quantized_proxies = quantized_proxies
        self.additional_proxies = additional_proxies
        self.quant_states = quant_states
        self.new_compute_inputs = new_compute_inputs

    def _nf4_matmul_bsym(self, bsym):
        """Return ``bsym`` re-targeted at ``bnb_matmul_nf4`` with the quant state appended."""
        assert len(bsym.args) == 3  # torch.linear(input, weight, bias)
        weight_key = self.quantized_proxies[bsym.args[1].name]
        state = self.quant_states[weight_key]
        # signature of the new symbol:
        # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
        extended_args = (
            *bsym.args[:3],
            self.additional_proxies[f"{weight_key}.absmax"],
            self.additional_proxies[f"{weight_key}.code"],
            state["blocksize"],
            state["dtype"],
            state["shape"],
        )
        return bsym.from_bsym(sym=bnb_matmul_nf4, subsymbols=[], args=extended_args)

    def _return_with_new_flat_args(self, bsym):
        """Return a copy of the return symbol whose ``"flat_args"`` mirror the new trace's args."""
        assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict)
        return_dict = bsym.args[0].copy()
        return_dict["flat_args"] = list(self.new_trace.args)  # we know that the args are flat
        return bsym.from_bsym(args=(return_dict,))

    def process_bsym(self, bsym):
        """Emit the (possibly rewritten) symbol for ``bsym`` and record its output as the result."""
        if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies:
            emitted = self._nf4_matmul_bsym(bsym)
        elif bsym.sym == prims.python_return:
            emitted = self._return_with_new_flat_args(bsym)
        else:
            # Keep the original symbol
            emitted = bsym.from_bsym()

        # All three cases emit exactly one symbol and keep the original output.
        self.add_processed_bsyms([emitted])
        self.set_result(bsym.output)
284+
285+
# Process the trace using the QuantizationProcessor
286+
processor = QuantizationProcessor(
287+
new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs
288+
)
263289

290+
# Now process the trace
291+
new_computation_trace, _ = processor()
264292
new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
265293
return prologue_trace, new_computation_trace, epilogue_trace

0 commit comments

Comments
 (0)