|
| 1 | +import copy |
1 | 2 | import dataclasses |
2 | 3 | import os |
3 | 4 | from typing import List, Optional, Tuple |
|
7 | 8 | from transformers.modeling_utils import no_init_weights |
8 | 9 | from transformers.models.gemma3.modeling_gemma3 import Gemma3MultiModalProjector |
9 | 10 |
|
| 11 | +from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \ |
| 12 | + BaseWeightMapper |
| 13 | + |
10 | 14 | from ..._utils import nvtx_range |
11 | 15 | from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, |
12 | 16 | register_input_processor) |
@@ -98,13 +102,14 @@ def __init__(self, model_config: ModelConfig[Gemma3Config]): |
98 | 102 | dtype=torch.int32, |
99 | 103 | device=self._device) |
100 | 104 |
|
101 | | - self.model_config = model_config |
| 105 | + model_config_cp = copy.deepcopy(model_config) |
| 106 | + self.model_config = model_config_cp |
102 | 107 |
|
103 | | - llm_model_config = self.get_sub_model_config(model_config, |
| 108 | + llm_model_config = self.get_sub_model_config(model_config_cp, |
104 | 109 | "text_config") |
105 | 110 | self.llm = Gemma3ForCausalLM(llm_model_config) |
106 | 111 |
|
107 | | - vision_model_config = self.get_sub_model_config(model_config, |
| 112 | + vision_model_config = self.get_sub_model_config(model_config_cp, |
108 | 113 | "vision_config") |
109 | 114 | self.siglip_tower = SiglipVisionModel(vision_model_config, |
110 | 115 | use_post_layernorm=True) |
@@ -141,9 +146,9 @@ def get_sub_model_config( |
141 | 146 | sub_model_config.pretrained_config.torch_dtype = model_config.pretrained_config.torch_dtype |
142 | 147 | return sub_model_config |
143 | 148 |
|
144 | | - def load_weights(self, weights): |
| 149 | + def load_weights(self, weights, weight_mapper: BaseWeightMapper): |
145 | 150 | llm_weights = filter_weights("language_model", weights) |
146 | | - self.llm.load_weights(llm_weights) |
| 151 | + self.llm.load_weights(llm_weights, weight_mapper) |
147 | 152 |
|
148 | 153 | vit_weights = filter_weights("vision_tower", weights) |
149 | 154 | self.siglip_tower.load_weights(vit_weights) |
|
0 commit comments