diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 53d8f2b51d..e3682236bd 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from thunder.core.trace_interpreter import TraceSubstitutionProcessor import thunder from thunder.core.transform_common import Transform @@ -217,49 +218,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo if psym.shape != csym.shape or psym.dtype != csym.dtype } - new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) - bound_symbols = new_computation_trace.bound_symbols - new_computation_trace.bound_symbols = [] - - new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) - new_computation_trace.names.update(i.name for i in new_compute_inputs) - new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] + # Add new compute inputs to the trace args before processing + computation_trace.args = (*computation_trace.args, *new_compute_inputs) + computation_trace.names.update(i.name for i in new_compute_inputs) + computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args] - with tracectx(new_computation_trace): + # Add unpack_trivial bindings for new inputs in the correct position + with tracectx(computation_trace): new_bindings = [ thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs ] - for idx, bsym in enumerate(bound_symbols): - if bsym.sym != prims.unpack_trivial: - break - new_computation_trace.bound_symbols.append(bsym.from_bsym()) - new_computation_trace.bound_symbols += new_bindings - - for bsym in bound_symbols[idx:]: - if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies: - assert len(bsym.args) == 3 # torch.linear(input, weight, bias) - n = quantized_proxies[bsym.args[1].name] - qs = self.quant_states[n] - # signature of the new symbol: - # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) - new_args = ( - *bsym.args[:3], - additional_proxies[f"{n}.absmax"], - additional_proxies[f"{n}.code"], - qs["blocksize"], - qs["dtype"], - qs["shape"], - ) - mm_bsym = bsym.from_bsym( - sym=bnb_matmul_nf4, - subsymbols=[], - args=new_args, - ) - - new_computation_trace.bound_symbols.append(mm_bsym) - else: - new_computation_trace.bound_symbols.append(bsym.from_bsym()) + # Insert the new bindings after the existing unpack_trivial bindings to maintain arg order + # Find the last unpack_trivial binding and insert after it + insert_idx = len(computation_trace.bound_symbols) + for i, bsym in enumerate(computation_trace.bound_symbols): + if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL: + insert_idx = i + 1 + + computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings + + # Now update metadata for the complete trace + new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) + + class QuantizationProcessor(TraceSubstitutionProcessor): + def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs): + super().__init__(trace) + self.quantized_proxies = quantized_proxies + self.additional_proxies = additional_proxies + self.quant_states = quant_states + self.new_compute_inputs = new_compute_inputs + + def process_bsym(self, bsym): + if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies: + assert len(bsym.args) == 3 # torch.linear(input, weight, bias) + n = self.quantized_proxies[bsym.args[1].name] + qs = self.quant_states[n] + # signature of the new symbol: + # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) + new_args = ( + *bsym.args[:3], + self.additional_proxies[f"{n}.absmax"], + self.additional_proxies[f"{n}.code"], + qs["blocksize"], + qs["dtype"], + qs["shape"], + ) + mm_bsym = bsym.from_bsym( + sym=bnb_matmul_nf4, + subsymbols=[], + args=new_args, + ) + self.add_processed_bsyms([mm_bsym]) + self.set_result(bsym.output) + elif bsym.sym == prims.python_return: + assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict) + new_return_dict = bsym.args[0].copy() + new_return_dict["flat_args"] = list(self.new_trace.args) # we know that the args are flat + self.add_processed_bsyms([bsym.from_bsym(args=(new_return_dict,))]) + self.set_result(bsym.output) + else: + # Keep the original symbol + self.add_processed_bsyms([bsym.from_bsym()]) + self.set_result(bsym.output) + + # Process the trace using the QuantizationProcessor + processor = QuantizationProcessor( + new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs + ) + # Now process the trace + new_computation_trace, _ = processor() new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) return prologue_trace, new_computation_trace, epilogue_trace