3636from .flexible_server_virtual_network import prepare_private_network , prepare_private_dns_zone , prepare_public_network
3737from .validators import pg_arguments_validator , validate_server_name , validate_and_format_restore_point_in_time , \
3838 validate_postgres_replica , validate_georestore_network , pg_byok_validator , validate_migration_runtime_server , \
39- validate_resource_group , check_resource_group , validate_citus_cluster
39+ validate_resource_group , check_resource_group , validate_citus_cluster , cluster_byok_validator
4040
4141logger = get_logger (__name__ )
4242DEFAULT_DB_NAME = 'flexibleserverdb'
@@ -98,10 +98,12 @@ def flexible_server_create(cmd, client,
9898 backup_byok_identity = backup_byok_identity ,
9999 backup_byok_key = backup_byok_key ,
100100 performance_tier = performance_tier ,
101- create_cluster = create_cluster ,
102- cluster_size = cluster_size )
101+ create_cluster = create_cluster )
103102
104- cluster = postgresql_flexibleservers .models .Cluster (cluster_size = cluster_size ) if create_cluster else None
103+ cluster = None
104+ if create_cluster == 'ElasticCluster' :
105+ cluster_size = cluster_size if cluster_size else 2
106+ cluster = postgresql_flexibleservers .models .Cluster (cluster_size = cluster_size )
105107
106108 server_result = firewall_id = None
107109
@@ -163,7 +165,7 @@ def flexible_server_create(cmd, client,
163165 firewall_id = create_firewall_rule (db_context , cmd , resource_group_name , server_name , start_ip , end_ip )
164166
165167 # Create mysql database if it does not exist
166- if database_name is not None or (create_default_db and create_default_db .lower () == 'enabled' ):
168+ if ( database_name is not None or (create_default_db and create_default_db .lower () == 'enabled' ) and create_cluster != 'ElasticCluster ' ):
167169 db_name = database_name if database_name else DEFAULT_DB_NAME
168170 _create_database (db_context , cmd , resource_group_name , server_name , db_name )
169171
@@ -231,9 +233,11 @@ def flexible_server_restore(cmd, client,
231233 logging_name = 'PostgreSQL' , command_group = 'postgres' , server_client = client , location = location )
232234 validate_server_name (db_context , server_name , 'Microsoft.DBforPostgreSQL/flexibleServers' )
233235
236+ instance = client .get (id_parts ['resource_group' ], id_parts ['name' ])
237+
238+ cluster_byok_validator (byok_identity , byok_key , backup_byok_identity , backup_byok_key , geo_redundant_backup , instance )
234239 pg_byok_validator (byok_identity , byok_key , backup_byok_identity , backup_byok_key , geo_redundant_backup )
235240
236- instance = client .get (id_parts ['resource_group' ], id_parts ['name' ])
237241 storage = postgresql_flexibleservers .models .Storage (type = storage_type if instance .storage .type != "PremiumV2_LRS" else None )
238242
239243 parameters = postgresql_flexibleservers .models .Server (
@@ -292,6 +296,7 @@ def flexible_server_update_custom_func(cmd, client, instance,
292296 auto_grow = None ,
293297 performance_tier = None ,
294298 iops = None , throughput = None ,
299+ cluster_size = None ,
295300 yes = False ):
296301
297302 # validator
@@ -809,7 +814,7 @@ def flexible_replica_promote(cmd, client, resource_group_name, server_name, prom
809814
810815
811816def _create_server (db_context , cmd , resource_group_name , server_name , tags , location , sku , administrator_login , administrator_login_password ,
812- storage , backup , network , version , high_availability , availability_zone , identity , data_encryption , auth_config ):
817+ storage , backup , network , version , high_availability , availability_zone , identity , data_encryption , auth_config , cluster ):
813818 validate_resource_group (resource_group_name )
814819
815820 logging_name , server_client = db_context .logging_name , db_context .server_client
@@ -835,6 +840,7 @@ def _create_server(db_context, cmd, resource_group_name, server_name, tags, loca
835840 identity = identity ,
836841 data_encryption = data_encryption ,
837842 auth_config = auth_config ,
843+ cluster = cluster ,
838844 create_mode = "Create" )
839845
840846 return resolve_poller (
@@ -861,8 +867,9 @@ def _create_database(db_context, cmd, resource_group_name, server_name, database
861867 '{} Database Create/Update' .format (logging_name ))
862868
863869
864- def database_create_func (client , resource_group_name , server_name , database_name = None , charset = None , collation = None ):
870+ def database_create_func (cmd , client , resource_group_name , server_name , database_name = None , charset = None , collation = None ):
865871 validate_resource_group (resource_group_name )
872+ validate_citus_cluster (cmd , resource_group_name , server_name )
866873
867874 if charset is None and collation is None :
868875 charset = 'utf8'
0 commit comments