diff --git a/kubeflow/trainer/backends/container/adapters/base.py b/kubeflow/trainer/backends/container/adapters/base.py index 3e38a6b89..15bb3fb2a 100644 --- a/kubeflow/trainer/backends/container/adapters/base.py +++ b/kubeflow/trainer/backends/container/adapters/base.py @@ -193,3 +193,22 @@ def get_network(self, network_id: str) -> Optional[dict]: Dictionary with network info including labels, or None if not found """ raise NotImplementedError() + + @abc.abstractmethod + def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int: + """ + Wait for a container to exit and return its exit code. + + This is a blocking call that waits until the container stops. + + Args: + container_id: Container ID + timeout: Maximum time to wait in seconds, or None to wait indefinitely + + Returns: + Container exit code + + Raises: + TimeoutError: If timeout is reached before container exits + """ + raise NotImplementedError() diff --git a/kubeflow/trainer/backends/container/adapters/docker.py b/kubeflow/trainer/backends/container/adapters/docker.py index ada446bcc..d6e5c45f9 100644 --- a/kubeflow/trainer/backends/container/adapters/docker.py +++ b/kubeflow/trainer/backends/container/adapters/docker.py @@ -227,3 +227,31 @@ def get_network(self, network_id: str) -> Optional[dict]: } except Exception: return None + + def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int: + """ + Wait for a Docker container to exit and return its exit code. + + Args: + container_id: Container ID + timeout: Maximum time to wait in seconds, or None to wait indefinitely + + Returns: + Container exit code + + Raises: + TimeoutError: If timeout is reached before container exits + """ + try: + container = self.get_container(container_id) + result = container.wait(timeout=timeout) + # Docker wait() returns a dict with 'StatusCode' key + if isinstance(result, dict): + return result.get("StatusCode", 0) + return int(result) + except Exception as e: + if "timeout" in str(e).lower(): + raise TimeoutError( + f"Container {container_id} did not exit within {timeout} seconds" + ) from e + raise diff --git a/kubeflow/trainer/backends/container/adapters/podman.py b/kubeflow/trainer/backends/container/adapters/podman.py index 585c41343..bb00c4834 100644 --- a/kubeflow/trainer/backends/container/adapters/podman.py +++ b/kubeflow/trainer/backends/container/adapters/podman.py @@ -254,3 +254,29 @@ def get_network(self, network_id: str) -> Optional[dict]: } except Exception: return None + + def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int: + """ + Wait for a Podman container to exit and return its exit code. 
+ + Args: + container_id: Container ID + timeout: Maximum time to wait in seconds, or None to wait indefinitely + + Returns: + Container exit code + + Raises: + TimeoutError: If timeout is reached before container exits + """ + try: + container = self.get_container(container_id) + result = container.wait(timeout=timeout) + # Podman wait() returns exit code directly + return int(result) + except Exception as e: + if "timeout" in str(e).lower(): + raise TimeoutError( + f"Container {container_id} did not exit within {timeout} seconds" + ) from e + raise diff --git a/kubeflow/trainer/backends/container/backend.py b/kubeflow/trainer/backends/container/backend.py index 3a3176e09..45c6511ba 100644 --- a/kubeflow/trainer/backends/container/backend.py +++ b/kubeflow/trainer/backends/container/backend.py @@ -196,6 +196,35 @@ def _runtime_type(self) -> str: """Get the runtime type for debugging/logging.""" return self._adapter._runtime_type + def _cleanup_container_resources( + self, + container_ids: Optional[list[str]] = None, + network_id: Optional[str] = None, + stop_timeout: int = 5, + ): + """ + Clean up container resources in a best-effort manner. + + Args: + container_ids: List of container IDs to stop and remove. + network_id: Network ID to delete. + stop_timeout: Timeout in seconds for stopping containers. + """ + from contextlib import suppress + + # Stop and remove containers + if container_ids: + for container_id in container_ids: + with suppress(Exception): + self._adapter.stop_container(container_id, timeout=stop_timeout) + with suppress(Exception): + self._adapter.remove_container(container_id, force=True) + + # Delete network + if network_id: + with suppress(Exception): + self._adapter.delete_network(network_id) + # ---- Runtime APIs ---- def list_runtimes(self) -> list[types.Runtime]: return list_training_runtimes_from_sources(self.cfg.runtime_source.sources) @@ -257,6 +286,35 @@ def train( workdir = container_utils.create_workdir(trainjob_name) logger.debug(f"Created working directory: {workdir}") + # Create network for multi-node communication and initializers + num_nodes = trainer.num_nodes or runtime.trainer.num_nodes or 1 + logger.debug(f"Creating network for {num_nodes} nodes") + + network_id = self._adapter.create_network( + name=f"{trainjob_name}-net", + labels={ + f"{self.label_prefix}/trainjob-name": trainjob_name, + f"{self.label_prefix}/runtime-name": runtime.name, + f"{self.label_prefix}/workdir": workdir, + }, + ) + logger.debug(f"Created network: {network_id}") + + # Run initializers if configured + if initializer: + logger.debug("Running initializers") + try: + self._run_initializers(trainjob_name, initializer, workdir, network_id) + logger.debug("Initializers completed successfully") + except Exception as e: + # Clean up network if initializers fail + logger.error(f"Initializer failed, cleaning up network: {e}") + from contextlib import suppress + + with suppress(Exception): + self._adapter.delete_network(network_id) + raise + # Generate training script code (inline, not written to disk) training_script_code = container_utils.get_training_script_code(trainer) logger.debug("Generated training script code") @@ -275,10 +333,6 @@ def train( # Construct pre-run command to install packages pre_install_cmd = container_utils.build_pip_install_cmd(trainer) - # Create network for multi-node communication - num_nodes = trainer.num_nodes or runtime.trainer.num_nodes or 1 - logger.debug(f"Creating network for {num_nodes} nodes") - # Determine number of processes per node from GPU 
count # For GPU training: spawn one process per GPU for optimal utilization # For CPU training: use single process (PyTorch parallelizes internally via threads) @@ -295,16 +349,6 @@ def train( else: logger.debug("No GPU specified, using 1 process per node") - network_id = self._adapter.create_network( - name=f"{trainjob_name}-net", - labels={ - f"{self.label_prefix}/trainjob-name": trainjob_name, - f"{self.label_prefix}/runtime-name": runtime.name, - f"{self.label_prefix}/workdir": workdir, - }, - ) - logger.debug(f"Created network: {network_id}") - # Create N containers (one per node) container_ids: list[str] = [] master_container_id = None @@ -421,20 +465,14 @@ def train( logger.exception("Full traceback:") # Try to clean up any resources that were created - from contextlib import suppress - try: # Stop and remove any containers that were created if "container_ids" in locals(): - for container_id in container_ids: - with suppress(Exception): - self._adapter.stop_container(container_id, timeout=5) - self._adapter.remove_container(container_id, force=True) - - # Remove network if it was created - if "network_id" in locals(): - with suppress(Exception): - self._adapter.delete_network(network_id) + self._cleanup_container_resources( + container_ids=container_ids, + network_id=network_id if "network_id" in locals() else None, + stop_timeout=5, + ) # Remove working directory if it was created if "workdir" in locals() and os.path.isdir(workdir): @@ -467,6 +505,153 @@ def _get_job_containers(self, name: str) -> list[dict]: return containers + def _run_initializers( + self, + job_name: str, + initializer: types.Initializer, + workdir: str, + network_id: str, + ): + """ + Run dataset and model initializers before training starts. + + Args: + job_name: Name of the training job. + initializer: Initializer configuration. + workdir: Working directory path on host. + network_id: Network ID for containers. + + Raises: + RuntimeError: If initializer fails to complete successfully. + """ + # Get initializer image + init_image = container_utils.get_initializer_image(self.cfg) + + # Pull initializer image if needed + container_utils.maybe_pull_image(self._adapter, init_image, self.cfg.pull_policy) + + # Run dataset initializer if configured + if initializer.dataset: + logger.debug("Running dataset initializer") + self._run_single_initializer( + job_name=job_name, + initializer_config=initializer.dataset, + init_type="dataset", + image=init_image, + workdir=workdir, + network_id=network_id, + ) + logger.debug("Dataset initializer completed") + + # Run model initializer if configured + if initializer.model: + logger.debug("Running model initializer") + self._run_single_initializer( + job_name=job_name, + initializer_config=initializer.model, + init_type="model", + image=init_image, + workdir=workdir, + network_id=network_id, + ) + logger.debug("Model initializer completed") + + def _run_single_initializer( + self, + job_name: str, + initializer_config: types.BaseInitializer, + init_type: str, + image: str, + workdir: str, + network_id: str, + ): + """ + Run a single initializer container and wait for completion. + + Args: + job_name: Name of the training job. + initializer_config: Initializer configuration. + init_type: Type of initializer ("dataset" or "model"). + image: Container image to use. + workdir: Working directory path on host. + network_id: Network ID for containers. + + Raises: + RuntimeError: If initializer fails. 
+ """ + container_name = f"{job_name}-{init_type}-initializer" + + # Build command and environment + command = container_utils.build_initializer_command(initializer_config, init_type) + env = container_utils.build_initializer_env(initializer_config, init_type) + + # Create labels for tracking + labels = { + f"{self.label_prefix}/trainjob-name": job_name, + f"{self.label_prefix}/step": f"{init_type}-initializer", + f"{self.label_prefix}/network-id": network_id, + } + + # Mount the shared volume + volumes = { + workdir: { + "bind": constants.WORKSPACE_PATH, + "mode": "rw", + } + } + + logger.debug(f"Starting {init_type} initializer container: {container_name}") + + # Create and start the initializer container + container_id = self._adapter.create_and_start_container( + image=image, + command=command, + name=container_name, + network_id=network_id, + environment=env, + labels=labels, + volumes=volumes, + working_dir=constants.WORKSPACE_PATH, + ) + + logger.debug(f"Initializer container started: {container_id[:12]}") + + # Wait for the initializer to complete + try: + # Use the wait API for efficient waiting + exit_code = self._adapter.wait_for_container( + container_id, timeout=self.cfg.initializer_timeout + ) + + if exit_code == 0: + logger.debug(f"{init_type} initializer completed successfully") + # Clean up the successful container + self._cleanup_container_resources(container_ids=[container_id], stop_timeout=0) + return + else: + # Get logs for debugging + logs = list(self._adapter.container_logs(container_id, follow=False)) + error_msg = ( + f"{init_type} initializer failed with exit code {exit_code}. " + f"Logs: {' '.join(logs[-10:]) if logs else 'No logs available'}" + ) + raise RuntimeError(error_msg) + + except TimeoutError: + logger.error( + f"{init_type} initializer did not complete within " + f"{self.cfg.initializer_timeout} seconds" + ) + # Clean up the timed-out container + self._cleanup_container_resources(container_ids=[container_id], stop_timeout=5) + raise + + except Exception as e: + logger.error(f"Error running {init_type} initializer: {e}") + # Clean up the failed container + self._cleanup_container_resources(container_ids=[container_id], stop_timeout=5) + raise + def __get_trainjob_from_containers( self, job_name: str, containers: list[dict] ) -> types.TrainJob: @@ -528,8 +713,8 @@ def __get_trainjob_from_containers( ) ) - # Get num_nodes from container count - num_nodes = len(containers) + # Count only training nodes (not initializers) for num_nodes + num_nodes = sum(1 for step in steps if step.name.startswith(constants.NODE)) return types.TrainJob( name=job_name, @@ -595,11 +780,20 @@ def get_job_logs( """Get logs for a training job by querying container runtime.""" containers = self._get_job_containers(name) - want_all = step == constants.NODE + "-0" + # Check if requesting logs from all node containers (default behavior) + want_all_nodes = step == constants.NODE + "-0" + for container in sorted(containers, key=lambda c: c["name"]): container_step = container["labels"].get(f"{self.label_prefix}/step", "") - if not want_all and container_step != step: + + # If want_all_nodes, only show node containers, not initializers + if want_all_nodes: + if not container_step.startswith(constants.NODE): + continue + # Otherwise, match the specific step (could be initializer or node) + elif container_step != step: continue + try: yield from self._adapter.container_logs(container["id"], follow) except Exception as e: @@ -642,18 +836,12 @@ def delete_job(self, name: str): workdir_host 
= network_labels.get(f"{self.label_prefix}/workdir") # Stop containers and remove - from contextlib import suppress - - for container in containers: - with suppress(Exception): - self._adapter.stop_container(container["id"], timeout=10) - with suppress(Exception): - self._adapter.remove_container(container["id"], force=True) - - # Remove network (best-effort) - if network_id: - with suppress(Exception): - self._adapter.delete_network(network_id) + container_ids = [c["id"] for c in containers] + self._cleanup_container_resources( + container_ids=container_ids, + network_id=network_id, + stop_timeout=10, + ) # Remove working directory if configured if self.cfg.auto_remove and workdir_host and os.path.isdir(workdir_host): diff --git a/kubeflow/trainer/backends/container/backend_test.py b/kubeflow/trainer/backends/container/backend_test.py index ac13ca0f5..f41c9c305 100644 --- a/kubeflow/trainer/backends/container/backend_test.py +++ b/kubeflow/trainer/backends/container/backend_test.py @@ -197,6 +197,31 @@ def get_network(self, network_id: str) -> Optional[dict]: } return None + def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int: + """ + Wait for a container to exit and return its exit code. + + For testing, immediately returns the container's exit code if it has exited, + or raises TimeoutError if the container is still running. + + Args: + container_id: Container ID + timeout: Maximum time to wait in seconds (not used in mock) + + Returns: + Container exit code + + Raises: + TimeoutError: If container is still running + """ + for container in self.containers_created: + if container["id"] == container_id: + if container["status"] == "exited": + return container.get("exit_code", 0) + # In mock, if not exited, simulate timeout + raise TimeoutError(f"Container {container_id} did not exit within timeout") + raise RuntimeError(f"Container {container_id} not found") + # Fixtures @pytest.fixture @@ -861,3 +886,333 @@ def test_create_adapter_error_message_format(): error_msg = str(exc_info.value) assert "Could not connect" in error_msg assert "tried:" in error_msg + + +# Tests for Initializer Support +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="train with HuggingFace dataset initializer", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "initializer": types.Initializer( + dataset=types.HuggingFaceDatasetInitializer( + storage_uri="hf://username/dataset-repo", + access_token="hf_token_123", + ) + ), + "expected_containers": 2, # 1 dataset-initializer + 1 training node + "expected_initializer_type": "dataset", + }, + ), + TestCase( + name="train with S3 dataset initializer", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "initializer": types.Initializer( + dataset=types.S3DatasetInitializer( + storage_uri="s3://my-bucket/dataset", + endpoint="https://s3.amazonaws.com", + region="us-west-2", + ) + ), + "expected_containers": 2, # 1 dataset-initializer + 1 training node + "expected_initializer_type": "dataset", + }, + ), + TestCase( + name="train with HuggingFace model initializer", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "initializer": types.Initializer( + model=types.HuggingFaceModelInitializer( + storage_uri="hf://username/model-repo", + access_token="hf_token_456", + ignore_patterns=["*.bin", "*.h5"], + ) + ), + "expected_containers": 2, # 1 model-initializer + 1 training node + "expected_initializer_type": "model", + }, + ), + TestCase( + name="train with S3 model initializer", + expected_status=SUCCESS, + 
config={ + "num_nodes": 1, + "initializer": types.Initializer( + model=types.S3ModelInitializer( + storage_uri="s3://my-bucket/model", + endpoint="https://s3.amazonaws.com", + access_key_id="my_access_key", + secret_access_key="my_secret_key", + ) + ), + "expected_containers": 2, # 1 model-initializer + 1 training node + "expected_initializer_type": "model", + }, + ), + TestCase( + name="train with both dataset and model initializers", + expected_status=SUCCESS, + config={ + "num_nodes": 2, + "initializer": types.Initializer( + dataset=types.HuggingFaceDatasetInitializer( + storage_uri="hf://username/dataset-repo" + ), + model=types.HuggingFaceModelInitializer(storage_uri="hf://username/model-repo"), + ), + "expected_containers": 4, # 1 dataset + 1 model + 2 training nodes + }, + ), + TestCase( + name="train with DataCache initializer", + expected_status=SUCCESS, + config={ + "num_nodes": 1, + "initializer": types.Initializer( + dataset=types.DataCacheInitializer( + storage_uri="cache://schema/table", + metadata_loc="s3://bucket/metadata.json", + num_data_nodes=3, + head_cpu="2", + head_mem="4Gi", + ) + ), + "expected_containers": 2, # 1 datacache-initializer + 1 training node + "expected_initializer_type": "dataset", + }, + ), + ], +) +def test_train_with_initializers(container_backend, test_case): + """Test training job creation with dataset and model initializers.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer( + func=simple_train_func, + num_nodes=test_case.config.get("num_nodes", 1), + ) + runtime = container_backend.get_runtime("torch-distributed") + + # Mock initializer containers to complete successfully + original_create = container_backend._adapter.create_and_start_container + + def mock_create_with_status(*args, **kwargs): + container_id = original_create(*args, **kwargs) + # If it's an initializer container, mark it as completed + if "initializer" in kwargs.get("name", ""): + container_backend._adapter.set_container_status(container_id, "exited", 0) + return container_id + + container_backend._adapter.create_and_start_container = mock_create_with_status + + job_name = container_backend.train( + runtime=runtime, trainer=trainer, initializer=test_case.config.get("initializer") + ) + + assert test_case.expected_status == SUCCESS + assert job_name is not None + + # Check that expected number of containers were created + assert ( + len(container_backend._adapter.containers_created) + == test_case.config["expected_containers"] + ) + + # Check that initializer containers have correct labels + initializer_containers = [ + c + for c in container_backend._adapter.containers_created + if "initializer" in c["labels"].get(f"{container_backend.label_prefix}/step", "") + ] + + if "expected_initializer_type" in test_case.config: + expected_type = test_case.config["expected_initializer_type"] + assert any(expected_type in c["name"] for c in initializer_containers) + + # Check that initializer containers have correct environment variables + for container in initializer_containers: + assert "STORAGE_URI" in container["environment"] + assert "OUTPUT_PATH" in container["environment"] + + # Verify OUTPUT_PATH is correct based on initializer type + if "dataset-initializer" in container["name"]: + assert container["environment"]["OUTPUT_PATH"] == constants.DATASET_PATH + elif "model-initializer" in container["name"]: + assert container["environment"]["OUTPUT_PATH"] == constants.MODEL_PATH + + # Verify the job can be retrieved and has correct steps + job = 
container_backend.get_job(job_name) + assert job.name == job_name + + # Check that initializer steps are included + step_names = [step.name for step in job.steps] + if test_case.config.get("initializer") and test_case.config["initializer"].dataset: + assert "dataset-initializer" in step_names + if test_case.config.get("initializer") and test_case.config["initializer"].model: + assert "model-initializer" in step_names + + # Check that num_nodes only counts training nodes, not initializers + assert job.num_nodes == test_case.config.get("num_nodes", 1) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="get logs from dataset initializer", + expected_status=SUCCESS, + config={ + "step": "dataset-initializer", + "expected_log_count": 1, + }, + ), + TestCase( + name="get logs from model initializer", + expected_status=SUCCESS, + config={ + "step": "model-initializer", + "expected_log_count": 1, + }, + ), + TestCase( + name="get logs from training node excludes initializers", + expected_status=SUCCESS, + config={ + "step": constants.NODE + "-0", + "expected_log_count": 1, + "should_exclude_initializers": True, + }, + ), + ], +) +def test_get_logs_with_initializers(container_backend, test_case): + """Test getting logs from initializer and training containers.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + runtime = container_backend.get_runtime("torch-distributed") + + initializer = types.Initializer( + dataset=types.HuggingFaceDatasetInitializer(storage_uri="hf://user/dataset"), + model=types.HuggingFaceModelInitializer(storage_uri="hf://user/model"), + ) + + # Mock initializer containers to complete successfully + original_create = container_backend._adapter.create_and_start_container + + def mock_create_with_status(*args, **kwargs): + container_id = original_create(*args, **kwargs) + if "initializer" in kwargs.get("name", ""): + container_backend._adapter.set_container_status(container_id, "exited", 0) + return container_id + + container_backend._adapter.create_and_start_container = mock_create_with_status + + job_name = container_backend.train( + runtime=runtime, trainer=trainer, initializer=initializer + ) + + # Get logs for the specified step + logs = list(container_backend.get_job_logs(job_name, step=test_case.config["step"])) + + assert test_case.expected_status == SUCCESS + assert len(logs) >= test_case.config["expected_log_count"] + + # If step is node-0, ensure initializer logs are not included + if test_case.config.get("should_exclude_initializers"): + log_str = "".join(logs) + # The logs should be from training containers only + assert "Complete log from container-" in log_str + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="initializer fails with non-zero exit code", + expected_status=FAILED, + config={ + "initializer": types.Initializer( + dataset=types.HuggingFaceDatasetInitializer( + storage_uri="hf://user/invalid-dataset" + ) + ), + "initializer_exit_code": 1, + }, + expected_error=RuntimeError, + ), + TestCase( + name="initializer timeout", + expected_status=FAILED, + config={ + "initializer": types.Initializer( + model=types.S3ModelInitializer(storage_uri="s3://bucket/model") + ), + "initializer_timeout": True, + }, + expected_error=TimeoutError, + ), + 
], +) +def test_initializer_failures(container_backend, test_case): + """Test handling of initializer failures.""" + print("Executing test:", test_case.name) + try: + trainer = types.CustomTrainer(func=simple_train_func, num_nodes=1) + runtime = container_backend.get_runtime("torch-distributed") + + # Mock initializer to fail + original_create = container_backend._adapter.create_and_start_container + + def mock_create_with_failure(*args, **kwargs): + container_id = original_create(*args, **kwargs) + if ( + "initializer" in kwargs.get("name", "") + and "initializer_exit_code" in test_case.config + ): + container_backend._adapter.set_container_status( + container_id, "exited", test_case.config["initializer_exit_code"] + ) + # For timeout test, keep status as running + return container_id + + container_backend._adapter.create_and_start_container = mock_create_with_failure + + # For timeout test, patch the timeout value + if test_case.config.get("initializer_timeout"): + with patch( + "kubeflow.trainer.backends.container.backend." + "ContainerBackend._run_single_initializer" + ) as mock_run: + mock_run.side_effect = TimeoutError("Initializer timeout") + container_backend.train( + runtime=runtime, + trainer=trainer, + initializer=test_case.config["initializer"], + ) + else: + container_backend.train( + runtime=runtime, trainer=trainer, initializer=test_case.config["initializer"] + ) + + except Exception as e: + assert type(e) is test_case.expected_error + # Verify cleanup happened (containers and network should be cleaned up) + # This is tested by checking that the error was raised + # before training containers were created + print("test execution complete") diff --git a/kubeflow/trainer/backends/container/types.py b/kubeflow/trainer/backends/container/types.py index f30025cb9..ee6e10820 100644 --- a/kubeflow/trainer/backends/container/types.py +++ b/kubeflow/trainer/backends/container/types.py @@ -65,3 +65,11 @@ class ContainerBackendConfig(BaseModel): default_factory=TrainingRuntimeSource, description="Configuration for training runtime sources", ) + initializer_image: str = Field( + default="kubeflow/training-operator:latest", + description="Container image for dataset and model initializers", + ) + initializer_timeout: int = Field( + default=600, + description="Timeout in seconds for initializer containers (default 10 minutes)", + ) diff --git a/kubeflow/trainer/backends/container/utils.py b/kubeflow/trainer/backends/container/utils.py index 0370eca28..42c05231b 100644 --- a/kubeflow/trainer/backends/container/utils.py +++ b/kubeflow/trainer/backends/container/utils.py @@ -211,3 +211,111 @@ def aggregate_container_statuses(adapter, containers: list[dict]) -> str: """ statuses = [get_container_status(adapter, c["id"]) for c in containers] return aggregate_status_from_containers(statuses) + + +def build_initializer_command(initializer: types.BaseInitializer, init_type: str) -> list[str]: + """ + Build the command for an initializer container. + + Args: + initializer: Dataset or model initializer configuration. + init_type: Type of initializer ("dataset" or "model"). + + Returns: + Command list for the initializer container. + + Raises: + ValueError: If the initializer type is not supported. 
+ """ + # Use the training-operator initializer script + # The initializer script is expected to be available in the image + if isinstance(initializer, (types.S3DatasetInitializer, types.S3ModelInitializer)): + python_cmd = "python -m kubeflow.storage_initializer.s3 " + elif isinstance( + initializer, (types.HuggingFaceDatasetInitializer, types.HuggingFaceModelInitializer) + ): + python_cmd = "python -m kubeflow.storage_initializer.hugging_face " + elif isinstance(initializer, types.DataCacheInitializer): + python_cmd = "python -m kubeflow.storage_initializer.datacache " + else: + raise ValueError( + f"Unsupported initializer type: {type(initializer).__name__}. " + "Supported types: HuggingFaceDatasetInitializer, HuggingFaceModelInitializer, " + "S3DatasetInitializer, S3ModelInitializer, DataCacheInitializer" + ) + + return ["bash", "-c", python_cmd] + + +def build_initializer_env(initializer: types.BaseInitializer, init_type: str) -> dict[str, str]: + """ + Build environment variables for an initializer container. + + Args: + initializer: Dataset or model initializer configuration. + init_type: Type of initializer ("dataset" or "model"). + + Returns: + Dictionary of environment variables. + """ + env = { + "STORAGE_URI": initializer.storage_uri, + } + + # Set the output path based on initializer type + if init_type == "dataset": + env["OUTPUT_PATH"] = constants.DATASET_PATH + else: # model + env["OUTPUT_PATH"] = constants.MODEL_PATH + + # Add optional fields based on initializer type + if isinstance( + initializer, (types.HuggingFaceDatasetInitializer, types.HuggingFaceModelInitializer) + ): + if initializer.access_token: + env["ACCESS_TOKEN"] = initializer.access_token + if hasattr(initializer, "ignore_patterns") and initializer.ignore_patterns: + env["IGNORE_PATTERNS"] = ",".join(initializer.ignore_patterns) + + elif isinstance(initializer, (types.S3DatasetInitializer, types.S3ModelInitializer)): + if initializer.endpoint: + env["ENDPOINT"] = initializer.endpoint + if initializer.access_key_id: + env["ACCESS_KEY_ID"] = initializer.access_key_id + if initializer.secret_access_key: + env["SECRET_ACCESS_KEY"] = initializer.secret_access_key + if initializer.region: + env["REGION"] = initializer.region + if initializer.role_arn: + env["ROLE_ARN"] = initializer.role_arn + if hasattr(initializer, "ignore_patterns") and initializer.ignore_patterns: + env["IGNORE_PATTERNS"] = ",".join(initializer.ignore_patterns) + + elif isinstance(initializer, types.DataCacheInitializer): + env["CLUSTER_SIZE"] = str(initializer.num_data_nodes + 1) + env["METADATA_LOC"] = initializer.metadata_loc + if initializer.head_cpu: + env["HEAD_CPU"] = initializer.head_cpu + if initializer.head_mem: + env["HEAD_MEM"] = initializer.head_mem + if initializer.worker_cpu: + env["WORKER_CPU"] = initializer.worker_cpu + if initializer.worker_mem: + env["WORKER_MEM"] = initializer.worker_mem + if initializer.iam_role: + env["IAM_ROLE"] = initializer.iam_role + + return env + + +def get_initializer_image(config) -> str: + """ + Get the container image for initializers from backend config. + + Args: + config: ContainerBackendConfig with initializer_image setting. + + Returns: + Container image name for initializers. + """ + return config.initializer_image
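
As a quick orientation for reviewers, the two new `ContainerBackendConfig` knobs can be set like this; a minimal sketch assuming the import path mirrors the file path in this diff, with a hypothetical image tag and all other config fields left at their defaults:

```python
from kubeflow.trainer.backends.container.types import ContainerBackendConfig

cfg = ContainerBackendConfig(
    # Hypothetical initializer image; the default is "kubeflow/training-operator:latest".
    initializer_image="ghcr.io/example/storage-initializer:v1",
    # Allow up to 20 minutes for large dataset/model downloads (default is 600 seconds).
    initializer_timeout=1200,
)
```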
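From the SDK-consumer side, the initializer flow added here looks roughly like the following; a sketch based on the new tests, assuming a configured `ContainerBackend` instance `backend` is already available and that the trainer/initializer types are importable as `types` (backend construction itself is outside this diff):

```python
from kubeflow.trainer import types  # assumed import path


def train_fn():
    # Placeholder user-provided training function.
    print("training...")


# Stage a HuggingFace dataset and model into the shared workdir before training starts.
initializer = types.Initializer(
    dataset=types.HuggingFaceDatasetInitializer(
        storage_uri="hf://username/dataset-repo",
        access_token="hf_token_123",
    ),
    model=types.HuggingFaceModelInitializer(storage_uri="hf://username/model-repo"),
)

runtime = backend.get_runtime("torch-distributed")
trainer = types.CustomTrainer(func=train_fn, num_nodes=2)

# Initializer containers run to completion (bounded by cfg.initializer_timeout)
# before any training-node containers are created.
job_name = backend.train(runtime=runtime, trainer=trainer, initializer=initializer)

# Logs are addressable per step: "node-0" now yields only training-node logs,
# while initializer steps are requested explicitly.
for line in backend.get_job_logs(job_name, step="dataset-initializer"):
    print(line, end="")
```

Note that `get_job(job_name).num_nodes` counts only the `node-*` containers after this change, so the two initializer containers above do not inflate the reported node count.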
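Finally, a small sketch of what the new helpers in `utils.py` produce for an S3 dataset initializer; the values follow directly from `build_initializer_command` and `build_initializer_env` above, though the import paths are assumptions:

```python
from kubeflow.trainer import types  # assumed import path
from kubeflow.trainer.backends.container import utils as container_utils  # assumed import path

s3_dataset = types.S3DatasetInitializer(
    storage_uri="s3://my-bucket/dataset",
    endpoint="https://s3.amazonaws.com",
    region="us-west-2",
)

command = container_utils.build_initializer_command(s3_dataset, "dataset")
# -> ["bash", "-c", "python -m kubeflow.storage_initializer.s3 "]

env = container_utils.build_initializer_env(s3_dataset, "dataset")
# -> {
#     "STORAGE_URI": "s3://my-bucket/dataset",
#     "OUTPUT_PATH": constants.DATASET_PATH,
#     "ENDPOINT": "https://s3.amazonaws.com",
#     "REGION": "us-west-2",
# }
```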