Skip to content

Commit fe5cfbd

Browse files
committed
fix linter
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 08a8405 commit fe5cfbd

File tree

13 files changed

+49
-426
lines changed

13 files changed

+49
-426
lines changed

lm_engine/hf_models/config/sequence_mixer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,8 @@ class _SoftmaxAttentionArgs(BaseArgs):
1616
add_bias: bool = False
1717
attention_multiplier: float | None = None
1818
sliding_window: int | None = None
19-
# needed for Qwen 2 MoE
20-
qkv_bias: bool = None
2119

2220
def model_post_init(self, __context: Any) -> None:
23-
if self.qkv_bias is None:
24-
self.qkv_bias = self.add_bias
25-
2621
assert self.sequence_mixer_type == "softmax_attention"
2722

2823

lm_engine/hf_models/mixins/dense/main.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
class CausalLMModelMixin(PreTrainedModelMixin):
2323
base_model_class = None
24-
model_parallel_state_dict_function = None
2524

2625
def __init__(self, config: CommonConfig, **kwargs) -> CausalLMModelMixin:
2726
super().__init__(config, **kwargs)
@@ -46,9 +45,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
4645
self.m_width = config.m_width
4746

4847
self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled()
49-
50-
if self.is_tp_enabled:
51-
self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()
48+
self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() if self.is_tp_enabled else None
5249

5350
def forward(
5451
self,
@@ -339,19 +336,3 @@ def _get_dummy_intermediate_tensor(
339336
)
340337

341338
return tensor
342-
343-
def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None:
344-
with torch.device(torch.cuda.current_device()):
345-
position_embedding_type = self.config.position_embedding_type
346-
347-
if position_embedding_type == "rope":
348-
self.transformer.rope.reset_parameters()
349-
350-
state_dict = self.__class__.model_parallel_state_dict_function(
351-
config=self.config,
352-
safetensors_weights_manager=safetensors_weights_manager,
353-
num_pipeline_stages=self.num_pipeline_stages,
354-
pipeline_stage_id=self.pipeline_stage_id,
355-
)
356-
357-
self.load_state_dict(state_dict)

lm_engine/hf_models/mixins/dense_TP/main.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333

3434
class CausalLMModelMixin_TP(CausalLMModelMixin):
35+
model_parallel_state_dict_function = None
36+
3537
def forward(
3638
self,
3739
input_ids: torch.Tensor | list[list[int]] | None = None,
@@ -177,3 +179,19 @@ def from_pretrained(
177179
model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path))
178180

179181
return model
182+
183+
def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None:
184+
with torch.device(torch.cuda.current_device()):
185+
position_embedding_type = self.config.position_embedding_type
186+
187+
if position_embedding_type == "rope":
188+
self.transformer.rope.reset_parameters()
189+
190+
state_dict = self.__class__.model_parallel_state_dict_function(
191+
config=self.config,
192+
safetensors_weights_manager=safetensors_weights_manager,
193+
num_pipeline_stages=self.num_pipeline_stages,
194+
pipeline_stage_id=self.pipeline_stage_id,
195+
)
196+
197+
self.load_state_dict(state_dict)

lm_engine/hf_models/model_conversion/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@
2222
_import_granitemoeshared_state_dict,
2323
)
2424
from .llama import _export_llama_config, _export_llama_state_dict, _import_llama_config, _import_llama_state_dict
25-
from .qwen2_moe import (
26-
_export_qwen2_moe_config,
27-
_export_qwen2_moe_state_dict,
28-
_import_qwen2_moe_config,
29-
_import_qwen2_moe_state_dict,
30-
)
3125

3226

3327
_MODEL_IMPORT_FUNCTIONS = {
@@ -36,7 +30,6 @@
3630
"granitemoeshared": (_import_granitemoeshared_config, _import_granitemoeshared_state_dict),
3731
"granitemoehybrid": (_import_granitemoehybrid_config, _import_granitemoehybrid_state_dict),
3832
"llama": (_import_llama_config, _import_llama_state_dict),
39-
"qwen2_moe": (_import_qwen2_moe_config, _import_qwen2_moe_state_dict),
4033
}
4134

4235

@@ -77,7 +70,6 @@ def import_from_huggingface(
7770
"granitemoeshared": (_export_granitemoeshared_config, _export_granitemoeshared_state_dict),
7871
"granitemoehybrid": (_export_granitemoehybrid_config, _export_granitemoehybrid_state_dict),
7972
"llama": (_export_llama_config, _export_llama_state_dict),
80-
"qwen2_moe": (_export_qwen2_moe_config, _export_qwen2_moe_state_dict),
8173
}
8274

8375

0 commit comments

Comments (0)