Skip to content

Commit 734b473

Browse files
committed
Add a new custom label to store user name
**Description** Introduced new label "sagemaker.user/created-by" with value user/role name from sts.get_caller_identity() response **Testing Done** Tested in beta account and can see the label added
1 parent e8b5b27 commit 734b473

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/hyperpod_cli/commands/job.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
PersistentVolumeClaim,
4949
SchedulerType,
5050
Volume,
51+
USER_NAME_LABEL_KEY,
5152
)
5253
from hyperpod_cli.clients.kubernetes_client import (
5354
KubernetesClient,
@@ -579,6 +580,9 @@ def start_job(
579580
logger.error("Cannot start Training job due to AWS credentials issue")
580581
sys.exit(1)
581582

583+
# get caller's user name
584+
user_name = get_user_name()
585+
582586
if namespace is None and config_file is None:
583587
namespace = _get_auto_fill_namespace_for_create_job()
584588

@@ -714,6 +718,10 @@ def start_job(
714718
config["cluster"]["cluster_config"].pop("annotations")
715719

716720
custom_labels = {}
721+
722+
# attach user label
723+
custom_labels[USER_NAME_LABEL_KEY] = user_name
724+
717725
_override_or_remove(
718726
config["cluster"]["cluster_config"], "pullPolicy", pull_policy
719727
)
@@ -787,8 +795,11 @@ def start_job(
787795
cluster_config = config.get("cluster").get("cluster_config")
788796
namespace = cluster_config.get("namespace", None)
789797
scheduler_type = cluster_config.get("scheduler_type", SchedulerType.get_default().value)
798+
790799
custom_labels = cluster_config.get("custom_labels", {})
791800
custom_labels = {} if custom_labels is None else custom_labels
801+
custom_labels[USER_NAME_LABEL_KEY] = user_name
802+
792803
queue_name = custom_labels.get(KUEUE_QUEUE_NAME_LABEL_KEY, None)
793804
# Autofill namespace
794805
if namespace is None:
@@ -1121,4 +1132,14 @@ def start_training_job(recipe, override_parameters, job_name, config_file, launc
11211132
if os.path.exists(file_to_delete):
11221133
os.remove(file_to_delete)
11231134

1124-
1135+
def get_user_name():
1136+
caller_arn = boto3.client("sts").get_caller_identity().get('Arn')
1137+
if 'user/' in caller_arn:
1138+
user_name = 'User-' + caller_arn.split('user/')[-1]
1139+
elif 'assumed-role' in caller_arn:
1140+
user_name = 'AssumedRole-' + caller_arn.split('assumed-role/')[-1]
1141+
else:
1142+
user_name = 'Unknown'
1143+
1144+
# label value does not allow slash
1145+
return user_name.replace('/', '-')

src/hyperpod_cli/constants/command_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
NVIDIA_GPU_RESOURCE_LIMIT_KEY = "nvidia.com/gpu"
4747
AVAILABLE_ACCELERATOR_DEVICES_KEY = "AvailableAcceleratorDevices"
4848
TOTAL_ACCELERATOR_DEVICES_KEY = "TotalAcceleratorDevices"
49+
USER_NAME_LABEL_KEY = "sagemaker.user/created-by"
4950

5051
class PullPolicy(Enum):
5152
ALWAYS = "Always"

0 commit comments

Comments
 (0)