
Commit 6ed3031

mollyheamazon authored and rsareddy0329 committed
return SDK class in pytorch model.py for v1_0 and v1_1, update pytorch_create function, update unit test (#243)
1 parent 3ee6d51 commit 6ed3031

File tree

5 files changed: +169 −311 lines

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py

Lines changed: 66 additions & 116 deletions
@@ -12,6 +12,7 @@
     HostPath,
     PersistentVolumeClaim
 )
+from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob
 
 
 class VolumeConfig(BaseModel):
@@ -228,133 +229,82 @@ def validate_label_selector_keys(cls, v):
         return v
 
     def to_domain(self) -> Dict:
-        """
-        Convert flat config to domain model (HyperPodPytorchJobSpec)
-        """
+        """Convert flat config to domain model (HyperPodPytorchJobSpec)"""
 
-        # Create container with required fields
-        container_kwargs = {
-            "name": "pytorch-job-container",
-            "image": self.image,
-            "resources": Resources(
-                requests={"nvidia.com/gpu": "0"},
-                limits={"nvidia.com/gpu": "0"},
-            ),
-        }
-
-        # Add optional container fields
-        if self.command is not None:
-            container_kwargs["command"] = self.command
-        if self.args is not None:
-            container_kwargs["args"] = self.args
-        if self.pull_policy is not None:
-            container_kwargs["image_pull_policy"] = self.pull_policy
-        if self.environment is not None:
-            container_kwargs["env"] = [
-                {"name": k, "value": v} for k, v in self.environment.items()
-            ]
-
-        if self.volume is not None:
-            volume_mounts = []
-            for i, vol in enumerate(self.volume):
-                volume_mount = {"name": vol.name, "mount_path": vol.mount_path}
-                volume_mounts.append(volume_mount)
-
-            container_kwargs["volume_mounts"] = volume_mounts
-
-
-        # Create container object
-        try:
-            container = Containers(**container_kwargs)
-        except Exception as e:
-            raise
-
-        # Create pod spec kwargs
-        spec_kwargs = {"containers": list([container])}
+        # Helper function to build dict with non-None values
+        def build_dict(**kwargs):
+            return {k: v for k, v in kwargs.items() if v is not None}
+
+        # Build container
+        container_kwargs = build_dict(
+            name="pytorch-job-container",
+            image=self.image,
+            resources=Resources(requests={"nvidia.com/gpu": "0"}, limits={"nvidia.com/gpu": "0"}),
+            command=self.command,
+            args=self.args,
+            image_pull_policy=self.pull_policy,
+            env=[{"name": k, "value": v} for k, v in self.environment.items()] if self.environment else None,
+            volume_mounts=[{"name": vol.name, "mount_path": vol.mount_path} for vol in self.volume] if self.volume else None
+        )
+
+        container = Containers(**container_kwargs)
 
-        # Add volumes to pod spec if present
-        if self.volume is not None:
+        # Build volumes
+        volumes = None
+        if self.volume:
             volumes = []
-            for i, vol in enumerate(self.volume):
+            for vol in self.volume:
                 if vol.type == "hostPath":
-                    host_path = HostPath(path=vol.path)
-                    volume_obj = Volumes(name=vol.name, host_path=host_path)
+                    volume_obj = Volumes(name=vol.name, host_path=HostPath(path=vol.path))
                 elif vol.type == "pvc":
-                    pvc_config = PersistentVolumeClaim(
-                        claim_name=vol.claim_name,
-                        read_only=vol.read_only == "true" if vol.read_only else False
-                    )
-                    volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config)
+                    volume_obj = Volumes(name=vol.name, persistent_volume_claim=PersistentVolumeClaim(
+                        claim_name=vol.claim_name,
+                        read_only=vol.read_only == "true" if vol.read_only else False
+                    ))
                 volumes.append(volume_obj)
-
-            spec_kwargs["volumes"] = volumes
-
-        # Add node selector if any selector fields are present
-        node_selector = {}
-        if self.instance_type is not None:
-            map = {"node.kubernetes.io/instance-type": self.instance_type}
-            node_selector.update(map)
-        if self.label_selector is not None:
-            node_selector.update(self.label_selector)
-        if self.deep_health_check_passed_nodes_only:
-            map = {"deep-health-check-passed": "true"}
-            node_selector.update(map)
-        if node_selector:
-            spec_kwargs.update({"node_selector": node_selector})
-
-        # Add other optional pod spec fields
-        if self.service_account_name is not None:
-            map = {"service_account_name": self.service_account_name}
-            spec_kwargs.update(map)
-
-        if self.scheduler_type is not None:
-            map = {"scheduler_name": self.scheduler_type}
-            spec_kwargs.update(map)
-
-        # Build metadata labels only if relevant fields are present
-        metadata_kwargs = {"name": self.job_name}
-        if self.namespace is not None:
-            metadata_kwargs["namespace"] = self.namespace
-
-        metadata_labels = {}
-        if self.queue_name is not None:
-            metadata_labels["kueue.x-k8s.io/queue-name"] = self.queue_name
-        if self.priority is not None:
-            metadata_labels["kueue.x-k8s.io/priority-class"] = self.priority
-
-        if metadata_labels:
-            metadata_kwargs["labels"] = metadata_labels
 
-        # Create replica spec with only non-None values
-        replica_kwargs = {
-            "name": "pod",
-            "template": Template(
-                metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)
-            ),
-        }
+        # Build node selector
+        node_selector = build_dict(
+            **{"node.kubernetes.io/instance-type": self.instance_type} if self.instance_type else {},
+            **self.label_selector if self.label_selector else {},
+            **{"deep-health-check-passed": "true"} if self.deep_health_check_passed_nodes_only else {}
+        )
 
-        if self.node_count is not None:
-            replica_kwargs["replicas"] = self.node_count
+        # Build spec
+        spec_kwargs = build_dict(
+            containers=[container],
+            volumes=volumes,
+            node_selector=node_selector if node_selector else None,
+            service_account_name=self.service_account_name,
+            scheduler_name=self.scheduler_type
+        )
 
-        replica_spec = ReplicaSpec(**replica_kwargs)
+        # Build metadata
+        metadata_labels = build_dict(
+            **{"kueue.x-k8s.io/queue-name": self.queue_name} if self.queue_name else {},
+            **{"kueue.x-k8s.io/priority-class": self.priority} if self.priority else {}
+        )
 
-        replica_specs = list([replica_spec])
+        metadata_kwargs = build_dict(
+            name=self.job_name,
+            namespace=self.namespace,
+            labels=metadata_labels if metadata_labels else None
+        )
 
-        job_kwargs = {"replica_specs": replica_specs}
-        # Add optional fields only if they exist
-        if self.tasks_per_node is not None:
-            job_kwargs["nproc_per_node"] = str(self.tasks_per_node)
+        # Build replica spec
+        replica_kwargs = build_dict(
+            name="pod",
+            template=Template(metadata=Metadata(**metadata_kwargs), spec=Spec(**spec_kwargs)),
+            replicas=self.node_count
+        )
 
-        if self.max_retry is not None:
-            job_kwargs["run_policy"] = RunPolicy(
-                clean_pod_policy="None", job_max_retry_count=self.max_retry
-            )
+        # Build job
+        job_kwargs = build_dict(
+            metadata=metadata_kwargs,
+            replica_specs=[ReplicaSpec(**replica_kwargs)],
+            nproc_per_node=str(self.tasks_per_node) if self.tasks_per_node else None,
+            run_policy=RunPolicy(clean_pod_policy="None", job_max_retry_count=self.max_retry) if self.max_retry else None
+        )
 
-        # Create base return dictionary
-        result = {
-            "name": self.job_name,
-            "namespace": self.namespace,
-            "labels": metadata_labels,
-            "spec": job_kwargs,
-        }
+        result = HyperPodPytorchJob(**job_kwargs)
         return result
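The refactor centers on the small `build_dict` helper, which drops `None`-valued keyword arguments so optional fields are only forwarded to the Pydantic models when actually set. A standalone sketch of the pattern (sample values below are illustrative, not from the commit):

```python
def build_dict(**kwargs):
    # Keep only keyword arguments whose value is not None.
    return {k: v for k, v in kwargs.items() if v is not None}

print(build_dict(name="pod", replicas=None, template="tmpl"))
# -> {'name': 'pod', 'template': 'tmpl'}

# Conditional **-unpacking merges zero or more key/value pairs per condition,
# mirroring the node_selector construction in the diff above.
instance_type = "ml.g5.8xlarge"  # illustrative value
label_selector = None            # illustrative value
node_selector = build_dict(
    **({"node.kubernetes.io/instance-type": instance_type} if instance_type else {}),
    **(label_selector if label_selector else {}),
)
print(node_selector)
# -> {'node.kubernetes.io/instance-type': 'ml.g5.8xlarge'}
```

Combined, these two idioms replace the old code's long chains of `if x is not None:` blocks, and also remove its accidental shadowing of the builtin `map`.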

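Per the commit message, `to_domain()` now returns a `HyperPodPytorchJob` SDK object rather than a plain dict, so the updated `pytorch_create` path can submit it directly. A minimal consumption sketch; the config class name `PyTorchJobConfig` and the `job.create()` call are assumptions about the surrounding CLI/SDK code, not verified against this repository:

```python
# Minimal consumption sketch (not from this commit). PyTorchJobConfig and
# job.create() are assumed names for the flat-config class and the SDK's
# submission method; verify against the actual repository before use.
from hyperpod_pytorch_job_template.v1_0.model import PyTorchJobConfig  # assumed name

config = PyTorchJobConfig(
    job_name="demo-job",
    image="pytorch-training:latest",  # placeholder image URI
    node_count=2,
)

job = config.to_domain()  # now a HyperPodPytorchJob object, not a dict
job.create()              # assumed SDK method for submitting the job
```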