diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py
index f0e10a89..cc6da021 100644
--- a/src/instructlab/training/batch_loss_manager.py
+++ b/src/instructlab/training/batch_loss_manager.py
@@ -174,7 +174,7 @@ def _compute_average_loss(
         total_batch_loss = (
             accumulated_loss * self.world_size / batch_num_loss_counted_tokens
         )
-        if self.model.is_gpt_oss and accumulated_aux_loss is not None:
+        if accumulated_aux_loss is not None:
             total_batch_loss += accumulated_aux_loss
 
         # reduce across ranks
diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py
index 430a890f..a77ee15a 100644
--- a/src/instructlab/training/gpt_oss_utils_correct.py
+++ b/src/instructlab/training/gpt_oss_utils_correct.py
@@ -398,6 +398,15 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
     """
     Determine if we should convert GPT-OSS format during saving.
     """
+    return is_known_model(model_path_or_config, "gpt_oss")
+
+
+def is_known_model(
+    model_path_or_config: str | PretrainedConfig, known_model_type: str | list[str]
+) -> bool:
+    """
+    Determine whether the model's model_type matches any of the given known types.
+    """
     if not isinstance(model_path_or_config, (PretrainedConfig, str)):
         raise ValueError(
             f"cannot detect model: received invalid argument of type {type(model_path_or_config)}"
         )
@@ -408,7 +417,10 @@ def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool:
 
     if isinstance(model_path_or_config, str):
         model_config = AutoConfig.from_pretrained(model_path_or_config)
-    return getattr(model_config, "model_type", None) == "gpt_oss"
+    known_model_types = (
+        [known_model_type] if isinstance(known_model_type, str) else known_model_type
+    )
+    return getattr(model_config, "model_type", None) in known_model_types
 
 
 def add_gpt_oss_quantization_config(config):
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index a0931353..b73a86f3 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -346,12 +346,13 @@ def main(args):
     # GPT-OSS specifically
     # We don't want to use use_orig_params for GPT-OSS models
     fsdp_should_use_orig_params = False
-    if m.is_gpt_oss:
-        logger.info("🎯 Detected GPT-OSS model - freezing router parameters")
-        freeze_router_params(m)
-        # For GPT-OSS, we need to use the original parameters so we can properly
-        # freeze the router parameters.
-        fsdp_should_use_orig_params = True
+    if m.is_gpt_oss or m.is_granitemoehybrid:
+        frozen_router_params = freeze_router_params(m)
+        if frozen_router_params:
+            logger.info("🎯 Detected an MoE model - froze router parameters")
+            # For an MoE model, we need to use the original parameters so we can properly
+            # freeze the router parameters.
+            fsdp_should_use_orig_params = True
 
     # Mini_trainer approach: simplified setup
     # No complex calculations needed - the data loader handles everything
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index de863e1d..bb89a1c4 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -43,7 +43,7 @@
     DistributedBackend,
     Optimizer,
 )
-from instructlab.training.gpt_oss_utils_correct import is_gpt_oss
+from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model
 from instructlab.training.type_definitions import ModelInputs, ModelLosses
 
 
@@ -65,6 +65,7 @@ def __init__(
         quant_config = None
 
         # check model type & set on the mclasss
+        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
         self.is_gpt_oss = is_gpt_oss(model_path)
         if self.is_gpt_oss:
             # Third Party
@@ -418,7 +419,7 @@ def compute_loss(
 
         # add the MoE auxiliary loss (currently we only support this for GPT-OSS)
         if (
-            self.is_gpt_oss
+            (self.is_gpt_oss or self.is_granitemoehybrid)
             and hasattr(output, "aux_loss")
             and output.aux_loss is not None
         ):
@@ -429,7 +430,7 @@ def compute_loss(
             scaled_main_loss = primary_loss * world_size / samples_in_batch
 
         # For GPT-OSS: add unscaled auxiliary loss after scaling main loss
-        if self.is_gpt_oss and aux_loss is not None:
+        if aux_loss is not None:
             scaled_main_loss += aux_loss
 
         raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss)
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index 275a4b7e..fc31858e 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -902,13 +902,13 @@ def load_latest_full_state(args, accelerator) -> None:
 
 def freeze_router_params(model: Model):
     """
-    Freeze router parameters for GPT-OSS models before FSDP setup.
+    Freeze router parameters for MoE models before FSDP setup.
 
     Args:
         model: The model to check and potentially freeze parameters
 
     Returns:
-        bool: True if this is a GPT-OSS model and parameters were frozen
+        bool: True if any router parameters were frozen
     """
 
     # Freeze router parameters BEFORE accelerator setup
@@ -919,8 +919,11 @@ def freeze_router_params(model: Model):
             frozen_count += 1
             logger.info(f"❄️ Frozen router parameter: {name}")
 
-    logger.info(f"✅ Frozen {frozen_count} router parameters for GPT-OSS model")
-    return True
+    if frozen_count > 0:
+        logger.info(f"✅ Frozen {frozen_count} router parameters for an MoE model")
+        return True
+    else:
+        return False
 
 
 def test_model_inference_quick(model, tokenizer, stage_name):
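
Note (not part of the patch): a minimal sketch of how the new is_known_model helper is intended to be called. The checkpoint name below is illustrative; any path or PretrainedConfig whose config reports a matching model_type behaves the same way.

    # Sketch only -- checkpoint name is illustrative, not part of the patch.
    from transformers import AutoConfig

    from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model

    # A single model_type string is normalized internally to a one-element list,
    # so is_gpt_oss(path) is now just is_known_model(path, "gpt_oss").
    is_known_model("openai/gpt-oss-20b", "gpt_oss")

    # A PretrainedConfig is accepted too, and a list checks several candidate
    # types at once.
    config = AutoConfig.from_pretrained("openai/gpt-oss-20b")
    is_known_model(config, ["gpt_oss", "granitemoehybrid"])

    # Anything other than a str or PretrainedConfig raises ValueError.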
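Both loss paths now key off the presence of an auxiliary loss rather than the model type, so any MoE model that emits output.aux_loss gets it folded in unscaled. A schematic of the two call sites after the patch, with names as in the diff (self. prefixes dropped):

    # Per-microbatch (model.py): scale the main loss, then add aux_loss unscaled.
    scaled_main_loss = primary_loss * world_size / samples_in_batch
    if aux_loss is not None:  # no longer gated on is_gpt_oss
        scaled_main_loss += aux_loss

    # Per-batch (batch_loss_manager.py): the same pattern on accumulated values.
    total_batch_loss = accumulated_loss * world_size / batch_num_loss_counted_tokens
    if accumulated_aux_loss is not None:
        total_batch_loss += accumulated_aux_loss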