Skip to content

Commit 083a0af

Browse files
authored
Backend: mock unit tests (#908)
* chore(backend): add type hints to test files
* feat(backend): improved type hints and decorators for deployments
* feat(backend): removed dependency on Cohere API key for testing
1 parent e17777b commit 083a0af

32 files changed

+685 additions, -591 deletions (lines changed)

.github/workflows/backend_integration_tests.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,47 @@ on:
1313
jobs:
1414
pytest:
1515
permissions: write-all
16-
environment: development
16+
# environment: development
1717
runs-on: ubuntu-latest
1818

1919
steps:
2020
- name: Checkout repo
2121
uses: actions/checkout@v3
22+
2223
- uses: actions/setup-python@v5
2324
with:
2425
python-version: '3.11'
2526
cache: 'pip'
27+
2628
- name: Install poetry
2729
uses: snok/install-poetry@v1
2830
with:
2931
virtualenvs-create: true
3032
virtualenvs-in-project: true
3133
virtualenvs-path: .venv
3234
installer-parallel: true
35+
3336
- name: Load cached venv
3437
id: cached-poetry-dependencies
3538
uses: actions/cache@v4
3639
with:
3740
path: .venv
3841
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
42+
3943
- name: Install dependencies
4044
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
4145
run: poetry install --with dev --no-interaction --no-root
46+
4247
- name: Setup test DB container
4348
run: make test-db
49+
4450
- name: Test with pytest
4551
if: github.actor != 'dependabot[bot]'
4652
run: |
4753
make run-integration-tests
4854
env:
4955
PYTHONPATH: src
56+
5057
- name: Upload coverage reports to Codecov
5158
uses: codecov/codecov-action@v4.0.1
5259
with:

.github/workflows/backend_unit_tests.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,41 @@ jobs:
1919
steps:
2020
- name: Checkout repo
2121
uses: actions/checkout@v3
22+
2223
- uses: actions/setup-python@v5
2324
with:
2425
python-version: '3.11'
2526
cache: 'pip'
27+
2628
- name: Install poetry
2729
uses: snok/install-poetry@v1
2830
with:
2931
virtualenvs-create: true
3032
virtualenvs-in-project: true
3133
virtualenvs-path: .venv
3234
installer-parallel: true
35+
3336
- name: Load cached venv
3437
id: cached-poetry-dependencies
3538
uses: actions/cache@v4
3639
with:
3740
path: .venv
3841
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
42+
3943
- name: Install dependencies
4044
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
4145
run: poetry install --with dev --no-interaction --no-root
46+
4247
- name: Setup test DB container
4348
run: make test-db
49+
4450
- name: Test with pytest
4551
if: github.actor != 'dependabot[bot]'
4652
run: |
4753
make run-unit-tests
4854
env:
4955
PYTHONPATH: src
56+
5057
- name: Upload coverage reports to Codecov
5158
uses: codecov/codecov-action@v4.0.1
5259
with:

src/backend/chat/collate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def rerank_and_chunk(
6060
reranked_results[tool_call_hashable] = tool_result
6161
continue
6262

63-
chunked_outputs = []
63+
chunked_outputs: list[dict[str, Any]] = []
6464
for output in tool_result["outputs"]:
6565
text = output.get("text")
6666

@@ -78,7 +78,7 @@ async def rerank_and_chunk(
7878

7979
res = await model.invoke_rerank(
8080
query=query,
81-
documents=chunked_outputs,
81+
documents=[output["text"] for output in chunked_outputs],
8282
ctx=ctx,
8383
)
8484

@@ -102,7 +102,12 @@ async def rerank_and_chunk(
102102
return list(reranked_results.values())
103103

104104

105-
def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=300):
105+
def chunk(
106+
content: str,
107+
compact_mode: bool = False,
108+
soft_word_cut_off: int = 100,
109+
hard_word_cut_off: int = 300,
110+
) -> list[str]:
106111
if compact_mode:
107112
content = content.replace("\n", " ")
108113

@@ -139,7 +144,7 @@ def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=
139144
return chunks
140145

141146

142-
def to_dict(obj):
147+
def to_dict(obj) -> dict:
143148
return json.loads(
144149
json.dumps(
145150
obj, default=lambda o: o.__dict__ if hasattr(o, "__dict__") else str(o)

src/backend/chat/custom/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_deployment(name: str, session: Session, ctx: Context, **kwargs: Any) ->
2424
kwargs["ctx"] = ctx
2525
try:
2626
deployment = deployment_service.get_deployment_instance_by_name(session, name, **kwargs)
27-
except DeploymentNotFoundError:
27+
except (DeploymentNotFoundError, Exception):
2828
deployment = deployment_service.get_default_deployment_instance(**kwargs)
2929

3030
return deployment

src/backend/model_deployments/azure.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, AsyncGenerator, Dict, List
1+
from typing import Any, AsyncGenerator
22

33
import cohere
44

@@ -43,33 +43,33 @@ def __init__(self, **kwargs: Any):
4343
base_url=self.chat_endpoint_url, api_key=self.api_key
4444
)
4545

46-
@classmethod
47-
def name(cls) -> str:
46+
@staticmethod
47+
def name() -> str:
4848
return "Azure"
4949

50-
@classmethod
51-
def env_vars(cls) -> List[str]:
50+
@staticmethod
51+
def env_vars() -> list[str]:
5252
return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR]
5353

54-
@classmethod
55-
def rerank_enabled(cls) -> bool:
54+
@staticmethod
55+
def rerank_enabled() -> bool:
5656
return False
5757

5858
@classmethod
59-
def list_models(cls) -> List[str]:
59+
def list_models(cls) -> list[str]:
6060
if not cls.is_available():
6161
return []
6262

6363
return cls.DEFAULT_MODELS
6464

65-
@classmethod
66-
def is_available(cls) -> bool:
65+
@staticmethod
66+
def is_available() -> bool:
6767
return (
6868
AzureDeployment.default_api_key is not None
6969
and AzureDeployment.default_chat_endpoint_url is not None
7070
)
7171

72-
async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
72+
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any:
7373
response = self.client.chat(
7474
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
7575
)
@@ -86,6 +86,6 @@ async def invoke_chat_stream(
8686
yield to_dict(event)
8787

8888
async def invoke_rerank(
89-
self, query: str, documents: List[Dict[str, Any]], ctx: Context
89+
self, query: str, documents: list[str], ctx: Context, **kwargs
9090
) -> Any:
9191
return None

src/backend/model_deployments/base.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, AsyncGenerator, Dict, List
2+
from typing import Any
33

44
from backend.config.settings import Settings
55
from backend.schemas.cohere_chat import CohereChatRequest
@@ -25,32 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any):
2525
def id(cls) -> str:
2626
return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower()
2727

28-
@classmethod
28+
@staticmethod
2929
@abstractmethod
30-
def name(cls) -> str: ...
30+
def name() -> str: ...
3131

32-
@classmethod
32+
@staticmethod
3333
@abstractmethod
34-
def env_vars(cls) -> List[str]: ...
34+
def env_vars() -> list[str]: ...
3535

36-
@classmethod
36+
@staticmethod
3737
@abstractmethod
38-
def rerank_enabled(cls) -> bool: ...
38+
def rerank_enabled() -> bool: ...
3939

4040
@classmethod
4141
@abstractmethod
42-
def list_models(cls) -> List[str]: ...
42+
def list_models(cls) -> list[str]: ...
4343

44-
@classmethod
44+
@staticmethod
4545
@abstractmethod
46-
def is_available(cls) -> bool: ...
46+
def is_available() -> bool: ...
4747

4848
@classmethod
4949
def is_community(cls) -> bool:
5050
return False
5151

5252
@classmethod
53-
def config(cls) -> Dict[str, Any]:
53+
def config(cls) -> dict[str, Any]:
5454
config = Settings().get(f"deployments.{cls.id()}")
5555

5656
if not config:
@@ -81,9 +81,9 @@ async def invoke_chat(
8181
@abstractmethod
8282
async def invoke_chat_stream(
8383
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
84-
) -> AsyncGenerator[Any, Any]: ...
84+
) -> Any: ...
8585

8686
@abstractmethod
8787
async def invoke_rerank(
88-
self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any
88+
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
8989
) -> Any: ...

src/backend/model_deployments/bedrock.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, AsyncGenerator, Dict, List
1+
from typing import Any, AsyncGenerator
22

33
import cohere
44

@@ -42,40 +42,40 @@ def __init__(self, **kwargs: Any):
4242
),
4343
)
4444

45-
@classmethod
46-
def name(cls) -> str:
45+
@staticmethod
46+
def name() -> str:
4747
return "Bedrock"
4848

49-
@classmethod
50-
def env_vars(cls) -> List[str]:
49+
@staticmethod
50+
def env_vars() -> list[str]:
5151
return [
5252
BEDROCK_ACCESS_KEY_ENV_VAR,
5353
BEDROCK_SECRET_KEY_ENV_VAR,
5454
BEDROCK_SESSION_TOKEN_ENV_VAR,
5555
BEDROCK_REGION_NAME_ENV_VAR,
5656
]
5757

58-
@classmethod
59-
def rerank_enabled(cls) -> bool:
58+
@staticmethod
59+
def rerank_enabled() -> bool:
6060
return False
6161

6262
@classmethod
63-
def list_models(cls) -> List[str]:
63+
def list_models(cls) -> list[str]:
6464
if not cls.is_available():
6565
return []
6666

6767
return cls.DEFAULT_MODELS
6868

69-
@classmethod
70-
def is_available(cls) -> bool:
69+
@staticmethod
70+
def is_available() -> bool:
7171
return (
7272
BedrockDeployment.access_key is not None
7373
and BedrockDeployment.secret_access_key is not None
7474
and BedrockDeployment.session_token is not None
7575
and BedrockDeployment.region_name is not None
7676
)
7777

78-
async def invoke_chat(self, chat_request: CohereChatRequest) -> Any:
78+
async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
7979
# bedrock accepts a subset of the chat request fields
8080
bedrock_chat_req = chat_request.model_dump(
8181
exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True
@@ -101,6 +101,6 @@ async def invoke_chat_stream(
101101
yield to_dict(event)
102102

103103
async def invoke_rerank(
104-
self, query: str, documents: List[Dict[str, Any]], ctx: Context
104+
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
105105
) -> Any:
106106
return None

src/backend/model_deployments/cohere_platform.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List
1+
from typing import Any
22

33
import cohere
44
import requests
@@ -29,20 +29,20 @@ def __init__(self, **kwargs: Any):
2929
)
3030
self.client = cohere.Client(api_key, client_name=self.client_name)
3131

32-
@classmethod
33-
def name(cls) -> str:
32+
@staticmethod
33+
def name() -> str:
3434
return "Cohere Platform"
3535

36-
@classmethod
37-
def env_vars(cls) -> List[str]:
36+
@staticmethod
37+
def env_vars() -> list[str]:
3838
return [COHERE_API_KEY_ENV_VAR]
3939

40-
@classmethod
41-
def rerank_enabled(cls) -> bool:
40+
@staticmethod
41+
def rerank_enabled() -> bool:
4242
return True
4343

4444
@classmethod
45-
def list_models(cls) -> List[str]:
45+
def list_models(cls) -> list[str]:
4646
logger = LoggerFactory().get_logger()
4747
if not CohereDeployment.is_available():
4848
return []
@@ -64,12 +64,12 @@ def list_models(cls) -> List[str]:
6464
models = response.json()["models"]
6565
return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]]
6666

67-
@classmethod
68-
def is_available(cls) -> bool:
67+
@staticmethod
68+
def is_available() -> bool:
6969
return CohereDeployment.api_key is not None
7070

7171
async def invoke_chat(
72-
self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any
72+
self, chat_request: CohereChatRequest, **kwargs: Any
7373
) -> Any:
7474
response = self.client.chat(
7575
**chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}),
@@ -99,7 +99,7 @@ async def invoke_chat_stream(
9999
yield event_dict
100100

101101
async def invoke_rerank(
102-
self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any
102+
self, query: str, documents: list[str], ctx: Context, **kwargs: Any
103103
) -> Any:
104104
response = self.client.rerank(
105105
query=query, documents=documents, model=DEFAULT_RERANK_MODEL

0 commit comments

Comments
 (0)