Skip to content

Commit cee2b56

Browse files
Fix missing low-level api imports (#856)
## Summary Some of the imports specified in the Readme don't work, this is a small fix to make them work. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe A100 Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 3bdb36e commit cee2b56

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/liger_kernel/transformers/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
1111
from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
1212
from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
13+
from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401
1314
from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
1415
from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
1516
from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
17+
from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
1618
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
1719
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
20+
from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
21+
from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
1822
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
1923
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
2024
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
@@ -134,6 +138,10 @@ def __getattr__(name: str):
134138
"LigerQwen3MoeSwiGLUMLP",
135139
"LigerSwiGLUMLP",
136140
"LigerTVDLoss",
141+
"LigerKLDIVLoss",
142+
"LigerMultiTokenAttention",
143+
"LigerSoftmax",
144+
"LigerSparsemax",
137145
]
138146

139147
# Add transformer-dependent symbols only if available
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401
2+
3+
__all__ = [
4+
"LigerEmbedding",
5+
]

0 commit comments

Comments
 (0)