Skip to content

Commit 4adc234

Browse files
authored
Merge branch 'main' into main
2 parents 63f1d17 + d1441a0 commit 4adc234

File tree

15 files changed

+89
-46
lines changed

15 files changed

+89
-46
lines changed

.github/workflows/amd-ci.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ jobs:
6464
run: |
6565
rocm-smi
6666
python -m pip install --upgrade pip
67-
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
67+
pip install -e .[dev]
68+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}/
6869
6970
- name: List Python Environments
7071
run: python -m pip list

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ uv.lock
2323

2424
# Benchmark images
2525
benchmark/visualizations
26-
.vscode/
26+
.vscode/
27+
.coverage

Makefile

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@ all: checkstyle test test-convergence
55

66
# Command to run pytest for correctness tests
77
test:
8-
python -m pytest --disable-warnings test/ --ignore=test/convergence
8+
python -m pytest --disable-warnings \
9+
-n auto \
10+
--dist=load \
11+
--cov=src/liger_kernel \
12+
--cov-report=term-missing \
13+
--ignore=test/convergence \
14+
test/
15+
coverage combine
16+
coverage report -m
17+
coverage html
918

1019
# Command to run ruff for linting and formatting code
1120
checkstyle:

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ y = orpo_loss(lm_head.weight, x, target)
129129
- `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
130130

131131
```bash
132-
# Need to pass the url when installing
133-
pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
132+
pip install -e .[dev]
133+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
134134
```
135135

136136
### Optional Dependencies
@@ -164,6 +164,9 @@ pip install -e .
164164

165165
# Setup Development Dependencies
166166
pip install -e ".[dev]"
167+
168+
# NOTE -> For AMD users only
169+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3/
167170
```
168171

169172

docs/Examples.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401
239239

240240
model = AutoModelForCausalLM.from_pretrained(
241241
"meta-llama/Llama-3.2-1B-Instruct",
242-
torch_dtype=torch.bfloat16,
242+
dtype=torch.bfloat16,
243243
)
244244

245245
tokenizer = AutoTokenizer.from_pretrained(

examples/alignment/run_orpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
model = AutoModelForCausalLM.from_pretrained(
1111
"meta-llama/Llama-3.2-1B-Instruct",
12-
torch_dtype=torch.bfloat16,
12+
dtype=torch.bfloat16,
1313
)
1414

1515
tokenizer = AutoTokenizer.from_pretrained(

examples/huggingface/training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def train():
4848
custom_args.model_name,
4949
trust_remote_code=True,
5050
use_cache=False,
51-
torch_dtype=torch.bfloat16,
51+
dtype=torch.bfloat16,
5252
# These args will get passed to the appropriate apply_liger_kernel_to_* function
5353
# to override the default settings
5454
# cross_entropy=True,
@@ -59,7 +59,7 @@ def train():
5959
custom_args.model_name,
6060
trust_remote_code=True,
6161
use_cache=False,
62-
torch_dtype=torch.bfloat16,
62+
dtype=torch.bfloat16,
6363
)
6464

6565
trainer = SFTTrainer(

examples/huggingface/training_multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def construct_model_and_processor(model_name: str, use_liger: bool) -> torch.nn.
5656
model = Qwen2VLForConditionalGeneration.from_pretrained(
5757
pretrained_model_name_or_path=model_name,
5858
use_cache=False,
59-
torch_dtype=torch.bfloat16,
59+
dtype=torch.bfloat16,
6060
low_cpu_mem_usage=True,
6161
attn_implementation="sdpa",
6262
)

examples/medusa/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def _model_loader():
319319
model = model_builder(
320320
model_args.model_name_or_path,
321321
cache_dir=training_args.cache_dir,
322-
torch_dtype=torch.bfloat16,
322+
dtype=torch.bfloat16,
323323
)
324324

325325
# Freeze the base model

pyproject.toml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,36 @@ pythonpath = ["src", "."]
2424
asyncio_mode = "auto"
2525
log_cli = true
2626
log_cli_level = "INFO"
27+
addopts = [
28+
"-n", "auto",
29+
"--dist=load", # use "load" to distribute tests and let pytest-cov combine coverage
30+
"--cov=src/liger_kernel",
31+
"--cov-report=term-missing",
32+
"--cov-report=html",
33+
"--cov-config=pyproject.toml",
34+
"--durations=0"
35+
]
36+
python_files = "test_*.py"
37+
testpaths = ["test/"]
38+
39+
[tool.coverage.run]
40+
branch = true
41+
parallel = true
42+
source = ["src/liger_kernel"]
43+
# xdist uses subprocesses; "multiprocessing" is a safe concurrency choice
44+
concurrency = ["multiprocessing"]
45+
46+
[tool.coverage.paths]
47+
liger_kernel = [
48+
"src/liger_kernel",
49+
"*/site-packages/liger_kernel"
50+
]
51+
52+
[tool.coverage.report]
53+
omit = ["test/*"]
54+
show_missing = true
55+
skip_covered = false
56+
2757

2858
[tool.ruff]
2959
line-length = 120

0 commit comments

Comments
 (0)