feat: add liger kernel with fused cross entropy loss #93
Merged: fabianlim merged 15 commits into foundation-model-stack:main from anhuong:fused-cross-entropyloss-simplified on Dec 2, 2024. The diff below shows changes from 6 of the 15 commits.
Commits (15):
- fe14d28 (achew010): initial implementation of fused-linear-loss on llama
- a554ac8 (achew010): syntax fixes and remove unused code
- 33ee02a (anhuong): add new num_logits_to_keep arg for llama.forward()
- 4322843 (anhuong): add mixtral model patch
- 05cdbe6 (anhuong): add mistral and granite model patch
- b898968 (anhuong): add benchmark
- 9454927 (anhuong): add new liger benchmarks
- 0aab2c6 (fabianlim): Merge remote-tracking branch 'origin/main' into fused-cross-entropylo…
- 2c202ef (fabianlim): some fixes
- 12ebdb9 (fabianlim): revise benches
- 45f1a89 (fabianlim): refactor to fused_ops
- dc075e3 (fabianlim): fix fmt + lint
- a02a0a0 (fabianlim): update full benches and readme
- 1a69314 (fabianlim): fix fast foak configs
- 4595137 (anhuong): docs: update foak readme benchmarks
plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py (new file: 222 additions, 0 deletions)
```python
# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved.
#
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import triton
import triton.language as tl


@triton.jit
def liger_cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    n_cols,
    n_non_ignore,
    ignore_index,
    label_smoothing: tl.constexpr,
    reduction: tl.constexpr,  # set as constexpr since reduction is always known at compile time
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes both the cross entropy loss and the gradient of the input.
    We only consider hard labels + mean reduction for now. Please refer to
    https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math.

    Parameters:
    X_ptr: Pointer to input tensor.
    X_stride (int): The stride of the input tensor.
    Y_ptr: Pointer to target tensor.
    Y_stride (int): The stride of the target tensor.
    loss_ptr: Pointer to tensor to store the loss.
    loss_stride (int): The stride of the loss tensor.
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (int): The number of non-ignored elements in the batch.
    ignore_index (int): The index to ignore in the target.
    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
    reduction (str): The string for the reduction to apply.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # https://github.com/triton-lang/triton/issues/1058
    # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64
    program_id = tl.program_id(0).to(tl.int64)

    # 1. Load Y_ptr first because if the target is ignore_index, we can return right away
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)

    # 2. Locate the start index
    X_ptr += program_id * X_stride

    if y == ignore_index:
        # set all X_ptr as 0
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return

    loss_ptr += program_id * loss_stride

    # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
    # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value; we use the notation from the paper
    d = 0.0  # d is the sum; we use the notation from the paper
    ori_X_y = tl.load(
        X_ptr + y
    )  # we need to store the original value of X_y for the loss calculation

    # Label smoothing is a generalization of normal cross entropy
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
    scaled_x_sum = 0.0
    eps = label_smoothing / n_cols

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        block_max = tl.max(X_block)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        m_new = tl.maximum(m, block_max)
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new

    # 4. [Online softmax] second pass: compute gradients
    # For 'mean' reduction, gradients are normalized by the number of non-ignored elements (N)
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # For label smoothing:
    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    #
    # For 'sum' reduction, no normalization is applied:
    # dx_y = softmax(x_y) - 1
    # dx_i = softmax(x_i), for i != y
    # For label smoothing:
    # dx_i = (softmax(x_i) - label_smoothing / V), V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
    #      = dx_i - (1 - label_smoothing)

    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(
            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
        )
        if reduction == "mean":
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps

        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

    # We need tl.debug_barrier() to ensure the new result of X_ptr is written, as mentioned in
    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
    tl.debug_barrier()

    # 5. Calculate the loss

    # loss = -log(softmax(X_y)), where
    # log(softmax(X_y)) = log(e^(X_y - max(X)) / sum(e^(X - max(X))))
    #                   = (X_y - max(X)) - log(sum(e^(X - max(X))))
    # sum(e^(X - max(X))) must be >= 1 because the max term is e^0 = 1,
    # so we can safely calculate log(softmax(X_y)) without overflow
    loss = -(ori_X_y - m - tl.log(d))

    # Original loss = H(q, p); with label smoothing regularization it becomes H(q', p), and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(-logsoftmax(x_i))
    # By using m (global max of x_i) and d (sum of e^(x_i - m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + log d))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
    if label_smoothing > 0:
        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
        loss = loss * (1 - label_smoothing) + smooth_loss

    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
    if reduction == "mean":
        loss = loss / n_non_ignore

    # 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
    X_y = tl.load(X_ptr + y)
    if reduction == "mean":
        X_y += -(1 - label_smoothing) / (n_non_ignore)
    else:
        X_y += -(1 - label_smoothing)

    tl.store(loss_ptr, loss)
    tl.store(X_ptr + y, X_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling.
# The optimal maximum block size depends on your hardware, your kernel, and your dtype.
MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning


@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
    The multiplication is performed in place on the tensor pointed to by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """

    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)

    # Locate the start index
    X_ptr += program_id * X_stride

    # Load the gradient output value
    grad_output = tl.load(grad_output_ptr)

    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
```
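The first pass (step 3) is the online-softmax recurrence of Algorithm 3 in https://arxiv.org/pdf/1805.02867: a single sweep over the row maintains the running max m and the running sum d of shifted exponentials. A quick way to convince yourself the recurrence is correct is to replay it block by block in plain PyTorch and compare against the direct computation. This is only an illustrative sketch; the vocabulary size, block size, and target index below are arbitrary choices, not values from this PR.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, BLOCK_SIZE = 1000, 128  # arbitrary sizes, for illustration only
x = torch.randn(V, dtype=torch.float64)

# Block-wise recurrence mirroring the kernel's first pass:
#   m_new = max(m, max(block));  d = d * exp(m - m_new) + sum(exp(block - m_new))
m, d = float("-inf"), 0.0
for i in range(0, V, BLOCK_SIZE):
    block = x[i : i + BLOCK_SIZE]
    m_new = max(m, block.max().item())
    d = d * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
    m = m_new

# The single sweep recovers the same statistics as the direct "safe softmax".
print(math.isclose(m, x.max().item()))
print(math.isclose(d, torch.exp(x - x.max()).sum().item(), rel_tol=1e-9))

# With m and d in hand, the per-row loss is -(x[y] - m - log(d)), exactly as in step 5.
y = 123  # arbitrary target index
loss = -(x[y].item() - m - math.log(d))
print(math.isclose(loss, F.cross_entropy(x[None, :], torch.tensor([y])).item(), rel_tol=1e-6))
```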
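The gradient formulas in the second pass (step 4) and the i == y fix-up (step 6) can be checked against autograd. The sketch below uses torch.nn.functional.cross_entropy as the reference for the plain 'mean'-reduction case (no label smoothing, no ignored targets); sizes are arbitrary and the snippet is illustrative only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 4, 16  # number of (non-ignored) rows and vocabulary size
x = torch.randn(N, V, requires_grad=True)
y = torch.randint(0, V, (N,))

F.cross_entropy(x, y, reduction="mean").backward()

# Closed form used by the kernel (label_smoothing = 0, no ignore_index hits):
#   dx_i = softmax(x_i) / N           for i != y
#   dx_y = (softmax(x_y) - 1) / N
manual = (torch.softmax(x.detach(), dim=-1) - F.one_hot(y, V).to(x.dtype)) / N

print(torch.allclose(x.grad, manual, atol=1e-6))  # True
```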
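Step 5's label-smoothing comment relies on the decomposition H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p), where q' mixes the hard label with a uniform distribution u. This is the same convention PyTorch uses, so the identity can be verified numerically; again a small illustrative sketch with arbitrary sizes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, eps = 32, 0.1                   # vocabulary size and label_smoothing
x = torch.randn(8, V)              # logits for 8 rows
y = torch.randint(0, V, (8,))      # hard labels

logp = F.log_softmax(x, dim=-1)
h_qp = -logp.gather(1, y[:, None]).squeeze(1)  # H(q, p): standard per-row NLL
h_up = -logp.mean(dim=-1)                      # H(u, p): uniform-target cross entropy

manual = ((1 - eps) * h_qp + eps * h_up).mean()
reference = F.cross_entropy(x, y, label_smoothing=eps)

print(torch.allclose(manual, reference))  # True
```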
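For completeness, here is a minimal sketch of how a kernel like this is typically driven from PyTorch: flatten the logits to (B*T, V), pick BLOCK_SIZE from MAX_FUSED_SIZE and the next power of two of the vocabulary size, launch one program per row, and sum the per-row losses (for 'mean' reduction the kernel has already divided each row by n_non_ignore). In the backward pass, element_mul_kernel scales the gradient stored in the logits buffer by the incoming grad_output. This is not the wrapper shipped in this PR; the import path, function names, and num_warps value are assumptions for illustration, and it requires a CUDA device with Triton installed.

```python
import torch
import triton

# Assumed import path, matching the file location added in this PR
from fms_acceleration_foak.kernels.liger.cross_entropy import (
    MAX_FUSED_SIZE,
    element_mul_kernel,
    liger_cross_entropy_kernel,
)


def fused_cross_entropy_forward(logits, target, ignore_index=-100, label_smoothing=0.0):
    """logits: contiguous (BT, V) tensor; overwritten in place with dloss/dlogits."""
    BT, V = logits.shape
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=logits.device)
    n_non_ignore = (target != ignore_index).sum().item()

    liger_cross_entropy_kernel[(BT,)](
        X_ptr=logits,
        X_stride=logits.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-1),
        loss_ptr=loss_1d,
        loss_stride=loss_1d.stride(-1),
        n_cols=V,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
        reduction="mean",
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,  # assumed tuning value, not taken from this PR
    )
    # Per-row losses are already divided by n_non_ignore, so "mean" is their sum.
    return loss_1d.sum()


def scale_stored_grad(grad_logits, grad_output):
    """Backward: in-place chain-rule scaling of the stored gradient by grad_output."""
    BT, V = grad_logits.shape
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    element_mul_kernel[(BT,)](
        grad_logits,
        grad_logits.stride(-2),
        grad_output,
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )
    return grad_logits
```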