Skip to content

Commit ca8edb7

Browse files
authored
refactor(iris): unify worker config around WorkerConfig proto (#3077)
Unify worker bootstrap and configuration. - Replace `BootstrapConfig` with a single `WorkerConfig` proto message that carries all worker configuration from autoscaler through bootstrap to the worker process - Eliminate scattered env var encoding (`IRIS_ACCELERATOR_TYPE`, `IRIS_WORKER_ATTRIBUTES`, etc.) — autoscaler now sets proto fields directly - Replace 8 CLI flags on `serve` with a single `--worker-config` JSON path - Split `env_probe.py` into pure hardware probing (`probe_hardware()`) and config-aware metadata building (`build_worker_metadata()`) - Update all platform implementations (GCP, Manual, CoreWeave, Local) and example YAML configs
1 parent 49022ac commit ca8edb7

35 files changed

+882
-839
lines changed

lib/iris/AGENTS.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ for GHCR. See `docs/image-push.md` for full details.
232232
image tags to the AR remote repo for the VM's continent:
233233
- `ghcr.io/org/image:v1``us-docker.pkg.dev/project/ghcr-mirror/org/image:v1`
234234

235-
Set `defaults.bootstrap.docker_image` to a `ghcr.io/...` tag. Non-GHCR tags
235+
Set `defaults.worker.docker_image` to a `ghcr.io/...` tag. Non-GHCR tags
236236
(`docker.io`, existing AR tags) pass through unchanged.
237237

238238
**Bundle storage** (`controller.bundle_prefix`) is a GCS URI with no zone
@@ -251,8 +251,14 @@ Iris follows a clean layering architecture:
251251
- Owns autoscaling logic and scaling group state
252252

253253
**Platform layer** (`cluster/platform/`): Platform abstractions for managing VMs
254-
- Provides VM lifecycle management (GCP, manual, local, CoreWeave)
255254
- Does NOT depend on controller layer
255+
- Four platform implementations with independent launch/teardown paths:
256+
- `gcp.py` — GCP TPU/VM slices, SSH bootstrap
257+
- `coreweave.py` — CoreWeave CKS, Kubernetes Pods on shared NodePools
258+
- `manual.py` — Pre-existing hosts, SSH bootstrap
259+
- `local.py` — Local development, in-process workers
260+
- Changes to shared interfaces (worker CLI, bootstrap flow, proto schemas)
261+
must be applied to all four platforms
256262

257263
**Cluster layer** (`cluster/`): High-level orchestration
258264
- `connect_cluster()` and `stop_all()` free functions for cluster lifecycle

lib/iris/examples/coreweave.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ defaults:
5858
milliseconds: 300000
5959
startup_grace_period:
6060
milliseconds: 2400000 # 40 min — covers autoscaler node provisioning + Pod startup
61-
default_task_image: ghcr.io/marin-community/iris-task:latest
62-
bootstrap:
61+
worker:
6362
docker_image: ghcr.io/marin-community/iris-worker:latest
64-
worker_port: 10001
63+
port: 10001
6564
cache_dir: /mnt/local/iris-cache
6665
runtime: kubernetes
66+
default_task_image: ghcr.io/marin-community/iris-task:latest
6767

6868
scale_groups:
6969
# CPU general-purpose — used for data processing, orchestration, etc.

lib/iris/examples/demo.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ platform:
88
project_id: hai-gcp-models
99

1010
defaults:
11-
default_task_image: ghcr.io/marin-community/iris-task:latest
1211
autoscaler:
1312
evaluation_interval:
1413
milliseconds: 10000
1514
scale_up_delay:
1615
milliseconds: 60000
1716
scale_down_delay:
1817
milliseconds: 300000
19-
bootstrap:
18+
worker:
2019
docker_image: ghcr.io/marin-community/iris-worker:latest
21-
worker_port: 10001
20+
default_task_image: ghcr.io/marin-community/iris-task:latest
21+
port: 10001
2222
controller_address: "${IRIS_CONTROLLER_ADDRESS}"
2323

2424
storage:

lib/iris/examples/marin.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ platform:
77
project_id: hai-gcp-models
88

99
defaults:
10-
default_task_image: ghcr.io/marin-community/iris-task:latest
1110
autoscaler:
1211
evaluation_interval:
1312
milliseconds: 10000
1413
scale_up_delay:
1514
milliseconds: 60000
1615
scale_down_delay:
1716
milliseconds: 300000
18-
bootstrap:
17+
worker:
1918
docker_image: ghcr.io/marin-community/iris-worker:latest
20-
worker_port: 10001
19+
default_task_image: ghcr.io/marin-community/iris-task:latest
20+
port: 10001
2121

2222
storage:
2323
bundle_prefix: gs://marin-us-central2/tmp/iris/bundles

lib/iris/examples/smoke.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ platform:
77
project_id: hai-gcp-models
88

99
defaults:
10-
default_task_image: ghcr.io/marin-community/iris-task:latest
1110
autoscaler:
1211
evaluation_interval:
1312
milliseconds: 10000
1413
scale_up_delay:
1514
milliseconds: 60000
1615
scale_down_delay:
1716
milliseconds: 300000
18-
bootstrap:
17+
worker:
1918
docker_image: ghcr.io/marin-community/iris-worker:latest
20-
worker_port: 10001
19+
default_task_image: ghcr.io/marin-community/iris-task:latest
20+
port: 10001
2121

2222
storage:
2323
bundle_prefix: gs://marin-us-central2/tmp/iris/bundles

lib/iris/scripts/smoke-test.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
# Custom per-job timeout
2828
uv run python scripts/smoke-test.py --job-timeout 120
2929
30+
# Local mode: in-process controller and workers (no cloud VMs)
31+
uv run python scripts/smoke-test.py --mode local
32+
3033
# Keep cluster running on failure for debugging
3134
uv run python scripts/smoke-test.py --mode keep
3235
@@ -697,8 +700,11 @@ class SmokeTestConfig:
697700
accelerator: AcceleratorConfig
698701
boot_timeout_seconds: int = DEFAULT_BOOT_TIMEOUT
699702
job_timeout_seconds: int = DEFAULT_JOB_TIMEOUT
700-
local: bool = False # Run locally without GCP
701-
mode: Literal["full", "keep", "redeploy"] = "full"
703+
mode: Literal["full", "keep", "redeploy", "local"] = "full"
704+
705+
@property
706+
def local(self) -> bool:
707+
return self.mode == "local"
702708

703709

704710
# =============================================================================
@@ -769,7 +775,7 @@ def run(self) -> bool:
769775
controller_url: str | None = None
770776

771777
if self.config.mode != "redeploy":
772-
if self.config.mode in ("full", "keep") and not self.config.local:
778+
if self.config.mode in ("full", "keep"):
773779
_log_section("PHASE 0: Clean Start")
774780
self._cleanup_existing()
775781
if self._interrupted:
@@ -837,10 +843,10 @@ def _print_header(self):
837843
logger.info("=" * 60)
838844
logger.info("")
839845
logger.info("Config: %s", self.config.config_path)
846+
logger.info("Mode: %s", self.config.mode)
840847
logger.info("Boot timeout: %ds", self.config.boot_timeout_seconds)
841848
logger.info("Job timeout: %ds", self.config.job_timeout_seconds)
842849
logger.info("Accelerator: %s (%s)", self._accel.label(), self._accel.device_type)
843-
logger.info("Local: %s", self.config.local)
844850

845851
# ----- Cluster lifecycle via CLI -----
846852

@@ -1479,22 +1485,17 @@ def _cleanup(self):
14791485
)
14801486
@click.option(
14811487
"--mode",
1482-
type=click.Choice(["full", "keep", "redeploy"]),
1488+
type=click.Choice(["full", "keep", "redeploy", "local"]),
14831489
default="full",
14841490
show_default=True,
1485-
help="Execution mode: 'full' (clean start + teardown), 'keep' (clean start + keep VMs), 'redeploy' (reuse VMs)",
1486-
)
1487-
@click.option(
1488-
"--local",
1489-
is_flag=True,
1490-
help="Run locally without GCP (in-process controller and workers)",
1491+
help="Execution mode: 'full' (clean start + teardown), 'keep' (clean start + keep VMs), "
1492+
"'redeploy' (reuse VMs), 'local' (in-process controller and workers, no cloud VMs)",
14911493
)
14921494
def main(
14931495
config_path: Path,
14941496
boot_timeout_seconds: int,
14951497
job_timeout_seconds: int,
14961498
mode: str,
1497-
local: bool,
14981499
):
14991500
"""Run Iris cluster autoscaling smoke test.
15001501
@@ -1507,6 +1508,9 @@ def main(
15071508
# Basic smoke test (uses examples/smoke.yaml by default)
15081509
uv run python scripts/smoke-test.py
15091510
1511+
# Local mode: in-process controller and workers
1512+
uv run python scripts/smoke-test.py --mode local
1513+
15101514
# CoreWeave GPU smoke test
15111515
uv run python scripts/smoke-test.py --config examples/coreweave.yaml
15121516
@@ -1527,7 +1531,6 @@ def main(
15271531
boot_timeout_seconds=boot_timeout_seconds,
15281532
job_timeout_seconds=job_timeout_seconds,
15291533
mode=mode, # type: ignore
1530-
local=local,
15311534
)
15321535

15331536
runner = SmokeTestRunner(config)

lib/iris/src/iris/cli/cluster.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,12 @@ def _build_and_push_task_image(task_tag: str, verbose: bool = False) -> None:
150150
def _build_cluster_images(config, verbose: bool = False) -> dict[str, str]:
151151
built: dict[str, str] = {}
152152

153-
for tag, typ in [(config.defaults.bootstrap.docker_image, "worker"), (config.controller.image, "controller")]:
153+
for tag, typ in [(config.defaults.worker.docker_image, "worker"), (config.controller.image, "controller")]:
154154
if tag:
155155
_build_and_push_for_tag(tag, typ, verbose=verbose)
156156
built[typ] = tag
157157

158-
task_tag = config.defaults.default_task_image
158+
task_tag = config.defaults.worker.default_task_image
159159
if task_tag:
160160
_build_and_push_task_image(task_tag, verbose=verbose)
161161
built["task"] = task_tag
@@ -175,8 +175,8 @@ def _pin_tag(tag: str | None, git_sha: str) -> str | None:
175175

176176
tags = {
177177
"controller": config.controller.image,
178-
"worker": config.defaults.bootstrap.docker_image,
179-
"task": config.defaults.default_task_image,
178+
"worker": config.defaults.worker.docker_image,
179+
"task": config.defaults.worker.default_task_image,
180180
}
181181
needs_pin = any(tag.endswith(":latest") for tag in tags.values() if tag)
182182
if not needs_pin:
@@ -188,9 +188,9 @@ def _pin_tag(tag: str | None, git_sha: str) -> str | None:
188188
if pinned["controller"]:
189189
config.controller.image = pinned["controller"]
190190
if pinned["worker"]:
191-
config.defaults.bootstrap.docker_image = pinned["worker"]
191+
config.defaults.worker.docker_image = pinned["worker"]
192192
if pinned["task"]:
193-
config.defaults.default_task_image = pinned["task"]
193+
config.defaults.worker.default_task_image = pinned["task"]
194194

195195
click.echo("Pinning :latest image tags to git SHA for this run:")
196196
for name, tag in pinned.items():

lib/iris/src/iris/cluster/config.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@
4545
scale_up_delay=Duration.from_seconds(60).to_proto(),
4646
scale_down_delay=Duration.from_seconds(300).to_proto(),
4747
),
48-
bootstrap=config_pb2.BootstrapConfig(
49-
worker_port=10001,
48+
worker=config_pb2.WorkerConfig(
49+
port=10001,
5050
cache_dir="/var/cache/iris",
51+
host="0.0.0.0",
52+
port_range="30000-40000",
5153
),
5254
)
5355

@@ -248,33 +250,31 @@ def validate_config(config: config_pb2.IrisClusterConfig) -> None:
248250
_validate_scale_group_resources(config)
249251
_validate_slice_templates(config)
250252
_validate_worker_settings(config)
251-
_validate_bootstrap_defaults(config)
253+
_validate_worker_defaults(config)
252254

253255

254-
def _validate_bootstrap_defaults(config: config_pb2.IrisClusterConfig) -> None:
255-
"""Validate bootstrap defaults required for worker-based platforms.
256+
def _validate_worker_defaults(config: config_pb2.IrisClusterConfig) -> None:
257+
"""Validate worker defaults required for worker-based platforms.
256258
257-
Local platform runs workers in-process and does not require bootstrap image/runtime.
259+
Local platform runs workers in-process and does not require a docker image/runtime.
258260
GCP/manual/CoreWeave create remote worker processes and must provide a worker image.
259261
"""
260262
# Some unit tests validate partial proto configs directly (without load_config/apply_defaults).
261-
# Only enforce bootstrap image checks once defaults/platform are explicitly present.
263+
# Only enforce worker image checks once defaults/platform are explicitly present.
262264
if not config.HasField("defaults"):
263265
return
264266

265267
platform_kind = config.platform.WhichOneof("platform")
266268
if platform_kind in (None, "local"):
267269
return
268270

269-
docker_image = config.defaults.bootstrap.docker_image.strip()
271+
docker_image = config.defaults.worker.docker_image.strip()
270272
if not docker_image:
271-
raise ValueError(
272-
"defaults.bootstrap.docker_image is required for non-local platforms " "(gcp/manual/coreweave)."
273-
)
273+
raise ValueError("defaults.worker.docker_image is required for non-local platforms (gcp/manual/coreweave).")
274274

275-
runtime = config.defaults.bootstrap.runtime.strip()
275+
runtime = config.defaults.worker.runtime.strip()
276276
if runtime and runtime not in {"docker", "kubernetes"}:
277-
raise ValueError(f"defaults.bootstrap.runtime must be one of docker/kubernetes, got {runtime!r}.")
277+
raise ValueError(f"defaults.worker.runtime must be one of docker/kubernetes, got {runtime!r}.")
278278

279279

280280
def _scale_groups_to_config(scale_groups: dict[str, config_pb2.ScaleGroupConfig]) -> config_pb2.IrisClusterConfig:
@@ -326,16 +326,16 @@ def _merge_proto_fields(target, source) -> None:
326326
def _deep_merge_defaults(target: config_pb2.DefaultsConfig, source: config_pb2.DefaultsConfig) -> None:
327327
"""Deep merge source defaults into target, field by field.
328328
329-
Sub-messages (timeouts, ssh, autoscaler, bootstrap) are merged field-by-field
329+
Sub-messages (timeouts, ssh, autoscaler, worker) are merged field-by-field
330330
so that partially-specified user configs overlay hardcoded defaults without
331-
wiping unset siblings. Top-level scalar fields (e.g. default_task_image) are
332-
merged via _merge_proto_fields which copies any explicitly-set value.
331+
wiping unset siblings. Top-level scalar fields are merged via
332+
_merge_proto_fields which copies any explicitly-set value.
333333
334334
Args:
335335
target: DefaultsConfig to merge into (modified in place)
336336
source: DefaultsConfig to merge from
337337
"""
338-
# Merge top-level scalar fields (e.g. default_task_image).
338+
# Merge top-level scalar fields.
339339
# We skip message fields here since sub-messages need deep merging below.
340340
for field_desc in source.DESCRIPTOR.fields:
341341
if field_desc.message_type is not None:
@@ -348,11 +348,13 @@ def _deep_merge_defaults(target: config_pb2.DefaultsConfig, source: config_pb2.D
348348
_merge_proto_fields(target.ssh, source.ssh)
349349
if source.HasField("autoscaler"):
350350
_merge_proto_fields(target.autoscaler, source.autoscaler)
351-
if source.HasField("bootstrap"):
352-
_merge_proto_fields(target.bootstrap, source.bootstrap)
353-
# Merge env_vars map separately (map fields don't use HasField)
354-
for key, value in source.bootstrap.env_vars.items():
355-
target.bootstrap.env_vars[key] = value
351+
if source.HasField("worker"):
352+
_merge_proto_fields(target.worker, source.worker)
353+
# Merge map fields separately (map fields don't support HasField)
354+
for key, value in source.worker.default_task_env.items():
355+
target.worker.default_task_env[key] = value
356+
for key, value in source.worker.worker_attributes.items():
357+
target.worker.worker_attributes[key] = value
356358

357359

358360
def _validate_autoscaler_config(config: config_pb2.AutoscalerConfig, context: str = "autoscaler") -> None:
@@ -619,12 +621,10 @@ def load_config(config_path: Path | str) -> config_pb2.IrisClusterConfig:
619621
# Expand environment variables in controller_address only.
620622
# Other fields (e.g., docker_image, ssh.key_file) are used as-is.
621623
# This is intentional - controller_address often needs $IRIS_CONTROLLER_ADDRESS for dynamic discovery.
622-
if "bootstrap" in data and "controller_address" in data["bootstrap"]:
623-
data["bootstrap"]["controller_address"] = os.path.expandvars(data["bootstrap"]["controller_address"])
624-
if "defaults" in data and "bootstrap" in data["defaults"]:
625-
defaults_bootstrap = data["defaults"]["bootstrap"]
626-
if "controller_address" in defaults_bootstrap:
627-
defaults_bootstrap["controller_address"] = os.path.expandvars(defaults_bootstrap["controller_address"])
624+
if "defaults" in data and "worker" in data["defaults"]:
625+
defaults_worker = data["defaults"]["worker"]
626+
if "controller_address" in defaults_worker:
627+
defaults_worker["controller_address"] = os.path.expandvars(defaults_worker["controller_address"])
628628

629629
_normalize_scale_group_resources(data)
630630
_expand_multi_zone_groups(data)
@@ -898,15 +898,15 @@ def as_local(self) -> IrisConfig:
898898
return IrisConfig(local_proto)
899899

900900
def controller_address(self) -> str:
901-
"""Get controller address from bootstrap config, if set.
901+
"""Get controller address from worker config, if set.
902902
903903
Returns:
904904
Controller address string, or empty string if not configured
905905
"""
906906
# TODO: Derive controller address from controller.manual/local when unset.
907-
bootstrap = self._proto.defaults.bootstrap
908-
if bootstrap.HasField("controller_address"):
909-
return bootstrap.controller_address
907+
worker = self._proto.defaults.worker
908+
if worker.HasField("controller_address"):
909+
return worker.controller_address
910910
return ""
911911

912912

@@ -915,7 +915,7 @@ def create_autoscaler(
915915
autoscaler_config: config_pb2.AutoscalerConfig,
916916
scale_groups: dict[str, config_pb2.ScaleGroupConfig],
917917
label_prefix: str,
918-
bootstrap_config: config_pb2.BootstrapConfig | None = None,
918+
base_worker_config: config_pb2.WorkerConfig | None = None,
919919
threads: ThreadContainer | None = None,
920920
):
921921
"""Create autoscaler from Platform and explicit config.
@@ -925,7 +925,7 @@ def create_autoscaler(
925925
autoscaler_config: Autoscaler settings (already resolved with defaults)
926926
scale_groups: Map of scale group name to config
927927
label_prefix: Prefix for labels on managed resources
928-
bootstrap_config: Worker bootstrap settings passed through to platform.create_slice().
928+
base_worker_config: Base worker configuration passed through to platform.create_slice().
929929
None disables bootstrap (test/local mode).
930930
threads: Thread container for background threads. Uses global default if not provided.
931931
@@ -975,5 +975,5 @@ def create_autoscaler(
975975
scale_groups=scaling_groups,
976976
config=autoscaler_config,
977977
platform=platform,
978-
bootstrap_config=bootstrap_config,
978+
base_worker_config=base_worker_config,
979979
)

0 commit comments

Comments
 (0)