Skip to content

Commit c80d1aa

Browse files
whisper tiny pass
1 parent 78d79da commit c80d1aa

File tree

4 files changed

+75
-22
lines changed

4 files changed

+75
-22
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,12 @@
33
# please consider a global .gitignore https://help.github.com/articles/ignoring-files
44
.gitignore
55
.egg-info
6+
.ruff_cache
67
.vagrant*
78
.hcl
89
.terraform.lock.hcl
910
.terraform
11+
pip-unpack-*
1012
__pycache__
1113
bin
1214
docker/docker

src/huggingface_inference_toolkit/utils.py

Lines changed: 34 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ def _load_repository_from_hf(
140140

141141
if framework is None:
142142
framework = _get_framework()
143+
logging.info(f"Framework: {framework}")
143144

144145
if isinstance(target_dir, str):
145146
target_dir = Path(target_dir)
@@ -149,22 +150,24 @@ def _load_repository_from_hf(
149150
target_dir.mkdir(parents=True)
150151

151152
# check if safetensors weights are available
152-
if framework == "pytorch":
153-
files = HfApi().model_info(repository_id).siblings
154-
if any(f.rfilename.endswith("safetensors") for f in files):
155-
framework = "safetensors"
153+
#if framework == "pytorch":
154+
#files = HfApi().model_info(repository_id).siblings
155+
#if any(f.rfilename.endswith("safetensors") for f in files):
156+
#framework = "safetensors"
156157

157158
# create regex to only include the framework specific weights
158159
ignore_regex = create_artifact_filter(framework)
160+
logging.info(f"ignore_regex: {ignore_regex}")
161+
logging.info(f"Framework after filtering: {framework}")
159162
logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }")
160163

161164
# Download the repository to the workdir and filter out non-framework specific weights
162165
snapshot_download(
163-
repository_id,
164-
revision=revision,
165-
local_dir=str(target_dir),
166-
local_dir_use_symlinks=False,
167-
ignore_patterns=ignore_regex,
166+
repo_id = repository_id,
167+
revision = revision,
168+
local_dir = str(target_dir),
169+
local_dir_use_symlinks = False,
170+
ignore_patterns = ignore_regex,
168171
)
169172

170173
return target_dir
@@ -223,7 +226,12 @@ def get_device():
223226
return -1
224227

225228

226-
def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
229+
def get_pipeline(
230+
task: str,
231+
model_dir: Path,
232+
framework = "pytorch",
233+
**kwargs,
234+
) -> Pipeline:
227235
"""
228236
create pipeline class for a specific task based on local saved model
229237
"""
@@ -244,6 +252,12 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
244252
"zero-shot-image-classification",
245253
}:
246254
kwargs["feature_extractor"] = model_dir
255+
hf_pipeline = pipeline(
256+
task=task,
257+
model=model_dir,
258+
device=device,
259+
**kwargs
260+
)
247261
elif task in {"image-to-text"}:
248262
pass
249263
else:
@@ -265,12 +279,20 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
265279
logging.info(f"Model: {model_dir}")
266280
logging.info(f"Device: {device}")
267281
logging.info(f"Args: {kwargs}")
268-
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
282+
hf_pipeline = pipeline(
283+
task=task,
284+
model=model_dir,
285+
device=device,
286+
**kwargs
287+
)
269288

270289
# wrapp specific pipeline to support better ux
271290
if task == "conversational":
272291
hf_pipeline = wrap_conversation_pipeline(hf_pipeline)
273-
elif task == "automatic-speech-recognition" and isinstance(hf_pipeline.model, WhisperForConditionalGeneration):
292+
elif task == "automatic-speech-recognition" and isinstance(
293+
hf_pipeline.model,
294+
WhisperForConditionalGeneration
295+
):
274296
# set chunk length to 30s for whisper to enable long audio files
275297
hf_pipeline._preprocess_params["chunk_length_s"] = 30
276298
hf_pipeline._preprocess_params["ignore_warning"] = True

tests/unit/test_utils.py

Lines changed: 29 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@
1717
wrap_conversation_pipeline,
1818
)
1919

20+
import logging
21+
2022
MODEL = "lysandre/tiny-bert-random"
2123
TASK = "text-classification"
2224
TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
@@ -112,17 +114,39 @@ def test_get_framework_tensorflow():
112114
def test_get_pipeline():
113115
with tempfile.TemporaryDirectory() as tmpdirname:
114116
storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch")
115-
pipe = get_pipeline(TASK, storage_dir.as_posix())
117+
pipe = get_pipeline(
118+
task = TASK,
119+
model_dir = storage_dir.as_posix(),
120+
framework = "pytorch"
121+
)
116122
res = pipe("Life is good, Life is bad")
117123
assert "score" in res[0]
118124

119125

120126
@require_torch
121127
def test_whisper_long_audio():
122128
with tempfile.TemporaryDirectory() as tmpdirname:
123-
storage_dir = _load_repository_from_hf("openai/whisper-tiny", tmpdirname, framework="pytorch")
124-
pipe = get_pipeline("automatic-speech-recognition", storage_dir.as_posix())
125-
res = pipe(os.path.join(os.getcwd(), "tests/resources/audio", "long_sample.mp3"))
129+
storage_dir = _load_repository_from_hf(
130+
repository_id = "openai/whisper-tiny",
131+
target_dir = tmpdirname,
132+
framework = "pytorch",
133+
revision = "be0ba7c2f24f0127b27863a23a08002af4c2c279"
134+
)
135+
logging.info(f"Temp dir: {tmpdirname}")
136+
logging.info(f"POSIX Path: {storage_dir.as_posix()}")
137+
logging.info(f"Contents: {os.listdir(tmpdirname)}")
138+
pipe = get_pipeline(
139+
task = "automatic-speech-recognition",
140+
model_dir = storage_dir.as_posix(),
141+
framework = "safetensors"
142+
)
143+
res = pipe(
144+
os.path.join(
145+
os.getcwd(),
146+
"tests/resources/audio",
147+
"long_sample.mp3"
148+
)
149+
)
126150

127151
assert len(res["text"]) > 700
128152

@@ -149,7 +173,7 @@ def test_wrap_conversation_pipeline():
149173
@require_torch
150174
def test_wrapped_pipeline():
151175
with tempfile.TemporaryDirectory() as tmpdirname:
152-
storage_dir = _load_repository_from_hf("microsoft/DialoGPT-small", tmpdirname, framework="pytorch")
176+
storage_dir = _load_repository_from_hf("hf-internal-testing/tiny-random-blenderbot", tmpdirname, framework="pytorch")
153177
conv_pipe = get_pipeline("conversational", storage_dir.as_posix())
154178
data = {
155179
"past_user_inputs": ["Which movie is the best ?"],

tox.ini

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,19 @@ commands = ruff src --fix
2020

2121
[testenv:unit-torch]
2222
install_command =
23-
pip install -e .
24-
pip install -e ".[test,dev,torch,st]"
25-
allowlist_externals = pytest
23+
pip install -e ".[test,torch,st]"
24+
allowlist_externals =
25+
pytest
2626
commands =
2727
pytest -s -v \
2828
{tty:--color=yes} \
29-
tests/unit/ {posargs} \
30-
--log-cli-level=ERROR \
29+
tests/unit/test_const.py \
30+
tests/unit/test_handler.py \
31+
tests/unit/test_sentence_transformers.py \
32+
tests/unit/test_serializer.py \
33+
tests/unit/test_utils.py \
34+
{posargs} \
35+
--log-cli-level=DEBUG \
3136
--log-format='%(asctime)s %(levelname)s %(module)s:%(lineno)d %(message)s'
3237

3338
[testenv:unit-torch-slow]

0 commit comments

Comments (0)