Skip to content

Commit 175cdf0

Browse files
Tcc0403daiyunwei1998vaibhavjindalyukiu00Mecoli1219
authored
Support transfomers v5 (#994)
>[!IMPORTANT] >Do not merge this PR before all issues are resolved! > >Testing with three versions: `4.52.0` , `4.57.6` and the latest stable version >[!NOTE] >nvi-ci is split into correctness test ci and convergence test ci to speed up testing in this PR >, and more jobs for testing bc with transformers v4 (4.49.0 and 4.57.6) > >Whether keeping this change or not is yet to be discussed ## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> This is a dev branch for aggregating PRs related to transformers v5 changes. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: <BLANK> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Signed-off-by: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Co-authored-by: Yunwei Dai <daiyunwei1998@gmail.com> Co-authored-by: Vaibhav Jindal <vaibhav.jndl@gmail.com> Co-authored-by: Yuki Uehara <74698040+yukiu00@users.noreply.github.com> Co-authored-by: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Co-authored-by: Vaibhav Jindal <vjindal@linkedin.com> Co-authored-by: Michael Lai <michaellai901026@gmail.com>
1 parent 7d82f73 commit 175cdf0

36 files changed

+1140
-1857
lines changed

.github/workflows/nvi-ci.yml

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
- name: Run checkstyle
4747
run: make checkstyle
4848

49-
tests:
49+
correctness-tests:
5050
runs-on: ubuntu-latest
5151
needs: [checkstyle]
5252
env:
@@ -69,15 +69,14 @@ jobs:
6969
7070
- name: Run tests
7171
run: |
72-
modal run dev.modal.tests
72+
modal run dev.modal.tests::liger_correctness_tests
7373
74-
tests-bwd:
74+
convergence-tests:
7575
runs-on: ubuntu-latest
7676
needs: [checkstyle]
7777
env:
7878
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
7979
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
80-
REBUILD_IMAGE: ${{ github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' }}
8180

8281
steps:
8382
- name: Checkout code
@@ -95,4 +94,109 @@ jobs:
9594
9695
- name: Run tests
9796
run: |
98-
modal run dev.modal.tests_bwd
97+
modal run dev.modal.tests::liger_convergence_tests
98+
99+
100+
correctness-tests-with-transformers-4-52-0:
101+
runs-on: ubuntu-latest
102+
needs: [checkstyle]
103+
env:
104+
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
105+
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
106+
107+
steps:
108+
- name: Checkout code
109+
uses: actions/checkout@v6
110+
111+
- name: Set up Python
112+
uses: actions/setup-python@v6
113+
with:
114+
python-version: '3.10'
115+
116+
- name: Install dependencies
117+
run: |
118+
python -m pip install --upgrade pip
119+
pip install modal
120+
121+
- name: Run tests
122+
run: |
123+
modal run dev.modal.tests::liger_oldest_v4_correctness_tests
124+
125+
126+
127+
convergence-tests-with-transformers-4-52-0:
128+
runs-on: ubuntu-latest
129+
needs: [checkstyle]
130+
env:
131+
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
132+
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
133+
134+
steps:
135+
- name: Checkout code
136+
uses: actions/checkout@v6
137+
138+
- name: Set up Python
139+
uses: actions/setup-python@v6
140+
with:
141+
python-version: '3.10'
142+
143+
- name: Install dependencies
144+
run: |
145+
python -m pip install --upgrade pip
146+
pip install modal
147+
148+
- name: Run tests
149+
run: |
150+
modal run dev.modal.tests::liger_oldest_v4_convergence_tests
151+
152+
correctness-tests-with-transformers-4-57-6:
153+
runs-on: ubuntu-latest
154+
needs: [checkstyle]
155+
env:
156+
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
157+
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
158+
159+
steps:
160+
- name: Checkout code
161+
uses: actions/checkout@v6
162+
163+
- name: Set up Python
164+
uses: actions/setup-python@v6
165+
with:
166+
python-version: '3.10'
167+
168+
- name: Install dependencies
169+
run: |
170+
python -m pip install --upgrade pip
171+
pip install modal
172+
173+
- name: Run tests
174+
run: |
175+
modal run dev.modal.tests::liger_latest_v4_correctness_tests
176+
177+
178+
179+
convergence-tests-with-transformers-4-57-6:
180+
runs-on: ubuntu-latest
181+
needs: [checkstyle]
182+
env:
183+
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
184+
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
185+
186+
steps:
187+
- name: Checkout code
188+
uses: actions/checkout@v6
189+
190+
- name: Set up Python
191+
uses: actions/setup-python@v6
192+
with:
193+
python-version: '3.10'
194+
195+
- name: Install dependencies
196+
run: |
197+
python -m pip install --upgrade pip
198+
pip install modal
199+
200+
- name: Run tests
201+
run: |
202+
modal run dev.modal.tests::liger_latest_v4_convergence_tests

benchmark/scripts/benchmark_llama4_rope.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
4040
num_key_value_heads=num_kv_heads,
4141
head_dim=head_dim,
4242
max_position_embeddings=seq_len,
43-
rope_theta=10000.0,
44-
rope_scaling=None, # Use default rope type
4543
)
4644

4745
rotary_emb = transformers_version_dispatch(
@@ -134,8 +132,6 @@ def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkR
134132
num_key_value_heads=num_kv_heads,
135133
head_dim=head_dim,
136134
max_position_embeddings=seq_len,
137-
rope_theta=10000.0,
138-
rope_scaling=None, # Use default rope type
139135
)
140136

141137
rotary_emb = transformers_version_dispatch(

dev/modal/tests.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
REMOTE_ROOT_PATH = "/root/liger-kernel"
77
PYTHON_VERSION = "3.12"
88

9+
OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION = "4.52.0"
10+
911
image = modal.Image.debian_slim(python_version=PYTHON_VERSION).pip_install("uv")
1012

1113
app = modal.App("liger_tests", image=image)
@@ -15,7 +17,77 @@
1517

1618

1719
@app.function(gpu="H100!", image=repo, timeout=90 * 60)
18-
def liger_tests():
20+
def liger_correctness_tests():
21+
import subprocess
22+
23+
subprocess.run(
24+
["uv pip install -e '.[dev]' --system"],
25+
check=True,
26+
shell=True,
27+
cwd=REMOTE_ROOT_PATH,
28+
)
29+
subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
30+
31+
32+
@app.function(gpu="H100!", image=repo, timeout=90 * 60)
33+
def liger_convergence_tests():
34+
import subprocess
35+
36+
subprocess.run(
37+
["uv pip install -e '.[dev]' --system"],
38+
check=True,
39+
shell=True,
40+
cwd=REMOTE_ROOT_PATH,
41+
)
42+
subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
43+
44+
45+
oldest_v4_app = modal.App("liger_oldest_v4_tests", image=image) # 4.52.0
46+
47+
48+
@oldest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
49+
def liger_oldest_v4_correctness_tests():
50+
import subprocess
51+
52+
subprocess.run(
53+
["uv pip install -e '.[dev]' --system"],
54+
check=True,
55+
shell=True,
56+
cwd=REMOTE_ROOT_PATH,
57+
)
58+
subprocess.run(
59+
[f"uv pip install 'transformers=={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}' --system"],
60+
check=True,
61+
shell=True,
62+
cwd=REMOTE_ROOT_PATH,
63+
)
64+
subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
65+
66+
67+
@oldest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
68+
def liger_oldest_v4_convergence_tests():
69+
import subprocess
70+
71+
subprocess.run(
72+
["uv pip install -e '.[dev]' --system"],
73+
check=True,
74+
shell=True,
75+
cwd=REMOTE_ROOT_PATH,
76+
)
77+
subprocess.run(
78+
[f"uv pip install 'transformers=={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}' --system"],
79+
check=True,
80+
shell=True,
81+
cwd=REMOTE_ROOT_PATH,
82+
)
83+
subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
84+
85+
86+
latest_v4_app = modal.App("liger_latest_v4_tests", image=image) # 4.57.6
87+
88+
89+
@latest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
90+
def liger_latest_v4_correctness_tests():
1991
import subprocess
2092

2193
subprocess.run(
@@ -24,5 +96,29 @@ def liger_tests():
2496
shell=True,
2597
cwd=REMOTE_ROOT_PATH,
2698
)
99+
subprocess.run(
100+
[f"uv pip install 'transformers>={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}, <5.0.0' --system"],
101+
check=True,
102+
shell=True,
103+
cwd=REMOTE_ROOT_PATH,
104+
)
27105
subprocess.run(["make test"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)
106+
107+
108+
@latest_v4_app.function(gpu="H100!", image=repo, timeout=90 * 60)
109+
def liger_latest_v4_convergence_tests():
110+
import subprocess
111+
112+
subprocess.run(
113+
["uv pip install -e '.[dev]' --system"],
114+
check=True,
115+
shell=True,
116+
cwd=REMOTE_ROOT_PATH,
117+
)
118+
subprocess.run(
119+
[f"uv pip install 'transformers>={OLDEST_SUPPORTED_TRANSFORMERS_V4_VERSION}, <5.0.0' --system"],
120+
check=True,
121+
shell=True,
122+
cwd=REMOTE_ROOT_PATH,
123+
)
28124
subprocess.run(["make test-convergence"], check=True, shell=True, cwd=REMOTE_ROOT_PATH)

dev/modal/tests_bwd.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def get_optional_dependencies():
3333
"""Get optional dependency groups."""
3434
return {
3535
"dev": [
36-
"transformers>=4.52.0, <5.0.0",
36+
"transformers>=4.52.0",
3737
"matplotlib>=3.7.2",
3838
"ruff>=0.12.0",
3939
"pytest>=7.1.2",

src/liger_kernel/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
2222
from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
2323
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
24+
from liger_kernel.transformers.swiglu import LigerExperts # noqa: F401
2425
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
2526
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
2627
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401

0 commit comments

Comments
 (0)