@fabianlim (Contributor) commented on Nov 3, 2024

This PR supersedes #69.

This PR adds a new accelerated-moe plugin that is Triton-only.

TODO:

## Performance

For `ibm-granite/granite-3.0-3b-a800m-instruct` and `mistralai/Mixtral-8x7B-Instruct-v0.1`:

- effective batch size 128
- bfloat16, no mixed precision
- torch memory logging disabled so that the logging overhead does not inflate the runtimes
- `framework_config = none`: plain FSDP baseline
- `moe-scattermoe-granite-ep1`: MoE world_size = 1
- `moe-scattermoe-granite-ep2`: MoE world_size = 2
| model_name_or_path | num_gpus | framework_config | nvidia_mem_reserved | train_runtime (s) | mem util | speedup |
|---|---|---|---|---|---|---|
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | none | 71199 | 2371.93 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1 | 71187 | 742.739 | 1.0 | 3.19 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1-padding-free | 48401 | 631.976 | 0.68 | 3.75 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 1 | moe-scattermoe-granite-ep1-padding-free-foak | 42651 | 615.453 | 0.6 | 3.85 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | none | 46829 | 1355.71 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1 | 52503 | 485.51 | 1.12 | 2.79 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1-padding-free | 42452 | 454.344 | 0.91 | 2.98 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep1-padding-free-foak | 37743 | 433.481 | 0.81 | 3.13 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2 | 40193 | 577.216 | 0.86 | 2.35 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2-padding-free | 31012 | 546.507 | 0.66 | 2.48 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 2 | moe-scattermoe-granite-ep2-padding-free-foak | 26075 | 524.775 | 0.56 | 2.58 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | none | 37996 | 708.391 | baseline | baseline |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1 | 51145 | 262.957 | 1.35 | 2.69 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1-padding-free | 38560 | 241.297 | 1.01 | 2.94 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep1-padding-free-foak | 35153 | 232.043 | 0.93 | 3.05 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2 | 40878.5 | 300.285 | 1.08 | 2.36 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2-padding-free | 28133 | 283.544 | 0.74 | 2.5 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep2-padding-free-foak | 24665.5 | 274.126 | 0.65 | 2.58 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4 | 31777.5 | 307.126 | 0.84 | 2.31 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4-padding-free | 21585.5 | 284.608 | 0.57 | 2.49 |
| ibm-granite/granite-3.0-3b-a800m-instruct | 4 | moe-scattermoe-granite-ep4-padding-free-foak | 18368 | 278.125 | 0.48 | 2.55 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | none | 65607.2 | 4180.95 | baseline | baseline |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | moe-scattermoe-granite-ep8 | 52004.8 | 1071.2 | 0.79 | 3.9 |
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 8 | moe-scattermoe-granite-ep8-foak | 51961.2 | 1043.67 | 0.79 | 4.01 |
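
For reference, the last two columns appear to be simple ratios against the `none` (FSDP) baseline at the same GPU count; a minimal sketch of that arithmetic, using the 1-GPU rows above:

```python
# Sketch of how "mem util" and "speedup" relate to the FSDP baseline
# (numbers taken from the 1-GPU granite rows in the table above).
baseline_mem, baseline_runtime = 71199, 2371.93   # framework_config = none
moe_mem, moe_runtime = 42651, 615.453             # ep1-padding-free-foak

mem_util = moe_mem / baseline_mem                 # ~0.60, matches the table
speedup = baseline_runtime / moe_runtime          # ~3.85, matches the table
print(f"mem util: {mem_util:.2f}, speedup: {speedup:.2f}")
```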

## Resumption

Non-sharded checkpoints: tested resumption on 2 devices for expert parallel size 1 and 2.

```python
import torch.distributed.checkpoint as dcp

reader = dcp.FileSystemReader("tmp3/checkpoint-10/pytorch_model_fsdp_0")
metadata = reader.read_metadata()
metadata.state_dict_metadata['model.model.layers.1.block_sparse_moe.w1.weight']
```

```
TensorStorageMetadata(properties=TensorProperties(dtype=torch.bfloat16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False), size=torch.Size([40, 1536, 512]), chunks=[ChunkStorageMetadata(offsets=torch.Size([0, 0, 0]), sizes=torch.Size([20, 1536, 512])), ChunkStorageMetadata(offsets=torch.Size([20, 0, 0]), sizes=torch.Size([20, 1536, 512]))])
```


Also tested resumption for sharded checkpoints (Mixtral):

```python
reader = dcp.FileSystemReader("tmp3/checkpoint-10/pytorch_model_fsdp_0")
metadata = reader.read_metadata()
metadata.state_dict_metadata['model.model.layers.1.block_sparse_moe.w1.weight']
```

```
TensorStorageMetadata(properties=TensorProperties(dtype=torch.bfloat16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False), size=torch.Size([8, 4096, 14336]), chunks=[ChunkStorageMetadata(offsets=torch.Size([0, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([1, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([2, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([3, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([4, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([5, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([6, 0, 0]), sizes=torch.Size([1, 4096, 14336])), ChunkStorageMetadata(offsets=torch.Size([7, 0, 0]), sizes=torch.Size([1, 4096, 14336]))])
```
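
As a quick sanity check (not part of the PR's utilities, just illustrative), the DCP chunk metadata can be used to confirm that the per-rank expert shards tile the full expert dimension:

```python
# Illustrative check: the per-rank chunks should exactly tile dim 0
# (the expert dimension) of the full parameter.
meta = metadata.state_dict_metadata[
    'model.model.layers.1.block_sparse_moe.w1.weight'
]
covered = sum(chunk.sizes[0] for chunk in meta.chunks)
assert covered == meta.size[0], (covered, meta.size[0])
print(f"{len(meta.chunks)} chunks cover all {meta.size[0]} experts")
```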

## Handling the State Dict

We provide a convenience function `restore_scattermoe_checkpoint_to_orig` that loads the DCP checkpoint and, if the pretrained checkpoint is provided, optionally converts it back to the original (non-scattermoe) format.

```python
from fms_acceleration_moe.utils.checkpoint_utils import restore_scattermoe_checkpoint_to_orig
from fms_acceleration_moe.utils import prepare_scattermoe
from transformers import AutoModelForCausalLM

MODEL = 'ibm-granite/granite-3.0-3b-a800m-instruct'
CKPT = "tmp2/checkpoint-50/pytorch_model_fsdp_0"

# load the model and convert it to scattermoe
model = AutoModelForCausalLM.from_pretrained(MODEL, device_map='cuda')
prepare_scattermoe(
    model,
    checkpoint_name_or_path=MODEL,
    rank=0,
    world_size=1,
    ep_degree=1,
    mixed_precision=False,  # currently hardcoded to OFF
)

# load the DCP checkpoint into the scattermoe model
sd = restore_scattermoe_checkpoint_to_orig(CKPT)
model.load_state_dict(sd)

# load the original (non-scattermoe) model
model2 = AutoModelForCausalLM.from_pretrained(MODEL, device_map='cuda')
# use the utility to convert the DCP checkpoint back to the original format
sd = restore_scattermoe_checkpoint_to_orig(CKPT, pretrained_model_name_or_path=MODEL)
model2.load_state_dict(sd)
```
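
Once `model2` holds the restored weights in the original architecture, it can be written back out in the standard Hugging Face format if desired (the output directory below is arbitrary):

```python
# Optionally persist the restored original-format model
# (output path is just an example).
model2.save_pretrained("granite-3.0-3b-a800m-restored")
```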


@fabianlim marked this pull request as draft on November 3, 2024
@fabianlim force-pushed the refactor/moe branch 3 times, most recently from f412899 to bccd967 on November 5, 2024
@fabianlim marked this pull request as ready for review on November 6, 2024
@fabianlim requested a review from anhuong on November 7, 2024