|
20 | 20 | import numpy as np |
21 | 21 | import torch |
22 | 22 | import wrapt |
| 23 | +from datasets import load_dataset |
23 | 24 | from transformers import AutoModel, AutoTokenizer |
24 | 25 |
|
25 | 26 | from nemo_deploy import ITritonDeployable |
|
30 | 31 | ) |
31 | 32 | from nemo_export_deploy_common.import_utils import ( |
32 | 33 | MISSING_MODELOPT_MSG, |
33 | | - MISSING_NEMO_MSG, |
34 | 34 | MISSING_TENSORRT_MSG, |
35 | 35 | UnavailableError, |
36 | 36 | ) |
|
63 | 63 | trt = MagicMock() |
64 | 64 | HAVE_TENSORRT = False |
65 | 65 |
|
66 | | -try: |
67 | | - from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import ( |
68 | | - get_quant_cfg_choices, |
69 | | - ) |
70 | | - |
71 | | - QUANT_CFG_CHOICES = get_quant_cfg_choices() |
72 | | - |
73 | | - HAVE_NEMO = True |
74 | | -except (ImportError, ModuleNotFoundError): |
75 | | - HAVE_NEMO = False |
76 | | - |
77 | 66 |
|
78 | 67 | @wrapt.decorator |
79 | 68 | def noop_decorator(func): |
@@ -254,6 +243,7 @@ def _export_to_onnx( |
254 | 243 | dynamic_axes={**dynamic_axes_input, **dynamic_axes_output}, |
255 | 244 | verbose=verbose, |
256 | 245 | opset_version=opset, |
| 246 | + dynamo=False, |
257 | 247 | ) |
258 | 248 | logging.info(f"Successfully exported PyTorch model to ONNX model {self.onnx_model_path}") |
259 | 249 |
|
@@ -494,17 +484,15 @@ def quantize( |
494 | 484 | forward_loop (callable): A function that accepts the model as a single parameter |
495 | 485 | and runs sample data through it. This is used for calibration during quantization. |
496 | 486 | """ |
497 | | - if not HAVE_NEMO: |
498 | | - raise UnavailableError(MISSING_NEMO_MSG) |
499 | | - |
500 | 487 | if not HAVE_MODELOPT: |
501 | 488 | raise UnavailableError(MISSING_MODELOPT_MSG) |
502 | 489 |
|
| 490 | + quant_cfg_choices = get_quant_cfg_choices() |
503 | 491 | if isinstance(quant_cfg, str): |
504 | | - assert quant_cfg in QUANT_CFG_CHOICES, ( |
505 | | - f"Quantization config {quant_cfg} is not supported. Supported configs: {list(QUANT_CFG_CHOICES)}" |
| 492 | + assert quant_cfg in quant_cfg_choices, ( |
| 493 | + f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}" |
506 | 494 | ) |
507 | | - quant_cfg = QUANT_CFG_CHOICES[quant_cfg] |
| 495 | + quant_cfg = quant_cfg_choices[quant_cfg] |
508 | 496 |
|
509 | 497 | logging.info("Starting quantization...") |
510 | 498 | mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop) |
@@ -539,3 +527,57 @@ def get_triton_output(self): |
539 | 527 | def triton_infer_fn(self, **inputs: np.ndarray): |
540 | 528 | """PyTriton inference function.""" |
541 | 529 | raise NotImplementedError("This function will be implemented later.") |
| 530 | + |
| 531 | + |
def get_calib_data_iter(
    data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512
):
    """Yield batches of truncated text samples for quantization calibration.

    Args:
        data: Dataset selector -- "wikitext", "cnn_dailymail", or a path to a
            local JSON file expected to contain a "text" column.
        batch_size: Number of samples per yielded batch.
        calib_size: Target total number of calibration samples.
        max_sequence_length: Maximum number of characters kept per sample.

    Yields:
        list[str]: Batches of text snippets, each truncated to
        ``max_sequence_length`` characters.
    """
    if data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        text_column = "article"
    else:
        # Assume a local JSON dataset with a column named "text"
        dataset = load_dataset("json", data_files=data, split="train")
        text_column = "text"

    # Clamp to the dataset size, but never go below one full batch.
    effective_size = max(min(len(dataset), calib_size), batch_size)
    for batch_idx in range(effective_size // batch_size):
        start = batch_idx * batch_size
        rows = dataset[start : start + batch_size][text_column]
        # Character-level truncation keeps each calibration input bounded.
        yield [sample[:max_sequence_length] for sample in rows]
| 552 | + |
| 553 | + |
def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]:
    """
    Retrieve a dictionary of modelopt quantization configuration choices.

    Probes the modelopt.torch.quantization (mtq) module for a known set of
    configuration attributes and maps each one that exists to a short alias.
    Because the available attributes vary across modelopt library versions,
    any missing configuration is skipped rather than raising.

    Returns:
        dict: A dictionary where keys are short names (e.g., "fp8") and values
        are the corresponding modelopt quantization configuration objects.
    """
    # (short alias, attribute name on mtq); tuple order fixes dict ordering.
    alias_to_attr = (
        ("int8", "INT8_DEFAULT_CFG"),
        ("int8_sq", "INT8_SMOOTHQUANT_CFG"),
        ("fp8", "FP8_DEFAULT_CFG"),
        ("block_fp8", "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG"),
        ("int4_awq", "INT4_AWQ_CFG"),
        ("w4a8_awq", "W4A8_AWQ_BETA_CFG"),
        ("int4", "INT4_BLOCKWISE_WEIGHT_ONLY_CFG"),
        ("nvfp4", "NVFP4_DEFAULT_CFG"),
    )
    # The walrus keeps the getattr result while filtering out absent (or
    # otherwise falsy) configurations, mirroring a guarded-append loop.
    return {alias: cfg for alias, attr in alias_to_attr if (cfg := getattr(mtq, attr, None))}
0 commit comments