3838
3939import six
4040import json
41+ import sys
42+
4143from boto3 .session import Session
4244from botocore .exceptions import ClientError
4345from botocore .exceptions import NoRegionError
4446from botocore .exceptions import NoCredentialsError
4547from botocore .exceptions import EndpointConnectionError
48+ from collections import defaultdict
4649
4750from st2reactor .sensor .base import PollingSensor
4851
@@ -55,19 +58,36 @@ def __init__(self, sensor_service, config=None, poll_interval=5):
5558 def setup (self ):
5659 self ._logger = self ._sensor_service .get_logger (name = self .__class__ .__name__ )
5760
58- self .session = None
59- self .sqs_res = None
61+ self .account_id = None
62+ self .credentials = {}
63+ self .sessions = {}
64+ self .cross_roles = {}
65+ self .sqs_res = defaultdict (dict )
6066
6167 def poll (self ):
6268 # setting SQS ServiceResource object from the parameter of datastore or configuration file
6369 self ._may_setup_sqs ()
6470
6571 for queue in self .input_queues :
66- msgs = self ._receive_messages (queue = self ._get_queue_by_name (queue ),
67- num_messages = self .max_number_of_messages )
72+ account_id , region = self ._get_info (queue )
73+
74+ while True :
75+ try :
76+ msgs = self ._receive_messages (queue = self ._get_queue (queue , account_id , region ),
77+ num_messages = self .max_number_of_messages )
78+ except ClientError as e :
79+ if e .response ['Error' ]['Code' ] == 'ExpiredToken' :
80+ self ._setup_multiaccount_session (account_id )
81+ continue
82+ raise
83+ break
84+
6885 for msg in msgs :
6986 if msg :
70- payload = {"queue" : queue , "body" : json .loads (msg .body )}
87+ payload = {"queue" : queue ,
88+ "account_id" : account_id ,
89+ "region" : region ,
90+ "body" : json .loads (msg .body )}
7191 self ._sensor_service .dispatch (trigger = "aws.sqs_new_message" , payload = payload )
7292 msg .delete ()
7393
@@ -89,7 +109,7 @@ def _get_config_entry(self, key, prefix=None):
89109 ''' Get configuration values either from Datastore or config file. '''
90110 config = self .config
91111 if prefix :
92- config = self ._config .get (prefix , {})
112+ config = self .config .get (prefix , {})
93113
94114 value = self ._sensor_service .get_value ('aws.%s' % (key ), local = False )
95115 if not value :
@@ -101,61 +121,156 @@ def _get_config_entry(self, key, prefix=None):
101121 return value
102122
103123 def _may_setup_sqs (self ):
104- queues = self ._get_config_entry (key = 'input_queues' , prefix = 'sqs_sensor' )
124+ self .access_key_id = self ._get_config_entry ('aws_access_key_id' )
125+ self .secret_access_key = self ._get_config_entry ('aws_secret_access_key' )
126+ self .aws_region = self ._get_config_entry ('region' )
127+ self .max_number_of_messages = self ._get_config_entry ('max_number_of_messages' ,
128+ prefix = 'sqs_other' )
129+
130+ if not self .account_id :
131+ self ._setup_session ()
105132
133+ queues = self ._get_config_entry (key = 'input_queues' , prefix = 'sqs_sensor' )
106134 # XXX: This is a hack as from datastore we can only receive a string while
107135 # from config.yaml we can receive a list
108136 if isinstance (queues , six .string_types ):
109- self .input_queues = [x .strip () for x in queues .split (',' )]
137+ self .input_queues = [six .moves .urllib .parse .urlparse (x .strip ()) for x in
138+ queues .split (',' )]
110139 elif isinstance (queues , list ):
111- self .input_queues = queues
140+ self .input_queues = [ six . moves . urllib . parse . urlparse ( queue ) for queue in queues ]
112141 else :
113142 self .input_queues = []
114143
115- self .aws_access_key = self ._get_config_entry ('aws_access_key_id' )
116- self .aws_secret_key = self ._get_config_entry ('aws_secret_access_key' )
117- self .aws_region = self ._get_config_entry ('region' )
118-
119- self .max_number_of_messages = self ._get_config_entry ('max_number_of_messages' ,
120- prefix = 'sqs_other' )
121-
122144 # checker configuration is update, or not
123- def _is_same_credentials ():
124- c = self .session .get_credentials ()
145+ def _is_same_credentials (session , account_id ):
146+ if not session :
147+ return False
148+
149+ c = session .get_credentials ()
125150 return c is not None and \
126- c .access_key == self .aws_access_key and \
127- c .secret_key == self .aws_secret_key and \
128- self .session .region_name == self .aws_region
151+ c .access_key == self .credentials [account_id ][0 ] and \
152+ c .secret_key == self .credentials [account_id ][1 ] and \
153+ (account_id == self .account_id or c .token == self .credentials [account_id ][2 ])
154+
155+ # Build a map between 'account_id' and its 'role arn' by parsing the matching config entry
156+ self .cross_roles = {
157+ arn .split (':' )[4 ]: arn
158+ for arn in self ._get_config_entry ('roles' , 'sqs_sensor' ) or []
159+ }
160+ required_accounts = {self ._get_info (queue )[0 ] for queue in self .input_queues }
129161
130- if self .session is None or not _is_same_credentials ():
131- self ._setup_sqs ()
162+ for account_id in required_accounts :
163+ if account_id != self .account_id and account_id not in self .cross_roles :
164+ continue
132165
133- def _setup_sqs (self ):
134- ''' Setup Boto3 structures '''
135- self ._logger .debug ('Setting up SQS resources' )
136- self .session = Session (aws_access_key_id = self .aws_access_key ,
137- aws_secret_access_key = self .aws_secret_key ,
138- region_name = self .aws_region )
166+ session = self .sessions .get (account_id )
167+ if not _is_same_credentials (session , account_id ):
168+ if account_id == self .account_id :
169+ self ._setup_session ()
170+ else :
171+ self ._setup_multiaccount_session (account_id )
172+
173+ def _setup_session (self ):
174+ ''' Setup Boto3 session '''
175+ session = Session (aws_access_key_id = self .access_key_id ,
176+ aws_secret_access_key = self .secret_access_key )
177+
178+ if not self .account_id :
179+ self .account_id = session .client ('sts' ).get_caller_identity ().get ('Account' )
180+ self .credentials [self .account_id ] = (self .access_key_id , self .secret_access_key , None )
181+
182+ self .sessions [self .account_id ] = session
183+ self .sqs_res .pop (self .account_id , None )
184+
185+ def _setup_multiaccount_session (self , account_id ):
186+ ''' Assume role and setup session for the cross-account capability'''
187+ try :
188+ assumed_role = self .sessions [self .account_id ].client ('sts' ).assume_role (
189+ RoleArn = self .cross_roles [account_id ],
190+ RoleSessionName = 'StackStormEvents'
191+ )
192+ except ClientError :
193+ self ._logger .error ('Could not assume role on %s' , account_id )
194+ return
195+
196+ self .credentials [account_id ] = (assumed_role ["Credentials" ]["AccessKeyId" ],
197+ assumed_role ["Credentials" ]["SecretAccessKey" ],
198+ assumed_role ["Credentials" ]["SessionToken" ])
199+
200+ session = Session (
201+ aws_access_key_id = self .credentials [account_id ][0 ],
202+ aws_secret_access_key = self .credentials [account_id ][1 ],
203+ aws_session_token = self .credentials [account_id ][2 ]
204+ )
205+ self .sessions [account_id ] = session
206+ self .sqs_res .pop (account_id , None )
207+
208+ def _setup_sqs (self , session , account_id , region ):
209+ ''' Setup SQS resources'''
210+ if region in self .sqs_res [account_id ]:
211+ return self .sqs_res [account_id ][region ]
139212
140213 try :
141- self .sqs_res = self .session .resource ('sqs' )
214+ self .sqs_res [account_id ][region ] = session .resource ('sqs' , region_name = region )
215+ return self .sqs_res [account_id ][region ]
142216 except NoRegionError :
143- self ._logger .warning ("The specified region '%s' is invalid" , self .aws_region )
217+ self ._logger .error ("The specified region '%s' for account %s is invalid." ,
218+ region , account_id )
219+
220+ def _get_info (self , queue ):
221+ ''' Retrieve the account ID and region from the queue URL '''
222+ # Pull default values from previous configuration
223+ account_id = self .account_id
224+ aws_region = self .aws_region
225+
226+ # Netloc will be empty if the queue is just a name
227+ if queue .netloc :
228+ try :
229+ account_id = queue .path .split ('/' )[1 ]
230+ except IndexError as e :
231+ six .reraise (type (e ), type (e )(
232+ "Queue URL must contain the account ID as the first part of the path, "
233+ "eg: https://sqs.<aws_region>.amazonaws.com/<account_id>/<queue_name>" ),
234+ sys .exc_info ()[2 ])
235+ else :
236+ self ._logger .debug ("Using %s as account_id" , account_id )
237+
238+ try :
239+ aws_region = queue .netloc .split ('.' )[1 ]
240+ except IndexError as e :
241+ six .reraise (type (e ), type (e )(
242+ "Queue URL must contain the AWS region, "
243+ "eg: https://sqs.<aws_region>.amazonaws.com/..." ),
244+ sys .exc_info ()[2 ])
245+ else :
246+ self ._logger .debug ("Using %s as the AWS region" , aws_region )
247+
248+ return account_id , aws_region
249+
250+ def _get_queue (self , queue , account_id , region ):
251+ ''' Fetch QUEUE by its name or URL and create new one if queue doesn't exist '''
252+ if account_id not in self .sessions :
253+ self ._logger .error ('Session for account id %s does not exist' , account_id )
254+ return None
255+
256+ sqs_res = self ._setup_sqs (self .sessions [account_id ], account_id , region )
257+ if sqs_res is None :
258+ return None
144259
145- def _get_queue_by_name (self , queueName ):
146- ''' Fetch QUEUE by it's name create new one if queue doesn't exist '''
147260 try :
148- return self .sqs_res .get_queue_by_name (QueueName = queueName )
261+ if queue .netloc :
262+ return sqs_res .Queue (six .moves .urllib .parse .urlunparse (queue ))
263+ return sqs_res .get_queue_by_name (QueueName = queue .path .split ('/' )[- 1 ])
149264 except ClientError as e :
150265 if e .response ['Error' ]['Code' ] == 'AWS.SimpleQueueService.NonExistentQueue' :
151- self ._logger .warning ("SQS Queue: %s doesn't exist, creating it." , queueName )
152- return self . sqs_res .create_queue (QueueName = queueName )
266+ self ._logger .warning ("SQS Queue: %s doesn't exist, creating it." , queue )
267+ return sqs_res .create_queue (QueueName = queue . path . split ( '/' )[ - 1 ] )
153268 elif e .response ['Error' ]['Code' ] == 'InvalidClientTokenId' :
154- self ._logger .warning ("Cloudn 't operate sqs because of invalid credential config" )
269+ self ._logger .warning ("Couldn 't operate sqs because of invalid credential config" )
155270 else :
156271 raise
157272 except NoCredentialsError :
158- self ._logger .warning ("Cloudn 't operate sqs because of invalid credential config" )
273+ self ._logger .warning ("Couldn 't operate sqs because of invalid credential config" )
159274 except EndpointConnectionError as e :
160275 self ._logger .warning (e )
161276
0 commit comments