Skip to content

Commit aac39bf

Browse files
author
Lily Pan
committed
add gpu profile driver install
1 parent 42cfd44 commit aac39bf

File tree

6 files changed

+106
-0
lines changed

6 files changed

+106
-0
lines changed

src/azure-cli/azure/cli/command_modules/acs/_consts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
CONST_GPU_INSTANCE_PROFILE_MIG4_G = "MIG4g"
5353
CONST_GPU_INSTANCE_PROFILE_MIG7_G = "MIG7g"
5454

55+
# gpu driver install modes: the values offered for the `gpu_driver` argument
# ("install" or "none")
56+
CONST_GPU_DRIVER_INSTALL = "install"
57+
CONST_GPU_DRIVER_NONE = "none"
58+
5559
# consts for ManagedCluster
5660
# load balancer sku
5761
CONST_LOAD_BALANCER_SKU_BASIC = "basic"

src/azure-cli/azure/cli/command_modules/acs/_help.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,9 @@
16471647
- name: --if-none-match
16481648
type: string
16491649
short-summary: Set to '*' to allow a new agentpool to be created, but to prevent updating an existing agentpool. Other values will be ignored.
1650+
- name: --gpu-driver
1651+
type: string
1652+
short-summary: Whether to install driver for GPU node pool. Possible values are "install" or "none". Default is "install".
16501653
16511654
examples:
16521655
- name: Create a nodepool in an existing AKS cluster with ephemeral os enabled.

src/azure-cli/azure/cli/command_modules/acs/_params.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
CONST_GPU_INSTANCE_PROFILE_MIG1_G, CONST_GPU_INSTANCE_PROFILE_MIG2_G,
1717
CONST_GPU_INSTANCE_PROFILE_MIG3_G, CONST_GPU_INSTANCE_PROFILE_MIG4_G,
1818
CONST_GPU_INSTANCE_PROFILE_MIG7_G, CONST_LOAD_BALANCER_SKU_BASIC,
19+
CONST_GPU_DRIVER_INSTALL, CONST_GPU_DRIVER_NONE,
1920
CONST_LOAD_BALANCER_SKU_STANDARD, CONST_MANAGED_CLUSTER_SKU_TIER_FREE,
2021
CONST_MANAGED_CLUSTER_SKU_TIER_STANDARD, CONST_MANAGED_CLUSTER_SKU_TIER_PREMIUM,
2122
CONST_NETWORK_DATAPLANE_AZURE, CONST_NETWORK_DATAPLANE_CILIUM,
@@ -189,6 +190,11 @@
189190
CONST_GPU_INSTANCE_PROFILE_MIG7_G,
190191
]
191192

193+
# choices offered for the `gpu_driver` argument (passed to get_enum_type in load_arguments)
gpu_driver_install_modes = [
194+
CONST_GPU_DRIVER_INSTALL,
195+
CONST_GPU_DRIVER_NONE
196+
]
197+
192198
nrg_lockdown_restriction_levels = [
193199
CONST_NRG_LOCKDOWN_RESTRICTION_LEVEL_READONLY,
194200
CONST_NRG_LOCKDOWN_RESTRICTION_LEVEL_UNRESTRICTED,
@@ -799,6 +805,7 @@ def load_arguments(self, _):
799805
c.argument('enable_secure_boot', action='store_true')
800806
c.argument("if_match")
801807
c.argument("if_none_match")
808+
c.argument('gpu_driver', arg_type=get_enum_type(gpu_driver_install_modes))
802809

803810
with self.argument_context('aks nodepool update', resource_type=ResourceType.MGMT_CONTAINERSERVICE, operation_group='agent_pools') as c:
804811
c.argument('enable_cluster_autoscaler', options_list=[

src/azure-cli/azure/cli/command_modules/acs/agentpool_decorator.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,6 +1550,34 @@ def get_if_none_match(self) -> str:
15501550
:return: string
15511551
"""
15521552
return self.raw_param.get("if_none_match")
1553+
1554+
def _get_gpu_driver(self) -> Union[str, None]:
    """Obtain the value of gpu_driver.

    The raw `gpu_driver` parameter is read first. In create mode, a non-None `gpu_profile.driver` already
    present on the attached `agentpool` object takes precedence over the raw parameter. No client-side
    default is applied here, so None is returned when neither source provides a value.

    :return: string or None
    """
    # read the original value passed by the command
    gpu_driver = self.raw_param.get("gpu_driver")
    # In create mode, try to read the property value corresponding to the parameter from the `agentpool` object
    if self.decorator_mode == DecoratorMode.CREATE:
        # getattr with a None fallback keeps backward compatibility with models lacking `gpu_profile`
        gpu_profile = getattr(self.agentpool, "gpu_profile", None) if self.agentpool else None
        if gpu_profile and gpu_profile.driver is not None:
            gpu_driver = gpu_profile.driver

    # this parameter does not need dynamic completion
    # this parameter does not need validation
    return gpu_driver
1574+
1575+
def get_gpu_driver(self) -> Union[str, None]:
    """Return the gpu_driver value as resolved by the internal getter `_get_gpu_driver`.

    :return: string or None
    """
    gpu_driver = self._get_gpu_driver()
    return gpu_driver
15531581

15541582

15551583
class AKSAgentPoolAddDecorator:
@@ -1914,6 +1942,23 @@ def set_up_agentpool_windows_profile(self, agentpool: AgentPool) -> AgentPool:
19141942
)
19151943

19161944
return agentpool
1945+
1946+
def set_up_gpu_profile(self, agentpool: AgentPool) -> AgentPool:
    """Populate the gpu profile of the AgentPool object from the gpu_driver value held by the context.

    The `gpu_profile` property is assigned only when the context yields a truthy gpu driver value;
    otherwise the agentpool is returned without touching `gpu_profile`.

    :return: the AgentPool object
    """
    self._ensure_agentpool(agentpool)

    driver = self.context.get_gpu_driver()
    if not driver:
        # no driver mode requested, leave gpu_profile as it is
        return agentpool

    # Construct the GPUProfile model carrying the requested driver mode
    agentpool.gpu_profile = self.models.GPUProfile(driver=driver)  # pylint: disable=no-member
    return agentpool
19171962

19181963
def construct_agentpool_profile_default(self, bypass_restore_defaults: bool = False) -> AgentPool:
19191964
"""The overall controller used to construct the AgentPool profile by default.
@@ -1959,6 +2004,8 @@ def construct_agentpool_profile_default(self, bypass_restore_defaults: bool = Fa
19592004
agentpool = self.set_up_agentpool_security_profile(agentpool)
19602005
# set up message of the day
19612006
agentpool = self.set_up_motd(agentpool)
2007+
# set up gpu profile
2008+
agentpool = self.set_up_gpu_profile(agentpool)
19622009
# restore defaults
19632010
if not bypass_restore_defaults:
19642011
agentpool = self._restore_defaults_in_agentpool(agentpool)

src/azure-cli/azure/cli/command_modules/acs/custom.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,6 +2417,8 @@ def aks_agentpool_add(
24172417
# etag headers
24182418
if_match=None,
24192419
if_none_match=None,
2420+
# gpu driver
2421+
gpu_driver=None,
24202422
):
24212423
# DO NOT MOVE: get all the original parameters and save them as a dictionary
24222424
raw_parameters = locals()

src/azure-cli/azure/cli/command_modules/acs/tests/latest/test_agentpool_decorator.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,24 @@ def common_get_node_public_ip_prefix_id(self):
690690
ctx_1.attach_agentpool(agentpool)
691691
self.assertEqual(ctx_1.get_node_public_ip_prefix_id(), "test_node_public_ip_prefix_id")
692692

693+
def common_get_gpu_driver(self):
    """Check get_gpu_driver: None when unset, otherwise the driver from the attached agentpool's gpu_profile."""
    # default
    ctx_1 = AKSAgentPoolContext(
        self.cmd,
        AKSAgentPoolParamDict({"gpu_driver": None}),
        self.models,
        DecoratorMode.CREATE,
        self.agentpool_decorator_mode,
    )
    self.assertEqual(ctx_1.get_gpu_driver(), None)
    # the getter reads `agentpool.gpu_profile.driver`, so the value must be supplied via a GPUProfile model
    agentpool_1 = self.create_initialized_agentpool_instance(
        gpu_profile=self.models.GPUProfile(driver="install")
    )
    ctx_1.attach_agentpool(agentpool_1)
    self.assertEqual(ctx_1.get_gpu_driver(), "install")

    # use a fresh context for the second agentpool rather than re-attaching to ctx_1
    # NOTE(review): attach_agentpool presumably rejects a second attach - confirm against the context class
    ctx_2 = AKSAgentPoolContext(
        self.cmd,
        AKSAgentPoolParamDict({"gpu_driver": None}),
        self.models,
        DecoratorMode.CREATE,
        self.agentpool_decorator_mode,
    )
    agentpool_2 = self.create_initialized_agentpool_instance(
        gpu_profile=self.models.GPUProfile(driver="none")
    )
    ctx_2.attach_agentpool(agentpool_2)
    self.assertEqual(ctx_2.get_gpu_driver(), "none")
710+
693711
def common_get_node_count_and_enable_cluster_autoscaler_min_max_count(
694712
self,
695713
):
@@ -1788,6 +1806,9 @@ def test_get_if_match(self):
17881806
def test_get_if_none_match(self):
17891807
self.get_if_none_match()
17901808

1809+
def test_get_gpu_driver(self):
    """Run the shared gpu_driver getter checks for this test case's decorator mode."""
    # delegate to the common helper; calling test_get_gpu_driver here would recurse forever
    self.common_get_gpu_driver()
1811+
17911812
class AKSAgentPoolContextManagedClusterModeTestCase(AKSAgentPoolContextCommonTestCase):
17921813
def setUp(self):
17931814
self.cli_ctx = MockCLI()
@@ -2431,6 +2452,28 @@ def common_set_up_agentpool_security_profile(self):
24312452
)
24322453
self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1)
24332454

2455+
def common_set_up_gpu_profile(self):
    """Check that set_up_gpu_profile rejects a wrong agentpool object and sets gpu_profile.driver."""
    dec_1 = AKSAgentPoolAddDecorator(
        self.cmd,
        self.client,
        {"gpu_driver": "install"},
        self.resource_type,
        self.agentpool_decorator_mode,
    )
    # fail on passing the wrong agentpool object
    with self.assertRaises(CLIInternalError):
        dec_1.set_up_gpu_profile(None)
    agentpool_1 = self.create_initialized_agentpool_instance(restore_defaults=False)
    dec_1.context.attach_agentpool(agentpool_1)
    dec_agentpool_1 = dec_1.set_up_gpu_profile(agentpool_1)
    dec_agentpool_1 = self._restore_defaults_in_agentpool(dec_agentpool_1)
    # the decorator builds GPUProfile(driver=...), so the expected object must use the same keyword
    ground_truth_agentpool_1 = self.create_initialized_agentpool_instance(
        gpu_profile=self.models.GPUProfile(
            driver="install",
        )
    )
    self.assertEqual(dec_agentpool_1, ground_truth_agentpool_1)
2476+
24342477
class AKSAgentPoolAddDecoratorStandaloneModeTestCase(AKSAgentPoolAddDecoratorCommonTestCase):
24352478
def setUp(self):
24362479
self.cli_ctx = MockCLI()

0 commit comments

Comments
 (0)