From 635d50503b638013cc746cbf7607ed6821bdd2ee Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Fri, 20 Jun 2025 17:44:22 -0700 Subject: [PATCH 01/11] added workflow file --- .github/workflows/coverage_test.yml | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .github/workflows/coverage_test.yml diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml new file mode 100644 index 0000000000..c6315110a4 --- /dev/null +++ b/.github/workflows/coverage_test.yml @@ -0,0 +1,42 @@ +name: Coverage Test + +on: + workflow_dispatch: {} + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: false + +defaults: + run: + shell: bash + +jobs: + coverage-test: + runs-on: ubuntu-22.04 + + steps: + - name: Checkout Code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Build Thunder Package + run: | + pip install -U build + python -m build --sdist --wheel --outdir dist/ + ls -l dist/ + + - name: Install Dependencies + run: | + pip install -U pip + pip install torch transformers lightning_sdk + pip install pytest pytest-benchmark + + - name: Run Coverage Trace Test + env: + ALLOW_COVERAGE_TRACE: "1" + run: pytest thunder/tests/test_coverage_trace.py \ No newline at end of file From e0e0c98a758cee989a23b4d756db445facfd711e Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Sun, 22 Jun 2025 22:49:00 -0700 Subject: [PATCH 02/11] fixed deps with pytest --- .github/workflows/coverage_test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index c6315110a4..eba010d50d 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -35,6 +35,7 @@ jobs: pip install -U pip pip install torch transformers lightning_sdk pip install pytest pytest-benchmark + pip install looseversion - name: Run Coverage Trace Test env: From 6c81782879b870c36d5a44f3fa2837d998e03ee7 Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Mon, 23 Jun 2025 11:29:03 -0700 Subject: [PATCH 03/11] fixed deps --- .github/workflows/coverage_test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index eba010d50d..baed399e06 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -35,9 +35,9 @@ jobs: pip install -U pip pip install torch transformers lightning_sdk pip install pytest pytest-benchmark - pip install looseversion + pip install -r requirements/base.txt - name: Run Coverage Trace Test env: ALLOW_COVERAGE_TRACE: "1" - run: pytest thunder/tests/test_coverage_trace.py \ No newline at end of file + run: pytest thunder/tests/test_coverage_trace.pygit \ No newline at end of file From 48d0a396d6c28bc608971b9c80be90a05b0e712b Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Mon, 23 Jun 2025 12:22:48 -0700 Subject: [PATCH 04/11] fixed typo --- .github/workflows/coverage_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index baed399e06..27c63aa68e 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -40,4 +40,4 @@ jobs: - name: Run Coverage Trace Test env: ALLOW_COVERAGE_TRACE: "1" - run: pytest thunder/tests/test_coverage_trace.pygit \ No newline at end of file + run: pytest thunder/tests/test_coverage_trace.py \ No newline at end of file From b91aeeeda964305e2d053e953bc5da66f3399c4d Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Mon, 23 Jun 2025 15:56:19 -0700 Subject: [PATCH 05/11] updated python path --- .github/workflows/coverage_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index 27c63aa68e..66f92c264c 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -40,4 +40,4 @@ jobs: - name: Run Coverage Trace Test env: ALLOW_COVERAGE_TRACE: "1" - run: pytest thunder/tests/test_coverage_trace.py \ No newline at end of file + run: PYTHONPATH=$(pwd)/thunder/tests pytest thunder/tests/test_coverage_trace.py \ No newline at end of file From b500737a6b89611fd555697e1026e78a293a6c17 Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Mon, 22 Sep 2025 22:25:37 -0700 Subject: [PATCH 06/11] refactored to use trace substituion processor --- thunder/transforms/quantization.py | 99 +++++++++++++++++++----------- 1 file changed, 63 insertions(+), 36 deletions(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 51e05d5454..ffa1ddd765 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from thunder.core.trace_interpreter import TraceSubstitutionProcessor import thunder from thunder.core.transform_common import Transform @@ -218,48 +219,74 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo } new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) - bound_symbols = new_computation_trace.bound_symbols - new_computation_trace.bound_symbols = [] - + class QuantizationProcessor(TraceSubstitutionProcessor): + def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs): + super().__init__(trace) + self.quantized_proxies = quantized_proxies + self.additional_proxies = additional_proxies + self.quant_states = quant_states + self.new_compute_inputs = new_compute_inputs + def process_bsym(self, bsym): + if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies: + assert len(bsym.args) == 3 # torch.linear(input, weight, bias) + n = self.quantized_proxies[bsym.args[1].name] + qs = self.quant_states[n] + # signature of the new symbol: + # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) + new_args = ( + *bsym.args[:3], + self.additional_proxies[f"{n}.absmax"], + self.additional_proxies[f"{n}.code"], + qs["blocksize"], + qs["dtype"], + qs["shape"], + ) + mm_bsym = bsym.from_bsym( + sym=bnb_matmul_nf4, + subsymbols=[], + args=new_args, + ) + self.add_processed_bsyms([mm_bsym]) + self.set_result(bsym.output) + elif bsym.sym == prims.python_return: + assert len(bsym.args) == 1 and isinstance(bsym.args[0], dict) + new_return_dict = bsym.args[0].copy() + new_return_dict["flat_args"] = list(self.new_trace.args) # we know that the args are flat + self.add_processed_bsyms([bsym.from_bsym(args=(new_return_dict,))]) + self.set_result(bsym.output) + else: + # Keep the original symbol + self.add_processed_bsyms([bsym.from_bsym()]) + self.set_result(bsym.output) + + # Process the trace using the QuantizationProcessor + processor = QuantizationProcessor( + new_computation_trace, + quantized_proxies, + additional_proxies, + self.quant_states, + new_compute_inputs + ) + + # Add new compute inputs to the trace args before processing new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) new_computation_trace.names.update(i.name for i in new_compute_inputs) new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] - + + # Add unpack_trivial bindings for new inputs with tracectx(new_computation_trace): new_bindings = [ thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs ] - - for idx, bsym in enumerate(bound_symbols): - if bsym.sym != prims.unpack_trivial: - break - new_computation_trace.bound_symbols.append(bsym.from_bsym()) - new_computation_trace.bound_symbols += new_bindings - - for bsym in bound_symbols[idx:]: - if bsym.sym == thunder.torch.linear and bsym.args[1].name in quantized_proxies: - assert len(bsym.args) == 3 # torch.linear(input, weight, bias) - n = quantized_proxies[bsym.args[1].name] - qs = self.quant_states[n] - # signature of the new symbol: - # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) - new_args = ( - *bsym.args[:3], - additional_proxies[f"{n}.absmax"], - additional_proxies[f"{n}.code"], - qs["blocksize"], - qs["dtype"], - qs["shape"], - ) - mm_bsym = bsym.from_bsym( - sym=bnb_matmul_nf4, - subsymbols=[], - args=new_args, - ) - - new_computation_trace.bound_symbols.append(mm_bsym) - else: - new_computation_trace.bound_symbols.append(bsym.from_bsym()) - + + # Insert the new bindings at the beginning of the trace + new_computation_trace.bound_symbols = new_bindings + new_computation_trace.bound_symbols + + # Now process the trace + new_computation_trace, _ = processor() new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) return prologue_trace, new_computation_trace, epilogue_trace + + + + From d1031cf8b1caeec6b6180f7452664b7c4b0684cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Sep 2025 05:30:46 +0000 Subject: [PATCH 07/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/quantization.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index ffa1ddd765..c42a4be4b5 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -219,6 +219,7 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo } new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) + class QuantizationProcessor(TraceSubstitutionProcessor): def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, new_compute_inputs): super().__init__(trace) @@ -226,6 +227,7 @@ def __init__(self, trace, quantized_proxies, additional_proxies, quant_states, n self.additional_proxies = additional_proxies self.quant_states = quant_states self.new_compute_inputs = new_compute_inputs + def process_bsym(self, bsym): if bsym.sym == thunder.torch.linear and bsym.args[1].name in self.quantized_proxies: assert len(bsym.args) == 3 # torch.linear(input, weight, bias) @@ -261,32 +263,24 @@ def process_bsym(self, bsym): # Process the trace using the QuantizationProcessor processor = QuantizationProcessor( - new_computation_trace, - quantized_proxies, - additional_proxies, - self.quant_states, - new_compute_inputs + new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs ) - + # Add new compute inputs to the trace args before processing new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) new_computation_trace.names.update(i.name for i in new_compute_inputs) new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] - + # Add unpack_trivial bindings for new inputs with tracectx(new_computation_trace): new_bindings = [ thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs ] - + # Insert the new bindings at the beginning of the trace new_computation_trace.bound_symbols = new_bindings + new_computation_trace.bound_symbols - + # Now process the trace new_computation_trace, _ = processor() new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) return prologue_trace, new_computation_trace, epilogue_trace - - - - From 6c6286f26261be9de92d8472c3c994bcfaffa3cb Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Tue, 30 Sep 2025 21:53:25 -0700 Subject: [PATCH 08/11] move trace modifications before metadata replacement --- thunder/transforms/quantization.py | 31 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index a7e0cca5a9..698ca6c0fc 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -218,6 +218,21 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo if psym.shape != csym.shape or psym.dtype != csym.dtype } + # Add new compute inputs to the trace args before processing + computation_trace.args = (*computation_trace.args, *new_compute_inputs) + computation_trace.names.update(i.name for i in new_compute_inputs) + computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args] + + # Add unpack_trivial bindings for new inputs + with tracectx(computation_trace): + new_bindings = [ + thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs + ] + + # Insert the new bindings at the beginning of the trace + computation_trace.bound_symbols = new_bindings + computation_trace.bound_symbols + + # Now update metadata for the complete trace new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) class QuantizationProcessor(TraceSubstitutionProcessor): @@ -266,21 +281,7 @@ def process_bsym(self, bsym): new_computation_trace, quantized_proxies, additional_proxies, self.quant_states, new_compute_inputs ) - # Add new compute inputs to the trace args before processing - new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) - new_computation_trace.names.update(i.name for i in new_compute_inputs) - new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] - - # Add unpack_trivial bindings for new inputs - with tracectx(new_computation_trace): - new_bindings = [ - thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs - ] - - # Insert the new bindings at the beginning of the trace - new_computation_trace.bound_symbols = new_bindings + new_computation_trace.bound_symbols - # Now process the trace new_computation_trace, _ = processor() new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) - return prologue_trace, new_computation_trace, epilogue_trace + return prologue_trace, new_computation_trace, epilogue_trace \ No newline at end of file From 69dbcff91a49d51405969dfd901124f786c9b6f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 04:54:15 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 698ca6c0fc..f8380547e1 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -284,4 +284,4 @@ def process_bsym(self, bsym): # Now process the trace new_computation_trace, _ = processor() new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) - return prologue_trace, new_computation_trace, epilogue_trace \ No newline at end of file + return prologue_trace, new_computation_trace, epilogue_trace From 37b595ba2e8b8ab019f814fbedbfb655729ba1f9 Mon Sep 17 00:00:00 2001 From: Teja Pulagam Date: Tue, 30 Sep 2025 22:09:41 -0700 Subject: [PATCH 10/11] aligned ordering --- thunder/transforms/quantization.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 698ca6c0fc..f70f777296 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -223,14 +223,20 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo computation_trace.names.update(i.name for i in new_compute_inputs) computation_trace._siginfo.args = [(a.name, None) for a in computation_trace.args] - # Add unpack_trivial bindings for new inputs + # Add unpack_trivial bindings for new inputs in the correct position with tracectx(computation_trace): new_bindings = [ thunder.core.prims.unpack_trivial.bind(i, output=i, name=i.name) for i in new_compute_inputs ] - # Insert the new bindings at the beginning of the trace - computation_trace.bound_symbols = new_bindings + computation_trace.bound_symbols + # Insert the new bindings after the existing unpack_trivial bindings to maintain arg order + # Find the last unpack_trivial binding and insert after it + insert_idx = len(computation_trace.bound_symbols) + for i, bsym in enumerate(computation_trace.bound_symbols): + if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL: + insert_idx = i + 1 + + computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings # Now update metadata for the complete trace new_computation_trace = trace_with_replaced_proxy_metadata(computation_trace, computation_proxy_map) From 652af5020729f4ad42a77d333fc44670d1cd1af4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 1 Oct 2025 05:11:54 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index 45517282fb..e3682236bd 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -235,7 +235,7 @@ def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilo for i, bsym in enumerate(computation_trace.bound_symbols): if bsym.sym.id == prims.PrimIDs.UNPACK_TRIVIAL: insert_idx = i + 1 - + computation_trace.bound_symbols[insert_idx:insert_idx] = new_bindings # Now update metadata for the complete trace