
Commit 3bece5a

Author: Mohamed Zeidan (committed)
CLI exceptions
1 parent 73245b9 commit 3bece5a

File tree: 2 files changed, +101, -104 lines

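The commit replaces the per-command try/except blocks with a shared `@handle_cli_exceptions()` decorator. The decorator's implementation is not part of this diff; the sketch below is a hypothetical illustration of what such a decorator typically looks like, assuming it mirrors the click.UsageError behavior the commands previously implemented inline.

# Hypothetical sketch only -- the real handle_cli_exceptions lives elsewhere in the
# package and is not shown in this commit.
import functools

import click


def handle_cli_exceptions():
    """Turn uncaught exceptions in a CLI command into a uniform click error."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Same user-facing behavior the commands previously coded inline.
                raise click.UsageError(str(e))

        return wrapper

    return decorator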

src/sagemaker/hyperpod/cli/commands/inference.py

Lines changed: 7 additions & 4 deletions
@@ -33,11 +33,11 @@
     registry=JS_REG,
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_js_endpoint_cli")
+@handle_cli_exceptions()
 def js_create(namespace, version, js_endpoint):
     """
     Create a jumpstart model endpoint.
     """
-
     js_endpoint.create(namespace=namespace)


@@ -55,11 +55,11 @@ def js_create(namespace, version, js_endpoint):
     registry=C_REG,
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_custom_endpoint_cli")
+@handle_cli_exceptions()
 def custom_create(namespace, version, custom_endpoint):
     """
     Create a custom model endpoint.
     """
-
     custom_endpoint.create(namespace=namespace)


@@ -85,6 +85,7 @@ def custom_create(namespace, version, custom_endpoint):
     help="Optional. The content type of the request to invoke. Default set to 'application/json'",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "invoke_custom_endpoint_cli")
+@handle_cli_exceptions()
 def custom_invoke(
     endpoint_name: str,
     body: str,
@@ -138,13 +139,13 @@ def custom_invoke(
     help="Optional. The namespace of the jumpstart model endpoint to list. Default set to 'default'",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_js_endpoints_cli")
+@handle_cli_exceptions()
 def js_list(
     namespace: Optional[str],
 ):
     """
     List all Hyperpod Jumpstart model endpoints.
     """
-
     endpoints = HPJumpStartEndpoint.model_construct().list(namespace)
     data = [ep.model_dump() for ep in endpoints]

@@ -181,13 +182,13 @@ def js_list(
     help="Optional. The namespace of the custom model endpoint to list. Default set to 'default'",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_custom_endpoints_cli")
+@handle_cli_exceptions()
 def custom_list(
     namespace: Optional[str],
 ):
     """
     List all Hyperpod custom model endpoints.
     """
-
     endpoints = HPEndpoint.model_construct().list(namespace)
     data = [ep.model_dump() for ep in endpoints]

@@ -729,6 +730,7 @@ def custom_get_logs(
     help="Required. The time frame to get logs for.",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_js_operator_logs")
+@handle_cli_exceptions()
 def js_get_operator_logs(
     since_hours: float,
 ):
@@ -748,6 +750,7 @@ def js_get_operator_logs(
     help="Required. The time frame get logs for.",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_custom_operator_logs")
+@handle_cli_exceptions()
 def custom_get_operator_logs(
     since_hours: float,
 ):

src/sagemaker/hyperpod/cli/commands/training.py

Lines changed: 94 additions & 100 deletions
@@ -19,45 +19,42 @@
     registry=SCHEMA_REGISTRY,
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_pytorchjob_cli")
+@handle_cli_exceptions()
 def pytorch_create(version, debug, config):
     """Create a PyTorch job."""
-    try:
-        click.echo(f"Using version: {version}")
-        job_name = config.get("name")
-        namespace = config.get("namespace")
-        spec = config.get("spec")
-        metadata_labels = config.get("labels")
-        annotations = config.get("annotations")
-
-        # Prepare metadata
-        metadata_kwargs = {"name": job_name}
-        if namespace:
-            metadata_kwargs["namespace"] = namespace
-        if metadata_labels:
-            metadata_kwargs["labels"] = metadata_labels
-        if annotations:
-            metadata_kwargs["annotations"] = annotations
-
-        # Prepare job kwargs
-        job_kwargs = {
-            "metadata": Metadata(**metadata_kwargs),
-            "replica_specs": spec.get("replica_specs"),
-        }
-
-        # Add nproc_per_node if present
-        if "nproc_per_node" in spec:
-            job_kwargs["nproc_per_node"] = spec.get("nproc_per_node")
-
-        # Add run_policy if present
-        if "run_policy" in spec:
-            job_kwargs["run_policy"] = spec.get("run_policy")
-
-        # Create job
-        job = HyperPodPytorchJob(**job_kwargs)
-        job.create(debug=debug)
-
-    except Exception as e:
-        raise click.UsageError(f"Failed to create job: {str(e)}")
+    click.echo(f"Using version: {version}")
+    job_name = config.get("name")
+    namespace = config.get("namespace")
+    spec = config.get("spec")
+    metadata_labels = config.get("labels")
+    annotations = config.get("annotations")
+
+    # Prepare metadata
+    metadata_kwargs = {"name": job_name}
+    if namespace:
+        metadata_kwargs["namespace"] = namespace
+    if metadata_labels:
+        metadata_kwargs["labels"] = metadata_labels
+    if annotations:
+        metadata_kwargs["annotations"] = annotations
+
+    # Prepare job kwargs
+    job_kwargs = {
+        "metadata": Metadata(**metadata_kwargs),
+        "replica_specs": spec.get("replica_specs"),
+    }
+
+    # Add nproc_per_node if present
+    if "nproc_per_node" in spec:
+        job_kwargs["nproc_per_node"] = spec.get("nproc_per_node")
+
+    # Add run_policy if present
+    if "run_policy" in spec:
+        job_kwargs["run_policy"] = spec.get("run_policy")
+
+    # Create job
+    job = HyperPodPytorchJob(**job_kwargs)
+    job.create(debug=debug)


 @click.command("hyp-pytorch-job")
@@ -68,74 +65,71 @@ def pytorch_create(version, debug, config):
     help="Optional. The namespace to list jobs from. Defaults to 'default' namespace.",
 )
 @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_pytorchjobs_cli")
+@handle_cli_exceptions()
 def list_jobs(namespace: str):
     """List all HyperPod PyTorch jobs."""
-    try:
-        jobs = HyperPodPytorchJob.list(namespace=namespace)
-
-        if not jobs:
-            click.echo("No jobs found.")
-            return
-
-        # Define headers and widths
-        headers = ["NAME", "NAMESPACE", "STATUS", "AGE"]
-        widths = [30, 20, 15, 15]
-
-        # Print header
-        header = "".join(f"{h:<{w}}" for h, w in zip(headers, widths))
-        click.echo("\n" + header)
-        click.echo("-" * sum(widths))
-
-        # Print each job
-        for job in jobs:
-            # Get status from conditions
-            status = "Unknown"
-            age = "N/A"
-            if job.status and job.status.conditions:
-                for condition in reversed(job.status.conditions):
-                    if condition.status == "True":
-                        status = condition.type
-                        break
-
-            # Calculate age
-            if job.status and job.status.conditions:
-                # Find the 'Created' condition to get the start time
-                created_condition = next(
-                    (c for c in job.status.conditions if c.type == "Created"), None
-                )
-                if created_condition and created_condition.lastTransitionTime:
-                    from datetime import datetime, timezone
-
-                    start_time = datetime.fromisoformat(
-                        created_condition.lastTransitionTime.replace("Z", "+00:00")
-                    )
-                    now = datetime.now(timezone.utc)
-                    delta = now - start_time
-                    if delta.days > 0:
-                        age = f"{delta.days}d"
-                    else:
-                        hours = delta.seconds // 3600
-                        if hours > 0:
-                            age = f"{hours}h"
-                        else:
-                            minutes = (delta.seconds % 3600) // 60
-                            age = f"{minutes}m"
-
-            # Format row
-            row = "".join(
-                [
-                    f"{job.metadata.name:<{widths[0]}}",
-                    f"{job.metadata.namespace:<{widths[1]}}",
-                    f"{status:<{widths[2]}}",
-                    f"{age:<{widths[3]}}",
-                ]
-            )
-            click.echo(row)
-
-        click.echo()  # Add empty line at the end
-
-    except Exception as e:
-        raise click.UsageError(f"Failed to list jobs: {str(e)}")
+    jobs = HyperPodPytorchJob.list(namespace=namespace)
+
+    if not jobs:
+        click.echo("No jobs found.")
+        return
+
+    # Define headers and widths
+    headers = ["NAME", "NAMESPACE", "STATUS", "AGE"]
+    widths = [30, 20, 15, 15]
+
+    # Print header
+    header = "".join(f"{h:<{w}}" for h, w in zip(headers, widths))
+    click.echo("\n" + header)
+    click.echo("-" * sum(widths))
+
+    # Print each job
+    for job in jobs:
+        # Get status from conditions
+        status = "Unknown"
+        age = "N/A"
+        if job.status and job.status.conditions:
+            for condition in reversed(job.status.conditions):
+                if condition.status == "True":
+                    status = condition.type
+                    break
+
+        # Calculate age
+        if job.status and job.status.conditions:
+            # Find the 'Created' condition to get the start time
+            created_condition = next(
+                (c for c in job.status.conditions if c.type == "Created"), None
+            )
+            if created_condition and created_condition.lastTransitionTime:
+                from datetime import datetime, timezone
+
+                start_time = datetime.fromisoformat(
+                    created_condition.lastTransitionTime.replace("Z", "+00:00")
+                )
+                now = datetime.now(timezone.utc)
+                delta = now - start_time
+                if delta.days > 0:
+                    age = f"{delta.days}d"
+                else:
+                    hours = delta.seconds // 3600
+                    if hours > 0:
+                        age = f"{hours}h"
+                    else:
+                        minutes = (delta.seconds % 3600) // 60
+                        age = f"{minutes}m"
+
+        # Format row
+        row = "".join(
+            [
+                f"{job.metadata.name:<{widths[0]}}",
+                f"{job.metadata.namespace:<{widths[1]}}",
+                f"{status:<{widths[2]}}",
+                f"{age:<{widths[3]}}",
+            ]
+        )
+        click.echo(row)

+    click.echo()  # Add empty line at the end


 @click.command("hyp-pytorch-job")
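With the decorator in place, each command body keeps only the happy path and errors surface as a uniform CLI failure. A minimal, self-contained sketch of how that behavior could be exercised with click's test runner, reusing the hypothetical handle_cli_exceptions sketched above and a toy command that stands in for the real HyperPod commands:

# Toy command, not one of the real HyperPod CLI commands.
import click
from click.testing import CliRunner


@click.command("toy-create")
@handle_cli_exceptions()  # hypothetical decorator sketched earlier
def toy_create():
    raise RuntimeError("backend unavailable")


runner = CliRunner()
result = runner.invoke(toy_create)
print(result.exit_code)  # non-zero: the UsageError aborts the command
print(result.output)     # includes "Error: backend unavailable"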

0 commit comments
