diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 7bd4a390..3d173c3b 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -113,8 +113,12 @@ def save_fsdp_optimizer(
     (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
 
     # filter out lora state dict
+    # TODO: Once expert layers are supported for LoRA tuning
+    # remove the "router" filtering
     lora_state_dict = {
-        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
+        k: v
+        for k, v in model_state_dict.items()
+        if ("lora_A" in k or "lora_B" in k) and "router" not in k
     }
 
     # - save model
diff --git a/plugins/accelerated-moe/tox.ini b/plugins/accelerated-moe/tox.ini
index 811f1329..e17e163c 100644
--- a/plugins/accelerated-moe/tox.ini
+++ b/plugins/accelerated-moe/tox.ini
@@ -4,6 +4,7 @@ envlist = py, lint
 [testenv]
 deps = 
     pytest>=7
+    importlib-metadata
     -e {toxinidir}
 skip_install = true
 commands =