-
Notifications
You must be signed in to change notification settings - Fork 207
Add replica groups in dstack-service #3408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 10 commits
22c1410
5abbcad
abba7da
d974292
caa4283
1ec1d6d
7b4bc52
8c5589d
0a54e07
f4c9fdf
a0e13f6
24d976e
263e312
5c71f76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel): | |
|
|
||
| @root_validator | ||
| def check_image_or_commands_present(cls, values): | ||
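| # Skip this check whenever `replicas` is set: replica groups carry their own commands. | ||
| # NOTE(review): `replicas` may also be a plain count or range, which skips the check too - confirm this is intended. | ||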
| replica_groups = values.get("replicas") | ||
| if replica_groups: | ||
| return values | ||
|
|
||
| if not values.get("commands") and not values.get("image"): | ||
| raise ValueError("Either `commands` or `image` must be set") | ||
| return values | ||
|
|
@@ -714,6 +719,62 @@ def schema_extra(schema: Dict[str, Any]): | |
| ) | ||
|
|
||
|
|
||
| class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel): | ||
|
||
| name: Annotated[ | ||
| Optional[str], | ||
| Field( | ||
| description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc. based on position." | ||
| ), | ||
| ] | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| count: Annotated[ | ||
| Range[int], | ||
| Field( | ||
| description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " | ||
| "If it's a range, the `scaling` property is required" | ||
| ), | ||
| ] | ||
| scaling: Annotated[ | ||
| Optional[ScalingSpec], | ||
| Field(description="The auto-scaling rules. Required if `count` is set to a range"), | ||
| ] = None | ||
| # TODO: Extract to ConfigurationWithResourcesParams mixin | ||
| resources: Annotated[ | ||
| ResourcesSpec, | ||
| Field(description="The resources requirements for replicas in this group"), | ||
| ] = ResourcesSpec() | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @validator("count") | ||
| def convert_count(cls, v: Range[int]) -> Range[int]: | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError("The minimum number of replicas must be greater than or equal to 0") | ||
| return v | ||
|
|
||
| @root_validator() | ||
| def override_commands_validation(cls, values): | ||
| """ | ||
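| Require `commands` for every replica group (there is no image-only option). | ||
| NOTE(review): despite the method name, this probably runs in addition to the parent validator in | ||
| ConfigurationWithCommandsParams rather than replacing it, since the two have different names - confirm. | ||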
| """ | ||
| commands = values.get("commands", []) | ||
| if not commands: | ||
| raise ValueError("`commands` must be set for replica groups") | ||
| return values | ||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| count = values.get("count") | ||
| if count and count.min != count.max and not scaling: | ||
| raise ValueError("When you set `count` to a range, ensure to specify `scaling`.") | ||
| if count and count.min == count.max and scaling: | ||
| raise ValueError("To use `scaling`, `count` must be set to a range.") | ||
| return values | ||
|
|
||
|
|
||
| class ServiceConfigurationParams(CoreModel): | ||
| port: Annotated[ | ||
| # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. | ||
|
|
@@ -755,13 +816,7 @@ class ServiceConfigurationParams(CoreModel): | |
| SERVICE_HTTPS_DEFAULT | ||
| ) | ||
| auth: Annotated[bool, Field(description="Enable the authorization")] = True | ||
| replicas: Annotated[ | ||
| Range[int], | ||
| Field( | ||
| description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " | ||
| "If it's a range, the `scaling` property is required" | ||
| ), | ||
| ] = Range[int](min=1, max=1) | ||
|
|
||
| scaling: Annotated[ | ||
| Optional[ScalingSpec], | ||
| Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), | ||
|
|
@@ -772,6 +827,19 @@ class ServiceConfigurationParams(CoreModel): | |
| Field(description="List of probes used to determine job health"), | ||
| ] = [] | ||
|
|
||
| replicas: Annotated[ | ||
| Optional[Union[Range[int], List[ReplicaGroup]]], | ||
| Field( | ||
| description=( | ||
| "List of replica groups. Each group defines replicas with shared configuration " | ||
| "(commands, port, resources, scaling, probes, rate_limits). " | ||
| "When specified, the top-level `replicas`, `commands`, `port`, `resources`, " | ||
| "`scaling`, `probes`, and `rate_limits` are ignored. " | ||
| "Each replica group must have a unique name." | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
| ), | ||
| ] = None | ||
|
|
||
| @validator("port") | ||
| def convert_port(cls, v) -> PortMapping: | ||
| if isinstance(v, int): | ||
|
|
@@ -786,26 +854,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]: | |
| return OpenAIChatModel(type="chat", name=v, format="openai") | ||
| return v | ||
|
|
||
| @validator("replicas") | ||
| def convert_replicas(cls, v: Range[int]) -> Range[int]: | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError("The minimum number of replicas must be greater than or equal to 0") | ||
| return v | ||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| replicas = values.get("replicas") | ||
| if replicas and replicas.min != replicas.max and not scaling: | ||
| raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") | ||
| if replicas and replicas.min == replicas.max and scaling: | ||
| raise ValueError("To use `scaling`, `replicas` must be set to a range.") | ||
| return values | ||
|
|
||
| @validator("rate_limits") | ||
| def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: | ||
| counts = Counter(limit.prefix for limit in v) | ||
|
|
@@ -827,6 +875,56 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: | |
| raise ValueError("Probes must be unique") | ||
| return v | ||
|
|
||
| @validator("replicas") | ||
| def validate_replicas( | ||
| cls, v: Optional[Union[Range[int], List[ReplicaGroup]]] | ||
| ) -> Optional[Union[Range[int], List[ReplicaGroup]]]: | ||
| if v is None: | ||
| return v | ||
| if isinstance(v, Range): | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError( | ||
| "The minimum number of replicas must be greater than or equal to 0" | ||
| ) | ||
| return v | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if isinstance(v, list): | ||
| if not v: | ||
| raise ValueError("`replicas` cannot be an empty list") | ||
|
|
||
| # Assign default names to groups without names | ||
| for index, group in enumerate(v): | ||
| if group.name is None: | ||
| group.name = f"replica{index}" | ||
|
|
||
| # Check for duplicate names | ||
| names = [group.name for group in v] | ||
| if len(names) != len(set(names)): | ||
| duplicates = [name for name in set(names) if names.count(name) > 1] | ||
| raise ValueError( | ||
| f"Duplicate replica group names found: {duplicates}. " | ||
| "Each replica group must have a unique name." | ||
| ) | ||
| return v | ||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| replicas = values.get("replicas") | ||
|
|
||
| if isinstance(replicas, Range): | ||
| if replicas and replicas.min != replicas.max and not scaling: | ||
| raise ValueError( | ||
| "When you set `replicas` to a range, ensure to specify `scaling`." | ||
| ) | ||
| if replicas and replicas.min == replicas.max and scaling: | ||
| raise ValueError("To use `scaling`, `replicas` must be set to a range.") | ||
| return values | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class ServiceConfigurationConfig( | ||
| ProfileParamsConfig, | ||
|
|
@@ -849,6 +947,36 @@ class ServiceConfiguration( | |
| ): | ||
| type: Literal["service"] = "service" | ||
|
|
||
| @property | ||
| def replica_groups(self) -> Optional[List[ReplicaGroup]]: | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
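| Return the replica groups in normalized form for type-safe access in code. | ||
| `replicas` may be None, a Range, or a list of ReplicaGroup; None and Range are wrapped into a | ||
| single group named "default" built from the top-level `commands`, `resources`, and `scaling`. | ||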
| """ | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if self.replicas is None: | ||
| return [ | ||
| ReplicaGroup( | ||
| name="default", | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| count=Range[int](min=1, max=1), | ||
| commands=self.commands or [], | ||
jvstme marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| resources=self.resources, | ||
| scaling=self.scaling, | ||
| ) | ||
| ] | ||
| if isinstance(self.replicas, list): | ||
| return self.replicas | ||
| if isinstance(self.replicas, Range): | ||
| return [ | ||
| ReplicaGroup( | ||
| name="default", | ||
| count=self.replicas, | ||
| commands=self.commands or [], | ||
| resources=self.resources, | ||
| scaling=self.scaling, | ||
| ) | ||
| ] | ||
| return None | ||
|
|
||
|
|
||
| AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration] | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.