
import arrow
from aws_library.ec2 import AWSTagKey, EC2InstanceData
+from aws_library.ec2._models import AWSTagValue
from fastapi import FastAPI
from models_library.users import UserID
from models_library.wallets import WalletID
from pydantic import parse_obj_as
from servicelib.logging_utils import log_catch
-
+from servicelib.utils import limited_gather
+
+from ..constants import (
+    DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY,
+    DOCKER_STACK_DEPLOY_COMMAND_NAME,
+    ROLE_TAG_KEY,
+    USER_ID_TAG_KEY,
+    WALLET_ID_TAG_KEY,
+    WORKER_ROLE_TAG_VALUE,
+)
from ..core.settings import get_application_settings
from ..modules.clusters import (
    delete_clusters,
    get_all_clusters,
    get_cluster_workers,
    set_instance_heartbeat,
)
+from ..utils.clusters import create_deploy_cluster_stack_script
from ..utils.dask import get_scheduler_auth, get_scheduler_url
-from ..utils.ec2 import HEARTBEAT_TAG_KEY
+from ..utils.ec2 import (
+    HEARTBEAT_TAG_KEY,
+    get_cluster_name,
+    user_id_from_instance_tags,
+    wallet_id_from_instance_tags,
+)
from .dask import is_scheduler_busy, ping_scheduler
+from .ec2 import get_ec2_client
+from .ssm import get_ssm_client

_logger = logging.getLogger(__name__)

@@ -42,8 +60,8 @@ def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime
async def _get_all_associated_worker_instances(
    app: FastAPI,
    primary_instances: Iterable[EC2InstanceData],
-) -> list[EC2InstanceData]:
-    worker_instances = []
+) -> set[EC2InstanceData]:
+    worker_instances: set[EC2InstanceData] = set()
    for instance in primary_instances:
        assert "user_id" in instance.tags  # nosec
        user_id = UserID(instance.tags[_USER_ID_TAG_KEY])
@@ -55,20 +73,20 @@ async def _get_all_associated_worker_instances(
            else None
        )

-        worker_instances.extend(
+        worker_instances.update(
            await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
        )
    return worker_instances


async def _find_terminateable_instances(
    app: FastAPI, instances: Iterable[EC2InstanceData]
-) -> list[EC2InstanceData]:
+) -> set[EC2InstanceData]:
    app_settings = get_application_settings(app)
    assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES  # nosec

    # get the corresponding ec2 instance data
-    terminateable_instances: list[EC2InstanceData] = []
+    terminateable_instances: set[EC2InstanceData] = set()

    time_to_wait_before_termination = (
        app_settings.CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
@@ -82,7 +100,7 @@ async def _find_terminateable_instances(
            elapsed_time_since_heartbeat = arrow.utcnow().datetime - last_heartbeat
            allowed_time_to_wait = time_to_wait_before_termination
            if elapsed_time_since_heartbeat >= allowed_time_to_wait:
-                terminateable_instances.append(instance)
+                terminateable_instances.add(instance)
            else:
                _logger.info(
                    "%s has still %ss before being terminateable",
@@ -93,14 +111,14 @@ async def _find_terminateable_instances(
            elapsed_time_since_startup = arrow.utcnow().datetime - instance.launch_time
            allowed_time_to_wait = startup_delay
            if elapsed_time_since_startup >= allowed_time_to_wait:
-                terminateable_instances.append(instance)
+                terminateable_instances.add(instance)

    # get all worker instances associated with the terminateable instances
    worker_instances = await _get_all_associated_worker_instances(
        app, terminateable_instances
    )

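+    # workers must be terminated together with their primary instance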
-    return terminateable_instances + worker_instances
+    return terminateable_instances.union(worker_instances)


async def check_clusters(app: FastAPI) -> None:
@@ -112,6 +130,7 @@ async def check_clusters(app: FastAPI) -> None:
        if await ping_scheduler(get_scheduler_url(instance), get_scheduler_auth(app))
    }

+    # set instance heartbeat if scheduler is busy
    for instance in connected_intances:
        with log_catch(_logger, reraise=False):
            # NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
@@ -124,6 +143,7 @@ async def check_clusters(app: FastAPI) -> None:
                    f"{instance.id=} for {instance.tags=}",
                )
                await set_instance_heartbeat(app, instance=instance)
+    # clean any cluster that is not doing anything
    if terminateable_instances := await _find_terminateable_instances(
        app, connected_intances
    ):
@@ -138,7 +158,7 @@ async def check_clusters(app: FastAPI) -> None:
        for instance in disconnected_instances
        if _get_instance_last_heartbeat(instance) is None
    }
-
+    # remove instances that have been starting for too long
    if terminateable_instances := await _find_terminateable_instances(
        app, starting_instances
    ):
@@ -149,7 +169,72 @@ async def check_clusters(app: FastAPI) -> None:
        )
        await delete_clusters(app, instances=terminateable_instances)

-    # the other instances are broken (they were at some point connected but now not anymore)
+    # NOTE: transmit the command to start the docker swarm/stack if needed.
+    # Once an instance is connected to the SSM server, the SSM client is used
+    # to send it a command that contains:
+    # - the docker-compose file in binary form,
+    # - the call to init the docker swarm,
+    # - and the call to deploy the stack
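+    # instances already tagged with the deploy command were handled on a previous pass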
+    instances_in_need_of_deployment = {
+        i
+        for i in starting_instances - terminateable_instances
+        if DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY not in i.tags
+    }
+
+    if instances_in_need_of_deployment:
+        app_settings = get_application_settings(app)
+        ssm_client = get_ssm_client(app)
+        ec2_client = get_ec2_client(app)
+        instances_in_need_of_deployment_ssm_connection_state = await limited_gather(
+            *[
+                ssm_client.is_instance_connected_to_ssm_server(i.id)
+                for i in instances_in_need_of_deployment
+            ],
+            reraise=False,
+            log=_logger,
+            limit=20,
+        )
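+        # NOTE: with reraise=False, failed checks come back as exceptions in the
+        # results list, so only entries that are exactly True count as connected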
+        ec2_connected_to_ssm_server = [
+            i
+            for i, c in zip(
+                instances_in_need_of_deployment,
+                instances_in_need_of_deployment_ssm_connection_state,
+                strict=True,
+            )
+            if c is True
+        ]
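+        # only instances reachable through SSM can receive the deploy command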
+        started_instances_ready_for_command = ec2_connected_to_ssm_server
+        if started_instances_ready_for_command:
+            # we need to send 1 command per machine here, as the user_id/wallet_id changes
+            for i in started_instances_ready_for_command:
+                ssm_command = await ssm_client.send_command(
+                    [i.id],
+                    command=create_deploy_cluster_stack_script(
+                        app_settings,
+                        cluster_machines_name_prefix=get_cluster_name(
+                            app_settings,
+                            user_id=user_id_from_instance_tags(i.tags),
+                            wallet_id=wallet_id_from_instance_tags(i.tags),
+                            is_manager=False,
+                        ),
+                        additional_custom_tags={
+                            USER_ID_TAG_KEY: i.tags[USER_ID_TAG_KEY],
+                            WALLET_ID_TAG_KEY: i.tags[WALLET_ID_TAG_KEY],
+                            ROLE_TAG_KEY: WORKER_ROLE_TAG_VALUE,
+                        },
+                    ),
+                    command_name=DOCKER_STACK_DEPLOY_COMMAND_NAME,
+                )
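+            # record the command id on the instances so the
+            # instances_in_need_of_deployment filter skips them on the next check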
+            await ec2_client.set_instances_tags(
+                started_instances_ready_for_command,
+                tags={
+                    DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: AWSTagValue(
+                        ssm_command.command_id
+                    ),
+                },
+            )
+
+    # the remaining instances are broken (they were connected at some point but are not anymore)
    broken_instances = disconnected_instances - starting_instances
    if terminateable_instances := await _find_terminateable_instances(
        app, broken_instances