|
48 | 48 | PersistentVolumeClaim, |
49 | 49 | SchedulerType, |
50 | 50 | Volume, |
| 51 | + USER_NAME_LABEL_KEY, |
51 | 52 | ) |
52 | 53 | from hyperpod_cli.clients.kubernetes_client import ( |
53 | 54 | KubernetesClient, |
@@ -579,6 +580,9 @@ def start_job( |
579 | 580 | logger.error("Cannot start Training job due to AWS credentials issue") |
580 | 581 | sys.exit(1) |
581 | 582 |
|
| 583 | + # get caller's user name |
| 584 | + user_name = get_user_name() |
| 585 | + |
582 | 586 | if namespace is None and config_file is None: |
583 | 587 | namespace = _get_auto_fill_namespace_for_create_job() |
584 | 588 |
|
@@ -714,6 +718,10 @@ def start_job( |
714 | 718 | config["cluster"]["cluster_config"].pop("annotations") |
715 | 719 |
|
716 | 720 | custom_labels = {} |
| 721 | + |
| 722 | + # attach user label |
| 723 | + custom_labels[USER_NAME_LABEL_KEY] = user_name |
| 724 | + |
717 | 725 | _override_or_remove( |
718 | 726 | config["cluster"]["cluster_config"], "pullPolicy", pull_policy |
719 | 727 | ) |
@@ -787,8 +795,11 @@ def start_job( |
787 | 795 | cluster_config = config.get("cluster").get("cluster_config") |
788 | 796 | namespace = cluster_config.get("namespace", None) |
789 | 797 | scheduler_type = cluster_config.get("scheduler_type", SchedulerType.get_default().value) |
| 798 | + |
790 | 799 | custom_labels = cluster_config.get("custom_labels", {}) |
791 | 800 | custom_labels = {} if custom_labels is None else custom_labels |
| 801 | + custom_labels[USER_NAME_LABEL_KEY] = user_name |
| 802 | + |
792 | 803 | queue_name = custom_labels.get(KUEUE_QUEUE_NAME_LABEL_KEY, None) |
793 | 804 | # Autofill namespace |
794 | 805 | if namespace is None: |
@@ -1121,4 +1132,14 @@ def start_training_job(recipe, override_parameters, job_name, config_file, launc |
1121 | 1132 | if os.path.exists(file_to_delete): |
1122 | 1133 | os.remove(file_to_delete) |
1123 | 1134 |
|
1124 | | - |
def get_user_name(caller_arn=None):
    """Return a Kubernetes-label-safe name identifying the calling AWS principal.

    Args:
        caller_arn: Optional ARN string to convert. When ``None`` (the
            default), the ARN is looked up via STS ``get_caller_identity``.

    Returns:
        ``"User-<name>"`` for IAM users, ``"AssumedRole-<role>-<session>"``
        for assumed roles, ``"FederatedUser-<name>"`` for federated users,
        or ``"Unknown"`` when the ARN is missing or of any other kind.
        Characters that are not valid in a label value are replaced by
        ``-`` and the result is capped at 63 characters.
    """
    # Local import so this block does not depend on the module's import list.
    import re

    if caller_arn is None:
        caller_arn = boto3.client("sts").get_caller_identity().get("Arn")
    if not caller_arn:
        # No ARN in the response; avoid a TypeError on the checks below.
        return "Unknown"

    # ARN layout: arn:partition:service:region:account-id:resource
    # Only the resource part is inspected, so a role or session name that
    # merely contains "user/" is not mistaken for an IAM user.
    resource = caller_arn.split(":", 5)[-1]

    user_name = "Unknown"
    for resource_prefix, label_prefix in (
        ("user/", "User-"),
        ("assumed-role/", "AssumedRole-"),
        ("federated-user/", "FederatedUser-"),
    ):
        if resource.startswith(resource_prefix):
            user_name = label_prefix + resource[len(resource_prefix):]
            break

    # Label values allow only alphanumerics, '-', '_' and '.', at most 63
    # characters, and must end with an alphanumeric. Slashes (IAM paths,
    # role/session separator) and characters such as '@' become '-'.
    sanitized = re.sub(r"[^A-Za-z0-9_.-]", "-", user_name)[:63]
    return sanitized.rstrip("-_.")
0 commit comments