Our `ScatterMoe` implementation is a module-swap; to add new models we need to update the specifications in [scattermoe_constants.py](./src/fms_acceleration_moe/utils/scattermoe_constants.py).
- See the code documentation within to understand how to add new models.
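
As a purely hypothetical illustration of what such a specification needs to capture (the real keys and schema are defined in `scattermoe_constants.py` itself), a module-swap spec generally identifies, per architecture, the MoE block to replace and how its router and expert weights are named:

```python
# Hypothetical sketch only, not the actual schema in scattermoe_constants.py.
# A module-swap spec has to tell the swapper, for each architecture:
#   * which submodule class implements the sparse MoE block,
#   * the name of the router/gate linear inside that block,
#   * the per-expert weight names to be regrouped into ScatterMoE's layout.
EXAMPLE_SPEC = {
    "MixtralForCausalLM": {
        "moe_module": "MixtralSparseMoeBlock",  # module class to swap out
        "router": "gate",                       # router linear inside that module
        "expert_weights": ["w1", "w2", "w3"],   # per-expert linear layers
    },
}
```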

### Conversion of ScatterMoE Checkpoints

`ScatterMoE` checkpoints are saved using `torch.distributed.checkpoint` (DCP), which by default uses `StateDictType.SHARDED_STATE_DICT`:
- `DTensors` have limited support for full state dicts.
- sharded state dicts are extremely efficient and require little communication overhead when saving.
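
For context, here is a minimal sketch of how an FSDP-wrapped model is typically written out as a DCP sharded checkpoint (assuming PyTorch 2.2+; this is not the plugin's internal save path):

```python
# Illustrative sketch, not the plugin's code: saving an FSDP model as a
# DCP sharded checkpoint, where each rank writes only its own shards.
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

def save_sharded(model: FSDP, save_dir: str):
    # Request a sharded state dict; parameters stay as sharded tensors / DTensors
    # instead of being gathered into full tensors on rank 0.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
    dcp.save(state_dict, checkpoint_id=save_dir)  # little communication overhead
```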

We provide a script to recover the original checkpoint:
- currently the script can be used only if DCP saves a single `pytorch_model_fsdp_0` folder
- say the checkpoint is stored in `hf/checkpoint-10`, then call the script on that folder.
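
The exact command of the provided script is not reproduced here. As a rough illustration of the underlying consolidation step, recent PyTorch releases (2.3+) ship a utility that merges a DCP sharded checkpoint into a single `torch.save` file; note that the plugin's script additionally maps ScatterMoE parameter names back to the original model's, which this snippet does not do:

```python
# Illustrative only: consolidate the DCP shards into one torch.save file.
# The output path is hypothetical; this does NOT rename ScatterMoE parameters
# back to the original model's naming, which the plugin's script handles.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save(
    "hf/checkpoint-10/pytorch_model_fsdp_0",  # DCP folder written during training
    "hf/checkpoint-10/consolidated.pt",       # hypothetical output file
)
```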

### Code Extracted from Megablocks

- we have only extracted two `autograd` functions [GatherOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/gather.py) and [ScatterOp](https://github.com/databricks/megablocks/blob/main/megablocks/ops/scatter.py).
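
The extracted ops themselves are Triton-backed; as a plain-PyTorch picture of the pattern they implement (a row gather whose backward is a scatter-add), consider the following illustrative sketch, which is not the Megablocks API:

```python
import torch

# Illustrative sketch of the gather/scatter pattern (not the Triton-backed
# Megablocks ops): forward gathers rows by index, backward scatter-adds
# the gradients back into their original positions.
class Gather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(indices)
        ctx.num_rows = x.shape[0]
        return x[indices]

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (indices,) = ctx.saved_tensors
        grad_x = torch.zeros(
            ctx.num_rows, *grad_out.shape[1:],
            dtype=grad_out.dtype, device=grad_out.device,
        )
        grad_x.index_add_(0, indices, grad_out)  # scatter-add: inverse of the gather
        return grad_x, None
```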

## Known Issues

These are some known issues that have not yet been resolved:
- The design currently swaps the mixture-of-experts module with [ScatterMoE](./src/fms_acceleration_moe/utils/scattermoe.py). This affects the `state_dict` of the model, so any saved checkpoint may need to be converted back to the original format.
- We should eventually remove the dependency on the external `kernel-hyperdrive` repository.
- We currently support loading only *sharded*, non-GGUF `safetensor` MoE checkpoints. This is a reasonable assumption, since MoE checkpoints are typically above the size limit that prevents them from being saved into a single checkpoint file.
- We currently support only `StateDictType.SHARDED_STATE_DICT`, because the implementation uses `DTensors`, which have limited support for full state dicts; sharded state dicts are also the most efficient option for saving.