5 files changed: +32 -15

File 1 of 5

@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss
 
         # reduce across ranks
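Why the scaling works: each rank accumulates a sum of per-token losses, and the cross-rank reduction that follows averages over ranks, so multiplying by world_size here makes the effective objective the mean over all loss-counted tokens in the global batch. A minimal numeric sketch with hypothetical values:

    # Hypothetical values; mirrors the arithmetic in the hunk above.
    world_size = 4
    accumulated_loss = 128.0               # sum of per-token losses on this rank
    batch_num_loss_counted_tokens = 512    # loss-counted tokens in the global batch
    accumulated_aux_loss = 0.02            # MoE load-balancing loss, or None

    total_batch_loss = accumulated_loss * world_size / batch_num_loss_counted_tokens
    if accumulated_aux_loss is not None:
        total_batch_loss += accumulated_aux_loss

    print(total_batch_loss)  # 128 * 4 / 512 + 0.02 = 1.02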
File 2 of 5

@@ -398,6 +398,15 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     """
     Determine if we should convert GPT-OSS format during saving.
     """
+    return is_known_model(model_path_or_config, "gpt_oss")
+
+
+def is_known_model(
+    model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
+) -> bool:
+    """
+    Determine if the model matches a known model type.
+    """
     if not isinstance(model_path_or_config, (PretrainedConfig, str)):
         raise ValueError(
             f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
@@ -408,7 +417,10 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     if isinstance(model_path_or_config, str):
         model_config = AutoConfig.from_pretrained(model_path_or_config)
 
-    return getattr(model_config, "model_type", None) == "gpt_oss"
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
+    return getattr(model_config, "model_type", None) in known_model_types
 
 
 def add_gpt_oss_quantization_config(config):
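The refactor keeps is_gpt_oss as a thin wrapper over the new is_known_model, which normalizes a single model_type string into a list before membership-testing the config's model_type. A small usage sketch, assuming both helpers are imported from instructlab.training.gpt_oss_utils_correct and using a synthetic config rather than a real checkpoint:

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.model_type = "granitemoehybrid"  # synthetic, for illustration only

    assert is_known_model(config, ["gpt_oss", "granitemoehybrid"])  # list form
    assert not is_known_model(config, "gpt_oss")   # string form, wrapped in a list
    assert not is_gpt_oss(config)                  # now delegates to is_known_model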
File 3 of 5

@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - froze router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True
 
     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
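Background on the flag: with use_orig_params=False, FSDP fuses each wrapped module's parameters into a single FlatParameter whose requires_grad must be uniform, so freezing only the router weights would raise an error; use_orig_params=True preserves per-parameter requires_grad. A minimal sketch of the constraint (wrapper name hypothetical; actually running it requires an initialized process group):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_model(model: torch.nn.Module, routers_frozen: bool) -> FSDP:
        # Mixed requires_grad inside one FlatParameter is only supported
        # with use_orig_params=True, matching the diff above.
        return FSDP(model, use_orig_params=routers_frozen)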
File 4 of 5

@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses
 
 
@@ -65,6 +65,7 @@ def __init__(
         quant_config = None
 
         # check model type & set on the class
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(
 
         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@
         scaled_main_loss = primary_loss * world_size / samples_in_batch
 
         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss
 
         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
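For context, the aux_loss consumed here is the router load-balancing term that Hugging Face MoE causal-LM outputs expose when router logits are requested. A hedged sketch of a forward pass that surfaces it (the checkpoint path is a placeholder, and not every MoE architecture populates the field):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/moe-checkpoint")  # placeholder
    input_ids = torch.tensor([[1, 2, 3, 4]])
    output = model(input_ids=input_ids, labels=input_ids, output_router_logits=True)
    print(output.aux_loss)  # the term the trainer adds, unscaled, above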
File 5 of 5

@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:
 
 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.
 
     Args:
         model: The model to check and potentially freeze parameters
 
     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if this is an MoE model and parameters were frozen
     """
 
     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")
 
-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False
 
 
 def test_model_inference_quick(model, tokenizer, stage_name):
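The hunk above shows only the tail of the freezing loop; here is a minimal self-contained sketch of name-based router freezing, assuming (illustratively) that router parameters are identifiable by a "router" substring in their names:

    import torch

    def freeze_router_params_sketch(model: torch.nn.Module) -> bool:
        # Assumed selection rule; the real condition is in the elided loop header.
        frozen_count = 0
        for name, param in model.named_parameters():
            if "router" in name:  # e.g. "model.layers.0.mlp.router.weight"
                param.requires_grad = False
                frozen_count += 1
        return frozen_count > 0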