1616
1717from argparse import Namespace
1818from dataclasses import dataclass
19+ from typing import Any
1920from unittest .mock import MagicMock , patch
2021import pytest
2122
22- from xpk .commands .cluster import _install_kueue , _validate_cluster_create_args
23+ from xpk .commands .cluster import _install_kueue , _validate_cluster_create_args , run_gke_cluster_create_command
2324from xpk .core .system_characteristics import SystemCharacteristics , UserFacingNameToSystemCharacteristics
25+ from xpk .core .testing .commands_tester import CommandsTester
2426from xpk .utils .feature_flags import FeatureFlags
2527
2628
@@ -29,10 +31,11 @@ class _Mocks:
2931 common_print_mock : MagicMock
3032 commands_print_mock : MagicMock
3133 commands_get_reservation_deployment_type : MagicMock
34+ commands_tester : CommandsTester
3235
3336
3437@pytest .fixture
35- def mock_common_print_and_exit (mocker ):
38+ def mocks (mocker ) -> _Mocks :
3639 common_print_mock = mocker .patch (
3740 'xpk.commands.common.xpk_print' ,
3841 return_value = None ,
@@ -48,103 +51,137 @@ def mock_common_print_and_exit(mocker):
4851 common_print_mock = common_print_mock ,
4952 commands_get_reservation_deployment_type = commands_get_reservation_deployment_type ,
5053 commands_print_mock = commands_print_mock ,
54+ commands_tester = CommandsTester (
55+ mocker ,
56+ run_command_with_updates_path = (
57+ 'xpk.commands.cluster.run_command_with_updates'
58+ ),
59+ ),
5160 )
5261
5362
54- DEFAULT_TEST_SYSTEM : SystemCharacteristics = (
55- UserFacingNameToSystemCharacteristics ['l4-1' ]
56- )
63+ def construct_args (** kwargs : Any ) -> Namespace :
64+ args_dict = dict (
65+ project = 'project' ,
66+ zone = 'us-central1-a' ,
67+ reservation = '' ,
68+ default_pool_cpu_machine_type = 'test-machine-type' ,
69+ cluster = 'test-cluster' ,
70+ default_pool_cpu_num_nodes = '100' ,
71+ sub_slicing = False ,
72+ gke_version = '' ,
73+ private = False ,
74+ authorized_networks = None ,
75+ enable_pathways = False ,
76+ enable_ray_cluster = False ,
77+ enable_workload_identity = False ,
78+ enable_gcsfuse_csi_driver = False ,
79+ enable_gcpfilestore_csi_driver = False ,
80+ enable_parallelstore_csi_driver = False ,
81+ enable_pd_csi_driver = False ,
82+ enable_lustre_csi_driver = False ,
83+ custom_cluster_arguments = '' ,
84+ num_slices = 1 ,
85+ num_nodes = 1 ,
86+ flex = False ,
87+ memory_limit = '100Gi' ,
88+ cpu_limit = 100 ,
89+ cluster_cpu_machine_type = '' ,
90+ )
91+ args_dict .update (kwargs )
92+ return Namespace (** args_dict )
93+
94+
95+ GPU_TEST_SYSTEM : SystemCharacteristics = UserFacingNameToSystemCharacteristics [
96+ 'l4-1'
97+ ]
5798SUB_SLICING_SYSTEM : SystemCharacteristics = (
5899 UserFacingNameToSystemCharacteristics ['v6e-4x4' ]
59100)
101+ TPU_TEST_SYSTEM : SystemCharacteristics = UserFacingNameToSystemCharacteristics [
102+ 'v6e-4x4'
103+ ]
60104
61105
62106def test_validate_cluster_create_args_for_correct_args_pass (
63- mock_common_print_and_exit : _Mocks ,
107+ mocks : _Mocks ,
64108):
65109 args = Namespace ()
66110
67- _validate_cluster_create_args (args , DEFAULT_TEST_SYSTEM )
111+ _validate_cluster_create_args (args , GPU_TEST_SYSTEM )
68112
69- assert mock_common_print_and_exit .common_print_mock .call_count == 0
113+ assert mocks .common_print_mock .call_count == 0
70114
71115
72116def test_validate_cluster_create_args_for_correct_sub_slicing_args_pass (
73- mock_common_print_and_exit : _Mocks ,
117+ mocks : _Mocks ,
74118):
75119 FeatureFlags .SUB_SLICING_ENABLED = True
76- args = Namespace (
120+ args = construct_args (
77121 sub_slicing = True ,
78122 reservation = 'test-reservation' ,
79- project = 'project' ,
80- zone = 'zone' ,
81123 )
82124
83125 _validate_cluster_create_args (args , SUB_SLICING_SYSTEM )
84126
85- assert mock_common_print_and_exit .common_print_mock .call_count == 0
127+ assert mocks .common_print_mock .call_count == 0
86128
87129
88130def test_validate_cluster_create_args_for_not_supported_system_throws (
89- mock_common_print_and_exit : _Mocks ,
131+ mocks : _Mocks ,
90132):
91133 FeatureFlags .SUB_SLICING_ENABLED = True
92- args = Namespace (
134+ args = construct_args (
93135 sub_slicing = True ,
94136 reservation = 'test-reservation' ,
95- project = 'project' ,
96- zone = 'zone' ,
97137 )
98138
99139 with pytest .raises (SystemExit ):
100- _validate_cluster_create_args (args , DEFAULT_TEST_SYSTEM )
140+ _validate_cluster_create_args (args , GPU_TEST_SYSTEM )
101141
102- assert mock_common_print_and_exit .common_print_mock .call_count == 1
142+ assert mocks .common_print_mock .call_count == 1
103143 assert (
104- mock_common_print_and_exit .common_print_mock .call_args [0 ][0 ]
144+ mocks .common_print_mock .call_args [0 ][0 ]
105145 == 'Error: l4-1 does not support Sub-slicing.'
106146 )
107147
108148
109149def test_validate_cluster_create_args_for_missing_reservation (
110- mock_common_print_and_exit : _Mocks ,
150+ mocks : _Mocks ,
111151):
112152 FeatureFlags .SUB_SLICING_ENABLED = True
113- args = Namespace (
114- sub_slicing = True , project = 'project' , zone = 'zone' , reservation = None
153+ args = construct_args (
154+ sub_slicing = True ,
155+ reservation = None ,
115156 )
116157
117158 with pytest .raises (SystemExit ):
118159 _validate_cluster_create_args (args , SUB_SLICING_SYSTEM )
119160
120- assert mock_common_print_and_exit .commands_print_mock .call_count == 1
161+ assert mocks .commands_print_mock .call_count == 1
121162 assert (
122163 'Validation failed: Sub-slicing cluster creation requires'
123- in mock_common_print_and_exit .commands_print_mock .call_args [0 ][0 ]
164+ in mocks .commands_print_mock .call_args [0 ][0 ]
124165 )
125166
126167
127168def test_validate_cluster_create_args_for_invalid_reservation (
128- mock_common_print_and_exit : _Mocks ,
169+ mocks : _Mocks ,
129170):
130171 FeatureFlags .SUB_SLICING_ENABLED = True
131- args = Namespace (
172+ args = construct_args (
132173 sub_slicing = True ,
133- project = 'project' ,
134- zone = 'zone' ,
135174 reservation = 'test-reservation' ,
136175 )
137- mock_common_print_and_exit .commands_get_reservation_deployment_type .return_value = (
138- 'SPARSE'
139- )
176+ mocks .commands_get_reservation_deployment_type .return_value = 'SPARSE'
140177
141178 with pytest .raises (SystemExit ):
142179 _validate_cluster_create_args (args , SUB_SLICING_SYSTEM )
143180
144- assert mock_common_print_and_exit .commands_print_mock .call_count == 5
181+ assert mocks .commands_print_mock .call_count == 5
145182 assert (
146183 'Refer to the documentation for more information on creating Cluster'
147- in mock_common_print_and_exit .commands_print_mock .call_args [0 ][0 ]
184+ in mocks .commands_print_mock .call_args [0 ][0 ]
148185 )
149186
150187
@@ -155,17 +192,73 @@ def test_install_kueue_returns_kueue_installation_code(
155192 mock_kueue_manager_install .return_value = 17
156193
157194 code = _install_kueue (
158- args = Namespace (
159- num_slices = 1 ,
160- num_nodes = 1 ,
161- flex = False ,
162- memory_limit = '100Gi' ,
163- cpu_limit = 100 ,
164- enable_pathways = False ,
165- sub_slicing = False ,
166- ),
167- system = DEFAULT_TEST_SYSTEM ,
195+ args = construct_args (),
196+ system = GPU_TEST_SYSTEM ,
168197 autoprovisioning_config = None ,
169198 )
170199
171200 assert code == 17
201+
202+
203+ def test_run_gke_cluster_create_command_specifies_custom_cluster_arguments_last (
204+ mocks : _Mocks ,
205+ ):
206+ result = run_gke_cluster_create_command (
207+ args = construct_args (
208+ custom_cluster_arguments = '--enable-autoscaling=False --foo=baz'
209+ ),
210+ gke_control_plane_version = '1.2.3' ,
211+ system = TPU_TEST_SYSTEM ,
212+ )
213+
214+ assert result == 0
215+ mocks .commands_tester .assert_command_run (
216+ 'clusters create' ,
217+ ' --enable-autoscaling' ,
218+ ' --enable-autoscaling=False --foo=baz' ,
219+ )
220+
221+
222+ def test_run_gke_cluster_create_command_without_gke_version_does_not_have_no_autoupgrade_flag (
223+ mocks : _Mocks ,
224+ ):
225+ result = run_gke_cluster_create_command (
226+ args = construct_args (gke_version = '' ),
227+ gke_control_plane_version = '1.2.3' ,
228+ system = TPU_TEST_SYSTEM ,
229+ )
230+
231+ assert result == 0
232+ mocks .commands_tester .assert_command_not_run (
233+ 'clusters create' , ' --no-enable-autoupgrade'
234+ )
235+
236+
237+ def test_run_gke_cluster_create_command_with_gke_version_has_no_autoupgrade_flag (
238+ mocks : _Mocks ,
239+ ):
240+ result = run_gke_cluster_create_command (
241+ args = construct_args (gke_version = '1.2.3' ),
242+ gke_control_plane_version = '1.2.3' ,
243+ system = TPU_TEST_SYSTEM ,
244+ )
245+
246+ assert result == 0
247+ mocks .commands_tester .assert_command_run (
248+ 'clusters create' , ' --no-enable-autoupgrade'
249+ )
250+
251+
252+ def test_run_gke_cluster_create_command_with_gpu_system_has_no_enable_autoupgrade (
253+ mocks : _Mocks ,
254+ ):
255+ result = run_gke_cluster_create_command (
256+ args = construct_args (gke_version = '' ),
257+ gke_control_plane_version = '1.2.3' ,
258+ system = GPU_TEST_SYSTEM ,
259+ )
260+
261+ assert result == 0
262+ mocks .commands_tester .assert_command_run (
263+ 'clusters create' , ' --no-enable-autoupgrade'
264+ )
0 commit comments