
Commit b828275

mta + softmax docs (#730)
## Summary

## Testing Done

- Hardware Type: RTX 3090
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Co-authored-by: Shao Tang <[email protected]>
1 parent e99bbb5 commit b828275

5 files changed (+23, -6 lines)

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ loss.backward()
277277
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
278278
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
279279
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
280+
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
281+
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
280282
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
281283

282284

benchmark/scripts/benchmark_softmax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from utils import parse_benchmark_script_args
99
from utils import run_benchmarks
1010

11-
from liger_kernel.transformers.softmax import LigerKernelSoftmax
11+
from liger_kernel.transformers.softmax import LigerSoftmax
1212
from liger_kernel.utils import infer_device
1313

1414
device = infer_device()
@@ -23,7 +23,7 @@ def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut
2323
dtype = extra_benchmark_config["dtype"]
2424

2525
x_shape = (M, N)
26-
liger_softmax = LigerKernelSoftmax().to(device).to(dtype)
26+
liger_softmax = LigerSoftmax().to(device).to(dtype)
2727
torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype)
2828

2929
x = torch.randn(x_shape, dtype=dtype, device=device)
@@ -72,7 +72,7 @@ def bench_memory_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
7272
dtype = extra_benchmark_config.get("dtype", torch.float32)
7373

7474
torch_softmax = torch.nn.Softmax(dim=-1)
75-
liger_softmax = LigerKernelSoftmax().to(device).to(dtype)
75+
liger_softmax = LigerSoftmax().to(device).to(dtype)
7676

7777
x = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
7878

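For a quick sanity check outside the benchmark harness above, a minimal timing sketch follows. It assumes a CUDA device and an installed `liger_kernel`; the shape, dtype, and iteration count are illustrative and not the values used by `benchmark_softmax.py`.

```python
# Minimal standalone timing sketch (not the benchmark harness above).
# Assumes a CUDA-capable device; shape, dtype, and iteration count are illustrative.
import torch

from liger_kernel.transformers.softmax import LigerSoftmax

device, dtype = "cuda", torch.float32
x = torch.randn(4096, 8192, device=device, dtype=dtype)

liger_softmax = LigerSoftmax().to(device).to(dtype)
torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype)


def time_fn(fn, iters=50):
    # Warm up, then time with CUDA events (elapsed_time returns milliseconds).
    for _ in range(5):
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


print(f"torch softmax: {time_fn(torch_softmax):.3f} ms")
print(f"liger softmax: {time_fn(liger_softmax):.3f} ms")
```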
docs/Low-Level-APIs.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
1010
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
1111
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
12+
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
13+
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
1214
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
1315

1416

@@ -51,6 +53,19 @@ This kernel combines linear transformations with cross-entropy loss calculations
5153
!!! Example "Try it out"
5254
You can experiment as shown in this example [here](https://colab.research.google.com/drive/1Z2QtvaIiLm5MWOs7X6ZPS1MN3hcIJFbj?usp=sharing)
5355

56+
### Multi Token Attention
57+
58+
The Multi Token Attention kernel implementation provides and optimized fused implementation of multi-token attention over the implemented Pytorch model baseline. This is a new attention mechanism that can operate on multiple Q and K inputs introduced by Meta Research.
59+
60+
Paper: https://arxiv.org/abs/2504.00927
61+
62+
### Softmax
63+
64+
The Softmax kernel implementation provides an optimized implementation of the softmax operation, which is a fundamental component in neural networks for converting raw scores into probability distributions.
65+
66+
The implementation shows notable speedups compared to the Softmax PyTorch implementation
67+
68+
5469
### Sparsemax
5570

5671
Sparsemax is a sparse alternative to softmax that produces sparse probability distributions. This kernel implements an efficient version of the sparsemax operation that can be used as a drop-in replacement for softmax in attention mechanisms or classification tasks.
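To make the mechanism described in the new Multi Token Attention section concrete, here is a small PyTorch sketch of the key-query convolution idea from the paper: attention logits are mixed across neighbouring query and key positions with a depthwise 2D convolution before the softmax. This is an illustrative baseline only, not the Liger kernel; the constructor arguments of `LigerMultiTokenAttention` are not shown in this commit, so the sketch does not assume any.

```python
# Illustrative PyTorch baseline of the key-query convolution idea from
# multi-token attention (https://arxiv.org/abs/2504.00927).
# This is NOT the Liger kernel; it only sketches the mechanism.
import math

import torch
import torch.nn.functional as F


def multi_token_attention_sketch(q, k, v, conv_weight):
    """Causal attention whose logits are mixed over (query, key) offsets.

    q, k, v:     (batch, heads, seq_len, head_dim)
    conv_weight: (heads, 1, c_q, c_k) depthwise kernel, odd c_q and c_k
    """
    b, h, s, d = q.shape
    logits = q @ k.transpose(-2, -1) / math.sqrt(d)  # (b, h, s, s)

    causal = torch.ones(s, s, dtype=torch.bool, device=q.device).tril()
    # Zero out future positions before the convolution so -inf never leaks in,
    # then re-apply the causal mask afterwards.
    mixed = F.conv2d(
        logits.masked_fill(~causal, 0.0), conv_weight, padding="same", groups=h
    )
    mixed = mixed.masked_fill(~causal, float("-inf"))

    return F.softmax(mixed, dim=-1) @ v  # (b, h, s, head_dim)
```

With `conv_weight` set to an identity kernel (1 at the centre, 0 elsewhere), this reduces to ordinary causal attention, which makes it a convenient correctness reference.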

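The `LigerSoftmax` module documented in the new Softmax section is used like `torch.nn.Softmax(dim=-1)`; the sketch below mirrors the pattern in the benchmark and test diffs of this commit (shape and dtype are illustrative).

```python
import torch

from liger_kernel.transformers.softmax import LigerSoftmax
from liger_kernel.utils import infer_device

device = infer_device()
dtype = torch.float32

x = torch.randn(128, 4096, device=device, dtype=dtype)

# Drop-in replacement for torch.nn.Softmax(dim=-1).
liger_softmax = LigerSoftmax().to(device).to(dtype)
probs = liger_softmax(x)

# Each row is a probability distribution: non-negative and sums to 1.
assert torch.allclose(
    probs.sum(dim=-1), torch.ones(128, device=device, dtype=dtype), atol=1e-5
)
```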
src/liger_kernel/transformers/softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from liger_kernel.ops.softmax import LigerSoftmaxFunction
55

66

7-
class LigerKernelSoftmax(nn.Module):
7+
class LigerSoftmax(nn.Module):
88
def __init__(self):
99
super().__init__()
1010

test/transformers/test_softmax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from test.utils import supports_bfloat16
77

88
from liger_kernel.transformers.functional import liger_softmax
9-
from liger_kernel.transformers.softmax import LigerKernelSoftmax
9+
from liger_kernel.transformers.softmax import LigerSoftmax
1010
from liger_kernel.utils import infer_device
1111

1212
device = infer_device()
@@ -47,7 +47,7 @@ def test_liger_softmax(shape, dtype, atol, rtol):
4747

4848
torch_softmax = torch.nn.Softmax(dim=-1)
4949
ref_out = torch_softmax(x1)
50-
liger_softmax = LigerKernelSoftmax().to(device).to(dtype)
50+
liger_softmax = LigerSoftmax().to(device).to(dtype)
5151
liger_out = liger_softmax(x2)
5252

5353
assert_verbose_allclose(ref_out, liger_out, atol=atol, rtol=rtol)
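The test file also imports the functional entry point `liger_kernel.transformers.functional.liger_softmax`. Its exact signature is not visible in this diff, so the call below assumes it takes a tensor and applies softmax over the last dimension, mirroring the module form; treat it as a sketch rather than the documented API.

```python
import torch

from liger_kernel.transformers.functional import liger_softmax

x = torch.randn(32, 1024, device="cuda")

# Assumption: liger_softmax(x) applies softmax over the last dimension.
out = liger_softmax(x)
torch.testing.assert_close(out, torch.softmax(x, dim=-1), rtol=1e-5, atol=1e-6)
```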
