Skip to content

Commit f7f8384

Browse files
authored
src directory polishing (#23)
Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ``` jobuser [ ~/Liger-Kernel ]$ make checkstyle test test-convergence flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 1 files All done! ✨ 🍰 ✨ 45 files left unchanged. pytest --disable-warnings test/ --ignore=test/convergence =============================================================================================================== test session starts ================================================================================================================ platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 111 items test/transformers/test_cross_entropy.py .......................................................... [ 52%] test/transformers/test_fused_linear_cross_entropy.py ...... [ 57%] test/transformers/test_geglu.py ........ [ 64%] test/transformers/test_rms_norm.py ................ [ 79%] test/transformers/test_rope.py ............ [ 90%] test/transformers/test_swiglu.py ........ [ 97%] test/transformers/test_transformers_monkey_patch.py . [ 98%] test/triton/test_triton_monkey_patch.py .. 
[100%] ========================================================================================================== 111 passed in 61.60s (0:01:01) ========================================================================================================== HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence =============================================================================================================== test session starts ================================================================================================================ platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 8 items test/convergence/test_mini_models.py ...... [ 75%] test/convergence/test_mini_models_no_logits.py .. [100%] =========================================================================================================== 8 passed in 97.43s (0:01:37) =========================================================================================================== ```
1 parent ee0ede6 commit f7f8384

File tree

6 files changed

+14
-42
lines changed

6 files changed

+14
-42
lines changed

docs/images/memory.png

-18.9 KB
Binary file not shown.

docs/images/speedup.png

-13.2 KB
Binary file not shown.

src/liger_kernel/ops/cross_entropy.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def liger_cross_entropy_kernel(
1717
BLOCK_SIZE: tl.constexpr,
1818
):
1919
"""
20-
This kernel computes both cross entropy loss and the gradient of the _input.
20+
This kernel computes both cross entropy loss and the gradient of the input.
2121
We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.
2222
2323
Parameters:
@@ -34,7 +34,7 @@ def liger_cross_entropy_kernel(
3434
"""
3535

3636
# https://github.com/triton-lang/triton/issues/1058
37-
# Essentially if B*T*V is too large, program_id * stride will overflow out of int32
37+
# If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
3838
program_id = tl.program_id(0).to(tl.int64)
3939

4040
# 1. Load Y_ptr first because if the target is ignore_index, we can return right away
@@ -90,13 +90,7 @@ def liger_cross_entropy_kernel(
9090
tl.debug_barrier()
9191

9292
# 5. Calculate the loss
93-
# Old Approach: Problematic LogSoftmax
94-
# min of bfloat16 and float32 is 1e-38, so we set a value larger than that but small enough
95-
# This will overflow if X_y * n_non_ignore is too small. Even if we add a tiny epsilon, it will still overflow
96-
# loss = -tl.log(X_y * n_non_ignore)
9793

98-
# New Approach: Safe LogSoftmax
99-
# Therefore, we propose to use safe logsoftmax by reordering the formula.
10094
# loss = log(softmax(X_y)) = log(e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
10195
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
10296
# sum(e ^ (X - max(X))) must be >= 1 because the max term is e ^ 0 = 1
@@ -114,7 +108,7 @@ def liger_cross_entropy_kernel(
114108
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
115109
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
116110
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
117-
MAX_FUSED_SIZE = 65536 // 2 # manual tune a bit
111+
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
118112

119113

120114
@triton.jit
@@ -184,28 +178,6 @@ def forward(ctx, _input, target, ignore_index):
184178
n_non_ignore = (target != ignore_index).sum().item()
185179

186180
# ensure _input and target are contiguous in the last dimension
187-
# there are examples that are NOT contiguous overall but contiguous in the last dimension
188-
####################################################################
189-
# tensor = torch.arange(1, 21).reshape(5, -1)
190-
# print(tensor)
191-
# tensor([[ 1, 2, 3, 4],
192-
# [ 5, 6, 7, 8],
193-
# [ 9, 10, 11, 12],
194-
# [13, 14, 15, 16],
195-
# [17, 18, 19, 20]])
196-
# print(tensor.is_contiguous())
197-
# True
198-
# slice = tensor[::2, :]
199-
# print(slice)
200-
# tensor([[ 1, 2, 3, 4],
201-
# [ 9, 10, 11, 12],
202-
# [17, 18, 19, 20]])
203-
# print(slice.is_contiguous())
204-
# False
205-
# print(slice.stride())
206-
# (8, 1)
207-
# slice is NOT a contiguous tensor but is contiguous in the last dimension, CE kernel can execute because the stride is 8, and each triton program will jump by 8
208-
####################################################################
209181
if _input.stride(-1) != 1:
210182
_input = _input.contiguous()
211183
if target.stride(-1) != 1:
@@ -252,10 +224,9 @@ def backward(ctx, grad_output):
252224
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
253225
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
254226
pass
227+
255228
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
256229
# for gradient storage and running backward multiple times causes anomalies with PyTorch but not with Triton.
257-
# Although the Brew trainer should only perform backward once, it encounters this issue.
258-
# https://github.com/triton-lang/triton/issues/4004
259230
else:
260231
BT, V = _input.shape
261232
n_rows = BT

src/liger_kernel/ops/fused_linear_cross_entropy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
"""Fusing the last linear layer with cross-entropy loss
2-
3-
Reference: https://github.com/mgmalek/efficient_cross_entropy
4-
"""
5-
61
import torch
72
import triton
83

@@ -11,13 +6,16 @@
116
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
127
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
138
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
14-
MAX_FUSED_SIZE = 65536 // 2 # manual tune a bit
9+
MAX_FUSED_SIZE = 65536 // 2
1510

1611

1712
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
1813
@staticmethod
1914
def forward(ctx, _input, linear, target, ignore_index):
2015
"""
16+
Fusing the last linear layer with cross-entropy loss
17+
Reference: https://github.com/mgmalek/efficient_cross_entropy
18+
2119
Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding
2220
the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can
2321
compute the gradient at the forward pass. By doing so, we don't have to store the _input and target
@@ -54,6 +52,8 @@ def forward(ctx, _input, linear, target, ignore_index):
5452

5553
grad_linear = torch.zeros_like(linear, device=device)
5654
grad_input = torch.zeros_like(_input, device=device)
55+
56+
# we use fp32 for loss accumulator
5757
loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
5858

5959
total_n_non_ignore = (target != ignore_index).sum().item()

src/liger_kernel/transformers/model/llama.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def lce_forward(
3737
cache_position: Optional[torch.LongTensor] = None,
3838
) -> Union[Tuple, CausalLMOutputWithPast]:
3939
r"""
40+
Copy-paste of the Llama forward pass, with torch cross entropy replaced by Liger fused linear cross entropy.
41+
42+
4043
Args:
4144
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
4245
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def apply_liger_kernel_to_llama(
1414
) -> None:
1515
"""
1616
Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
17-
to make GPU go burrr.
1817
1918
Args:
2019
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -53,7 +52,6 @@ def apply_liger_kernel_to_mistral(
5352
) -> None:
5453
"""
5554
Apply Liger kernels to replace original implementation in HuggingFace Mistral models
56-
to make GPU go burrr.
5755
5856
Args:
5957
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
@@ -82,12 +80,12 @@ def apply_liger_kernel_to_mixtral(
8280
) -> None:
8381
"""
8482
Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
85-
to make GPU go burrr.
8683
8784
Args:
8885
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
8986
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
9087
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
88+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
9189
"""
9290

9391
from transformers.models.mixtral import modeling_mixtral

0 commit comments

Comments
 (0)