@@ -102,9 +102,13 @@ def temporary_ini_file():
102102 yield str (path )
103103
104104
105- def get_cloudformation_exports (region_name , endpoint_url , role_arn , profile_name ):
105+ def get_cloudformation_exports (
106+ region_name , endpoint_url , role_arn , profile_name , headers
107+ ):
106108 session = create_sdk_session (region_name , profile_name )
107- temp_credentials = get_temporary_credentials (session , role_arn = role_arn )
109+ temp_credentials = get_temporary_credentials (
110+ session , role_arn = role_arn , headers = headers
111+ )
108112 cfn_client = session .client (
109113 "cloudformation" , endpoint_url = endpoint_url , ** temp_credentials
110114 )
@@ -132,13 +136,13 @@ def __retrieve_args(match):
132136
133137
134138def render_template (
135- overrides_string , region_name , endpoint_url , role_arn , profile_name
139+ overrides_string , region_name , endpoint_url , role_arn , profile_name , headers
136140):
137141 regex = r"{{([-A-Za-z0-9:\s]+?)}}"
138142 variables = set (str (match ).strip () for match in re .findall (regex , overrides_string ))
139143 if variables :
140144 exports = get_cloudformation_exports (
141- region_name , endpoint_url , role_arn , profile_name
145+ region_name , endpoint_url , role_arn , profile_name , headers
142146 )
143147 invalid_exports = variables - exports .keys ()
144148 if len (invalid_exports ) > 0 :
@@ -166,15 +170,20 @@ def filter_overrides(overrides, project):
166170 return overrides
167171
168172
169- def get_overrides (root , region_name , endpoint_url , role_arn , profile_name ):
173+ def get_overrides (root , region_name , endpoint_url , role_arn , profile_name , headers ):
170174 if not root :
171175 return empty_override ()
172176
173177 path = root / "overrides.json"
174178 try :
175179 with path .open ("r" , encoding = "utf-8" ) as f :
176180 overrides_raw = render_template (
177- f .read (), region_name , endpoint_url , role_arn , profile_name
181+ f .read (),
182+ region_name ,
183+ endpoint_url ,
184+ role_arn ,
185+ profile_name ,
186+ headers = headers ,
178187 )
179188 except FileNotFoundError :
180189 LOG .debug ("Override file '%s' not found. No overrides will be applied" , path )
@@ -203,15 +212,22 @@ def get_overrides(root, region_name, endpoint_url, role_arn, profile_name):
203212
204213# pylint: disable=R0914
205214# flake8: noqa: C901
206- def get_hook_overrides (root , region_name , endpoint_url , role_arn , profile_name ):
215+ def get_hook_overrides (
216+ root , region_name , endpoint_url , role_arn , profile_name , headers
217+ ):
207218 if not root :
208219 return empty_hook_override ()
209220
210221 path = root / "overrides.json"
211222 try :
212223 with path .open ("r" , encoding = "utf-8" ) as f :
213224 overrides_raw = render_template (
214- f .read (), region_name , endpoint_url , role_arn , profile_name
225+ f .read (),
226+ region_name ,
227+ endpoint_url ,
228+ role_arn ,
229+ profile_name ,
230+ headers = headers ,
215231 )
216232 except FileNotFoundError :
217233 LOG .debug ("Override file '%s' not found. No overrides will be applied" , path )
@@ -258,7 +274,7 @@ def get_hook_overrides(root, region_name, endpoint_url, role_arn, profile_name):
258274
259275
260276# pylint: disable=R0914,too-many-arguments
261- def get_inputs (root , region_name , endpoint_url , value , role_arn , profile_name ):
277+ def get_inputs (root , region_name , endpoint_url , value , role_arn , profile_name , headers ):
262278 inputs = {}
263279 if not root :
264280 return None
@@ -280,7 +296,12 @@ def get_inputs(root, region_name, endpoint_url, value, role_arn, profile_name):
280296 file_path = path / file
281297 with file_path .open ("r" , encoding = "utf-8" ) as f :
282298 overrides_raw = render_template (
283- f .read (), region_name , endpoint_url , role_arn , profile_name
299+ f .read (),
300+ region_name ,
301+ endpoint_url ,
302+ role_arn ,
303+ profile_name ,
304+ headers = headers ,
284305 )
285306 overrides = {}
286307 for pointer , obj in overrides_raw .items ():
@@ -355,6 +376,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
355376 project .type_name ,
356377 args .log_group_name ,
357378 args .log_role_arn ,
379+ headers = {"account_id" : args .source_account , "source_arn" : args .source_arn },
358380 executable_entrypoint = project .executable_entrypoint ,
359381 docker_image = args .docker_image ,
360382 typeconfig = args .typeconfig ,
@@ -378,6 +400,7 @@ def get_contract_plugin_client(args, project, overrides, inputs):
378400 project .type_name ,
379401 args .log_group_name ,
380402 args .log_role_arn ,
403+ headers = {"account_id" : args .source_account , "source_arn" : args .source_arn },
381404 typeconfig = args .typeconfig ,
382405 executable_entrypoint = project .executable_entrypoint ,
383406 docker_image = args .docker_image ,
@@ -402,6 +425,7 @@ def test(args):
402425 args .cloudformation_endpoint_url ,
403426 args .role_arn ,
404427 args .profile ,
428+ headers = {"account_id" : args .source_account , "source_arn" : args .source_arn },
405429 )
406430 else :
407431 overrides = get_overrides (
@@ -410,6 +434,7 @@ def test(args):
410434 args .cloudformation_endpoint_url ,
411435 args .role_arn ,
412436 args .profile ,
437+ headers = {"account_id" : args .source_account , "source_arn" : args .source_arn },
413438 )
414439 filter_overrides (overrides , project )
415440
@@ -422,6 +447,7 @@ def test(args):
422447 index ,
423448 args .role_arn ,
424449 args .profile ,
450+ headers = {"account_id" : args .source_account , "source_arn" : args .source_arn },
425451 )
426452 if not inputs :
427453 break
@@ -509,6 +535,15 @@ def setup_subparser(subparsers, parents):
509535 " '~/.cfn-cli/typeConfiguration.json.'"
510536 ),
511537 )
538+ parser .add_argument (
539+ "--source-account" ,
540+ help = "Source Account key used for Assume Role to Run Contract Tests" ,
541+ )
542+
543+ parser .add_argument (
544+ "--source-arn" ,
545+ help = "Source Type Version Arn key used for Assume Role to Run Contract Tests" ,
546+ )
512547
513548
514549def _sam_arguments (parser ):
0 commit comments