Skip to content

Commit f732fd3

Browse files
committed
chore(ci): fix test path for scripts
1 parent 8fb2764 commit f732fd3

File tree

6 files changed

+142
-171
lines changed

6 files changed

+142
-171
lines changed

.github/workflows/build.yml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,26 @@ jobs:
3535
poetry config virtualenvs.create true --local
3636
poetry config virtualenvs.in-project true --local
3737
38+
# Add caching for Poetry dependencies
39+
- name: Cache Poetry dependencies
40+
uses: actions/cache@v3
41+
with:
42+
path: .venv
43+
key: ${{ runner.os }}-poetry-3.9-${{ hashFiles('**/poetry.lock') }}
44+
restore-keys: |
45+
${{ runner.os }}-poetry-3.9-
46+
3847
- name: Install dependencies
3948
run: poetry install
4049

4150
- name: Build package
4251
run: poetry build
4352

44-
- name: Upload artifact
45-
uses: actions/upload-artifact@v4
53+
- name: Upload build artifacts
54+
uses: actions/upload-artifact@v3
4655
with:
4756
name: dist
4857
path: dist/
49-
retention-days: 1
5058

5159
- name: Check code quality
52-
run: |
53-
poetry run black --check babeltron/
54-
poetry run isort --check-only babeltron/
60+
run: make lint

.github/workflows/test.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,25 @@ jobs:
3232
poetry config virtualenvs.create true --local
3333
poetry config virtualenvs.in-project true --local
3434
35+
- name: Cache Poetry dependencies
36+
uses: actions/cache@v3
37+
with:
38+
path: .venv
39+
key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
40+
restore-keys: |
41+
${{ runner.os }}-poetry-${{ matrix.python-version }}-
42+
3543
- name: Install dependencies
3644
run: poetry install
3745

3846
- name: Run tests
3947
run: make test
4048

4149
- name: Check code quality
42-
run: |
43-
make lint
50+
run: make lint
4451

4552
- name: Upload coverage reports to Codecov
53+
if: matrix.python-version == '3.9'
4654
uses: codecov/codecov-action@v3
4755
env:
4856
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

babeltron/scripts/__init__.py

Whitespace-only changes.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to download M2M-100 models for Babeltron.
4+
"""
5+
import argparse
6+
import os
7+
from pathlib import Path
8+
from typing import List, Optional, Union
9+
10+
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
11+
12+
# Comma-separated list of model sizes the downloader accepts; the
# BABELTRON_MODEL_SIZES env var lets a deployment restrict or extend it.
_sizes_csv = os.environ.get("BABELTRON_MODEL_SIZES", "418M,1.2B,12B")
VALID_MODEL_SIZES: List[str] = _sizes_csv.split(",")

# Size used when the caller does not request a specific model.
DEFAULT_MODEL_SIZE: str = os.environ.get("BABELTRON_DEFAULT_MODEL_SIZE", "418M")

# Models are cached under the user's home directory by default.
DEFAULT_OUTPUT_DIR: Path = Path.home() / "models"
18+
19+
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the model downloader.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.
            (New optional parameter; existing zero-argument calls are
            unaffected.)

    Returns:
        argparse.Namespace with ``size`` and ``output_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Download M2M100 translation models")
    parser.add_argument(
        "--size",
        choices=VALID_MODEL_SIZES,
        default=DEFAULT_MODEL_SIZE,
        # Derive the help text from the configured choices instead of
        # hard-coding "418M, 1.2B, or 12B", which goes stale whenever
        # BABELTRON_MODEL_SIZES overrides the list.
        help=f"Model size to download (one of: {', '.join(VALID_MODEL_SIZES)})",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory to save the model",
    )
    return parser.parse_args(argv)
31+
32+
33+
def download_model(
    model_size: str = DEFAULT_MODEL_SIZE,
    output_dir: Optional[Union[str, Path]] = None,
    show_progress: bool = True,
) -> str:
    """Download the M2M-100 model and tokenizer from the Hugging Face hub.

    Args:
        model_size (str): Size of the model to download; must be one of
            VALID_MODEL_SIZES (by default 418M, 1.2B, or 12B).
        output_dir (str or Path, optional): Directory to save the model to.
            Defaults to DEFAULT_OUTPUT_DIR when omitted.
        show_progress (bool): Whether to print status messages while
            downloading. (This parameter was previously accepted but
            silently ignored; it now gates the informational output.)

    Returns:
        str: Path to the downloaded model directory.

    Raises:
        ValueError: If ``model_size`` is not a recognized size.
    """
    if model_size not in VALID_MODEL_SIZES:
        raise ValueError(f"Model size must be one of {VALID_MODEL_SIZES}")

    model_name = f"facebook/m2m100_{model_size}"

    # Normalize to Path so both str and Path callers are handled uniformly.
    output_dir = DEFAULT_OUTPUT_DIR if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    def _log(message: str) -> None:
        # Fix: honor show_progress instead of leaving it a dead parameter.
        if show_progress:
            print(message)

    _log(f"Downloading {model_name} model and tokenizer to {output_dir}...")

    _log("Downloading tokenizer...")
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    tokenizer.save_pretrained(output_dir)

    _log("Downloading model (this may take a while)...")
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    model.save_pretrained(output_dir)

    _log(f"Model and tokenizer successfully saved to {output_dir}")

    return str(output_dir)
74+
75+
76+
def main() -> int:
    """CLI entry point: download the model requested on the command line.

    Returns:
        int: 0 on success, 1 on failure — suitable as a process exit code.
    """
    args = parse_args()

    try:
        # args.size is already the exact suffix used in the Hugging Face
        # model id (f"facebook/m2m100_{size}"), so the identity remapping
        # table the original code built here was dead weight and is gone.
        print(f"Downloading {args.size} model...")
        print(
            "This may take a while depending on your internet connection and the model size."
        )

        download_model(model_size=args.size, output_dir=args.output_dir)

        print(f"Model successfully downloaded and saved to {args.output_dir}")
    except Exception as e:
        # Broad catch is deliberate at this top-level CLI boundary: report
        # the failure and exit non-zero instead of dumping a traceback.
        print(f"Error downloading model: {e}")
        return 1

    return 0
99+
100+
101+
if __name__ == "__main__":
    # Propagate main()'s status code to the shell instead of discarding it;
    # SystemExit needs no extra import, unlike sys.exit().
    raise SystemExit(main())

0 commit comments

Comments (0)