Skip to content

Commit 802c232

Browse files
malexw and tianjing-li authored
backend: Deployments refactor; Add deployment service and fix deployment config setting (#831)
* Deployments refactor; Add deployment service and fix deployment config setting
* Changes for code review
* Fix a number of integration and unit tests
* Fix failing chat tests
* Move some tests from unit/routers to integration/routers
* Fix a few more tests
* Fix a few more tests
* Fix remainder of broken integration tests
* Fix lint issues
* Run prettier on Coral
* Remove old, unused model crud helper
* Fix failing deployments unit tests
* Coral fix to account for agent.tools possibly being null
* Fix TS styling
* Provide a dummy Cohere API key during testing
* Update Coral to align with latest version of the backend API
* Fix lint issues in Coral
* Last few changes for code review
* Update generated API in assistants_web
* Fix assistants_web build
* Fix backend lint issues
* Simplify validate_deployment_header
* Don't seed the DB with deployment data, and fix a DeploymentDefinition serialization issue
* Fix backend lint issues
* Fix broken unit tests
* Skip cohere deployments tests since they're breaking other tests
* Fix deployment integration tests
* More fixes to deployments integration tests
* Fix deployment integration tests
* What API key are we using to call Cohere in the tests?
* Mock list_models of CoherePlatform model to avoid Cohere API calls
* Mask all deployment config values when looking up deployment info
* Fix integration tests
* Fix typecheck issue
* wip
* Minor changes
* use defaults
* resolve lint
* Update generated client

---------

Co-authored-by: Tianjing Li <tianjinglimail@gmail.com>
1 parent 5ad48e6 commit 802c232

File tree

88 files changed

+6130
-3691
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+6130
-3691
lines changed

src/backend/alembic/versions/2024_08_01_117f0d9b1d3d_seed_deployments_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from alembic import op
1212

13-
from backend.database_models.seeders.deplyments_models_seed import (
13+
from backend.database_models.seeders.deployments_models_seed import (
1414
delete_default_models,
1515
deployments_models_seed,
1616
)

src/backend/chat/custom/utils.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from typing import Any
22

3-
from backend.config.deployments import (
4-
AVAILABLE_MODEL_DEPLOYMENTS,
5-
get_default_deployment,
6-
)
3+
from backend.database_models.database import get_session
4+
from backend.exceptions import DeploymentNotFoundError
75
from backend.model_deployments.base import BaseDeployment
86
from backend.schemas.context import Context
7+
from backend.services import deployment as deployment_service
98

109

1110
def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
@@ -16,22 +15,12 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
1615
1716
Returns:
1817
BaseDeployment: Deployment implementation instance based on the deployment name.
19-
20-
Raises:
21-
ValueError: If the deployment is not supported.
2218
"""
2319
kwargs["ctx"] = ctx
24-
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)
25-
26-
# Check provided deployment against config const
27-
if deployment is not None:
28-
return deployment.deployment_class(**kwargs, **deployment.kwargs)
29-
30-
# Fallback to first available deployment
31-
default = get_default_deployment(**kwargs)
32-
if default is not None:
33-
return default
20+
try:
21+
session = next(get_session())
22+
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
23+
except DeploymentNotFoundError:
24+
deployment = deployment_service.get_default_deployment(**kwargs)
3425

35-
raise ValueError(
36-
f"Deployment {name} is not supported, and no available deployments were found."
37-
)
26+
return deployment

src/backend/config/default_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import datetime
22

3-
from backend.config.deployments import ModelDeploymentName
43
from backend.config.tools import Tool
4+
from backend.model_deployments.cohere_platform import CohereDeployment
55
from backend.schemas.agent import AgentPublic
66

77
DEFAULT_AGENT_ID = "default"
8-
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
8+
DEFAULT_DEPLOYMENT = CohereDeployment.name()
99
DEFAULT_MODEL = "command-r-plus"
1010

1111
def get_default_agent() -> AgentPublic:

src/backend/config/deployments.py

Lines changed: 15 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,35 @@
1-
from enum import StrEnum
2-
31
from backend.config.settings import Settings
4-
from backend.model_deployments import (
5-
AzureDeployment,
6-
BedrockDeployment,
7-
CohereDeployment,
8-
SageMakerDeployment,
9-
SingleContainerDeployment,
10-
)
11-
from backend.model_deployments.azure import AZURE_ENV_VARS
122
from backend.model_deployments.base import BaseDeployment
13-
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
14-
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
15-
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
16-
from backend.model_deployments.single_container import SC_ENV_VARS
17-
from backend.schemas.deployment import Deployment
183
from backend.services.logger.utils import LoggerFactory
194

205
logger = LoggerFactory().get_logger()
216

227

23-
class ModelDeploymentName(StrEnum):
24-
CoherePlatform = "Cohere Platform"
25-
SageMaker = "SageMaker"
26-
Azure = "Azure"
27-
Bedrock = "Bedrock"
28-
SingleContainer = "Single Container"
29-
30-
31-
use_community_features = Settings().get('feature_flags.use_community_features')
8+
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }
329

33-
# TODO names in the map below should not be the display names but ids
34-
ALL_MODEL_DEPLOYMENTS = {
35-
ModelDeploymentName.CoherePlatform: Deployment(
36-
id="cohere_platform",
37-
name=ModelDeploymentName.CoherePlatform,
38-
deployment_class=CohereDeployment,
39-
models=CohereDeployment.list_models(),
40-
is_available=CohereDeployment.is_available(),
41-
env_vars=COHERE_ENV_VARS,
42-
),
43-
ModelDeploymentName.SingleContainer: Deployment(
44-
id="single_container",
45-
name=ModelDeploymentName.SingleContainer,
46-
deployment_class=SingleContainerDeployment,
47-
models=SingleContainerDeployment.list_models(),
48-
is_available=SingleContainerDeployment.is_available(),
49-
env_vars=SC_ENV_VARS,
50-
),
51-
ModelDeploymentName.SageMaker: Deployment(
52-
id="sagemaker",
53-
name=ModelDeploymentName.SageMaker,
54-
deployment_class=SageMakerDeployment,
55-
models=SageMakerDeployment.list_models(),
56-
is_available=SageMakerDeployment.is_available(),
57-
env_vars=SAGE_MAKER_ENV_VARS,
58-
),
59-
ModelDeploymentName.Azure: Deployment(
60-
id="azure",
61-
name=ModelDeploymentName.Azure,
62-
deployment_class=AzureDeployment,
63-
models=AzureDeployment.list_models(),
64-
is_available=AzureDeployment.is_available(),
65-
env_vars=AZURE_ENV_VARS,
66-
),
67-
ModelDeploymentName.Bedrock: Deployment(
68-
id="bedrock",
69-
name=ModelDeploymentName.Bedrock,
70-
deployment_class=BedrockDeployment,
71-
models=BedrockDeployment.list_models(),
72-
is_available=BedrockDeployment.is_available(),
73-
env_vars=BEDROCK_ENV_VARS,
74-
),
75-
}
7610

11+
def get_available_deployments() -> list[type[BaseDeployment]]:
12+
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())
7713

78-
def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
79-
if use_community_features:
14+
if Settings().get("feature_flags.use_community_features"):
8015
try:
8116
from community.config.deployments import (
8217
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
8318
)
84-
85-
model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
86-
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
87-
return model_deployments
88-
except ImportError:
19+
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
20+
except ImportError as e:
8921
logger.warning(
90-
event="[Deployments] No available community deployments have been configured"
22+
event="[Deployments] No available community deployments have been configured", ex=e
9123
)
9224

93-
deployments = Settings().get('deployments.enabled_deployments')
94-
if deployments is not None and len(deployments) > 0:
95-
return {
96-
key: value
97-
for key, value in ALL_MODEL_DEPLOYMENTS.items()
98-
if value.id in Settings().get('deployments.enabled_deployments')
99-
}
100-
101-
return ALL_MODEL_DEPLOYMENTS
102-
103-
104-
def get_default_deployment(**kwargs) -> BaseDeployment:
105-
# Fallback to the first available deployment
106-
fallback = None
107-
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
108-
if deployment.is_available:
109-
fallback = deployment.deployment_class(**kwargs)
110-
break
111-
112-
default = Settings().get('deployments.default_deployment')
113-
if default:
114-
return next(
115-
(
116-
v.deployment_class(**kwargs)
117-
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
118-
if v.id == default
119-
),
120-
fallback,
121-
)
122-
else:
123-
return fallback
124-
125-
126-
def find_config_by_deployment_id(deployment_id: str) -> Deployment:
127-
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
128-
if deployment.id == deployment_id:
129-
return deployment
130-
return None
131-
132-
133-
def find_config_by_deployment_name(deployment_name: str) -> Deployment:
134-
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
135-
if deployment.name == deployment_name:
136-
return deployment
137-
return None
25+
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
26+
if enabled_deployment_ids:
27+
return [
28+
deployment
29+
for deployment in installed_deployments
30+
if deployment.id() in enabled_deployment_ids
31+
]
13832

33+
return installed_deployments
13934

14035
AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()

src/backend/crud/deployment.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
import os
21

32
from sqlalchemy.orm import Session
43

54
from backend.database_models import Deployment
65
from backend.model_deployments.utils import class_name_validator
7-
from backend.schemas.deployment import Deployment as DeploymentSchema
8-
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
9-
from backend.services.transaction import validate_transaction
10-
from community.config.deployments import (
11-
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
6+
from backend.schemas.deployment import (
7+
DeploymentCreate,
8+
DeploymentDefinition,
9+
DeploymentUpdate,
1210
)
11+
from backend.services.transaction import validate_transaction
1312

1413

1514
@validate_transaction
@@ -19,7 +18,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:
1918
2019
Args:
2120
db (Session): Database session.
22-
deployment (DeploymentSchema): Deployment data to be created.
21+
deployment (DeploymentDefinition): Deployment data to be created.
2322
2423
Returns:
2524
Deployment: Created deployment.
@@ -132,27 +131,24 @@ def delete_deployment(db: Session, deployment_id: str) -> None:
132131

133132

134133
@validate_transaction
135-
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
134+
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
136135
"""
137136
Create a new deployment by config.
138137
139138
Args:
140139
db (Session): Database session.
141140
deployment (str): Deployment data to be created.
142-
deployment_config (DeploymentSchema): Deployment config.
141+
deployment_config (DeploymentDefinition): Deployment config.
143142
144143
Returns:
145144
Deployment: Created deployment.
146145
"""
147146
deployment = Deployment(
148147
name=deployment_config.name,
149148
description="",
150-
default_deployment_config= {
151-
env_var: os.environ.get(env_var, "")
152-
for env_var in deployment_config.env_vars
153-
},
154-
deployment_class_name=deployment_config.deployment_class.__name__,
155-
is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
149+
default_deployment_config=deployment_config.config,
150+
deployment_class_name=deployment_config.class_name,
151+
is_community=deployment_config.is_community,
156152
)
157153
db.add(deployment)
158154
db.commit()

src/backend/crud/model.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from sqlalchemy.orm import Session
22

3-
from backend.database_models import Deployment
43
from backend.database_models.model import Model
5-
from backend.schemas.deployment import Deployment as DeploymentSchema
4+
from backend.schemas.deployment import DeploymentDefinition
65
from backend.schemas.model import ModelCreate, ModelUpdate
6+
from backend.services.logger.utils import LoggerFactory
77
from backend.services.transaction import validate_transaction
88

9+
logger = LoggerFactory().get_logger()
10+
911

1012
@validate_transaction
1113
def create_model(db: Session, model: ModelCreate) -> Model:
@@ -127,29 +129,29 @@ def delete_model(db: Session, model_id: str) -> None:
127129
db.commit()
128130

129131

130-
def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
132+
def create_model_by_config(db: Session, deployment_config: DeploymentDefinition, deployment_id: str, model: str | None) -> Model:
131133
"""
132134
Create a new model by config if present
133135
134136
Args:
135137
db (Session): Database session.
136-
deployment (Deployment): Deployment data.
137-
deployment_config (DeploymentSchema): Deployment config data.
138-
model (str): Model data.
138+
deployment_config (DeploymentDefinition): A deployment definition for any kind of deployment.
139+
deployment_id (str): Deployment ID for a deployment from the DB.
140+
model (str | None): Optional model name that should have its data returned from this call.
139141
140142
Returns:
141143
Model: Created model.
142144
"""
143-
deployment_config_models = deployment_config.models
144-
deployment_db_models = get_models_by_deployment_id(db, deployment.id)
145+
logger.debug(event="create_model_by_config", deployment_models=deployment_config.models, deployment_id=deployment_id, model=model)
146+
deployment_db_models = get_models_by_deployment_id(db, deployment_id)
145147
model_to_return = None
146-
for deployment_config_model in deployment_config_models:
148+
for deployment_config_model in deployment_config.models:
147149
model_in_db = any(record.name == deployment_config_model for record in deployment_db_models)
148150
if not model_in_db:
149151
new_model = Model(
150152
name=deployment_config_model,
151153
cohere_name=deployment_config_model,
152-
deployment_id=deployment.id,
154+
deployment_id=deployment_id,
153155
)
154156
db.add(new_model)
155157
db.commit()
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sqlalchemy.orm import Session
2+
3+
from backend.database_models import Deployment, Model, Organization
4+
5+
6+
def deployments_models_seed(op):
7+
"""
8+
No-op. Formerly seeded default deployments, models, organization, user and agent.
9+
"""
10+
# Previously we would seed the default deployments and models here. We've changed this
11+
# behaviour during a refactor of the deployments module so that deployments and models
12+
# are inserted when they're first used. This solves an issue where seed data would
13+
# sometimes be inserted with invalid config data.
14+
pass
15+
16+
17+
def delete_default_models(op):
18+
"""
19+
Delete deployments and models.
20+
"""
21+
session = Session(op.get_bind())
22+
session.query(Deployment).delete()
23+
session.query(Model).delete()
24+
session.query(Organization).filter_by(id="default").delete()
25+
session.commit()

0 commit comments

Comments
 (0)