Skip to content

Commit 252251c

Browse files
committed
chore(ci): fix test path for scripts
1 parent 8fb2764 commit 252251c

File tree

6 files changed

+128
-169
lines changed

6 files changed

+128
-169
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,4 @@ jobs:
4949
retention-days: 1
5050

5151
- name: Check code quality
52-
run: |
53-
poetry run black --check babeltron/
54-
poetry run isort --check-only babeltron/
52+
run: make lint

.github/workflows/test.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,18 @@ jobs:
3333
poetry config virtualenvs.in-project true --local
3434
3535
- name: Install dependencies
36-
run: poetry install
36+
run: |
37+
python -m pip install --upgrade pip
38+
pip install poetry
39+
# Force remove any existing nvidia package to avoid conflicts
40+
pip uninstall -y nvidia-cuda-runtime-cu11 || true
41+
poetry install
3742
3843
- name: Run tests
3944
run: make test
4045

4146
- name: Check code quality
42-
run: |
43-
make lint
47+
run: make lint
4448

4549
- name: Upload coverage reports to Codecov
4650
uses: codecov/codecov-action@v3

babeltron/scripts/__init__.py

Whitespace-only changes.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to download M2M-100 models for Babeltron.
4+
"""
5+
import argparse
6+
import os
7+
from pathlib import Path
8+
from typing import List, Optional, Union
9+
10+
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
11+
12+
# Accepted model sizes; override with a comma-separated BABELTRON_MODEL_SIZES.
_sizes_csv = os.getenv("BABELTRON_MODEL_SIZES", "418M,1.2B,12B")
VALID_MODEL_SIZES: List[str] = _sizes_csv.split(",")

# Size used when --size is not supplied on the command line.
DEFAULT_MODEL_SIZE: str = os.getenv("BABELTRON_DEFAULT_MODEL_SIZE", "418M")

# Models are saved under ~/models unless --output-dir says otherwise.
DEFAULT_OUTPUT_DIR: Path = Path.home() / "models"
17+
18+
19+
def parse_args():
    """Parse command-line arguments for the model download script.

    Returns:
        argparse.Namespace: parsed arguments with ``size`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Download M2M100 translation models")
    parser.add_argument(
        "--size",
        choices=VALID_MODEL_SIZES,
        default=DEFAULT_MODEL_SIZE,
        # Build the help text from the configured sizes so it stays accurate
        # when BABELTRON_MODEL_SIZES overrides the default list; the old text
        # hard-coded "(418M, 1.2B, or 12B)".
        help=f"Model size to download (one of: {', '.join(VALID_MODEL_SIZES)})",
    )
    parser.add_argument(
        "--output-dir", default=DEFAULT_OUTPUT_DIR, help="Directory to save the model"
    )
    return parser.parse_args()
31+
32+
33+
def download_model(
    model_size: str = DEFAULT_MODEL_SIZE,
    output_dir: Optional[Union[str, Path]] = None,
    show_progress: bool = True,
) -> str:
    """Download an M2M-100 checkpoint and its tokenizer to a local directory.

    Args:
        model_size (str): Size of the model to download (418M, 1.2B, or 12B)
        output_dir (str or Path, optional): Directory to save the model to;
            defaults to DEFAULT_OUTPUT_DIR when omitted.
        show_progress (bool): Accepted for API compatibility; currently unused
            by this implementation.

    Returns:
        str: Path to the downloaded model directory

    Raises:
        ValueError: If ``model_size`` is not listed in VALID_MODEL_SIZES.
    """
    if model_size not in VALID_MODEL_SIZES:
        raise ValueError(f"Model size must be one of {VALID_MODEL_SIZES}")

    model_name = f"facebook/m2m100_{model_size}"

    # Normalize the destination: fall back to the default, otherwise coerce
    # whatever the caller supplied (str or Path) into a Path.
    target = DEFAULT_OUTPUT_DIR if output_dir is None else Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    print(f"Downloading {model_name} model and tokenizer to {target}...")

    print("Downloading tokenizer...")
    M2M100Tokenizer.from_pretrained(model_name).save_pretrained(target)

    print("Downloading model (this may take a while)...")
    M2M100ForConditionalGeneration.from_pretrained(model_name).save_pretrained(target)

    print(f"Model and tokenizer successfully saved to {target}")

    return str(target)
74+
75+
76+
def main():
    """CLI entry point: parse arguments and download the requested model.

    Returns:
        int: 0 on success, 1 on any failure (suitable for sys.exit).
    """
    args = parse_args()

    try:
        print(f"Downloading {args.size} model...")
        print(
            "This may take a while depending on your internet connection and the model size."
        )

        # Pass the argparse-validated size straight through. The previous
        # identity map ({"418M": "418M", ...}) raised KeyError for any size
        # added via BABELTRON_MODEL_SIZES, even though argparse accepted it.
        download_model(model_size=args.size, output_dir=args.output_dir)

        print(f"Model successfully downloaded and saved to {args.output_dir}")

    except Exception as e:
        # Broad catch is deliberate: this is the top-level CLI boundary, and
        # any failure should surface as a message plus a non-zero exit code.
        print(f"Error downloading model: {e}")
        return 1

    return 0
99+
100+
101+
if __name__ == "__main__":
    # Propagate main()'s return code (0/1) to the shell; the bare call
    # discarded it, so failures exited with status 0.
    raise SystemExit(main())

0 commit comments

Comments (0)