
Commit a6d0dec

Merge branch 'main' into chenhany/fix_eagle3_multi_layer_hook
2 parents 18a61e8 + 0d279f1 commit a6d0dec

File tree

5 files changed: +39 -49 lines changed

.github/workflows/gpu_tests.yml
.github/workflows/unit_tests.yml
modelopt/torch/prune/plugins/mcore_minitron.py
modelopt/torch/speculative/plugins/megatron_eagle.py
tests/gpu/torch/quantization/backends/test_gemm_common.py

.github/workflows/gpu_tests.yml

Lines changed: 21 additions & 1 deletion
@@ -22,20 +22,31 @@ jobs:
       any_changed: ${{ steps.changed-tests.outputs.any_changed }}
     steps:
       - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
       - id: get-pr-info
         uses: nv-gha-runners/get-pr-info@main
+      # Get commit from main branch that is present in the PR to use as base for changed files
+      - id: calculate-merge-base
+        env:
+          PR_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }}
+          BASE_SHA: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).base.sha }}
+        run: |
+          (echo -n "merge-base="; git merge-base "$BASE_SHA" "$PR_SHA") | tee --append "${GITHUB_OUTPUT}"
       - name: Check for changes in test-relevant directories
         id: changed-tests
         uses: step-security/[email protected]
         with:
+          base_sha: ${{ steps.calculate-merge-base.outputs.merge-base }}
+          sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).head.sha }}
           files: |
             .github/workflows/gpu_tests.yml
             modelopt/**
             tests/gpu/**
             tox.ini
             pyproject.toml
             setup.py
-          base_sha: ${{ fromJSON(steps.get-pr-info.outputs.pr-info).base.ref }}
+          fail_on_initial_diff_error: true
   wait-checks:
     needs: [check-file-changes]
     if: needs.check-file-changes.outputs.any_changed == 'true'

@@ -70,3 +81,12 @@ jobs:
     timeout-minutes: 90
     container: *gpu_container
     steps: *gpu_steps
+  gpu-pr-required-check:
+    # Run even if gpu-tests-pr is skipped
+    if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }}
+    needs: [check-file-changes, gpu-tests-pr]
+    runs-on: ubuntu-latest
+    steps:
+      - name: Required GPU tests did not succeed
+        if: ${{ needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && needs.gpu-tests-pr.result != 'success') }}
+        run: exit 1
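
Why the new step: the changed-files action previously diffed against the PR's base.ref, which drifts as main advances; the calculate-merge-base step instead uses the merge base, the last main commit already contained in the PR (as the inline comment notes), which is why the checkout now fetches full history. A minimal Python sketch of the same lookup, assuming git is on PATH and the code runs inside a clone with full history; the refs in the example call are placeholders:

import subprocess

def merge_base(base_ref: str, head_ref: str) -> str:
    """Return the most recent common ancestor of two refs, as the workflow's
    git merge-base "$BASE_SHA" "$PR_SHA" step does."""
    result = subprocess.run(
        ["git", "merge-base", base_ref, head_ref],
        check=True, capture_output=True, text=True,
    )
    return result.stdout.strip()

# Placeholder refs; the workflow passes the PR's base and head SHAs instead
# and appends the result to $GITHUB_OUTPUT as "merge-base=<sha>".
print(f"merge-base={merge_base('origin/main', 'HEAD')}")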

.github/workflows/unit_tests.yml

Lines changed: 6 additions & 0 deletions
@@ -126,3 +126,9 @@ jobs:
           python-version: "3.12"
       - name: Run unit tests
         run: pip install tox && tox -e py312-partial-unit-${{ matrix.test-env }}
+  unit-pr-required-check:
+    if: github.event_name == 'pull_request'
+    needs: [linux, windows, multi-py, multi-torch, multi-transformers, partial-install]
+    runs-on: ubuntu-latest
+    steps:
+      - run: echo "All PR unit test jobs completed"

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 1 addition & 43 deletions
@@ -59,38 +59,6 @@
 }


-def get_supported_models():
-    """Get the supported models for Minitron pruning.
-
-    NOTE: Keep inside function to avoid circular import issues.
-    """
-    supported_models = set()
-
-    try:
-        from megatron.core.models.gpt import GPTModel
-
-        supported_models.add(GPTModel)
-    except Exception:
-        pass
-
-    try:
-        from megatron.core.models.mamba import MambaModel
-
-        supported_models.add(MambaModel)
-    except Exception:
-        pass
-
-    try:
-        from nemo.collections import llm
-
-        # NOTE: llm.MambaModel is a subclass of llm.GPTModel
-        supported_models.add(llm.GPTModel)
-    except Exception:
-        pass
-
-    return supported_models
-
-
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm."""


@@ -158,17 +126,6 @@ def before_search(self) -> None:
     def run_search(self) -> None:
         """Run actual search."""
         # Run forward loop to collect activations and sort parameters
-        model_cfg = None
-        supported_models = get_supported_models()
-        for m_type in supported_models:
-            if isinstance(self.model, m_type):
-                model_cfg = self.model.config
-                break
-        if model_cfg is None:
-            raise NotImplementedError(
-                f"Only {supported_models} models are supported! Got: {type(self.model)}"
-            )
-
         assert self.forward_loop is not None
         is_training = self.model.training
         self.model.eval()

@@ -187,6 +144,7 @@ def run_search(self) -> None:
             hp.active = export_config[hp_name]

         # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads
+        model_cfg = self.model.config
         orig_kv_channels = getattr(model_cfg, "kv_channels")
         if orig_kv_channels is None:
             orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
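
Net effect of this file's change: run_search no longer gates on a hand-maintained set of supported model classes; it reads self.model.config directly and, per the comment above, falls back to hidden_size // num_attention_heads when kv_channels is unset. A small self-contained sketch of that fallback, using a SimpleNamespace stand-in for a real Megatron/NeMo model config (the num_attention_heads continuation is inferred from the comment, since the hunk is cut off mid-call):

from types import SimpleNamespace

def resolve_kv_channels(model_cfg) -> int:
    """Mirror the fallback in run_search: derive kv_channels when the config leaves it None."""
    orig_kv_channels = getattr(model_cfg, "kv_channels")
    if orig_kv_channels is None:
        orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr(
            model_cfg, "num_attention_heads"
        )
    return orig_kv_channels

cfg = SimpleNamespace(kv_channels=None, hidden_size=4096, num_attention_heads=32)
assert resolve_kv_channels(cfg) == 128  # 4096 // 32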

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def dict_to_config(
         fp16=fp16,
         bf16=bf16,
         params_dtype=getattr(torch, architecture_config["torch_dtype"]),
-        pipeline_dtype=None,
+        pipeline_dtype=getattr(torch, architecture_config["torch_dtype"]),
         num_layers=architecture_config.get("num_hidden_layers"),
         hidden_size=architecture_config.get("hidden_size"),
         ffn_hidden_size=architecture_config.get("intermediate_size"),
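
The one-line fix makes pipeline_dtype follow the checkpoint's torch_dtype string, matching params_dtype instead of staying None. A quick sketch of that string-to-dtype lookup; the architecture_config dict below is a hypothetical stand-in for the config the function receives:

import torch

architecture_config = {"torch_dtype": "bfloat16"}  # hypothetical HF-style config entry
dtype = getattr(torch, architecture_config["torch_dtype"])
assert dtype is torch.bfloat16
# After this commit, dict_to_config passes this same value for both params_dtype and pipeline_dtype.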

tests/gpu/torch/quantization/backends/test_gemm_common.py

Lines changed: 10 additions & 4 deletions
@@ -29,6 +29,12 @@
 set_seed()


+@pytest.fixture(autouse=True)
+def setup_seed():
+    """Set seed before each test function."""
+    set_seed()
+
+
 @pytest.mark.parametrize(
     ("config", "gemm_forward", "atol", "rtol"),
     [

@@ -257,9 +263,9 @@ def forward_loop(model, run_backward=False):

     # The way the compression of the weights and inputs might be different.
     # E.g. we may use torch.compile in the gemms.
-    assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 3)
-    assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 3)
+    assert torch.allclose(output_dynamic_quant_gemm, output_dynamic_quant, atol=atol / 2)
+    assert torch.allclose(output_calib_quant_gemm, output_calib_quant, atol=atol / 2)
     assert torch.allclose(
-        output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 3
+        output_dynamic_quant_gemm, output_dynamic_quant_compressed, atol=atol / 2
     )
-    assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 3)
+    assert torch.allclose(output_calib_quant_gemm, output_calib_quant_compressed, atol=atol / 2)
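
Two adjustments in this test file: an autouse fixture now reseeds before every test instead of relying only on the single module-level set_seed() call, and the GEMM-vs-reference tolerances are loosened from atol / 3 to atol / 2. A standalone sketch of the autouse pattern, with torch.manual_seed standing in for the test suite's set_seed helper:

import pytest
import torch

@pytest.fixture(autouse=True)
def setup_seed():
    """Runs before every test in the module, so each test starts from the same RNG state."""
    torch.manual_seed(0)  # stand-in for the repo's set_seed()

def test_reseeded_draws_match():
    x = torch.randn(4)   # produced from the fixture's seed
    torch.manual_seed(0)
    y = torch.randn(4)   # identical draw after reseeding manually
    assert torch.allclose(x, y)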
