|
1 | 1 | from collections.abc import Sequence
|
| 2 | +from thunder.core.trace_interpreter import TraceSubstitutionProcessor |
2 | 3 |
|
3 | 4 | import thunder
|
4 | 5 | from thunder.core.transform_common import Transform
|
@@ -217,49 +218,76 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo
|
217 | 218 | if psym.shape != csym.shape or psym.dtype != csym.dtype
|
218 | 219 | }
|
219 | 220 |
|
220 |
| - new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) |
221 |
| - bound_symbols = new_computation_trace.bound_symbols |
222 |
| - new_computation_trace.bound_symbols = [] |
223 |
| - |
224 |
| - new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) |
225 |
| - new_computation_trace.names.update(i.name for i in new_compute_inputs) |
226 |
| - new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] |
| 221 | + # Add new compute inputs to the trace args before processing |
| 222 | + computation_trace.args = (*computation_trace.args, *new_compute_inputs) |
| 223 | + computation_trace.names.update(i.name for i in new_compute_inputs) |
| 224 | + computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args] |
227 | 225 |
|
228 |
| - with tracectx(new_computation_trace): |
| 226 | + # Add unpack_trivial bindings for new inputs in the correct position |
| 227 | + with tracectx(computation_trace): |
229 | 228 | new_bindings = [
|
230 | 229 | thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs
|
231 | 230 | ]
|
232 | 231 |
|
233 |
| - for idx, bsym in enumerate(bound_symbols): |
234 |
| - if bsym.sym != prims.unpack_trivial: |
235 |
| - break |
236 |
| - new_computation_trace.bound_symbols.append(bsym.from_bsym()) |
237 |
| - new_computation_trace.bound_symbols += new_bindings |
238 |
| - |
239 |
| - for bsym in bound_symbols[idx:]: |
240 |
| - if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies: |
241 |
| - assert len(bsym.args) == 3 # torch.linear(input, weight, bias) |
242 |
| - n = quantized_proxies[bsym.args[1].name] |
243 |
| - qs = self.quant_states[n] |
244 |
| - # signature of the new symbol: |
245 |
| - # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) |
246 |
| - new_args = ( |
247 |
| - *bsym.args[:3], |
248 |
| - additional_proxies[f"{n}.absmax"], |
249 |
| - additional_proxies[f"{n}.code"], |
250 |
| - qs["blocksize"], |
251 |
| - qs["dtype"], |
252 |
| - qs["shape"], |
253 |
| - ) |
254 |
| - mm_bsym = bsym.from_bsym( |
255 |
| - sym=bnb_matmul_nf4, |
256 |
| - subsymbols=[], |
257 |
| - args=new_args, |
258 |
| - ) |
259 |
| - |
260 |
| - new_computation_trace.bound_symbols.append(mm_bsym) |
261 |
| - else: |
262 |
| - new_computation_trace.bound_symbols.append(bsym.from_bsym()) |
| 232 | + # Insert the new bindings after the existing unpack_trivial bindings to maintain arg order |
| 233 | + # Find the last unpack_trivial binding and insert after it |
| 234 | + insert_idx = len(computation_trace.bound_symbols) |
| 235 | + for i, bsym in enumerate(computation_trace.bound_symbols): |
| 236 | + if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL: |
| 237 | + insert_idx = i + 1 |
| 238 | + |
| 239 | + computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings |
| 240 | + |
| 241 | + # Now update metadata for the complete trace |
| 242 | + new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) |
| 243 | + |
| 244 | + class QuantizationProcessor(TraceSubstitutionProcessor): |
| 245 | + def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs): |
| 246 | + super().__init__(trace) |
| 247 | + self.quantized_proxies = quantized_proxies |
| 248 | + self.additional_proxies = additional_proxies |
| 249 | + self.quant_states = quant_states |
| 250 | + self.new_compute_inputs = new_compute_inputs |
| 251 | + |
| 252 | + def process_bsym(self, bsym): |
| 253 | + if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies: |
| 254 | + assert len(bsym.args) == 3 # torch.linear(input, weight, bias) |
| 255 | + n = self.quantized_proxies[bsym.args[1].name] |
| 256 | + qs = self.quant_states[n] |
| 257 | + # signature of the new symbol: |
| 258 | + # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) |
| 259 | + new_args = ( |
| 260 | + *bsym.args[:3], |
| 261 | + self.additional_proxies[f"{n}.absmax"], |
| 262 | + self.additional_proxies[f"{n}.code"], |
| 263 | + qs["blocksize"], |
| 264 | + qs["dtype"], |
| 265 | + qs["shape"], |
| 266 | + ) |
| 267 | + mm_bsym = bsym.from_bsym( |
| 268 | + sym=bnb_matmul_nf4, |
| 269 | + subsymbols=[], |
| 270 | + args=new_args, |
| 271 | + ) |
| 272 | + self.add_processed_bsyms([mm_bsym]) |
| 273 | + self.set_result(bsym.output) |
| 274 | + elif bsym.sym == prims.python_return: |
| 275 | + assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict) |
| 276 | + new_return_dict = bsym.args[0].copy() |
| 277 | + new_return_dict["flat_args"] = list(self.new_trace.args) # we know that the args are flat |
| 278 | + self.add_processed_bsyms([bsym.from_bsym(args=(new_return_dict,))]) |
| 279 | + self.set_result(bsym.output) |
| 280 | + else: |
| 281 | + # Keep the original symbol |
| 282 | + self.add_processed_bsyms([bsym.from_bsym()]) |
| 283 | + self.set_result(bsym.output) |
| 284 | + |
| 285 | + # Process the trace using the QuantizationProcessor |
| 286 | + processor = QuantizationProcessor( |
| 287 | + new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs |
| 288 | + ) |
263 | 289 |
|
| 290 | + # Now process the trace |
| 291 | + new_computation_trace, _ = processor() |
264 | 292 | new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass"))
|
265 | 293 | return prologue_trace, new_computation_trace, epilogue_trace
|
0 commit comments