
Commit 3d05302

Handle granite 4 as MoE models in training (#669)
* Support granite 4 models as MoE models
* fix ruff errors
* Support granite 4 models as MoE models
* fix ruff errors
* address bot's comment
1 parent: 781c36f

5 files changed: +32 additions, −15 deletions
src/instructlab/training/batch_loss_manager.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss

         # reduce across ranks
```
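Dropping the is_gpt_oss guard is safe here because only MoE forward passes populate the aux-loss accumulator. A minimal standalone sketch of the averaging step under that assumption (the free function is hypothetical; the real logic is a method on the batch loss manager):

```python
import torch


def compute_average_loss_sketch(
    accumulated_loss: torch.Tensor,
    accumulated_aux_loss: torch.Tensor | None,
    world_size: int,
    batch_num_loss_counted_tokens: int,
) -> torch.Tensor:
    # Normalize the main loss by the number of loss-counted tokens in the batch.
    total_batch_loss = accumulated_loss * world_size / batch_num_loss_counted_tokens
    # The None check alone now gates the aux loss: non-MoE models never
    # accumulate one, so the explicit is_gpt_oss test was redundant.
    if accumulated_aux_loss is not None:
        total_batch_loss += accumulated_aux_loss
    return total_batch_loss
```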

src/instructlab/training/gpt_oss_utils_correct.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -398,6 +398,15 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     """
     Determine if we should convert GPT-OSS format during saving.
     """
+    return is_known_model(model_path_or_config, "gpt_oss")
+
+
+def is_known_model(
+    model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
+) -> bool:
+    """
+    Determine if the model is a known model.
+    """
     if not isinstance(model_path_or_config, (PretrainedConfig, str)):
         raise ValueError(
             f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
@@ -408,7 +417,10 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     if isinstance(model_path_or_config, str):
         model_config = AutoConfig.from_pretrained(model_path_or_config)

-    return getattr(model_config, "model_type", None) == "gpt_oss"
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
+    return getattr(model_config, "model_type", None) in known_model_types


 def add_gpt_oss_quantization_config(config):
```
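For orientation, a hedged usage sketch of the new helper; the granite checkpoint id is illustrative and assumes its config reports model_type == "granitemoehybrid":

```python
from transformers import AutoConfig

from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model

# A granite 4 checkpoint id, used illustratively here.
config = AutoConfig.from_pretrained("ibm-granite/granite-4.0-tiny-preview")

is_known_model(config, "granitemoehybrid")               # single type -> bool
is_known_model(config, ["gpt_oss", "granitemoehybrid"])  # list form: membership test
is_gpt_oss(config)                                       # now a thin wrapper; False here
```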

src/instructlab/training/main_ds.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - frozen router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True

     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
```
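The use_orig_params flag matters because FSDP flattens parameters into shards, and a flat parameter that mixes frozen and trainable tensors is only supported when the original parameters are preserved. A sketch of how the flag might be consumed downstream (the wrapping call and the m.model attribute are assumptions for illustration, not repo code):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# use_orig_params=True lets FSDP honor per-parameter requires_grad, which is
# what keeps the frozen router weights frozen under sharding.
wrapped_model = FSDP(
    m.model,  # assumes the underlying nn.Module lives at m.model
    use_orig_params=fsdp_should_use_orig_params,
)
```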

src/instructlab/training/model.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses


@@ -65,6 +65,7 @@ def __init__(
         quant_config = None

         # check model type & set on the mclasss
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(

         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@ def compute_loss(
         scaled_main_loss = primary_loss * world_size / samples_in_batch

         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss

         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
```
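Read together, the two compute_loss hunks combine the losses roughly as below; this standalone sketch assumes output is a transformers MoE causal-LM output whose aux_loss is the router load-balancing term:

```python
import torch


def moe_loss_sketch(
    model,   # assumes the .is_gpt_oss / .is_granitemoehybrid flags set above
    output,  # a transformers causal-LM output; MoE variants expose aux_loss
    primary_loss: torch.Tensor,
    world_size: int,
    samples_in_batch: int,
) -> torch.Tensor:
    aux_loss = None
    if (
        (model.is_gpt_oss or model.is_granitemoehybrid)
        and hasattr(output, "aux_loss")
        and output.aux_loss is not None
    ):
        aux_loss = output.aux_loss

    # Scale the main loss for the distributed world first, then add the router
    # load-balancing loss unscaled, per the comment in the diff.
    scaled_main_loss = primary_loss * world_size / samples_in_batch
    if aux_loss is not None:
        scaled_main_loss += aux_loss
    return scaled_main_loss
```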

src/instructlab/training/utils.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:


 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.

     Args:
         model: The model to check and potentially freeze parameters

     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if this is an MoE model and parameters were frozen
     """

     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")

-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False


 def test_model_inference_quick(model, tokenizer, stage_name):
```
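The second hunk starts mid-loop, so the parameter-matching rule is not visible in this diff; here is a hedged reconstruction of the whole helper, assuming it matches parameter names containing "router":

```python
import torch.nn as nn


def freeze_router_params_sketch(model: nn.Module) -> bool:
    frozen_count = 0
    for name, param in model.named_parameters():
        if "router" in name:  # assumed matching rule; not shown in the hunk
            param.requires_grad = False
            frozen_count += 1
    # New behavior: report whether anything was actually frozen so the caller
    # in main_ds.py can decide whether FSDP needs use_orig_params=True.
    return frozen_count > 0
```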
