diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index fb8ab1bc..7c5c54b1 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -235,7 +235,7 @@ def get_state_dict_from_dcp_checkpoint( planner=_EmptyStateDictLoadPlanner(), no_dist=True, ) - return [KEY_MODEL] + return sd[KEY_MODEL] # function to get state dict from regular checkoint