Skip to content

Commit 0dbb6b6

Browse files
committed
feat(trainer): address reviewer feedback for initializer support
- Make initializer image configurable via ContainerBackendConfig
- Make initializer timeout configurable (default 600 seconds)
- Implement wait API in adapters instead of polling
- Clean up successful initializer containers after completion
- Clean up network on initializer failure
- Raise ValueError for unsupported initializer types (no datacache fallback)

All tests passing (173/173). Addresses all feedback from PR #188.

Signed-off-by: HKanoje <[email protected]>
1 parent 2387f05 commit 0dbb6b6

File tree

7 files changed

+171
-39
lines changed

7 files changed

+171
-39
lines changed

kubeflow/trainer/backends/container/adapters/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,22 @@ def get_network(self, network_id: str) -> Optional[dict]:
193193
Dictionary with network info including labels, or None if not found
194194
"""
195195
raise NotImplementedError()
196+
197+
@abc.abstractmethod
def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int:
    """
    Wait for a container to exit and return its exit code.

    This is a blocking call that waits until the container stops.
    Implementations are expected to translate their runtime-specific
    timeout errors into the built-in TimeoutError so callers can handle
    a timeout uniformly across adapters.

    Args:
        container_id: Container ID
        timeout: Maximum time to wait in seconds, or None to wait indefinitely

    Returns:
        Container exit code

    Raises:
        TimeoutError: If timeout is reached before container exits
    """
    raise NotImplementedError()

kubeflow/trainer/backends/container/adapters/docker.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,31 @@ def get_network(self, network_id: str) -> Optional[dict]:
227227
}
228228
except Exception:
229229
return None
230+
231+
def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int:
    """
    Wait for a Docker container to exit and return its exit code.

    Args:
        container_id: Container ID
        timeout: Maximum time to wait in seconds, or None to wait indefinitely

    Returns:
        Container exit code

    Raises:
        TimeoutError: If timeout is reached before container exits
        RuntimeError: If the wait result is a dict without a 'StatusCode' key
    """
    try:
        container = self.get_container(container_id)
        result = container.wait(timeout=timeout)
    except Exception as e:
        # Prefer the exception type over its message: any class in the MRO
        # whose name mentions "Timeout" (e.g. ReadTimeout, TimeoutError)
        # counts as a timeout.
        timeout_by_type = any("Timeout" in cls.__name__ for cls in type(e).__mro__)

        # Message fallback for errors that only describe the timeout in text.
        # The container id is stripped first so an id such as "timeout-job"
        # echoed back in an unrelated error is not mistaken for a timeout.
        message = str(e).lower()
        if container_id:
            message = message.replace(container_id.lower(), "")
        timeout_by_message = "timeout" in message or "timed out" in message

        if timeout_by_type or timeout_by_message:
            raise TimeoutError(
                f"Container {container_id} did not exit within {timeout} seconds"
            ) from e
        raise

    # Docker wait() returns a dict with 'StatusCode' key
    if isinstance(result, dict):
        if "StatusCode" not in result:
            # Never assume success when the exit status is unknown.
            raise RuntimeError(
                f"Docker wait() returned no StatusCode for container {container_id}: {result!r}"
            )
        return int(result["StatusCode"])
    return int(result)

kubeflow/trainer/backends/container/adapters/podman.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,29 @@ def get_network(self, network_id: str) -> Optional[dict]:
254254
}
255255
except Exception:
256256
return None
257+
258+
def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int:
    """
    Wait for a Podman container to exit and return its exit code.

    Args:
        container_id: Container ID
        timeout: Maximum time to wait in seconds, or None to wait indefinitely

    Returns:
        Container exit code

    Raises:
        TimeoutError: If timeout is reached before container exits
        RuntimeError: If the wait result is a dict without a 'StatusCode' key
    """
    try:
        container = self.get_container(container_id)
        # NOTE(review): confirm the installed Podman client honors the
        # `timeout` keyword on wait(); it is passed through unchanged here.
        result = container.wait(timeout=timeout)
    except Exception as e:
        # Prefer the exception type over its message: any class in the MRO
        # whose name mentions "Timeout" counts as a timeout.
        timeout_by_type = any("Timeout" in cls.__name__ for cls in type(e).__mro__)

        # Message fallback, with the container id stripped so an id that
        # happens to contain "timeout" cannot trigger a false positive.
        message = str(e).lower()
        if container_id:
            message = message.replace(container_id.lower(), "")
        timeout_by_message = "timeout" in message or "timed out" in message

        if timeout_by_type or timeout_by_message:
            raise TimeoutError(
                f"Container {container_id} did not exit within {timeout} seconds"
            ) from e
        raise

    # Podman wait() returns exit code directly; the Docker-style dict shape is
    # tolerated as well so both adapters parse results the same way.
    if isinstance(result, dict):
        if "StatusCode" not in result:
            # Never assume success when the exit status is unknown.
            raise RuntimeError(
                f"Podman wait() returned no StatusCode for container {container_id}: {result!r}"
            )
        return int(result["StatusCode"])
    return int(result)

kubeflow/trainer/backends/container/backend.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,17 @@ def train(
274274
# Run initializers if configured
275275
if initializer:
276276
logger.debug("Running initializers")
277-
self._run_initializers(trainjob_name, initializer, workdir, network_id)
278-
logger.debug("Initializers completed successfully")
277+
try:
278+
self._run_initializers(trainjob_name, initializer, workdir, network_id)
279+
logger.debug("Initializers completed successfully")
280+
except Exception as e:
281+
# Clean up network if initializers fail
282+
logger.error(f"Initializer failed, cleaning up network: {e}")
283+
from contextlib import suppress
284+
285+
with suppress(Exception):
286+
self._adapter.delete_network(network_id)
287+
raise
279288

280289
# Generate training script code (inline, not written to disk)
281290
training_script_code = container_utils.get_training_script_code(trainer)
@@ -493,7 +502,7 @@ def _run_initializers(
493502
RuntimeError: If initializer fails to complete successfully.
494503
"""
495504
# Get initializer image
496-
init_image = container_utils.get_initializer_image()
505+
init_image = container_utils.get_initializer_image(self.cfg)
497506

498507
# Pull initializer image if needed
499508
container_utils.maybe_pull_image(self._adapter, init_image, self.cfg.pull_policy)
@@ -586,32 +595,40 @@ def _run_single_initializer(
586595

587596
# Wait for the initializer to complete
588597
try:
589-
import time
598+
# Use the wait API for efficient waiting
599+
exit_code = self._adapter.wait_for_container(
600+
container_id, timeout=self.cfg.initializer_timeout
601+
)
590602

591-
timeout = 600 # 10 minutes timeout for initialization
592-
polling_interval = 2
593-
elapsed = 0
603+
if exit_code == 0:
604+
logger.debug(f"{init_type} initializer completed successfully")
605+
# Clean up the successful container
606+
from contextlib import suppress
607+
608+
with suppress(Exception):
609+
self._adapter.remove_container(container_id, force=True)
610+
return
611+
else:
612+
# Get logs for debugging
613+
logs = list(self._adapter.container_logs(container_id, follow=False))
614+
error_msg = (
615+
f"{init_type} initializer failed with exit code {exit_code}. "
616+
f"Logs: {' '.join(logs[-10:]) if logs else 'No logs available'}"
617+
)
618+
raise RuntimeError(error_msg)
594619

595-
while elapsed < timeout:
596-
status, exit_code = self._adapter.container_status(container_id)
620+
except TimeoutError:
621+
logger.error(
622+
f"{init_type} initializer did not complete within "
623+
f"{self.cfg.initializer_timeout} seconds"
624+
)
625+
# Clean up the timed-out container
626+
from contextlib import suppress
597627

598-
if status == "exited":
599-
if exit_code == 0:
600-
logger.debug(f"{init_type} initializer completed successfully")
601-
return
602-
else:
603-
# Get logs for debugging
604-
logs = list(self._adapter.container_logs(container_id, follow=False))
605-
error_msg = (
606-
f"{init_type} initializer failed with exit code {exit_code}. "
607-
f"Logs: {' '.join(logs[-10:]) if logs else 'No logs available'}"
608-
)
609-
raise RuntimeError(error_msg)
610-
611-
time.sleep(polling_interval)
612-
elapsed += polling_interval
613-
614-
raise TimeoutError(f"{init_type} initializer did not complete within {timeout} seconds")
628+
with suppress(Exception):
629+
self._adapter.stop_container(container_id, timeout=5)
630+
self._adapter.remove_container(container_id, force=True)
631+
raise
615632

616633
except Exception as e:
617634
logger.error(f"Error running {init_type} initializer: {e}")

kubeflow/trainer/backends/container/backend_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,31 @@ def get_network(self, network_id: str) -> Optional[dict]:
197197
}
198198
return None
199199

200+
def wait_for_container(self, container_id: str, timeout: Optional[int] = None) -> int:
    """
    Wait for a container to exit and return its exit code.

    For testing, immediately returns the container's exit code if it has exited,
    or raises TimeoutError if the container is still running. No real waiting
    takes place.

    Args:
        container_id: Container ID
        timeout: Maximum time to wait in seconds (not used in mock)

    Returns:
        Container exit code (0 when the record has no "exit_code" entry)

    Raises:
        TimeoutError: If container is still running
        RuntimeError: If no created container has the given ID
    """
    for container in self.containers_created:
        if container["id"] != container_id:
            continue
        if container["status"] == "exited":
            return container.get("exit_code", 0)
        # In mock, if not exited, simulate timeout
        raise TimeoutError(f"Container {container_id} did not exit within timeout")
    raise RuntimeError(f"Container {container_id} not found")
224+
200225

201226
# Fixtures
202227
@pytest.fixture

kubeflow/trainer/backends/container/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,11 @@ class ContainerBackendConfig(BaseModel):
6565
default_factory=TrainingRuntimeSource,
6666
description="Configuration for training runtime sources",
6767
)
68+
initializer_image: str = Field(
69+
default="kubeflow/training-operator:latest",
70+
description="Container image for dataset and model initializers",
71+
)
72+
initializer_timeout: int = Field(
73+
default=600,
74+
description="Timeout in seconds for initializer containers (default 10 minutes)",
75+
)

kubeflow/trainer/backends/container/utils.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -223,18 +223,26 @@ def build_initializer_command(initializer: types.BaseInitializer, init_type: str
223223
224224
Returns:
225225
Command list for the initializer container.
226+
227+
Raises:
228+
ValueError: If the initializer type is not supported.
226229
"""
227230
# Use the training-operator initializer script
228231
# The initializer script is expected to be available in the image
229-
python_cmd = (
230-
"python -m kubeflow.storage_initializer.s3 "
231-
if isinstance(initializer, (types.S3DatasetInitializer, types.S3ModelInitializer))
232-
else "python -m kubeflow.storage_initializer.hugging_face "
233-
if isinstance(
234-
initializer, (types.HuggingFaceDatasetInitializer, types.HuggingFaceModelInitializer)
232+
if isinstance(initializer, (types.S3DatasetInitializer, types.S3ModelInitializer)):
233+
python_cmd = "python -m kubeflow.storage_initializer.s3 "
234+
elif isinstance(
235+
initializer, (types.HuggingFaceDatasetInitializer, types.HuggingFaceModelInitializer)
236+
):
237+
python_cmd = "python -m kubeflow.storage_initializer.hugging_face "
238+
elif isinstance(initializer, types.DataCacheInitializer):
239+
python_cmd = "python -m kubeflow.storage_initializer.datacache "
240+
else:
241+
raise ValueError(
242+
f"Unsupported initializer type: {type(initializer).__name__}. "
243+
"Supported types: HuggingFaceDatasetInitializer, HuggingFaceModelInitializer, "
244+
"S3DatasetInitializer, S3ModelInitializer, DataCacheInitializer"
235245
)
236-
else "python -m kubeflow.storage_initializer.datacache "
237-
)
238246

239247
return ["bash", "-c", python_cmd]
240248

@@ -300,13 +308,14 @@ def build_initializer_env(initializer: types.BaseInitializer, init_type: str) ->
300308
return env
301309

302310

303-
def get_initializer_image() -> str:
311+
def get_initializer_image(config) -> str:
304312
"""
305-
Get the container image for initializers.
313+
Get the container image for initializers from backend config.
314+
315+
Args:
316+
config: ContainerBackendConfig with initializer_image setting.
306317
307318
Returns:
308319
Container image name for initializers.
309320
"""
310-
# Use the training-operator image which contains initializer scripts
311-
# This can be made configurable via backend config in the future
312-
return "kubeflow/training-operator:latest"
321+
return config.initializer_image

0 commit comments

Comments
 (0)