|
20 | 20 | from tensorrt_llm.quantization import QuantAlgo |
21 | 21 | from tensorrt_llm.sampling_params import SamplingParams |
22 | 22 |
|
| 23 | +from ..conftest import get_device_count, llm_models_root |
23 | 24 | from .accuracy_core import GSM8K, MMLU, CnnDailymail, LlmapiAccuracyTestHarness |
24 | 25 |
|
25 | 26 |
|
@@ -226,6 +227,9 @@ class TestNemotronSuperV3(LlmapiAccuracyTestHarness): |
226 | 227 |
|
227 | 228 | MODEL_NAME = "nvidia/Nemotron-Super-V3" |
228 | 229 | MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev" |
| 230 | + MODEL_PATH_FP8 = f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-fp8-fp8kv" |
| 231 | + MODEL_PATH_FP4 = f"{llm_models_root()}/Nemotron-SuperV3-phase1-mtp-nvfp4-fp8kv" |
| 232 | + |
229 | 233 | # Set minimum possible seq len + small buffer, for test speed & memory usage |
230 | 234 | MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN, |
231 | 235 | GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN) |
@@ -271,3 +275,45 @@ def test_bf16(self): |
271 | 275 | task.evaluate(llm, sampling_params=sampling_params) |
272 | 276 | task = GSM8K(self.MODEL_NAME) |
273 | 277 | task.evaluate(llm) |
| 278 | + |
| 279 | + @pytest.mark.skip_less_device_memory(180000) |
| 280 | + @pytest.mark.skip_less_device(4) |
| 281 | + @pytest.mark.parametrize("world_size", [4, 8]) |
| 282 | + def test_fp8(self, world_size): |
| 283 | + if get_device_count() < world_size: |
| 284 | + pytest.skip("Not enough devices for world size, skipping test") |
| 285 | + kwargs = self.get_default_kwargs() |
| 286 | + sampling_params = self.get_default_sampling_params() |
| 287 | + with AutoDeployLLM(model=self.MODEL_PATH_FP8, |
| 288 | + tokenizer=self.MODEL_PATH_FP8, |
| 289 | + world_size=world_size, |
| 290 | + **kwargs) as llm: |
| 291 | + # Manually set quant_config for FP8 model to get the accuracy threshold |
| 292 | + llm.args.quant_config.quant_algo = QuantAlgo.FP8 |
| 293 | + llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8 |
| 294 | + |
| 295 | + task = MMLU(self.MODEL_NAME) |
| 296 | + task.evaluate(llm, sampling_params=sampling_params) |
| 297 | + task = GSM8K(self.MODEL_NAME) |
| 298 | + task.evaluate(llm) |
| 299 | + |
| 300 | + @pytest.mark.skip("Skipping FP4 test until it is supported") |
| 301 | + @pytest.mark.skip_less_device_memory(180000) |
| 302 | + @pytest.mark.parametrize("world_size", [1, 4, 8]) |
| 303 | + def test_fp4(self, world_size): |
| 304 | + if get_device_count() < world_size: |
| 305 | + pytest.skip("Not enough devices for world size, skipping test") |
| 306 | + kwargs = self.get_default_kwargs() |
| 307 | + sampling_params = self.get_default_sampling_params() |
| 308 | + with AutoDeployLLM(model=self.MODEL_PATH_FP4, |
| 309 | + tokenizer=self.MODEL_PATH_FP4, |
| 310 | + world_size=world_size, |
| 311 | + **kwargs) as llm: |
| 312 | + # Manually set quant_config for FP4 model to get the accuracy threshold |
| 313 | + llm.args.quant_config.quant_algo = QuantAlgo.NVFP4 |
| 314 | + llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.NVFP4 |
| 315 | + |
| 316 | + task = MMLU(self.MODEL_NAME) |
| 317 | + task.evaluate(llm, sampling_params=sampling_params) |
| 318 | + task = GSM8K(self.MODEL_NAME) |
| 319 | + task.evaluate(llm) |
0 commit comments