Skip to content

Commit 508cb18

Browse files
committed
Added functionality to create consumption gpu app
1 parent b3f7d8d commit 508cb18

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

src/containerapp/azext_containerapp/_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def load_arguments(self, _):
3333
validator=validate_build_env_vars, is_preview=True)
3434
c.argument('max_inactive_revisions', type=int, help="Max inactive revisions a Container App can have.", is_preview=True)
3535
c.argument('registry_identity', help="The managed identity with which to authenticate to the Azure Container Registry (instead of username/password). Use 'system' for a system-defined identity, Use 'system-environment' for an environment level system-defined identity or a resource id for a user-defined environment/containerapp level identity. The managed identity should have been assigned acrpull permissions on the ACR before deployment (use 'az role assignment create --role acrpull ...').")
36+
c.argument('consumption_gpu_profile', help="Enable and select a consumption GPU workload profile type for your app. If there is no existing consumption GPU workload profile, it will attempt to create one for you.", is_preview=True)
3637

3738
# Springboard
3839
with self.argument_context('containerapp create', arg_group='Service Binding', is_preview=True) as c:

src/containerapp/azext_containerapp/containerapp_decorator.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,17 @@ def set_argument_registry_server(self, registry_server):
664664
def set_argument_no_wait(self, no_wait):
665665
self.set_param("no_wait", no_wait)
666666

667+
def get_argument_enable_consumption_gpu(self):
    """Return the value stored under the ``consumption_gpu_profile`` parameter.

    The value is returned exactly as ``self.get_param`` provides it; no
    normalization or validation happens here.
    """
    consumption_gpu_profile = self.get_param("consumption_gpu_profile")
    return consumption_gpu_profile
669+
670+
def validate_consumption_gpu_profile(self):
    """Check the requested consumption GPU workload profile type.

    Returns:
        The lower-cased profile type when a recognised value was supplied,
        or ``None`` when the argument was not supplied at all.

    Raises:
        InvalidArgumentValueError: a value was supplied but it is not one of
            the recognised consumption GPU workload profile types.
    """
    requested = self.get_argument_enable_consumption_gpu()
    if requested is None:
        return None

    # Matching is case-insensitive, so compare on the lower-cased form.
    normalized = requested.lower()
    recognised_types = (
        "consumption-gpu-nc8as-t4",
        "consumption-gpu-nc4as-t4",
        "consumption-gpu-nc24-a100",
        "consumption-gpu-nc12-a100",
        "consumption-gpu-nv6ads-a10",
    )
    if normalized in recognised_types:
        return normalized

    raise InvalidArgumentValueError('Containerapp consumption GPU workload profile must be one of the following: Consumption-GPU-NC8as-T4, Consumption-GPU-NC4as-T4, Consumption-GPU-NC24-A100, Consumption-GPU-NC12-A100, Consumption-GPU-NV6ads-A10.')
677+
667678
# do not create role assignment if it's env system msi
668679
def check_create_acrpull_role_assignment(self):
669680
identity = self.get_argument_registry_identity()
@@ -699,6 +710,47 @@ def set_up_system_assigned_identity_as_default_if_using_acr(self):
699710
return
700711
self.set_argument_registry_identity('system')
701712

713+
def set_up_consumption_gpu_wp_payload(self, consumption_gpu_profile_type):
    """Build the workload profile entry for a consumption GPU profile type.

    Args:
        consumption_gpu_profile_type: Profile type string; matched
            case-insensitively.

    Returns:
        A dict with ``workloadProfileType`` (canonical casing) and ``name``
        (the workload profile name used for that type).

    Raises:
        ValidationError: the profile type is not one of the known types.
    """
    # Lower-cased type -> (canonical workloadProfileType, workload profile name).
    # A table keeps this in step with the list of values accepted by
    # validate_consumption_gpu_profile, which includes the NV6ads-A10 type.
    known_profiles = {
        "consumption-gpu-nc8as-t4": ("Consumption-GPU-NC8as-T4", "consumption-8core-t4"),
        "consumption-gpu-nc4as-t4": ("Consumption-GPU-NC4as-T4", "consumption-4core-t4"),
        "consumption-gpu-nc24-a100": ("Consumption-GPU-NC24-A100", "consumption-24core-a100"),
        "consumption-gpu-nc12-a100": ("Consumption-GPU-NC12-A100", "consumption-12core-a100"),
        # NOTE(review): name follows the "consumption-<cores>core-<gpu>" pattern
        # of the entries above — confirm the preferred profile name for A10.
        "consumption-gpu-nv6ads-a10": ("Consumption-GPU-NV6ads-A10", "consumption-6core-a10"),
    }

    entry = known_profiles.get(consumption_gpu_profile_type.lower())
    if entry is None:
        raise ValidationError(f"Invalid consumption GPU profile type: {consumption_gpu_profile_type}.")

    workload_profile_type, profile_name = entry
    return {
        "workloadProfileType": workload_profile_type,
        "name": profile_name,
    }
738+
739+
def update_consumption_gpu_wp(self, managed_env_info, consumption_gpu_profile_type):
    """Build an environment update payload that adds a consumption GPU profile.

    Args:
        managed_env_info: Managed environment dict; its
            ``properties.workloadProfiles`` list and ``name`` are read.
        consumption_gpu_profile_type: Profile type passed on to
            ``set_up_consumption_gpu_wp_payload``.

    Returns:
        A ``(payload, name)`` tuple: ``payload`` is a dict whose
        ``properties.workloadProfiles`` holds the existing profiles followed
        by the new one, and ``name`` is the new workload profile's name.

    Raises:
        ValidationError: the environment has no ``workloadProfiles`` property,
            or the profile type is rejected by
            ``set_up_consumption_gpu_wp_payload``.
    """
    existing_wp = safe_get(managed_env_info, "properties", "workloadProfiles")
    if existing_wp is None:
        env_name = safe_get(managed_env_info, "name")
        raise ValidationError(f"Existing environment {env_name} cannot enable workload profiles. If you want to use Consumption GPU, please create a new one.")

    consumption_gpu_profile = self.set_up_consumption_gpu_wp_payload(consumption_gpu_profile_type)

    # Build a new list instead of appending in place, so the caller's
    # managed_env_info is not modified as a side effect.
    payload = {
        "properties": {
            "workloadProfiles": [*existing_wp, consumption_gpu_profile]
        }
    }
    return payload, consumption_gpu_profile["name"]
753+
702754
def parent_construct_payload(self):
703755
# preview logic
704756
self.check_create_acrpull_role_assignment()
@@ -727,6 +779,26 @@ def parent_construct_payload(self):
727779
if not managed_env_info:
728780
raise ValidationError("The environment '{}' does not exist. Specify a valid environment".format(self.get_argument_managed_env()))
729781

782+
consumption_gpu_wp = self.validate_consumption_gpu_profile()
783+
if consumption_gpu_wp is not None:
784+
logger.warning("Enabled preview feature: Consumption GPU workload profile.")
785+
if self.get_argument_workload_profile_name() is not None:
786+
raise ValidationError("Both --consumption-gpu-profile and --workload-profile-name (-w) are specified. Only one can be selected.")
787+
existing_wp = safe_get(managed_env_info, "properties", "workloadProfiles", default=None)
788+
consumption_gpu_wp_name = None
789+
if existing_wp is not None:
790+
for wp in existing_wp:
791+
if wp["workloadProfileType"].lower() == consumption_gpu_wp:
792+
consumption_gpu_wp_name = wp["name"]
793+
break
794+
if consumption_gpu_wp_name is None:
795+
env_client = self.get_environment_client
796+
wp_payload, consumption_gpu_wp_name = self.update_consumption_gpu_wp(managed_env_info, consumption_gpu_wp)
797+
env_client().update(cmd=self.cmd, resource_group_name=managed_env_rg, name=managed_env_name, managed_environment_envelope=wp_payload)
798+
managed_env_info = self.get_environment_client().show(cmd=self.cmd, resource_group_name=managed_env_rg, name=managed_env_name)
799+
800+
self.set_argument_workload_profile_name(consumption_gpu_wp_name)
801+
730802
while not self.get_argument_no_wait() and safe_get(managed_env_info, "properties", "provisioningState", default="").lower() in ["inprogress", "updating"]:
731803
logger.info("Waiting for environment provisioning to finish before creating container app")
732804
time.sleep(5)
@@ -735,7 +807,7 @@ def parent_construct_payload(self):
735807
location = managed_env_info["location"]
736808
_ensure_location_allowed(self.cmd, location, CONTAINER_APPS_RP, "containerApps")
737809

738-
if not self.get_argument_workload_profile_name() and "workloadProfiles" in managed_env_info:
810+
if not self.get_argument_workload_profile_name() and "workloadProfiles" in managed_env_info and consumption_gpu_wp_name is None:
739811
workload_profile_name = get_default_workload_profile_name_from_env(self.cmd, managed_env_info, managed_env_rg)
740812
self.set_argument_workload_profile_name(workload_profile_name)
741813

src/containerapp/azext_containerapp/custom.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,8 @@ def create_containerapp(cmd,
490490
max_inactive_revisions=None,
491491
runtime=None,
492492
enable_java_metrics=None,
493-
enable_java_agent=None):
493+
enable_java_agent=None,
494+
consumption_gpu_profile=None):
494495
raw_parameters = locals()
495496

496497
containerapp_create_decorator = ContainerAppPreviewCreateDecorator(

0 commit comments

Comments
 (0)