Skip to content
Open
63 changes: 59 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,38 @@ def _format_job_name(
show_deployment_num: bool,
show_replica: bool,
show_job: bool,
group_index: Optional[int] = None,
last_shown_group_index: Optional[int] = None,
) -> str:
name_parts = []
prefix = ""
if show_replica:
name_parts.append(f"replica={job.job_spec.replica_num}")
# Show group information if replica groups are used
if group_index is not None:
# Show group=X replica=Y when group changes, or just replica=Y when same group
if group_index != last_shown_group_index:
# First job in group: use 3 spaces indent
prefix = " "
name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
else:
# Subsequent job in same group: align "replica=" with first job's "replica="
# Calculate padding: width of " group={last_shown_group_index} "
padding_width = 3 + len(f"group={last_shown_group_index}") + 1
prefix = " " * padding_width
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
# Legacy behavior: no replica groups
prefix = " "
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
prefix = " "

if show_job:
name_parts.append(f"job={job.job_spec.job_num}")
name_suffix = (
f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
)
name_value = " " + (" ".join(name_parts) if name_parts else "")
name_value = prefix + (" ".join(name_parts) if name_parts else "")
name_value += name_suffix
return name_value

Expand Down Expand Up @@ -359,6 +381,14 @@ def get_runs_table(
)
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num

# Replica Group Changes: Build mapping from replica group names to indices
group_name_to_index: Dict[str, int] = {}
# Replica Group Changes: Check if replicas attribute exists (only available for ServiceConfiguration)
replicas = getattr(run.run_spec.configuration, "replicas", None)
if replicas:
for idx, group in enumerate(replicas):
group_name_to_index[group.name] = idx

run_row: Dict[Union[str, int], Any] = {
"NAME": _format_run_name(run, show_deployment_num),
"SUBMITTED": format_date(run.submitted_at),
Expand All @@ -372,13 +402,35 @@ def get_runs_table(
if not merge_job_rows:
add_row_from_dict(table, run_row)

for job in run.jobs:
# Sort jobs by group index first, then by replica_num within each group
def get_job_sort_key(job: Job) -> tuple:
group_index = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)
# Use a large number for jobs without groups to put them at the end
return (group_index if group_index is not None else 999999, job.job_spec.replica_num)

sorted_jobs = sorted(run.jobs, key=get_job_sort_key)

last_shown_group_index: Optional[int] = None
for job in sorted_jobs:
latest_job_submission = job.job_submissions[-1]
status_formatted = _format_job_submission_status(latest_job_submission, verbose)

# Get group index for this job
group_index: Optional[int] = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)

job_row: Dict[Union[str, int], Any] = {
"NAME": _format_job_name(
job, latest_job_submission, show_deployment_num, show_replica, show_job
job,
latest_job_submission,
show_deployment_num,
show_replica,
show_job,
group_index=group_index,
last_shown_group_index=last_shown_group_index,
),
"STATUS": status_formatted,
"PROBES": _format_job_probes(
Expand All @@ -390,6 +442,9 @@ def get_runs_table(
"GPU": "-",
"PRICE": "-",
}
# Update last shown group index for next iteration
if group_index is not None:
last_shown_group_index = group_index
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
shared_offer: Optional[InstanceOfferWithAvailability] = None
Expand Down
177 changes: 154 additions & 23 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel):

@root_validator
def check_image_or_commands_present(cls, values):
# If replicas is present, skip validation - commands come from replica groups
replica_groups = values.get("replicas")
if replica_groups:
return values

if not values.get("commands") and not values.get("image"):
raise ValueError("Either `commands` or `image` must be set")
return values
Expand Down Expand Up @@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]):
)


class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I think we could allow to set many more properties per replica group. If the user can set commands, they may also want to set entrypoint, working_dir, image, volumes, repos, etc. And if the user can set resources, they may also want to set instance_types, spot_policy, reservation, etc.

Although it may be a good idea to leave this to a future iteration, because some properties may be non-trivial to support correctly

name: Annotated[
str,
Field(description="The name of the replica group"),
]
replicas: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
]
scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
] = None
probes: Annotated[
list[ProbeConfig],
Field(description="List of probes used to determine job health for this replica group"),
] = []
rate_limits: Annotated[
list[RateLimit],
Field(description="Rate limiting rules for this replica group"),
] = []
# TODO: Extract to ConfigurationWithResourcesParams mixin
resources: Annotated[
ResourcesSpec,
Field(description="The resources requirements for replicas in this group"),
] = ResourcesSpec()

@validator("replicas")
def convert_replicas(cls, v: Range[int]) -> Range[int]:
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
return v

@root_validator()
def override_commands_validation(cls, values):
"""
Override parent validator from ConfigurationWithCommandsParams.
ReplicaGroup always requires commands (no image option).
"""
commands = values.get("commands", [])
if not commands:
raise ValueError("`commands` must be set for replica groups")
return values

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
replicas = values.get("replicas")
if replicas and replicas.min != replicas.max and not scaling:
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
if replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
return values

@validator("rate_limits")
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
counts = Counter(limit.prefix for limit in v)
duplicates = [prefix for prefix, count in counts.items() if count > 1]
if duplicates:
raise ValueError(
f"Prefixes {duplicates} are used more than once."
" Each rate limit should have a unique path prefix"
)
return v

@validator("probes")
def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
if has_duplicates(v):
raise ValueError("Probes must be unique")
return v


class ServiceConfigurationParams(CoreModel):
port: Annotated[
# NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
Expand Down Expand Up @@ -754,13 +838,7 @@ class ServiceConfigurationParams(CoreModel):
SERVICE_HTTPS_DEFAULT
)
auth: Annotated[bool, Field(description="Enable the authorization")] = True
replicas: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
] = Range[int](min=1, max=1)

scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
Expand All @@ -771,6 +849,19 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []

replicas: Annotated[
Optional[Union[Range[int], List[ReplicaGroup]]],
Field(
description=(
"List of replica groups. Each group defines replicas with shared configuration "
"(commands, port, resources, scaling, probes, rate_limits). "
"When specified, the top-level `replicas`, `commands`, `port`, `resources`, "
"`scaling`, `probes`, and `rate_limits` are ignored. "
"Each replica group must have a unique name."
)
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
Expand All @@ -785,16 +876,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
return OpenAIChatModel(type="chat", name=v, format="openai")
return v

@validator("replicas")
def convert_replicas(cls, v: Range[int]) -> Range[int]:
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
return v

@validator("gateway")
def validate_gateway(
cls, v: Optional[Union[bool, str]]
Expand All @@ -806,13 +887,29 @@ def validate_gateway(
return v

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
def normalize_replicas(cls, values):
replicas = values.get("replicas")
if replicas and replicas.min != replicas.max and not scaling:
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
if replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
if isinstance(replicas, list) and len(replicas) > 0:
if all(isinstance(item, ReplicaGroup) for item in replicas):
return values

# Handle backward compatibility: convert old-style replica config to groups
old_replicas = values.get("replicas")
if isinstance(old_replicas, Range):
replica_count = old_replicas
else:
replica_count = Range[int](min=1, max=1)
values["replicas"] = [
ReplicaGroup(
name="default",
replicas=replica_count,
commands=values.get("commands", []),
resources=values.get("resources"),
scaling=values.get("scaling"),
probes=values.get("probes", []),
rate_limits=values.get("rate_limits"),
)
]
return values

@validator("rate_limits")
Expand All @@ -836,6 +933,28 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
raise ValueError("Probes must be unique")
return v

@validator("replicas")
def validate_replicas(
cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
if v is None:
return v
if isinstance(v, Range):
return v

if isinstance(v, list):
if not v:
raise ValueError("`replicas` cannot be an empty list")
# Check for duplicate names
names = [group.name for group in v]
if len(names) != len(set(names)):
duplicates = [name for name in set(names) if names.count(name) > 1]
raise ValueError(
f"Duplicate replica group names found: {duplicates}. "
"Each replica group must have a unique name."
)
return v


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand All @@ -858,6 +977,18 @@ class ServiceConfiguration(
):
type: Literal["service"] = "service"

@property
def replica_groups(self) -> Optional[List[ReplicaGroup]]:
"""
Get normalized replica groups. After validation, replicas is always List[ReplicaGroup] or None.
Use this property for type-safe access in code.
"""
if self.replicas is None:
return None
if isinstance(self.replicas, list):
return self.replicas
return None


AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class JobSpec(CoreModel):
job_num: int
job_name: str
jobs_per_replica: int = 1 # default value for backward compatibility
replica_group: Optional[str] = "default"
app_specs: Optional[List[AppSpec]]
user: Optional[UnixUser] = None # default value for backward compatibility
commands: List[str]
Expand Down
Loading
Loading