Skip to content

Commit 4686f52

Browse files
committed
Add tests
1 parent 204423d commit 4686f52

File tree

2 files changed

+247
-15
lines changed

2 files changed

+247
-15
lines changed

label_studio_ml/examples/deepgram/test_api.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,38 +10,90 @@
1010
- Change the `request` and `expected_response` variables to match the input and output of your model.
1111
"""
1212

13-
import pytest
1413
import json
15-
from model import NewModel
14+
15+
import pytest
16+
from label_studio_ml.response import ModelResponse
17+
from label_studio_sdk.label_interface.objects import PredictionValue
18+
from model import DeepgramModel
1619

1720

1821
@pytest.fixture
def client():
    """Yield a test client for an app built around ``DeepgramModel``."""
    # Imported inside the fixture so the WSGI module is only loaded when a test needs it.
    from _wsgi import init_app

    wsgi_app = init_app(model_class=DeepgramModel)
    wsgi_app.config['TESTING'] = True
    with wsgi_app.test_client() as test_client:
        yield test_client
2528

2629

27-
def test_predict(client, monkeypatch):
    """
    Scenario: call the /predict endpoint with a minimal payload.
    Steps   : replace DeepgramModel.setup with a no-op, then POST one task with empty data.
    Checks  : the response is HTTP 200 and its `results` list is empty (no context was sent).
    """
    # A no-op setup means model instantiation does not require DEEPGRAM_API_KEY
    monkeypatch.setattr(DeepgramModel, 'setup', lambda self: None)

    payload = {
        'tasks': [{'id': 1, 'data': {}}],
        'label_config': '<View></View>',
        'project': '1.1234567890',
    }

    resp = client.post('/predict', data=json.dumps(payload), content_type='application/json')
    assert resp.status_code == 200

    parsed = json.loads(resp.data)
    assert 'results' in parsed
    # With no context in the request, predict produces no predictions
    assert parsed['results'] == []
54+
55+
56+
def test_predict_endpoint_returns_stubbed_predictions(client, monkeypatch):
    """
    Scenario: call the /predict endpoint with the model's predict method stubbed out.
    Steps   : replace DeepgramModel.setup with a no-op and DeepgramModel.predict with a stub,
              POST a payload that includes a context, and decode the JSON body.
    Checks  : the response is HTTP 200 and `results` carries exactly the stubbed prediction.
    """
    # The single region the stubbed prediction will contain
    canned_result = [{
        'from_name': 'text',
        'to_name': 'audio',
        'type': 'textarea',
        'value': {'text': ['Hello from stub']},
    }]
    stub_prediction = PredictionValue(result=canned_result)

    def fake_predict(self, tasks, context=None, **params):
        # Ignores its inputs and always returns the canned prediction
        return ModelResponse(predictions=[stub_prediction])

    # A no-op setup means model instantiation does not require DEEPGRAM_API_KEY
    monkeypatch.setattr(DeepgramModel, 'setup', lambda self: None)
    monkeypatch.setattr(DeepgramModel, 'predict', fake_predict)

    request_payload = {
        'tasks': [{'id': 42, 'data': {'text': 'Sample request text'}}],
        'label_config': '<View><TextArea name="text" toName="audio"/></View>',
        'project': '1.1234567890',
        'params': {'context': {'result': []}},
    }

    resp = client.post(
        '/predict',
        data=json.dumps(request_payload),
        content_type='application/json',
    )
    assert resp.status_code == 200

    parsed = json.loads(resp.data)
    assert 'results' in parsed
    predictions = parsed['results']
    assert len(predictions) == 1
    # The returned region must match the stubbed one
    first_region = predictions[0]['result'][0]
    assert first_region['value']['text'] == ['Hello from stub']
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import os
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
from label_studio_ml.response import ModelResponse
6+
7+
# Ensure the Label Studio SDK inside the Deepgram example sees harmless defaults.
8+
os.environ.setdefault('LABEL_STUDIO_URL', 'http://localhost')
9+
os.environ.setdefault('LABEL_STUDIO_API_KEY', 'test-token')
10+
11+
from label_studio_ml.examples.deepgram import model as deepgram_model # noqa: E402
12+
13+
14+
@pytest.fixture
def env_settings(monkeypatch):
    """Set the environment variables required by the Deepgram example and return them."""
    env = dict(
        DEEPGRAM_API_KEY='dg-key',
        AWS_DEFAULT_REGION='us-east-1',
        S3_BUCKET='test-bucket',
        S3_FOLDER='tts',
    )
    # monkeypatch restores the previous environment when the test finishes
    for name in env:
        monkeypatch.setenv(name, env[name])
    return env
26+
27+
28+
@pytest.fixture
def patched_clients(monkeypatch):
    """Replace the Deepgram SDK, boto3 client factory, and Label Studio SDK with mocks.

    Returns a dict exposing the mocks under the keys 'deepgram_client',
    'deepgram_ctor', 's3_client', and 'ls' so tests can inspect them.
    """
    dg_instance = MagicMock(name='DeepgramClientInstance')
    dg_factory = MagicMock(return_value=dg_instance)
    s3 = MagicMock(name='S3Client')
    label_studio = MagicMock(name='LabelStudio')

    # Patch the names as they are looked up inside the model module
    monkeypatch.setattr(deepgram_model, 'DeepgramClient', dg_factory)
    monkeypatch.setattr(deepgram_model.boto3, 'client', MagicMock(return_value=s3))
    monkeypatch.setattr(deepgram_model, 'ls', label_studio)

    return dict(
        deepgram_client=dg_instance,
        deepgram_ctor=dg_factory,
        s3_client=s3,
        ls=label_studio,
    )
47+
48+
49+
def test_setup_raises_without_api_key(monkeypatch):
    """
    Scenario: the model is created while DEEPGRAM_API_KEY is absent.
    Steps   : delete the env var (if present) and instantiate the model.
    Checks  : a ValueError whose message mentions DEEPGRAM_API_KEY is raised.
    """
    # raising=False: the variable may already be unset in the test environment
    monkeypatch.delenv('DEEPGRAM_API_KEY', raising=False)

    model_cls = deepgram_model.DeepgramModel
    with pytest.raises(ValueError, match='DEEPGRAM_API_KEY'):
        model_cls()
59+
60+
61+
def test_setup_initializes_clients_with_api_key(env_settings, patched_clients):
    """
    Scenario: setup runs with all expected env vars present.
    Steps   : with external clients mocked, instantiate the model and call setup().
    Checks  : the Deepgram constructor was invoked, and the model holds the mocked
              Deepgram/S3 clients plus the region, bucket, and folder from the env.
    """
    model = deepgram_model.DeepgramModel()
    model.setup()

    assert patched_clients['deepgram_ctor'].called
    assert model.deepgram_client is patched_clients['deepgram_client']
    assert model.s3_client is patched_clients['s3_client']

    # Each S3 attribute on the model mirrors one environment variable
    attr_to_env = (
        ('s3_region', 'AWS_DEFAULT_REGION'),
        ('s3_bucket', 'S3_BUCKET'),
        ('s3_folder', 'S3_FOLDER'),
    )
    for attr, env_name in attr_to_env:
        assert getattr(model, attr) == env_settings[env_name]
76+
77+
78+
def test_setup_falls_back_to_access_token(env_settings, patched_clients):
    """
    Scenario: the Deepgram client constructor rejects its first set of keyword arguments.
    Steps   : make the mocked constructor raise TypeError once, then return the client.
    Checks  : the constructor is called twice (first with api_key, then with access_token)
              and the model keeps the client from the second call.
    """
    ctor = patched_clients['deepgram_ctor']
    dg_client = patched_clients['deepgram_client']
    # First call raises, second call returns the mocked client
    ctor.side_effect = [TypeError('unexpected kwarg'), dg_client]

    model = deepgram_model.DeepgramModel()

    assert ctor.call_count == 2
    first_call, second_call = ctor.call_args_list
    assert 'api_key' in first_call.kwargs
    assert 'access_token' in second_call.kwargs
    assert model.deepgram_client is dg_client
96+
97+
98+
def test_predict_no_context_returns_empty_modelresponse(env_settings, patched_clients):
    """
    Scenario: predict is called with no context.
    Steps   : with env vars set and clients mocked, call predict with context=None.
    Checks  : an empty ModelResponse comes back and neither Deepgram audio generation
              nor the S3 upload was invoked.
    """
    model = deepgram_model.DeepgramModel()

    response = model.predict(tasks=[{'id': 1}], context=None)

    assert isinstance(response, ModelResponse)
    assert response.predictions == []

    # No external service should have been touched
    generate = patched_clients['deepgram_client'].speak.v1.audio.generate
    upload = patched_clients['s3_client'].upload_file
    generate.assert_not_called()
    upload.assert_not_called()
114+
115+
116+
def test_predict_generates_audio_uploads_to_s3_and_updates_task(env_settings, patched_clients):
    """
    Scenario: predict runs its happy path.
    Steps   : have the mocked Deepgram client return audio chunks, then call predict
              with a context carrying the user's text.
    Checks  : Deepgram is called once with the text, the S3 upload uses the expected
              key and ExtraArgs, ls.tasks.update receives the resulting URL, and the
              local file passed to the upload no longer exists.
    """
    generate = patched_clients['deepgram_client'].speak.v1.audio.generate
    upload = patched_clients['s3_client'].upload_file
    generate.return_value = [b'chunk-a', b'chunk-b']

    model = deepgram_model.DeepgramModel()
    model.setup()

    model.predict(
        tasks=[{'id': 123}],
        context={
            'user_id': 'user-7',
            'result': [{'value': {'text': ['Hello Deepgram']}}],
        },
    )

    generate.assert_called_once_with(text='Hello Deepgram')
    assert upload.call_count == 1

    upload_call = upload.call_args
    local_path = upload_call.args[0]
    extra_args = upload_call.kwargs['ExtraArgs']
    assert extra_args['ContentType'] == 'audio/mpeg'
    assert extra_args['ACL'] == 'public-read'
    assert extra_args['CacheControl'].startswith('public')

    # The object key combines the folder, task id, and user id
    expected_key = f"{env_settings['S3_FOLDER']}/123_user-7.mp3"
    assert upload_call.args[2] == expected_key

    expected_url = f"https://{env_settings['S3_BUCKET']}.s3.{env_settings['AWS_DEFAULT_REGION']}.amazonaws.com/{expected_key}"
    patched_clients['ls'].tasks.update.assert_called_once_with(
        id=123,
        data={'text': 'Hello Deepgram', 'audio': expected_url},
    )

    # The file handed to upload_file must have been removed afterwards
    assert not os.path.exists(local_path)
154+
155+
156+
def test_predict_s3_failure_raises_and_cleans_up_temp_file(env_settings, patched_clients):
    """
    Scenario: the S3 upload fails during predict.
    Steps   : have the mocked Deepgram client return a chunk and make upload_file raise.
    Checks  : the RuntimeError propagates, the local file passed to the upload no longer
              exists, and ls.tasks.update is never called.
    """
    upload = patched_clients['s3_client'].upload_file
    patched_clients['deepgram_client'].speak.v1.audio.generate.return_value = [b'chunk']
    upload.side_effect = RuntimeError('s3 boom')

    model = deepgram_model.DeepgramModel()
    model.setup()

    failing_context = {
        'user_id': 'user-1',
        'result': [{'value': {'text': ['Explode']}}],
    }
    with pytest.raises(RuntimeError, match='s3 boom'):
        model.predict(tasks=[{'id': 999}], context=failing_context)

    # The mock records the call even though it raised, so the path is still available
    attempted_path = upload.call_args.args[0]
    assert not os.path.exists(attempted_path)
    patched_clients['ls'].tasks.update.assert_not_called()
180+

0 commit comments

Comments
 (0)