Description
Consider adding a FusedCrossEntropyLoss kernel to the FOAK set of kernels, given the additional improvement seen when using it in earlier tests (see Background below).
Considerations:
- Enable usage of FusedCrossEntropyLoss in FOAK plugin
- Enable FOAK kernels to be selectively activated for use in full finetuning and regular PEFT
- (KIV) consider enabling chunked CE loss for the current plugin
Background
The current FOAK kernels were compared against the kernels from Liger using Liger's full FT benchmark script with the following parameters:
- model: "meta-llama/Meta-Llama-3-8B"
- dataset: 10000 sample subset of "tatsu-lab/alpaca"
- max_seq_length: 512
- epoch: 1
- num devices: 4 x A100-80GB GPUs
Four Triton kernels are activated in the comparison against their FOAK equivalents:
- Fused MLP (SwiGLU) (MLP fused with Activation)
- RoPE Embeddings
- RMSNorm
- CrossEntropyLoss / FusedCrossEntropyLoss (last linear layer fused with the loss; see the sketch after this list)
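The fusion in the last item refers to combining the LM head projection with the loss so that the full logits tensor is never materialized. Below is a minimal PyTorch sketch of the idea; the function and parameter names (`fused_linear_cross_entropy`, `chunk_rows`, etc.) are illustrative assumptions, not the actual FOAK/Liger kernel API, and the real kernels implement this inside a single Triton kernel:

```python
import torch
import torch.nn.functional as F

def fused_linear_cross_entropy(hidden, lm_head_weight, labels, chunk_rows=1024):
    # Sketch only: the peak memory win comes from never holding the full
    # [num_tokens, vocab] logits tensor; logits are produced for a chunk of
    # token rows at a time and reduced into the loss immediately.
    num_tokens = hidden.shape[0]
    total_loss = hidden.new_zeros(())
    num_valid = (labels != -100).sum().clamp(min=1)  # usual ignore_index convention
    for start in range(0, num_tokens, chunk_rows):
        end = min(start + chunk_rows, num_tokens)
        logits = hidden[start:end] @ lm_head_weight.t()  # [chunk_rows, vocab]
        total_loss = total_loss + F.cross_entropy(
            logits.float(), labels[start:end], ignore_index=-100, reduction="sum"
        )
    return total_loss / num_valid
```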
The benchmarks report the following metrics:
- avg_tokens_per_sec: total input tokens seen by the model divided by the total runtime (secs) of each run
- total_peak_allocated_memory: total peak allocated GPU memory in MB
We observe that the FOAK kernels match Liger in both speed and memory consumption when all 4 kernels are used (with the unfused CrossEntropyLoss kernel), but Liger performs better with FusedCrossEntropyLoss for
- speed (up to 20% improvement)
- memory (up to 36% improvement)
Additional Notes
- We also noticed that Liger's CrossEntropyLoss kernel does not support chunking of the LM vocab, unlike the current FOAK kernels from Unsloth. Chunking allows the loss computation to be performed over smaller chunks of the vocab before a final reduction over all the chunk losses (see the sketch below). This could be a potential limitation/slowdown when the LM head of the model has a large vocab dimension (e.g. 256k).
- Considering that the Liger kernels appear to be drop-in replacements for the FOAK kernels, we would expect a mix of FOAK and Liger kernels to be compatible in the current FOAK plugin for QPEFT.
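To make the chunking point concrete, here is a minimal PyTorch sketch of cross-entropy with the log-sum-exp computed over the vocab axis in chunks via a running (streaming) reduction, the same rescaling trick used in online-softmax kernels. Names and the chunk size are illustrative assumptions; the actual Unsloth/FOAK kernels do this in Triton and also handle ignore_index, which is omitted here for brevity:

```python
import torch

def chunked_vocab_cross_entropy(logits, labels, vocab_chunk=65536):
    num_tokens, vocab = logits.shape
    # running log-sum-exp state, updated one vocab chunk at a time
    running_max = torch.full((num_tokens,), float("-inf"), device=logits.device)
    running_sum = torch.zeros(num_tokens, device=logits.device)
    for start in range(0, vocab, vocab_chunk):
        chunk = logits[:, start:start + vocab_chunk].float()
        chunk_max = chunk.max(dim=-1).values
        new_max = torch.maximum(running_max, chunk_max)
        # rescale the partial sum to the new running max, then fold in the chunk
        running_sum = running_sum * torch.exp(running_max - new_max) + torch.exp(
            chunk - new_max.unsqueeze(-1)
        ).sum(dim=-1)
        running_max = new_max
    lse = running_max + torch.log(running_sum)  # log-sum-exp over the full vocab
    label_logits = logits.gather(1, labels.unsqueeze(1)).squeeze(1).float()
    return (lse - label_logits).mean()  # mean per-token negative log-likelihood
```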
Extracted from fms-acceleration FOAK slides
| model_name_or_path | framework_config | num_gpus | batch_size | tokens_per_second | % Increase in throughput | peak_mem_alloc_in_GIB |
|---|---|---|---|---|---|---|
| llama3/hf/70b_pre_trained | accelerated-peft-bnb | 2 | 2 | 398 | 0 | 49.0 |
| llama3/hf/70b_pre_trained | accelerated-peft-bnb-foak | 2 | 2 | 434 | 9 | 48.7 |
| llama3/hf/70b_pre_trained | accelerated-peft-bnb-liger | 2 | 2 | ? | ? | ? |
