55
66import arrow
77from aws_library .ec2 import AWSTagKey , EC2InstanceData
8+ from aws_library .ec2 ._models import AWSTagValue
89from fastapi import FastAPI
910from models_library .users import UserID
1011from models_library .wallets import WalletID
1112from pydantic import parse_obj_as
1213from servicelib .logging_utils import log_catch
13-
14+ from servicelib .utils import limited_gather
15+
16+ from ..constants import (
17+ DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY ,
18+ DOCKER_STACK_DEPLOY_COMMAND_NAME ,
19+ ROLE_TAG_KEY ,
20+ USER_ID_TAG_KEY ,
21+ WALLET_ID_TAG_KEY ,
22+ WORKER_ROLE_TAG_VALUE ,
23+ )
1424from ..core .settings import get_application_settings
1525from ..modules .clusters import (
1626 delete_clusters ,
1727 get_all_clusters ,
1828 get_cluster_workers ,
1929 set_instance_heartbeat ,
2030)
31+ from ..utils .clusters import create_deploy_cluster_stack_script
2132from ..utils .dask import get_scheduler_auth , get_scheduler_url
22- from ..utils .ec2 import HEARTBEAT_TAG_KEY
33+ from ..utils .ec2 import (
34+ HEARTBEAT_TAG_KEY ,
35+ get_cluster_name ,
36+ user_id_from_instance_tags ,
37+ wallet_id_from_instance_tags ,
38+ )
2339from .dask import is_scheduler_busy , ping_scheduler
40+ from .ec2 import get_ec2_client
41+ from .ssm import get_ssm_client
2442
2543_logger = logging .getLogger (__name__ )
2644
@@ -42,8 +60,8 @@ def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime
4260async def _get_all_associated_worker_instances (
4361 app : FastAPI ,
4462 primary_instances : Iterable [EC2InstanceData ],
45- ) -> list [EC2InstanceData ]:
46- worker_instances = []
63+ ) -> set [EC2InstanceData ]:
64+ worker_instances : set [ EC2InstanceData ] = set ()
4765 for instance in primary_instances :
4866 assert "user_id" in instance .tags # nosec
4967 user_id = UserID (instance .tags [_USER_ID_TAG_KEY ])
@@ -55,20 +73,20 @@ async def _get_all_associated_worker_instances(
5573 else None
5674 )
5775
58- worker_instances .extend (
76+ worker_instances .update (
5977 await get_cluster_workers (app , user_id = user_id , wallet_id = wallet_id )
6078 )
6179 return worker_instances
6280
6381
6482async def _find_terminateable_instances (
6583 app : FastAPI , instances : Iterable [EC2InstanceData ]
66- ) -> list [EC2InstanceData ]:
84+ ) -> set [EC2InstanceData ]:
6785 app_settings = get_application_settings (app )
6886 assert app_settings .CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES # nosec
6987
7088 # get the corresponding ec2 instance data
71- terminateable_instances : list [EC2InstanceData ] = []
89+ terminateable_instances : set [EC2InstanceData ] = set ()
7290
7391 time_to_wait_before_termination = (
7492 app_settings .CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
@@ -82,7 +100,7 @@ async def _find_terminateable_instances(
82100 elapsed_time_since_heartbeat = arrow .utcnow ().datetime - last_heartbeat
83101 allowed_time_to_wait = time_to_wait_before_termination
84102 if elapsed_time_since_heartbeat >= allowed_time_to_wait :
85- terminateable_instances .append (instance )
103+ terminateable_instances .add (instance )
86104 else :
87105 _logger .info (
88106 "%s has still %ss before being terminateable" ,
@@ -93,14 +111,14 @@ async def _find_terminateable_instances(
93111 elapsed_time_since_startup = arrow .utcnow ().datetime - instance .launch_time
94112 allowed_time_to_wait = startup_delay
95113 if elapsed_time_since_startup >= allowed_time_to_wait :
96- terminateable_instances .append (instance )
114+ terminateable_instances .add (instance )
97115
98116 # get all terminateable instances associated worker instances
99117 worker_instances = await _get_all_associated_worker_instances (
100118 app , terminateable_instances
101119 )
102120
103- return terminateable_instances + worker_instances
121+ return terminateable_instances . union ( worker_instances )
104122
105123
106124async def check_clusters (app : FastAPI ) -> None :
@@ -112,6 +130,7 @@ async def check_clusters(app: FastAPI) -> None:
112130 if await ping_scheduler (get_scheduler_url (instance ), get_scheduler_auth (app ))
113131 }
114132
133+ # set instance heartbeat if scheduler is busy
115134 for instance in connected_intances :
116135 with log_catch (_logger , reraise = False ):
117136 # NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
@@ -124,6 +143,7 @@ async def check_clusters(app: FastAPI) -> None:
124143 f"{ instance .id = } for { instance .tags = } " ,
125144 )
126145 await set_instance_heartbeat (app , instance = instance )
146+ # clean any cluster that is not doing anything
127147 if terminateable_instances := await _find_terminateable_instances (
128148 app , connected_intances
129149 ):
@@ -138,7 +158,7 @@ async def check_clusters(app: FastAPI) -> None:
138158 for instance in disconnected_instances
139159 if _get_instance_last_heartbeat (instance ) is None
140160 }
141-
161+ # remove instances that were starting for too long
142162 if terminateable_instances := await _find_terminateable_instances (
143163 app , starting_instances
144164 ):
@@ -149,7 +169,72 @@ async def check_clusters(app: FastAPI) -> None:
149169 )
150170 await delete_clusters (app , instances = terminateable_instances )
151171
152- # the other instances are broken (they were at some point connected but now not anymore)
172+ # NOTE: transmit command to start docker swarm/stack if needed
173+ # once the instance is connected to the SSM server,
174+ # use ssm client to send the command to these instances,
175+ # we send a command that contains:
176+ # the docker-compose file in binary,
177+ # the call to init the docker swarm and the call to deploy the stack
178+ instances_in_need_of_deployment = {
179+ i
180+ for i in starting_instances - terminateable_instances
181+ if DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY not in i .tags
182+ }
183+
184+ if instances_in_need_of_deployment :
185+ app_settings = get_application_settings (app )
186+ ssm_client = get_ssm_client (app )
187+ ec2_client = get_ec2_client (app )
188+ instances_in_need_of_deployment_ssm_connection_state = await limited_gather (
189+ * [
190+ ssm_client .is_instance_connected_to_ssm_server (i .id )
191+ for i in instances_in_need_of_deployment
192+ ],
193+ reraise = False ,
194+ log = _logger ,
195+ limit = 20 ,
196+ )
197+ ec2_connected_to_ssm_server = [
198+ i
199+ for i , c in zip (
200+ instances_in_need_of_deployment ,
201+ instances_in_need_of_deployment_ssm_connection_state ,
202+ strict = True ,
203+ )
204+ if c is True
205+ ]
206+ started_instances_ready_for_command = ec2_connected_to_ssm_server
207+ if started_instances_ready_for_command :
208+ # we need to send 1 command per machine here, as the user_id/wallet_id changes
209+ for i in started_instances_ready_for_command :
210+ ssm_command = await ssm_client .send_command (
211+ [i .id ],
212+ command = create_deploy_cluster_stack_script (
213+ app_settings ,
214+ cluster_machines_name_prefix = get_cluster_name (
215+ app_settings ,
216+ user_id = user_id_from_instance_tags (i .tags ),
217+ wallet_id = wallet_id_from_instance_tags (i .tags ),
218+ is_manager = False ,
219+ ),
220+ additional_custom_tags = {
221+ USER_ID_TAG_KEY : i .tags [USER_ID_TAG_KEY ],
222+ WALLET_ID_TAG_KEY : i .tags [WALLET_ID_TAG_KEY ],
223+ ROLE_TAG_KEY : WORKER_ROLE_TAG_VALUE ,
224+ },
225+ ),
226+ command_name = DOCKER_STACK_DEPLOY_COMMAND_NAME ,
227+ )
228+ await ec2_client .set_instances_tags (
229+ started_instances_ready_for_command ,
230+ tags = {
231+ DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY : AWSTagValue (
232+ ssm_command .command_id
233+ ),
234+ },
235+ )
236+
237+ # the remaining instances are broken (they were connected at some point but are not anymore)
153238 broken_instances = disconnected_instances - starting_instances
154239 if terminateable_instances := await _find_terminateable_instances (
155240 app , broken_instances
0 commit comments