Skip to content

Commit 030cb71

Browse files
authored
Polish test/ and others (#26)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> as title ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ``` jobuser [ ~/Liger-Kernel ]$ make checkstyle && make test && make test-convergence flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 1 files All done! ✨ 🍰 ✨ 45 files left unchanged. pytest --disable-warnings test/ --ignore=test/convergence ================================================================================================== test session starts =================================================================================================== platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 111 items test/transformers/test_cross_entropy.py .......................................................... [ 52%] test/transformers/test_fused_linear_cross_entropy.py ...... [ 57%] test/transformers/test_geglu.py ........ [ 64%] test/transformers/test_rms_norm.py ................ [ 79%] test/transformers/test_rope.py ............ [ 90%] test/transformers/test_swiglu.py ........ [ 97%] test/transformers/test_transformers_monkey_patch.py . [ 98%] test/triton/test_triton_monkey_patch.py .. 
[100%] ============================================================================================= 111 passed in 60.64s (0:01:00) ============================================================================================= HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence ================================================================================================== test session starts =================================================================================================== platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 8 items test/convergence/test_mini_models.py ...... [ 75%] test/convergence/test_mini_models_no_logits.py .. [100%] ============================================================================================== 8 passed in 95.88s (0:01:35) ============================================================================================== ```
1 parent f7f8384 commit 030cb71

File tree

7 files changed

+10
-12
lines changed

7 files changed

+10
-12
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile.
183183

184184
[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
185185

186+
## Acknowledgement
187+
188+
- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
189+
- [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
190+
191+
186192
## License
187193

188194
[BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)

test/convergence/test_mini_models_no_logits.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class MiniModelConfig:
2626

2727
MINI_MODEL_SETUPS = {
2828
"mini_llama3": MiniModelConfig(
29+
# TODO (easy): replace with oss public path
2930
tokenizer_path="/shared/public/models/Meta-Llama-3-8B/",
3031
liger_kernel_patch_func=apply_liger_kernel_to_llama,
3132
model_class=LlamaForCausalLM,

test/transformers/test_cross_entropy.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,16 +155,8 @@ def test_correctness_with_ignore_index(
155155
@pytest.mark.parametrize(
156156
"scalar, dtype, atol, rtol",
157157
[
158-
# (0.01, torch.bfloat16, 1e-8, 5e-2),
159-
# (0.1, torch.bfloat16, 1e-8, 5e-2),
160158
(1.0, torch.bfloat16, 1e-8, 5e-2),
161-
# (10.0, torch.bfloat16, 1e-8, 5e-2),
162-
# (100.0, torch.bfloat16, 1e-8, 5e-2),
163-
# (0.01, torch.float32, 1e-8, 1e-6),
164-
# (0.1, torch.float32, 1e-8, 1e-6),
165159
(1.0, torch.float32, 1e-8, 1e-6),
166-
# (10.0, torch.float32, 1e-8, 1e-6),
167-
# (100.0, torch.float32, 1e-8, 1e-6),
168160
],
169161
)
170162
def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol):

test/transformers/test_fused_linear_cross_entropy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def forward(self, x, y):
5858
[
5959
(2, 4, 512, 512),
6060
(8, 2048, 4096, 32000), # llama2, mistral
61+
# Commented out to speed up testing
6162
# (4, 2048, 4096, 128256), # llama3 8B
6263
# (4, 1024, 8192, 128256), # llama3 70B
6364
(4, 423, 8192, 32000), # random shape

test/transformers/test_geglu.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
)
1313
SLEEP_SECONDS = 0.1
1414

15-
# TODO (yun dai): triton 3.0.0 breaks geglu due to tanh module issue
16-
1715

1816
@pytest.mark.parametrize(
1917
"bsz, seq_len, hidden_size, intermediate_size",

test/transformers/test_rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,5 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol):
7272
)
7373
is True
7474
)
75-
# import pdb; pdb.set_trace()
75+
7676
assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) is True

test/transformers/test_swiglu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# atol is for small values: they have more difference, so set atol higher
3030
# rtol is for larger values: they are very close, so set rtol lower
3131
(torch.float32, 1e-0, 1e-5),
32-
# TODO: we should find a better way to tune this lol. 1e4 is too large apparently
32+
# TODO: we should find a better way to tune this. 1e4 is too large apparently
3333
(torch.bfloat16, 1e4, 1e-2),
3434
],
3535
)

0 commit comments

Comments
 (0)