@@ -23,35 +23,46 @@ def start(self):
2323 # build the create cluster request
2424 clusters = get_cluster_from_connection_info (self .config ['connectionInfo' ], self .plugin_config ['connectionInfo' ])
2525
26+ is_autopilot = self .config .get ("isAutopilot" , False )
27+ is_regional = is_autopilot or self .config .get ('isRegional' , False )
28+
2629 cluster_builder = clusters .new_cluster_builder ()
2730
2831 cluster_builder .with_name (self .cluster_name )
29- cluster_builder .with_version (self .config .get ("clusterVersion" , "latest" ))
30- cluster_builder .with_initial_node_count (self .config .get ("numNodes" , 3 ))
32+ if is_autopilot :
33+ cluster_builder .with_regional (True , []) # autopilot => regional
34+ cluster_builder .with_autopilot (True , self .config .get ("releaseChannel" , "STABLE" ))
35+ else :
36+ cluster_builder .with_version (self .config .get ("clusterVersion" , "latest" ))
37+ cluster_builder .with_initial_node_count (self .config .get ("numNodes" , 3 ))
3138 cluster_builder .with_network (self .config .get ("inheritFromDSSHost" , True ),
3239 self .config .get ("network" , "" ).strip (),
3340 self .config .get ("subNetwork" , "" ).strip ())
34- cluster_builder .with_vpc_native_settings (self .config .get ("isVpcNative" , None ),
41+ vpc_native = is_autopilot or self .config .get ("isVpcNative" , None )
42+ cluster_builder .with_vpc_native_settings (vpc_native ,
3543 self .config .get ("podIpRange" , "" ),
3644 self .config .get ("svcIpRange" , "" ))
3745 cluster_builder .with_labels (self .config .get ("clusterLabels" , {}))
38- cluster_builder .with_legacy_auth (self .config .get ("legacyAuth" , False ))
39- cluster_builder .with_http_load_balancing (self .config .get ("httpLoadBalancing" , False ))
40- for node_pool in self .config .get ('nodePools' , []):
41- node_pool_builder = cluster_builder .get_node_pool_builder ()
42- node_pool_builder .with_node_count (node_pool .get ('numNodes' , 3 ))
43- node_pool_builder .use_gcr_io (node_pool .get ('useGcrIo' , False ))
44- node_pool_builder .with_oauth_scopes (node_pool .get ('oauthScopes' , None ))
45- node_pool_builder .with_machine_type (node_pool .get ('machineType' , None ))
46- node_pool_builder .with_disk_type (node_pool .get ('diskType' , None ))
47- node_pool_builder .with_disk_size_gb (node_pool .get ('diskSizeGb' , None ))
48- node_pool_builder .with_service_account (node_pool .get ('serviceAccountType' , None ),
49- node_pool .get ('serviceAccount' , None ))
50- node_pool_builder .with_auto_scaling (node_pool .get ('numNodesAutoscaling' , False ), node_pool .get ('minNumNodes' , 2 ), node_pool .get ('maxNumNodes' , 5 ))
51- node_pool_builder .with_gpu (node_pool .get ('withGpu' , False ), node_pool .get ('gpuType' , None ), node_pool .get ('gpuCount' , 1 ))
52- node_pool_builder .with_nodepool_labels (node_pool .get ('nodepoolLabels' , {}))
53- node_pool_builder .with_nodepool_tags (node_pool .get ('networkTags' , []))
54- node_pool_builder .build ()
46+ if not is_autopilot :
47+ cluster_builder .with_legacy_auth (self .config .get ("legacyAuth" , False ))
48+ cluster_builder .with_http_load_balancing (self .config .get ("httpLoadBalancing" , False ))
49+ if is_regional :
50+ cluster_builder .with_regional (True , self .config .get ("locations" , []))
51+ for node_pool in self .config .get ('nodePools' , []):
52+ node_pool_builder = cluster_builder .get_node_pool_builder ()
53+ node_pool_builder .with_node_count (node_pool .get ('numNodes' , 3 ))
54+ node_pool_builder .use_gcr_io (node_pool .get ('useGcrIo' , False ))
55+ node_pool_builder .with_oauth_scopes (node_pool .get ('oauthScopes' , None ))
56+ node_pool_builder .with_machine_type (node_pool .get ('machineType' , None ))
57+ node_pool_builder .with_disk_type (node_pool .get ('diskType' , None ))
58+ node_pool_builder .with_disk_size_gb (node_pool .get ('diskSizeGb' , None ))
59+ node_pool_builder .with_service_account (node_pool .get ('serviceAccountType' , None ),
60+ node_pool .get ('serviceAccount' , None ))
61+ node_pool_builder .with_auto_scaling (node_pool .get ('numNodesAutoscaling' , False ), node_pool .get ('minNumNodes' , 2 ), node_pool .get ('maxNumNodes' , 5 ))
62+ node_pool_builder .with_gpu (node_pool .get ('withGpu' , False ), node_pool .get ('gpuType' , None ), node_pool .get ('gpuCount' , 1 ))
63+ node_pool_builder .with_nodepool_labels (node_pool .get ('nodepoolLabels' , {}))
64+ node_pool_builder .with_nodepool_tags (node_pool .get ('networkTags' , []))
65+ node_pool_builder .build ()
5566 cluster_builder .with_settings_valve (self .config .get ("creationSettingsValve" , None ))
5667
5768 start_op = cluster_builder .build ()
@@ -62,7 +73,7 @@ def start(self):
6273 logging .info ("Cluster started" )
6374
6475 # cluster is ready, fetch its info from GKE
65- cluster = clusters .get_cluster (self .cluster_name )
76+ cluster = clusters .get_cluster (self .cluster_name , 'regional' if is_regional else 'zonal' )
6677 cluster_info = cluster .get_info ()
6778
6879 # build the config file for kubectl
@@ -77,15 +88,18 @@ def start(self):
7788 create_admin_binding (self .config .get ("userName" , None ), kube_config_path )
7889
7990 # Launch NVIDIA driver installer daemonset (will only apply on tainted gpu nodes)
80- create_installer_daemonset (kube_config_path = kube_config_path )
91+ if not is_autopilot : # GPUs are not supported on autopilot (says the GKE doc)
92+ create_installer_daemonset (kube_config_path = kube_config_path )
8193
8294 # collect and prepare the overrides so that DSS can know where and how to use the cluster
8395 overrides = make_overrides (self .config , kube_config , kube_config_path )
8496 return [overrides , {'kube_config_path' :kube_config_path , 'cluster' :cluster_info }]
8597
8698 def stop (self , data ):
8799 clusters = get_cluster_from_connection_info (self .config ['connectionInfo' ], self .plugin_config ['connectionInfo' ])
88- cluster = clusters .get_cluster (self .cluster_name )
100+ is_autopilot = self .config .get ("isAutopilot" , False )
101+ is_regional = is_autopilot or self .config .get ('isRegional' , False )
102+ cluster = clusters .get_cluster (self .cluster_name , 'regional' if is_regional else 'zonal' )
89103 stop_op = cluster .stop ()
90104 logging .info ("Waiting for cluster stop" )
91105 stop_op .wait_done ()
0 commit comments