Skip to content

Commit f8beddb

Browse files
feat: add pre-scripts and post-scripts args for start-job command
1 parent 13889de commit f8beddb

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ hyperpod connect-cluster --cluster-name <cluster-name> [--region <region>] [--na
134134
This command submits a new training job to the connected SageMaker HyperPod cluster.
135135
136136
```
137-
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>]
137+
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>] [--pre-script <cmd1;cmd2>] [--post-script <cmd1;cmd2>]
138138
```
139139
140140
* `job-name` (string) - Required. The name of the job.
@@ -148,6 +148,8 @@ hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <
148148
* `environment` (dict[string, string]) - Optional. The environment variables (key-value pairs) to set in the containers.
149149
* `node-count` (int) - Required. The number of nodes (instances) to launch the jobs on.
150150
* `instance-type` (string) - Required. The instance type to launch the job on. Note that the instance types you can use are the available instances within your SageMaker quotas for instances prefixed with `ml`.
151+
* `pre-script` (string) - Optional. Commands to run before the job starts. Multiple commands should be separated by semicolons.
152+
* `post-script` (string) - Optional. Commands to run after the job completes. Multiple commands should be separated by semicolons.
151153
* `tasks-per-node` (int) - Optional. The number of devices to use per instance.
152154
* `label-selector` (dict[string, list[string]]) - Optional. A dictionary of labels and their values that will override the predefined node selection rules based on the SageMaker HyperPod `node-health-status` label and values. If users provide this field, the CLI will launch the job with this customized label selection.
153155
* `deep-health-check-passed-nodes-only` (bool) - Optional. If set to `true`, the job will be launched only on nodes that have the `deep-health-check-status` label with the value `passed`.

src/hyperpod_cli/commands/job.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,18 @@ def cancel_job(
430430
help="Optional. Add a temp directory for containers to store data in the hosts."
431431
" <volume_name>:</host/mount/path>:</container/mount/path>,<volume_name>:</host/mount/path1>:</container/mount/path1>",
432432
)
433+
@click.option(
434+
"--pre-script",
435+
type=click.STRING,
436+
required=False,
437+
help="Optional. Commands to run before the job starts. Multiple commands should be separated by semicolons.",
438+
)
439+
@click.option(
440+
"--post-script",
441+
type=click.STRING,
442+
required=False,
443+
help="Optional. Commands to run after the job completes. Multiple commands should be separated by semicolons.",
444+
)
433445
@click.option(
434446
"--recipe",
435447
type=click.STRING,
@@ -548,6 +560,8 @@ def start_job(
548560
service_account_name: Optional[str],
549561
persistent_volume_claims: Optional[str],
550562
volumes: Optional[str],
563+
pre_script: Optional[str],
564+
post_script: Optional[str],
551565
recipe: Optional[str],
552566
override_parameters: Optional[str],
553567
debug: bool,
@@ -793,6 +807,8 @@ def start_job(
793807
label_selector=label_selector,
794808
max_retry=max_retry,
795809
deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only,
810+
pre_script=pre_script,
811+
post_script=post_script,
796812
)
797813
# TODO: Unblock this after fixing customer using EKS cluster.
798814
console_link = utils.get_cluster_console_url()
@@ -973,7 +989,8 @@ def execute_command(cmd, env=None):
973989
def start_training_job(recipe, override_parameters, job_name, config_file, launcher_config_path=None, launcher_config_file_name=None,
974990
pull_policy=None, restart_policy=None, namespace=None,
975991
service_account_name=None, priority_class_name=None, volumes=None, persistent_volume_claims=None,
976-
auto_resume=None, label_selector=None, max_retry=None, deep_health_check_passed_nodes_only=None):
992+
auto_resume=None, label_selector=None, max_retry=None, deep_health_check_passed_nodes_only=None,
993+
pre_script=None, post_script=None):
977994

978995
logger.info(f"recipe: {recipe}, override_parameters: {override_parameters}, job_name: {job_name}, config_file: {config_file}, launcher_config_path: {launcher_config_path}, launcher_config_file_name: {launcher_config_file_name}")
979996
env = os.environ.copy()
@@ -1035,6 +1052,12 @@ def start_training_job(recipe, override_parameters, job_name, config_file, launc
10351052
cmd.append(f'+cluster.persistent_volume_claims.{idx}.claimName="{claim_name}"')
10361053
cmd.append(f'+cluster.persistent_volume_claims.{idx}.mountPath="{mount_path}"')
10371054

1055+
if pre_script:
1056+
cmd.append(f'+cluster.pre_script="{pre_script}"')
1057+
1058+
if post_script:
1059+
cmd.append(f'+cluster.post_script="{post_script}"')
1060+
10381061
if label_selector:
10391062
cmd.append(f'+cluster.label_selector={label_selector}')
10401063
elif deep_health_check_passed_nodes_only:

0 commit comments

Comments
 (0)