|
21 | 21 | import numpy as np |
22 | 22 | import torch |
23 | 23 | import wrapt |
| 24 | +from datasets import load_dataset |
24 | 25 | from transformers import AutoModel, AutoTokenizer |
25 | 26 |
|
26 | 27 | from nemo_deploy import ITritonDeployable |
|
31 | 32 | ) |
32 | 33 | from nemo_export_deploy_common.import_utils import ( |
33 | 34 | MISSING_MODELOPT_MSG, |
34 | | - MISSING_NEMO_MSG, |
35 | 35 | MISSING_TENSORRT_MSG, |
36 | 36 | UnavailableError, |
37 | 37 | ) |
|
59 | 59 | trt = MagicMock() |
60 | 60 | HAVE_TENSORRT = False |
61 | 61 |
|
62 | | -try: |
63 | | - from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import ( |
64 | | - get_quant_cfg_choices, |
65 | | - ) |
66 | | - |
67 | | - QUANT_CFG_CHOICES = get_quant_cfg_choices() |
68 | | - |
69 | | - HAVE_NEMO = True |
70 | | -except (ImportError, ModuleNotFoundError): |
71 | | - HAVE_NEMO = False |
72 | | - |
73 | 62 |
|
74 | 63 | @wrapt.decorator |
75 | 64 | def noop_decorator(func): |
@@ -250,6 +239,7 @@ def _export_to_onnx( |
250 | 239 | dynamic_axes={**dynamic_axes_input, **dynamic_axes_output}, |
251 | 240 | verbose=verbose, |
252 | 241 | opset_version=opset, |
| 242 | + dynamo=False, |
253 | 243 | ) |
254 | 244 | logger.info(f"Successfully exported PyTorch model to ONNX model {self.onnx_model_path}") |
255 | 245 |
|
@@ -490,17 +480,15 @@ def quantize( |
490 | 480 | forward_loop (callable): A function that accepts the model as a single parameter |
491 | 481 | and runs sample data through it. This is used for calibration during quantization. |
492 | 482 | """ |
493 | | - if not HAVE_NEMO: |
494 | | - raise UnavailableError(MISSING_NEMO_MSG) |
495 | | - |
496 | 483 | if not HAVE_MODELOPT: |
497 | 484 | raise UnavailableError(MISSING_MODELOPT_MSG) |
498 | 485 |
|
| 486 | + quant_cfg_choices = get_quant_cfg_choices() |
499 | 487 | if isinstance(quant_cfg, str): |
500 | | - assert quant_cfg in QUANT_CFG_CHOICES, ( |
501 | | - f"Quantization config {quant_cfg} is not supported. Supported configs: {list(QUANT_CFG_CHOICES)}" |
| 488 | + assert quant_cfg in quant_cfg_choices, ( |
| 489 | + f"Quantization config {quant_cfg} is not supported. Supported configs: {list(quant_cfg_choices)}" |
502 | 490 | ) |
503 | | - quant_cfg = QUANT_CFG_CHOICES[quant_cfg] |
| 491 | + quant_cfg = quant_cfg_choices[quant_cfg] |
504 | 492 |
|
505 | 493 | logger.info("Starting quantization...") |
506 | 494 | mtq.quantize(self.model, quant_cfg, forward_loop=forward_loop) |
@@ -535,3 +523,57 @@ def get_triton_output(self): |
535 | 523 | def triton_infer_fn(self, **inputs: np.ndarray): |
536 | 524 | """PyTriton inference function.""" |
537 | 525 | raise NotImplementedError("This function will be implemented later.") |
| 526 | + |
| 527 | + |
def get_calib_data_iter(
    data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512
):
    """Yield batches of truncated text samples for quantization calibration.

    Args:
        data: Either a known dataset name ("wikitext", "cnn_dailymail") or a
            path to a local JSON file expected to contain a "text" column.
        batch_size: Number of samples per yielded batch.
        calib_size: Target number of calibration samples to draw overall.
        max_sequence_length: Character cap applied to every sample.

    Yields:
        list[str]: One batch of text samples, each truncated to
        ``max_sequence_length`` characters.
    """
    if data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        text_column = "article"
    elif data == "wikitext":
        dataset = load_dataset("wikitext", "wikitext-103-v1", split="train")
        text_column = "text"
    else:
        # Assume a local JSON dataset with a column named "text"
        dataset = load_dataset("json", data_files=data, split="train")
        text_column = "text"

    # Clamp to the dataset size while always keeping at least one full batch.
    effective_size = max(min(len(dataset), calib_size), batch_size)
    for batch_idx in range(effective_size // batch_size):
        start = batch_idx * batch_size
        texts = dataset[start : start + batch_size][text_column]
        yield [text[:max_sequence_length] for text in texts]
| 548 | + |
| 549 | + |
def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]:
    """
    Retrieve a dictionary of modelopt quantization configuration choices.

    Probes the modelopt.torch.quantization (mtq) module for each known
    configuration attribute and keeps only those that exist (and are truthy)
    in the installed modelopt version, so the result adapts automatically to
    the configuration choices shipped by different library versions.

    Returns:
        dict: A dictionary where keys are short names (e.g., "fp8") and values
        are the corresponding modelopt quantization configuration objects.
    """
    # Short name -> mtq attribute name; dict preserves insertion order, so the
    # result keeps the same ordering as before.
    name_map = {
        "int8": "INT8_DEFAULT_CFG",
        "int8_sq": "INT8_SMOOTHQUANT_CFG",
        "fp8": "FP8_DEFAULT_CFG",
        "block_fp8": "FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG",
        "int4_awq": "INT4_AWQ_CFG",
        "w4a8_awq": "W4A8_AWQ_BETA_CFG",
        "int4": "INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
        "nvfp4": "NVFP4_DEFAULT_CFG",
    }
    return {short: cfg for short, attr in name_map.items() if (cfg := getattr(mtq, attr, None))}
0 commit comments