-
Notifications
You must be signed in to change notification settings - Fork 428
Expand file tree
/
Copy pathconftest.py
More file actions
232 lines (174 loc) · 6.32 KB
/
conftest.py
File metadata and controls
232 lines (174 loc) · 6.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os
from typing import Any, Generator
from unittest.mock import patch
import pytest
from alembic.command import upgrade
from alembic.config import Config
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from backend.chat.enums import FinishReason
from backend.database_models import get_session
from backend.database_models.agent import Agent
from backend.database_models.deployment import Deployment
from backend.database_models.model import Model
from backend.main import app, create_app
from backend.schemas.chat import StreamEvent
from backend.schemas.organization import Organization
from backend.schemas.user import User
from backend.tests.unit.factories import get_factory
# Connection URL for the test database. Taken from the DATABASE_URL
# environment variable when set, otherwise a local Postgres on port 5433.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://postgres:postgres@localhost:5433")
@pytest.fixture
def client() -> Generator[TestClient, None, None]:
    """
    Yields a TestClient around the module-level `app`, with no dependency
    overrides. The client is not entered as a context manager here.
    """
    plain_client = TestClient(app)
    yield plain_client
@pytest.fixture(scope="function")
def engine() -> Generator[Any, None, None]:
    """
    Yields a SQLAlchemy engine connected to DATABASE_URL.

    The fixture is function-scoped: a new engine is created for every test
    and disposed of (releasing its connection pool) when that test finishes.
    """
    # echo=True makes SQLAlchemy log every SQL statement it emits.
    db_engine = create_engine(DATABASE_URL, echo=True)
    yield db_engine
    db_engine.dispose()
@pytest.fixture(scope="function")
def session(engine: Any) -> Generator[Session, None, None]:
    """
    Yields a SQLAlchemy session within a transaction
    that is rolled back after every function.

    Alembic migrations are applied first, then a connection is opened, an
    outer transaction is begun on it and a Session is bound to that
    connection. At teardown the session is closed, the transaction rolled
    back and the connection closed (the close runs even if the rollback
    fails).
    """
    # Run Alembic migrations before opening the test connection, so that a
    # failing migration cannot leave an open connection/transaction behind.
    # upgrade() is not handed `connection`; it uses whatever alembic.ini /
    # env.py configure. NOTE(review): assumes that targets the same database
    # as DATABASE_URL — confirm.
    alembic_cfg = Config("src/backend/alembic.ini")
    upgrade(alembic_cfg, "head")

    connection = engine.connect()
    # Begin the outer transaction that wraps everything done through the session
    transaction = connection.begin()
    # Use connection within the started transaction
    db_session = Session(bind=connection)
    try:
        yield db_session
    finally:
        db_session.close()
        try:
            # Roll back the transaction, discarding the test's writes
            transaction.rollback()
        finally:
            # Close connection so it returns to the connection pool
            connection.close()
@pytest.fixture(scope="function")
def session_client(session: Session) -> Generator[TestClient, None, None]:
    """
    Yields a TestClient for a freshly created app whose `get_session`
    dependency is overridden to hand out the test's `session`.

    The client is used as a context manager for the duration of the test;
    afterwards the app's dependency overrides are reset.
    """
    def _provide_session() -> Generator[Session, Any, None]:
        yield session

    test_app = create_app()
    test_app.dependency_overrides[get_session] = _provide_session
    print(f"Session at fixture {session}")
    with TestClient(test_app) as api_client:
        yield api_client
    test_app.dependency_overrides = dict()
@pytest.fixture(scope="session")
def engine_chat() -> Generator[Any, None, None]:
    """
    Session-scoped SQLAlchemy engine for DATABASE_URL: created once for the
    whole pytest run and disposed of when the run ends.
    """
    # echo=True makes SQLAlchemy log every SQL statement it emits.
    chat_engine = create_engine(DATABASE_URL, echo=True)
    yield chat_engine
    chat_engine.dispose()
@pytest.fixture(scope="session")
def session_chat(engine_chat: Any) -> Generator[Session, None, None]:
    """
    Yields a SQLAlchemy session within a transaction that is rolled back
    once, when the pytest session ends.

    We need to use the fixture in the session scope because the chat
    endpoint is asynchronous and needs to be open for the entire session.

    Alembic migrations are applied first, then a connection is opened, an
    outer transaction is begun on it and a Session is bound to that
    connection. At teardown the session is closed, the transaction rolled
    back and the connection closed (the close runs even if the rollback
    fails).
    """
    # Run Alembic migrations before opening the connection, so that a
    # failing migration cannot leave an open connection/transaction behind.
    # upgrade() is not handed `connection`; it uses whatever alembic.ini /
    # env.py configure. NOTE(review): assumes that targets the same database
    # as DATABASE_URL — confirm.
    alembic_cfg = Config("src/backend/alembic.ini")
    upgrade(alembic_cfg, "head")

    connection = engine_chat.connect()
    # Begin the outer transaction that wraps everything done through the session
    transaction = connection.begin()
    # Use connection within the started transaction
    chat_session = Session(bind=connection)
    try:
        yield chat_session
    finally:
        chat_session.close()
        try:
            # Roll back the transaction, discarding everything written during the run
            transaction.rollback()
        finally:
            # Close connection so it returns to the connection pool
            connection.close()
@pytest.fixture(scope="session")
def session_client_chat(session_chat: Session) -> Generator[TestClient, None, None]:
    """
    Session-scoped TestClient for a freshly created app whose `get_session`
    dependency is overridden to hand out `session_chat`.

    Kept at session scope because, per the original authors, the chat router
    needs its connection to stay open for the entire test session.
    The app's dependency overrides are reset after the client exits.
    """
    def _provide_chat_session() -> Generator[Session, Any, None]:
        yield session_chat

    chat_app = create_app()
    chat_app.dependency_overrides[get_session] = _provide_chat_session
    print(f"Session at fixture {session_chat}")
    with TestClient(chat_app) as chat_client:
        yield chat_client
    chat_app.dependency_overrides = dict()
@pytest.fixture
def user(session: Session) -> User:
    """Create and return a record via the "User" factory bound to `session`."""
    user_factory = get_factory("User", session)
    return user_factory.create()
@pytest.fixture
def organization(session: Session) -> Organization:
    """Create and return a record via the "Organization" factory bound to `session`."""
    organization_factory = get_factory("Organization", session)
    return organization_factory.create()
@pytest.fixture
def deployment(session: Session) -> Deployment:
    """
    Create and return a record via the "Deployment" factory bound to
    `session`, with deployment_class_name set to "CohereDeployment".
    """
    deployment_factory = get_factory("Deployment", session)
    return deployment_factory.create(deployment_class_name="CohereDeployment")
@pytest.fixture
def model(session: Session) -> Model:
    """Create and return a record via the "Model" factory bound to `session`."""
    model_factory = get_factory("Model", session)
    return model_factory.create()
@pytest.fixture
def agent(session: Session) -> Agent:
    """Create and return a record via the "Agent" factory bound to `session`."""
    agent_factory = get_factory("Agent", session)
    return agent_factory.create()
@pytest.fixture
def inject_events() -> list[dict]:
    """
    Extra events spliced into `mock_event_stream` between the stream-start
    and text-generation events. Empty by default; presumably overridden or
    parametrized by individual tests that need extra events.
    """
    no_extra_events: list[dict] = list()
    return no_extra_events
@pytest.fixture
def mock_event_stream(inject_events: list[dict]) -> list[dict]:
    """
    Build a fake chat event stream, in order:

    1. a STREAM_START event,
    2. any events from `inject_events` (skipped when it is falsy),
    3. a TEXT_GENERATION event with the text "This is a test.",
    4. a STREAM_END event carrying an empty response and COMPLETE finish reason.

    The start and end events share the same generation id.
    """
    generation_id = "ca0f398e-f8c8-48f0-b093-12d1754d00ed"

    start_event = {
        "event_type": StreamEvent.STREAM_START,
        "generation_id": generation_id,
    }
    text_event = {
        "event_type": StreamEvent.TEXT_GENERATION,
        "text": "This is a test.",
    }
    end_event = {
        "event_type": StreamEvent.STREAM_END,
        "response": {
            "generation_id": generation_id,
            "citations": [],
            "documents": [],
            "search_results": [],
            "search_queries": [],
        },
        "finish_reason": FinishReason.COMPLETE,
    }

    extra_events = inject_events or []
    return [start_event, *extra_events, text_event, end_event]
@pytest.fixture
def mock_available_model_deployments(mock_event_stream: list[dict]) -> Generator[Any, None, None]:
    """
    Replace backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS with the
    direct subclasses of MockDeployment, keyed by each class's name(), and
    make `mock_event_stream` the event stream on MockDeployment.

    Yields the object returned by `patch`, which here is the replacement
    mapping itself. Both the mapping and MockDeployment.event_stream are
    restored when the test finishes.
    """
    # Imported inside the fixture, so the mock module is only loaded when
    # this fixture actually runs.
    from backend.tests.unit.model_deployments.mock_deployments.mock_base import (
        MockDeployment,
    )

    # patch.object restores the previous class attribute at teardown (or
    # removes it, via create=True, if it did not exist). Assigning it directly
    # would leak this test's event stream into every later test.
    with patch.object(MockDeployment, "event_stream", mock_event_stream, create=True):
        # __subclasses__() lists direct subclasses only.
        mocked_deployments = {
            deployment_cls.name(): deployment_cls
            for deployment_cls in MockDeployment.__subclasses__()
        }
        with patch(
            "backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS",
            mocked_deployments,
        ) as mock:
            yield mock
@pytest.fixture
def mock_cohere_list_models():
    """
    Patch CohereDeployment.list_models for the duration of a test so that it
    returns a fixed list of model names; yields the patch's mock.
    """
    fake_model_names = ["command", "command-r", "command-r-plus", "command-light-nightly"]
    patch_target = "backend.model_deployments.cohere_platform.CohereDeployment.list_models"
    with patch(patch_target, return_value=fake_model_names) as list_models_mock:
        yield list_models_mock