9 changes: 5 additions & 4 deletions examples/nemo_run/qat/README.md
@@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar

You can run the example either locally or on a [Slurm cluster](ADVANCED.md).

To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer`, `NeMo`, and `Megatron-LM` repositories, then mount them into your Docker container.

- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
- `git clone https://github.com/NVIDIA-NeMo/NeMo.git`
- `git clone https://github.com/NVIDIA/Megatron-LM.git`

Example docker command:

```bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
```

You will also need to set your Hugging Face token with `export HF_TOKEN=<your-token>`. You may also need to give the Docker container write access to the `examples/nemo_run` folder by running `chmod 777 nemo_run` so that logs can be written.
@@ -92,7 +93,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
To perform QAD training, run:

```bash
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
```

## Supported models
20 changes: 17 additions & 3 deletions modelopt/torch/quantization/model_calib.py
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
return

def sync_quantizer_amax_across_dp(quantizer, parallel_state):
"""Synchronize the amax across all ranks in the data parallel group."""
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
sync_quantizer_amax_across_dp(_q, parallel_state)
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp(child, module.parallel_state)

# TP sync:
# Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -116,6 +116,7 @@ def sync_quantizer_amax_across_tp(
):
if isinstance(quantizer, SequentialQuantizer):
for _q in quantizer:
# Sync amax across TP for each sub-quantizer of the sequential quantizer
sync_quantizer_amax_across_tp(
_q, linear_name, quantizer_type, axes_for_sync, parallel_state
)
@@ -598,19 +599,32 @@ def forward(self, input, *args, **kwargs):
# This will also perform distributed amax sync for input_quantizers
max_calibrate(model, lambda model: None)

def sync_act_scale_across_dp(module, data_parallel_group):
"""Sync activation scale across Data Parallel (DP)."""
if data_parallel_group.is_initialized():
dist.all_reduce(
module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
)

for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps > 0
):
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps

if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
torch.isnan(module.awq_lite.weight_scale)
):
module.awq_lite.is_enabled = False
# Hack: MoEs forward all tokens through all experts if _if_calib is True
module._if_calib = True
else:
sync_act_scale_across_dp(
module,
module.parallel_state.data_parallel_group,
)
Comment on lines +615 to +627

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify: NaN check placement may cause collective operation mismatch.

The current implementation places the sync call in the else block (lines 624-627), which means:

  • Rank A (has NaN): sets is_enabled = False, skips else block, does NOT call sync_act_scale_across_dp
  • Rank B (no NaN): enters else block, CALLS sync_act_scale_across_dp

Since dist.all_reduce (inside sync_act_scale_across_dp at line 606) is a collective operation requiring all ranks to participate, this creates a deadlock when ranks disagree on NaN presence.

Past review comments suggest this should be resolved by either:

  1. Performing a collective NaN check first, then all ranks sync (or skip) together
  2. Moving NaN detection after sync

Consider adding a collective NaN check before the conditional:

# Check if ANY rank has NaN
has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
    torch.isnan(module.awq_lite.weight_scale)
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX, 
                    group=module.parallel_state.data_parallel_group.group)

if has_nan.item() > 0:
    module.awq_lite.is_enabled = False
else:
    sync_act_scale_across_dp(module, module.parallel_state.data_parallel_group)

Based on past review comments.


🏁 Script executed:

#!/bin/bash
rg -n "sync_act_scale_across_dp" -C3

Length of output: 1217


🏁 Script executed:

#!/bin/bash
sed -n '580,650p' modelopt/torch/quantization/model_calib.py

Length of output: 3165


Collectively sync NaN detection before calling sync_act_scale_across_dp
Perform a global NaN check across all ranks, then either disable calibration everywhere or call the collective sync together. For example, replace the per-rank branch (model_calib.py ~621–627) with:

has_nan_local = (
    torch.any(torch.isnan(module.awq_lite.act_scale))
    or torch.any(torch.isnan(module.awq_lite.weight_scale))
)
has_nan = torch.tensor(int(has_nan_local), device=module.weight.device)
if module.parallel_state.data_parallel_group.is_initialized():
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX,
                    group=module.parallel_state.data_parallel_group.group)

if has_nan.item() > 0:
    module.awq_lite.is_enabled = False
else:
    sync_act_scale_across_dp(
        module,
        module.parallel_state.data_parallel_group,
    )

This ensures every rank participates in the collective operation and prevents deadlock.

🤖 Prompt for AI Agents
In modelopt/torch/quantization/model_calib.py around lines 615 to 627, replace
the current per-rank NaN check and conditional call to sync_act_scale_across_dp
with a collective NaN detection: compute a has_nan_local boolean from act_scale
or weight_scale NaNs, create a tensor on the module weight/device with that
value, perform an all_reduce (MAX) across
module.parallel_state.data_parallel_group if it's initialized, then if the
reduced has_nan is >0 set module.awq_lite.is_enabled = False on all ranks,
otherwise call sync_act_scale_across_dp; ensure the tensor is on the correct
device and the collective uses the data_parallel_group to avoid deadlocks.
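
A minimal, self-contained sketch of the flag-then-branch pattern the review recommends (editorial addition, not from the PR or the review tooling; the group setup and values are illustrative). It shows why every rank must reach the same collectives in the same order:

```python
# Two-process toy example: rank 0 "sees" a NaN in its act_scale. A MAX
# all-reduce of the flag makes both ranks agree before branching, so the
# follow-up reduction of act_scale is reached by all ranks or by none.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    act_scale = torch.ones(4)
    if rank == 0:
        act_scale[0] = float("nan")  # simulate a bad calibration on one rank

    # Collective NaN check: after the MAX reduce, has_nan is identical everywhere.
    has_nan = torch.tensor(float(torch.isnan(act_scale).any()))
    dist.all_reduce(has_nan, op=dist.ReduceOp.MAX)

    if has_nan.item() > 0:
        print(f"rank {rank}: disabling awq_lite for this module")  # all ranks take this branch
    else:
        # Emulate ReduceOp.AVG (gloo has no AVG; NCCL does).
        dist.all_reduce(act_scale, op=dist.ReduceOp.SUM)
        act_scale /= world_size

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```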


AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
12 changes: 11 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -15,13 +15,15 @@

"""Support quantization for megatron linear layers."""

import logging
import warnings
from typing import Any

import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -38,6 +40,8 @@
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

logger = logging.getLogger(__name__)

__all__ = []


@@ -217,8 +221,14 @@ class _MegatronParallelLinear(_ParallelLinear):
]

def _setup(self):
data_parallel_group = None
try:
data_parallel_group = get_data_parallel_group(with_context_parallel=True)
except AssertionError:
logger.warning("Context parallel group is not initialized, using data parallel group")
data_parallel_group = get_data_parallel_group()
self.parallel_state = ParallelState(
getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)
super()._setup()
5 changes: 4 additions & 1 deletion modelopt/torch/utils/distributed.py
@@ -247,7 +247,10 @@ def __init__(
self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)

def __repr__(self) -> str:
return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
return (
f"data_parallel_group: {self.data_parallel_group}, "
f"tensor_parallel_group: {self.tensor_parallel_group}"
)


def get_group(ranks: list[int]):
20 changes: 16 additions & 4 deletions tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -83,9 +83,10 @@


class MegatronModel(MegatronModule):
def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False):
config = TransformerConfig(
tensor_model_parallel_size=tp_size,
context_parallel_size=cp_size,
pipeline_model_parallel_size=1,
normalization="LayerNorm",
# Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -126,7 +127,11 @@ def forward(self, x):
x = x[0]
return x

def get_dummy_input(self) -> torch.Tensor:
def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
if seed is not None:
gen = torch.Generator()
gen.manual_seed(seed)
return torch.randn(1, 4, 32, generator=gen)
return torch.randn(1, 4, 32)


@@ -383,13 +388,20 @@ def run_mcore_inference_with_dummy_input(


def initialize_for_megatron(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
seed=1234,
context_parallel_size=1,
):
"""Initialize Megatron model parallelism.

NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
"""
initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
initialize_model_parallel(
tensor_model_parallel_size,
pipeline_model_parallel_size,
context_parallel_size=context_parallel_size,
)
model_parallel_cuda_manual_seed(seed)
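
A minimal usage sketch (illustrative only; it assumes an already-spawned 4-process distributed test with 2-way TP and 2-way CP) pairing the new `context_parallel_size` argument with the cleanup the docstring asks for:

```python
from megatron.core.parallel_state import destroy_model_parallel

# Assumes torch.distributed is already initialized across 4 ranks (2 TP x 2 CP).
initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=1234)
try:
    model = MegatronModel(tp_size=2, cp_size=2).cuda()  # hypothetical test body
    model(model.get_dummy_input().cuda())
finally:
    destroy_model_parallel()  # required when reusing a non-spawned process
```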


149 changes: 136 additions & 13 deletions tests/_test_utils/torch_quantization/quantize_common.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from unittest.mock import patch

import pytest
import torch
@@ -22,7 +23,9 @@

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite
from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
from modelopt.torch.quantization.utils import is_quantized_linear
from modelopt.torch.utils import torch_to

@@ -116,40 +119,160 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
mto.restore_from_modelopt_state(model_ref, state_dict)


def tensor_parallel_test_helper(model, config, tp_group, dp_group):
# The input to fist layer, the column parallel should be the same across all tp ranks
def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
quantizer_attr = getattr(quantizer, attr).clone()
print("quantizer.attr before reduce", getattr(quantizer, attr))
dist.all_reduce(quantizer_attr, op=op, group=group)
print("quantizer.attr after reduce", getattr(quantizer, attr))
print("quantizer_attr after reduce", quantizer_attr)
assert torch.allclose(quantizer_attr, getattr(quantizer, attr))


original_awq_lite = model_calib_module.awq_lite


def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
"""Function to mock awq_lite function to always use debug=True for testing"""
return original_awq_lite(model, forward_loop, alpha_step, debug=True)

Comment on lines +134 to +137

⚠️ Potential issue | 🔴 Critical

Forward the AWQ-Lite kwargs in the patch

The _debug_awq_lite wrapper drops every extra keyword argument that callers pass to awq_lite (e.g., tensor_parallel_group, data_parallel_group, max_calib_steps). The upstream API explicitly accepts **kwargs, so the first call that includes one of those options will now raise a TypeError, breaking AWQ-Lite calibration in the very paths this PR is exercising. Please mirror the original signature and forward **kwargs to original_awq_lite while forcing debug=True.

-def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True):
-    """Function to mock awq_lite function to always use debug=True for testing"""
-    return original_awq_lite(model, forward_loop, alpha_step, debug=True)
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Force awq_lite debug mode during tests without dropping optional args."""
+    return original_awq_lite(
+        model,
+        forward_loop,
+        alpha_step=alpha_step,
+        debug=True,
+        **kwargs,
+    )
🤖 Prompt for AI Agents
In tests/_test_utils/torch_quantization/quantize_common.py around lines 134 to
137, the _debug_awq_lite wrapper drops any extra keyword arguments callers pass
to awq_lite which causes TypeError when upstream calls include options like
tensor_parallel_group or max_calib_steps; update the wrapper to mirror the
original awq_lite signature by accepting *args and **kwargs (or the same
explicit params plus **kwargs) and forward them to original_awq_lite while
forcing debug=True (i.e., call original_awq_lite(..., debug=True, **kwargs) so
all upstream options are preserved).


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite):
# The input to the first (column-parallel) layer should be the same across all TP ranks
calib_data = model.get_dummy_input().cuda()
dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Sanity check
forward_loop(model)

if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
# Let's check the amax of the row-parallel input quantizer; it should be the same across all TP ranks
activation_amax = model.fc2.input_quantizer.amax.clone()
dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)

_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group)
# Let's check the row-parallel weight amax; it should be the same across all TP ranks
weight_amax = model.fc2.weight_quantizer.amax.clone()
dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
_reduce_quantizer_attr(
model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group
)

if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Let's check the column-parallel pre_quant_scale; it should be the same across all TP ranks
input_quantizer = model.fc1.input_quantizer
pre_quant_scale = input_quantizer.pre_quant_scale.clone()
dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
_reduce_quantizer_attr(
input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group
)

if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Check activation scale for AWQ lite
_reduce_quantizer_attr(
@jenchen13 (Contributor, Author) commented on Oct 2, 2025:

@realAsma For TP, I only test fc1 (column parallel) act scale during awq lite, because fc2 row parallel will fail. For DP/CP I can test both column + row parallel act scale. I'm assuming row parallel fails because it's split across the c_in dimension in activation ... is this right?

model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=tp_group,
)

dist.destroy_process_group()
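
Regarding the inline question above: a small illustrative sketch (assumptions mine, not from the PR) of why the row-parallel `fc2` activation scale is not directly comparable across TP ranks — each rank only holds statistics for its own shard of the input channels:

```python
import torch

c_in, tp = 32, 4
x = torch.randn(8, c_in)                    # full activation seen by a column-parallel layer
x_shards = x.chunk(tp, dim=-1)              # a row-parallel layer's input is sharded along c_in

# A per-input-channel activation statistic (simplified stand-in for awq_lite.act_scale):
full_scale = x.abs().mean(dim=0)            # shape [c_in] -- identical on every TP rank
shard_scales = [s.abs().mean(dim=0) for s in x_shards]  # each [c_in // tp], different channels per rank

# Reducing shard_scales element-wise across TP ranks would mix statistics of
# different channels, which is why the TP test only checks the column-parallel fc1 scale.
print(full_scale.shape, shard_scales[0].shape)  # torch.Size([32]) torch.Size([8])
```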


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite):
calib_data = model.get_dummy_input().cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

# Sanity check
forward_loop(model)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

# Weight quantizer amax
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
else:
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group)
else:
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group)

if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
# Check act scale
_reduce_quantizer_attr(
model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=group,
)
_reduce_quantizer_attr(
model.fc2.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
group=group,
)


@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
# Seed calib data with the DP rank so data-parallel peers calibrate on different inputs
dp_rank = dist.get_rank(group=dp_group)
calib_data = model.get_dummy_input(seed=dp_rank).cuda()

def forward_loop(model):
model(calib_data)

model = mtq.quantize(model, config, forward_loop)

def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX):
quantizer_attr = getattr(quantizer, attr).clone()

# Perform all-reduce operations
dist.all_reduce(quantizer_attr, op=op, group=tp_group)

dist.all_reduce(quantizer_attr, op=op, group=dp_group)

assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)

# Input quantizer amax
if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
_reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
_reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

# Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
# Channel-wise (INT8) only expects same amax across row parallel ranks
# Block-wise quantization does not expect same amax across row and column parallel ranks
if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc1.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
else:
_reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)

if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
for quantizer in model.fc2.weight_quantizer:
_reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
else:
_reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)

# Check act scale
if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
_reduce_quantizer_attr(
model.fc1.awq_lite,
"act_scale",
dist.ReduceOp.AVG,
)


def auto_quantize_helper(model):
model, search_state = mtq.auto_quantize(
model,
6 changes: 6 additions & 0 deletions tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
pytest.skip("Need at least 2 GPUs to run this test")


@pytest.fixture
def need_8_gpus():
if torch.cuda.device_count() < 8:
pytest.skip("Need at least 8 GPUs to run this test")


@pytest.fixture(scope="module")
def set_torch_dtype(request):
orig_dtype = torch.get_default_dtype()