Skip to content
Open
66 changes: 62 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,38 @@ def _format_job_name(
show_deployment_num: bool,
show_replica: bool,
show_job: bool,
group_index: Optional[int] = None,
last_shown_group_index: Optional[int] = None,
) -> str:
name_parts = []
prefix = ""
if show_replica:
name_parts.append(f"replica={job.job_spec.replica_num}")
# Show group information if replica groups are used
if group_index is not None:
# Show group=X replica=Y when group changes, or just replica=Y when same group
if group_index != last_shown_group_index:
# First job in group: use 3 spaces indent
prefix = " "
name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
else:
# Subsequent job in same group: align "replica=" with first job's "replica="
# Calculate padding: width of " group={last_shown_group_index} "
padding_width = 3 + len(f"group={last_shown_group_index}") + 1
prefix = " " * padding_width
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
# Legacy behavior: no replica groups
prefix = " "
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
prefix = " "

if show_job:
name_parts.append(f"job={job.job_spec.job_num}")
name_suffix = (
f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
)
name_value = " " + (" ".join(name_parts) if name_parts else "")
name_value = prefix + (" ".join(name_parts) if name_parts else "")
name_value += name_suffix
return name_value

Expand Down Expand Up @@ -359,6 +381,17 @@ def get_runs_table(
)
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num

group_name_to_index: Dict[str, int] = {}
if run.run_spec.configuration.type == "service" and hasattr(
run.run_spec.configuration, "replica_groups"
):
replica_groups = run.run_spec.configuration.replica_groups
if replica_groups:
for idx, group in enumerate(replica_groups):
# Use group name or default to "replica{idx}" if name is None
group_name = group.name or f"replica{idx}"
group_name_to_index[group_name] = idx

run_row: Dict[Union[str, int], Any] = {
"NAME": _format_run_name(run, show_deployment_num),
"SUBMITTED": format_date(run.submitted_at),
Expand All @@ -372,13 +405,35 @@ def get_runs_table(
if not merge_job_rows:
add_row_from_dict(table, run_row)

for job in run.jobs:
# Sort jobs by group index first, then by replica_num within each group
def get_job_sort_key(job: Job) -> tuple:
    """Build the ordering key ``(group position, replica_num)`` for a job row.

    The group position is looked up in ``group_name_to_index`` by the job's
    ``replica_group``. Jobs whose group cannot be resolved get a large
    placeholder position so they are ordered after all grouped jobs.
    """
    spec = job.job_spec
    position = None
    if group_name_to_index:
        if spec.replica_group:
            position = group_name_to_index.get(spec.replica_group)
    if position is None:
        # Placeholder rank that pushes ungrouped jobs to the end of the listing
        position = 999999
    return (position, spec.replica_num)

sorted_jobs = sorted(run.jobs, key=get_job_sort_key)

last_shown_group_index: Optional[int] = None
for job in sorted_jobs:
latest_job_submission = job.job_submissions[-1]
status_formatted = _format_job_submission_status(latest_job_submission, verbose)

# Get group index for this job
group_index: Optional[int] = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)

job_row: Dict[Union[str, int], Any] = {
"NAME": _format_job_name(
job, latest_job_submission, show_deployment_num, show_replica, show_job
job,
latest_job_submission,
show_deployment_num,
show_replica,
show_job,
group_index=group_index,
last_shown_group_index=last_shown_group_index,
),
"STATUS": status_formatted,
"PROBES": _format_job_probes(
Expand All @@ -390,6 +445,9 @@ def get_runs_table(
"GPU": "-",
"PRICE": "-",
}
# Update last shown group index for next iteration
if group_index is not None:
last_shown_group_index = group_index
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
shared_offer: Optional[InstanceOfferWithAvailability] = None
Expand Down
182 changes: 155 additions & 27 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel):

@root_validator
def check_image_or_commands_present(cls, values):
    """Require `commands` or `image` unless replica groups are configured.

    Args:
        values: The field values collected for the model being validated.

    Returns:
        ``values`` unchanged.

    Raises:
        ValueError: If neither `commands` nor `image` is set and `replicas`
            is not a non-empty list of replica groups.
    """
    # Only the list form of `replicas` (replica groups) brings its own
    # commands. `replicas` may also hold a plain count/range, which carries
    # no commands, so any other value must not bypass the check below.
    replicas = values.get("replicas")
    if isinstance(replicas, list) and replicas:
        return values

    if not values.get("commands") and not values.get("image"):
        raise ValueError("Either `commands` or `image` must be set")
    return values
Expand Down Expand Up @@ -714,6 +719,62 @@ def schema_extra(schema: Dict[str, Any]):
)


class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I think we could allow setting many more properties per replica group. If the user can set commands, they may also want to set entrypoint, working_dir, image, volumes, repos, etc. And if the user can set resources, they may also want to set instance_types, spot_policy, reservation, etc.

Although it may be a good idea to leave this to a future iteration, because some properties may be non-trivial to support correctly

name: Annotated[
Optional[str],
Field(
description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc. based on position."
),
]
count: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
]
scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `count` is set to a range"),
] = None
# TODO: Extract to ConfigurationWithResourcesParams mixin
resources: Annotated[
ResourcesSpec,
Field(description="The resources requirements for replicas in this group"),
] = ResourcesSpec()

@validator("count")
def convert_count(cls, v: Range[int]) -> Range[int]:
    """Normalize and validate the replica count range.

    A missing lower bound is filled in with ``0`` (the range is updated in
    place) and the same object is returned.

    Raises:
        ValueError: If the upper bound is missing or the lower bound is
            negative.
    """
    if v.max is None:
        raise ValueError("The maximum number of replicas is required")

    if v.min is None:
        # Open lower bound means "may scale down to zero"
        v.min = 0
    elif v.min < 0:
        raise ValueError("The minimum number of replicas must be greater than or equal to 0")

    return v

@root_validator()
def override_commands_validation(cls, values):
    """
    Reject a replica group that has no commands.

    Unlike the commands-or-image check, an image alone is not sufficient
    here: `commands` must be present and non-empty.

    NOTE(review): the name suggests this replaces the inherited check from
    ConfigurationWithCommandsParams; whether the inherited validator still
    runs as well depends on the validation framework -- confirm.
    """
    if values.get("commands"):
        return values
    raise ValueError("`commands` must be set for replica groups")

@root_validator()
def validate_scaling(cls, values):
    """Check that `scaling` is given exactly when `count` is a real range.

    Nothing is checked when `count` is absent from ``values``.

    Raises:
        ValueError: If `count` spans a range but `scaling` is missing, or if
            `count` is a fixed number but `scaling` is provided.
    """
    count = values.get("count")
    if not count:
        return values

    has_scaling = bool(values.get("scaling"))
    is_range = count.min != count.max
    if is_range and not has_scaling:
        raise ValueError("When you set `count` to a range, ensure to specify `scaling`.")
    if not is_range and has_scaling:
        raise ValueError("To use `scaling`, `count` must be set to a range.")
    return values


class ServiceConfigurationParams(CoreModel):
port: Annotated[
# NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
Expand Down Expand Up @@ -755,13 +816,7 @@ class ServiceConfigurationParams(CoreModel):
SERVICE_HTTPS_DEFAULT
)
auth: Annotated[bool, Field(description="Enable the authorization")] = True
replicas: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
] = Range[int](min=1, max=1)

scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
Expand All @@ -772,6 +827,19 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []

replicas: Annotated[
Optional[Union[Range[int], List[ReplicaGroup]]],
Field(
description=(
"List of replica groups. Each group defines replicas with shared configuration "
"(commands, port, resources, scaling, probes, rate_limits). "
"When specified, the top-level `replicas`, `commands`, `port`, `resources`, "
"`scaling`, `probes`, and `rate_limits` are ignored. "
"Each replica group must have a unique name."
)
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
Expand All @@ -786,26 +854,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
return OpenAIChatModel(type="chat", name=v, format="openai")
return v

@validator("replicas")
def convert_replicas(cls, v: Range[int]) -> Range[int]:
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
return v

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
replicas = values.get("replicas")
if replicas and replicas.min != replicas.max and not scaling:
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
if replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
return values

@validator("rate_limits")
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
counts = Counter(limit.prefix for limit in v)
Expand All @@ -827,6 +875,56 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
raise ValueError("Probes must be unique")
return v

@validator("replicas")
def validate_replicas(
    cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
    """Validate `replicas` in either of its two forms.

    * Range form: the upper bound is required, a missing lower bound is set
      to ``0`` in place, and a negative lower bound is rejected.
    * List form: the list must be non-empty, groups without a name are
      assigned ``replica<index>`` in place, and names must be unique.

    Returns:
        The same ``v`` object (possibly updated in place), or ``None``.

    Raises:
        ValueError: On a missing maximum, a negative minimum, an empty list,
            or duplicate replica group names.
    """
    if v is None:
        return v

    if isinstance(v, Range):
        if v.max is None:
            raise ValueError("The maximum number of replicas is required")
        if v.min is None:
            v.min = 0
        if v.min < 0:
            raise ValueError(
                "The minimum number of replicas must be greater than or equal to 0"
            )
        return v

    if isinstance(v, list):
        if not v:
            raise ValueError("`replicas` cannot be an empty list")

        # Assign default names to groups without names
        for index, group in enumerate(v):
            if group.name is None:
                group.name = f"replica{index}"

        # Count names in one pass; Counter keeps first-seen order, so the
        # duplicates are reported in a stable, reproducible order.
        name_counts = Counter(group.name for group in v)
        duplicates = [name for name, seen in name_counts.items() if seen > 1]
        if duplicates:
            raise ValueError(
                f"Duplicate replica group names found: {duplicates}. "
                "Each replica group must have a unique name."
            )
    return v

@root_validator()
def validate_scaling(cls, values):
    """Cross-check `scaling` against `replicas` when `replicas` is a range.

    The check applies only to the Range form of `replicas`; any other value
    (including a list of replica groups or ``None``) is passed through.

    Raises:
        ValueError: If `replicas` spans a range but `scaling` is missing, or
            if `replicas` is a fixed number but `scaling` is provided.
    """
    replicas = values.get("replicas")
    if not isinstance(replicas, Range) or not replicas:
        return values

    scaling = values.get("scaling")
    if replicas.min != replicas.max:
        if not scaling:
            raise ValueError(
                "When you set `replicas` to a range, ensure to specify `scaling`."
            )
    elif scaling:
        raise ValueError("To use `scaling`, `replicas` must be set to a range.")
    return values


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand All @@ -849,6 +947,36 @@ class ServiceConfiguration(
):
type: Literal["service"] = "service"

@property
def replica_groups(self) -> Optional[List[ReplicaGroup]]:
    """
    Return `replicas` expressed as a list of replica groups.

    * If `replicas` is already a list, it is returned as is.
    * If `replicas` is unset, a single group named "default" with a count of
      exactly one replica is returned.
    * If `replicas` is a range, a single group named "default" with that
      range as its count is returned.

    The synthesized "default" group takes its commands, resources and
    scaling from the top-level configuration. ``None`` is returned for any
    other `replicas` value.

    NOTE(review): building the "default" group instantiates ReplicaGroup,
    which presumably runs its validators (e.g. non-empty `commands`) and may
    raise on access -- confirm callers are prepared for that.
    """
    replicas = self.replicas
    if isinstance(replicas, list):
        return replicas

    if replicas is None:
        count = Range[int](min=1, max=1)
    elif isinstance(replicas, Range):
        count = replicas
    else:
        return None

    default_group = ReplicaGroup(
        name="default",
        count=count,
        commands=self.commands or [],
        resources=self.resources,
        scaling=self.scaling,
    )
    return [default_group]


AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class JobSpec(CoreModel):
job_num: int
job_name: str
jobs_per_replica: int = 1 # default value for backward compatibility
replica_group: str = "default"
app_specs: Optional[List[AppSpec]]
user: Optional[UnixUser] = None # default value for backward compatibility
commands: List[str]
Expand Down
Loading