Skip to content

Commit 0a5d527

Browse files
author
ochafik
committed
Update fetch_server_test_models.py
1 parent 0e87ae2 commit 0a5d527

File tree

1 file changed

+72
-52
lines changed

1 file changed

+72
-52
lines changed

scripts/fetch_server_test_models.py

100644100755
Lines changed: 72 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python
12
'''
23
This script fetches all the models used in the server tests.
34
@@ -7,13 +8,14 @@
78
89
Example:
910
python scripts/fetch_server_test_models.py
10-
( cd examples/server/tests && ./tests.sh --tags=slow )
11+
( cd examples/server/tests && ./tests.sh -v -x -m slow )
1112
'''
12-
from behave.parser import Parser
13+
import ast
1314
import glob
15+
import logging
1416
import os
17+
from typing import Generator
1518
from pydantic import BaseModel
16-
import re
1719
import subprocess
1820
import sys
1921

@@ -26,53 +28,71 @@ class Config:
2628
frozen = True
2729

2830

29-
models = set()
30-
31-
model_file_re = re.compile(r'a model file ([^\s\n\r]+) from HF repo ([^\s\n\r]+)')
32-
33-
34-
def process_step(step):
35-
if (match := model_file_re.search(step.name)):
36-
(hf_file, hf_repo) = match.groups()
37-
models.add(HuggingFaceModel(hf_repo=hf_repo, hf_file=hf_file))
38-
39-
40-
feature_files = glob.glob(
41-
os.path.join(
42-
os.path.dirname(__file__),
43-
'../examples/server/tests/features/*.feature'))
44-
45-
for feature_file in feature_files:
46-
with open(feature_file, 'r') as file:
47-
feature = Parser().parse(file.read())
48-
if not feature: continue
49-
50-
if feature.background:
51-
for step in feature.background.steps:
52-
process_step(step)
53-
54-
for scenario in feature.walk_scenarios(with_outlines=True):
55-
for step in scenario.steps:
56-
process_step(step)
57-
58-
cli_path = os.environ.get(
59-
'LLAMA_SERVER_BIN_PATH',
60-
os.path.join(
61-
os.path.dirname(__file__),
62-
'../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))
63-
64-
for m in sorted(list(models), key=lambda m: m.hf_repo):
65-
if '<' in m.hf_repo or '<' in m.hf_file:
66-
continue
67-
if '-of-' in m.hf_file:
68-
print(f'# Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file', file=sys.stderr)
69-
continue
70-
print(f'# Ensuring model at {m.hf_repo} / {m.hf_file} is fetched')
71-
cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
72-
if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
73-
cmd.append('-fa')
31+
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
7432
try:
75-
subprocess.check_call(cmd)
76-
except subprocess.CalledProcessError:
77-
print(f'# Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}', file=sys.stderr)
78-
exit(1)
33+
with open(test_file) as f:
34+
tree = ast.parse(f.read())
35+
except Exception as e:
36+
logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
37+
return
38+
39+
for node in ast.walk(tree):
40+
if isinstance(node, ast.FunctionDef):
41+
for dec in node.decorator_list:
42+
if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize':
43+
param_names = ast.literal_eval(dec.args[0]).split(",")
44+
if not "hf_repo" in param_names or not "hf_file" in param_names:
45+
continue
46+
47+
raw_param_values = dec.args[1]
48+
if not isinstance(raw_param_values, ast.List):
49+
logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
50+
continue
51+
52+
hf_repo_idx = param_names.index("hf_repo")
53+
hf_file_idx = param_names.index("hf_file")
54+
55+
for t in raw_param_values.elts:
56+
if not isinstance(t, ast.Tuple):
57+
logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
58+
continue
59+
yield HuggingFaceModel(
60+
hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
61+
hf_file=ast.literal_eval(t.elts[hf_file_idx]))
62+
63+
64+
if __name__ == '__main__':
65+
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
66+
67+
models = sorted(list(set([
68+
model
69+
for test_file in glob.glob('examples/server/tests/unit/test_*.py')
70+
for model in collect_hf_model_test_parameters(test_file)
71+
])), key=lambda m: (m.hf_repo, m.hf_file))
72+
73+
logging.info(f'Found {len(models)} models in parameterized tests:')
74+
for m in models:
75+
logging.info(f' - {m.hf_repo} / {m.hf_file}')
76+
77+
cli_path = os.environ.get(
78+
'LLAMA_SERVER_BIN_PATH',
79+
os.path.join(
80+
os.path.dirname(__file__),
81+
'../build/bin/Release/llama-cli.exe' if os.name == 'nt' \
82+
else '../build/bin/llama-cli'))
83+
84+
for m in models:
85+
if '<' in m.hf_repo or '<' in m.hf_file:
86+
continue
87+
if '-of-' in m.hf_file:
88+
logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
89+
continue
90+
logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
91+
cmd = [cli_path, '-hfr', m.hf_repo, '-hff', m.hf_file, '-n', '1', '-p', 'Hey', '--no-warmup', '--log-disable']
92+
if m.hf_file != 'tinyllamas/stories260K.gguf' and not m.hf_file.startswith('Mistral-Nemo'):
93+
cmd.append('-fa')
94+
try:
95+
subprocess.check_call(cmd)
96+
except subprocess.CalledProcessError:
97+
logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
98+
exit(1)

0 commit comments

Comments
 (0)