Skip to content

Commit cf0b5b3

Browse files
authored
Extract model download logic to utils (#146)
* Extract ensure_model to utils * migrate existing usage_model refs to utils * Fix ruff errors
1 parent 9bb68cf commit cf0b5b3

File tree

4 files changed

+51
-90
lines changed

4 files changed

+51
-90
lines changed

agents-core/vision_agents/core/utils/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import io
23
import logging
34
import re
@@ -6,6 +7,7 @@
67
from dataclasses import dataclass
78
from typing import Dict, Optional
89
from PIL import Image
10+
import httpx
911

1012

1113
# Type alias for markdown file contents: maps filename to file content
@@ -108,3 +110,48 @@ def get_vision_agents_version() -> str:
108110
return importlib.metadata.version("vision-agents")
109111
except importlib.metadata.PackageNotFoundError:
110112
return "unknown"
113+
114+
115+
async def ensure_model(path: str, url: str) -> str:
116+
"""
117+
Download a model file asynchronously if it doesn't exist.
118+
119+
Args:
120+
path: Local path where the model should be saved
121+
url: URL to download the model from
122+
123+
Returns:
124+
The path to the model file
125+
"""
126+
127+
logger = logging.getLogger(__name__)
128+
if not os.path.exists(path):
129+
model_name = os.path.basename(path)
130+
logger.info(f"Downloading {model_name}...")
131+
132+
try:
133+
async with httpx.AsyncClient(timeout=300.0, follow_redirects=True) as client:
134+
async with client.stream("GET", url) as response:
135+
response.raise_for_status()
136+
137+
# Write file in chunks to avoid loading entire file in memory
138+
chunks = []
139+
async for chunk in response.aiter_bytes(chunk_size=8192):
140+
chunks.append(chunk)
141+
142+
# Write all chunks to file in thread to avoid blocking event loop
143+
def write_file():
144+
with open(path, "wb") as f:
145+
for chunk in chunks:
146+
f.write(chunk)
147+
148+
await asyncio.to_thread(write_file)
149+
150+
logger.info(f"{model_name} downloaded.")
151+
except httpx.HTTPError as e:
152+
# Clean up partial download on error
153+
if os.path.exists(path):
154+
os.remove(path)
155+
raise RuntimeError(f"Failed to download {model_name}: {e}")
156+
157+
return path

plugins/smart_turn/tests/test_smart_turn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
from vision_agents.plugins.smart_turn.smart_turn_detection import (
1010
SileroVAD,
1111
SmartTurnDetection,
12-
ensure_model,
1312
SILERO_ONNX_URL,
1413
SILERO_ONNX_FILENAME,
1514
)
15+
from vision_agents.core.utils.utils import ensure_model
16+
1617
import logging
1718

1819
logger = logging.getLogger(__name__)

plugins/smart_turn/vision_agents/plugins/smart_turn/smart_turn_detection.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from dataclasses import dataclass
55
from typing import Optional, Any
66

7-
import httpx
87
from getstream.video.rtc.track_util import PcmData, AudioFormat
98
import numpy as np
109
import onnxruntime as ort
@@ -19,6 +18,7 @@
1918
TurnStartedEvent,
2019
TurnEndedEvent,
2120
)
21+
from vision_agents.core.utils.utils import ensure_model
2222

2323
import logging
2424

@@ -462,47 +462,3 @@ def _predict_speech(self, chunk_f32: np.ndarray) -> float:
462462

463463
# out shape is (1, 1) -> return scalar
464464
return float(out[0][0])
465-
466-
467-
async def ensure_model(path: str, url: str) -> str:
468-
"""
469-
Download a model file asynchronously if it doesn't exist.
470-
471-
Args:
472-
path: Local path where the model should be saved
473-
url: URL to download the model from
474-
475-
Returns:
476-
The path to the model file
477-
"""
478-
if not os.path.exists(path):
479-
model_name = os.path.basename(path)
480-
481-
try:
482-
async with httpx.AsyncClient(
483-
timeout=300.0, follow_redirects=True
484-
) as client:
485-
async with client.stream("GET", url) as response:
486-
response.raise_for_status()
487-
488-
# Write file in chunks to avoid loading entire file in memory
489-
# Use asyncio.to_thread for blocking file I/O operations
490-
chunks = []
491-
async for chunk in response.aiter_bytes(chunk_size=8192):
492-
chunks.append(chunk)
493-
494-
# Write all chunks to file in thread to avoid blocking event loop
495-
def write_file():
496-
with open(path, "wb") as f:
497-
for chunk in chunks:
498-
f.write(chunk)
499-
500-
await asyncio.to_thread(write_file)
501-
502-
except httpx.HTTPError as e:
503-
# Clean up partial download on error
504-
if os.path.exists(path):
505-
os.remove(path)
506-
raise RuntimeError(f"Failed to download {model_name}: {e}")
507-
508-
return path

plugins/vogent/vision_agents/plugins/vogent/vogent_turn_detection.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from dataclasses import dataclass
55
from typing import Optional, Any
66

7-
import httpx
87
import numpy as np
98
from faster_whisper import WhisperModel
109
from getstream.video.rtc.track_util import PcmData, AudioFormat
@@ -17,6 +16,7 @@
1716
TurnStartedEvent,
1817
TurnEndedEvent,
1918
)
19+
from vision_agents.core.utils.utils import ensure_model
2020

2121
import logging
2222

@@ -509,46 +509,3 @@ def predict_speech(self, chunk_f32: np.ndarray) -> float:
509509
return float(out[0][0])
510510

511511

512-
async def ensure_model(path: str, url: str) -> str:
513-
"""
514-
Download a model file asynchronously if it doesn't exist.
515-
516-
Args:
517-
path: Local path where the model should be saved
518-
url: URL to download the model from
519-
520-
Returns:
521-
The path to the model file
522-
"""
523-
if not os.path.exists(path):
524-
model_name = os.path.basename(path)
525-
logger.info(f"Downloading {model_name}...")
526-
527-
try:
528-
async with httpx.AsyncClient(timeout=300.0, follow_redirects=True) as client:
529-
async with client.stream("GET", url) as response:
530-
response.raise_for_status()
531-
532-
# Write file in chunks to avoid loading entire file in memory
533-
chunks = []
534-
async for chunk in response.aiter_bytes(chunk_size=8192):
535-
chunks.append(chunk)
536-
537-
# Write all chunks to file in thread to avoid blocking event loop
538-
def write_file():
539-
with open(path, "wb") as f:
540-
for chunk in chunks:
541-
f.write(chunk)
542-
543-
await asyncio.to_thread(write_file)
544-
545-
logger.info(f"{model_name} downloaded.")
546-
except httpx.HTTPError as e:
547-
# Clean up partial download on error
548-
if os.path.exists(path):
549-
os.remove(path)
550-
raise RuntimeError(f"Failed to download {model_name}: {e}")
551-
552-
return path
553-
554-

0 commit comments

Comments
 (0)