Skip to content
Merged
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 53 additions & 32 deletions thunder/transforms/quantization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from thunder.core.trace_interpreter import TraceSubstitutionProcessor

import thunder
from thunder.core.transform_common import Transform
Expand Down Expand Up @@ -218,48 +219,68 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo
}

new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map)
bound_symbols = new_computation_trace.bound_symbols
new_computation_trace.bound_symbols = []

class QuantizationProcessor(TraceSubstitutionProcessor):
    """Trace processor that swaps quantized ``linear`` calls for ``bnb_matmul_nf4``.

    Exactly one replacement symbol is emitted per processed bound symbol:

    * a ``thunder.torch.linear`` call whose weight proxy (``args[1]``) is named
      in ``quantized_proxies`` is re-emitted as a ``bnb_matmul_nf4`` call;
    * the ``prims.python_return`` symbol is re-emitted with the ``"flat_args"``
      entry of its return dict replaced by the args of ``self.new_trace``;
    * every other symbol is copied as-is.

    In all cases ``bsym.output`` is registered as the result.
    """

    def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs):
        """Store the lookup tables used by ``process_bsym``.

        Args:
            trace: trace handed to ``TraceSubstitutionProcessor.__init__``.
            quantized_proxies: mapping from a weight proxy name to the key ``n``
                used to index ``quant_states`` and ``additional_proxies``.
            additional_proxies: mapping holding the ``"{n}.absmax"`` and
                ``"{n}.code"`` entries passed to ``bnb_matmul_nf4``.
            quant_states: mapping from ``n`` to a dict with ``"blocksize"``,
                ``"dtype"`` and ``"shape"`` entries.
            new_compute_inputs: stored on the instance; not read by this class.
        """
        super().__init__(trace)
        self.quantized_proxies = quantized_proxies
        self.additional_proxies = additional_proxies
        self.quant_states = quant_states
        self.new_compute_inputs = new_compute_inputs

    def process_bsym(self, bsym):
        """Emit the replacement for ``bsym`` and register ``bsym.output`` as its result."""
        if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies:
            assert len(bsym.args) == 3  # torch.linear(input, weight, bias)
            n = self.quantized_proxies[bsym.args[1].name]
            state = self.quant_states[n]
            absmax = self.additional_proxies[f"{n}.absmax"]
            code = self.additional_proxies[f"{n}.code"]
            # Target signature:
            # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
            nf4_args = (
                *bsym.args[:3],
                absmax,
                code,
                state["blocksize"],
                state["dtype"],
                state["shape"],
            )
            replacement = bsym.from_bsym(sym=bnb_matmul_nf4, subsymbols=[], args=nf4_args)
        elif bsym.sym == prims.python_return:
            assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict)
            # Work on a copy so the dict held by the original symbol is not mutated.
            return_dict = bsym.args[0].copy()
            # The args are known to be flat, so they can be listed directly.
            # NOTE(review): self.new_trace is presumably set up by the base class -- confirm.
            return_dict["flat_args"] = list(self.new_trace.args)
            replacement = bsym.from_bsym(args=(return_dict,))
        else:
            # Nothing to rewrite: keep a copy of the original symbol.
            replacement = bsym.from_bsym()

        # Every branch finishes the same way: one emitted symbol, same output.
        self.add_processed_bsyms([replacement])
        self.set_result(bsym.output)

# Process the trace using the QuantizationProcessor
processor = QuantizationProcessor(
new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs
)

# Add new compute inputs to the trace args before processing
new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs)
new_computation_trace.names.update(i.name for i in new_compute_inputs)
new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]

# Add unpack_trivial bindings for new inputs
with tracectx(new_computation_trace):
new_bindings = [
thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs
]

for idx, bsym in enumerate(bound_symbols):
if bsym.sym != prims.unpack_trivial:
break
new_computation_trace.bound_symbols.append(bsym.from_bsym())
new_computation_trace.bound_symbols += new_bindings

for bsym in bound_symbols[idx:]:
if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies:
assert len(bsym.args) == 3 # torch.linear(input, weight, bias)
n = quantized_proxies[bsym.args[1].name]
qs = self.quant_states[n]
# signature of the new symbol:
# bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)
new_args = (
*bsym.args[:3],
additional_proxies[f"{n}.absmax"],
additional_proxies[f"{n}.code"],
qs["blocksize"],
qs["dtype"],
qs["shape"],
)
mm_bsym = bsym.from_bsym(
sym=bnb_matmul_nf4,
subsymbols=[],
args=new_args,
)

new_computation_trace.bound_symbols.append(mm_bsym)
else:
new_computation_trace.bound_symbols.append(bsym.from_bsym())
# Insert the new bindings at the beginning of the trace
new_computation_trace.bound_symbols = new_bindings + new_computation_trace.bound_symbols

# Now process the trace
new_computation_trace, _ = processor()
new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
return prologue_trace, new_computation_trace, epilogue_trace
Loading