|
6 | 6 | import torch.nn.functional as F |
7 | 7 | from torch import nn |
8 | 8 |
|
9 | | -import tensorrt_llm.logger as trtllm_logger |
10 | 9 | from tensorrt_llm._utils import get_sm_version, is_sm_100f |
| 10 | +from tensorrt_llm.logger import logger |
11 | 11 | from tensorrt_llm.quantization.functional import \ |
12 | 12 | preprocess_weights_for_mixed_gemm |
13 | 13 | from tensorrt_llm.quantization.utils.fp4_utils import ( |
@@ -271,8 +271,6 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], |
271 | 271 | module.w2_bias.data if module.bias else None) |
272 | 272 |
|
273 | 273 | self.load_quant_scales(module, weights) |
274 | | - # Re-setup quant scales after loading weights as the tensors may have been modified. |
275 | | - self.setup_quant_scales(module) |
276 | 274 |
|
277 | 275 | if self.need_load_shared_weights(module): |
278 | 276 | local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( |
@@ -323,7 +321,8 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], |
323 | 321 | module.initial_global_assignments) |
324 | 322 |
|
325 | 323 | def post_load_weights(self, module: torch.nn.Module): |
326 | | - pass |
| 324 | + # Re-setup quant scales after loading weights as the tensors may have been modified. |
| 325 | + self.setup_quant_scales(module) |
327 | 326 |
|
328 | 327 | def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]): |
329 | 328 | pass |
@@ -722,14 +721,15 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], |
722 | 721 | if int(name.split(".")[0]) not in expert_ids: |
723 | 722 | continue |
724 | 723 | weight_name = name.replace("weight_scale_inv", "weight") |
725 | | - trtllm_logger.logger.debug(f"Resmoothing {weight_name}") |
| 724 | + logger.debug(f"Resmoothing {weight_name}") |
726 | 725 | weight = weights[weight_name][:] |
727 | 726 | scale = weights[name][:] |
728 | 727 | weights[weight_name], weights[name] = resmooth_to_fp8_e8m0( |
729 | 728 | weight, scale) |
730 | 729 | super().load_weights(module, weights, weight_loading_mode) |
731 | 730 |
|
732 | 731 | def post_load_weights(self, module: torch.nn.Module): |
| 732 | + super().post_load_weights(module) |
733 | 733 | if is_sm_100f(): |
734 | 734 | transfromed_w3_w1_scale = transform_sf_into_required_layout( |
735 | 735 | module.quant_scales[0], |
|
0 commit comments