import json
import logging
import os
import socket
import time
import urllib
import urllib.error
import urllib.request

import boto3
import botocore
import botocore.exceptions
10+
11+
# The root logger emits INFO and above; the AWS SDK loggers are restricted to
# CRITICAL messages.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for _sdk_logger_name in ("boto3", "botocore"):
    logging.getLogger(_sdk_logger_name).setLevel(logging.CRITICAL)
16+
17+
# Module-level EC2 client, created once at import time and shared by
# get_vpc_id, get_nat_gateway_id and replace_route.
ec2_client = boto3.client("ec2")
19+
# Keys looked up in the JSON lifecycle message. handler requires LIFECYCLE_KEY
# and ASG_KEY to be present; EC2_KEY is not referenced in the visible code.
LIFECYCLE_KEY = "LifecycleHookName"
ASG_KEY = "AutoScalingGroupName"
EC2_KEY = "EC2InstanceId"

# Default number of seconds between connectivity checks. It is a string because
# it is the fallback for the CONNECTIVITY_CHECK_INTERVAL environment variable
# and is converted with int() by connectivity_test_handler, which keeps
# checking for approximately one minute per invocation.
DEFAULT_CONNECTIVITY_CHECK_INTERVAL = "5"

# URLs probed for connectivity when the CHECK_URLS environment variable is not
# set.
DEFAULT_CHECK_URLS = [
    "https://www.example.com",
    "https://www.google.com",
]

# Timeout, in seconds, passed to each connectivity check request.
REQUEST_TIMEOUT = 5

# Whether or not to use IPv6 when the HAS_IPV6 environment variable is not set.
DEFAULT_HAS_IPV6 = True
35+
36+
def disable_ipv6():
    """Override ``socket.getaddrinfo`` so that it performs IPv4-only lookups.

    Replaces ``socket.getaddrinfo`` for the whole process with a wrapper that
    ignores the address family requested by the caller and always asks for
    ``socket.AF_INET``.
    See https://github.com/chime/terraform-aws-alternat/issues/87

    The patch is idempotent: when the wrapper is already installed, calling
    this function again is a no-op. Without that guard every call would wrap
    the previous wrapper, so a long-lived process that calls this repeatedly
    would build an ever-deeper call chain.
    """
    prv_getaddrinfo = socket.getaddrinfo
    if getattr(prv_getaddrinfo, "_ipv4_only", False):
        # Already patched by an earlier call; do not stack another wrapper.
        return

    # Mirror the standard-library signature so that callers may pass
    # family/type/proto/flags positionally or by keyword. The parameter name
    # ``type`` shadows the builtin on purpose: it must match the stdlib name.
    def getaddrinfo_ipv4(host, port, family=0, type=0, proto=0, flags=0):
        return prv_getaddrinfo(host, port, socket.AF_INET, type, proto, flags)

    getaddrinfo_ipv4._ipv4_only = True
    socket.getaddrinfo = getaddrinfo_ipv4
46+
47+
def get_az_and_vpc_zone_identifier(auto_scaling_group):
    """Look up an Auto Scaling Group's first availability zone and VPC zone identifier.

    Args:
        auto_scaling_group: Name of the Auto Scaling Group to describe.

    Returns:
        A ``(availability_zone, vpc_zone_identifier)`` tuple taken from the
        first group in the response: its first ``AvailabilityZones`` entry and
        its ``VPCZoneIdentifier`` value.

    Raises:
        botocore.exceptions.ClientError: If the describe call fails (logged,
            then re-raised).
        MissingVPCZoneIdentifierError: If the response contains no groups.
    """
    asg_client = boto3.client("autoscaling")

    try:
        response = asg_client.describe_auto_scaling_groups(
            AutoScalingGroupNames=[auto_scaling_group]
        )
    except botocore.exceptions.ClientError as error:
        logger.error("Unable to describe autoscaling groups")
        raise error

    groups = response["AutoScalingGroups"]
    if not groups:
        # Nothing matched the requested name; hand the raw response to the caller.
        raise MissingVPCZoneIdentifierError(response)

    asg = groups[0]
    logger.debug("Auto Scaling Group: %s", asg)

    availability_zone = asg["AvailabilityZones"][0]
    logger.debug("Availability Zone: %s", availability_zone)

    vpc_zone_identifier = asg["VPCZoneIdentifier"]
    logger.debug("VPC zone identifier: %s", vpc_zone_identifier)

    return availability_zone, vpc_zone_identifier
70+
71+
def get_vpc_id(route_table):
    """Return the VPC ID that a route table belongs to.

    Args:
        route_table: ID of the route table to describe.

    Returns:
        The ``VpcId`` of the route table, or ``None`` when the response does
        not contain exactly one route table.

    Raises:
        botocore.exceptions.ClientError: If the describe call fails (logged,
            then re-raised).
    """
    try:
        response = ec2_client.describe_route_tables(RouteTableIds=[route_table])
    except botocore.exceptions.ClientError as error:
        logger.error("Unable to get vpc id")
        raise error

    # Anything other than exactly one match yields no VPC ID.
    if "RouteTables" not in response or len(response["RouteTables"]) != 1:
        return None

    vpc_id = response["RouteTables"][0]["VpcId"]
    logger.debug("VPC ID: %s", vpc_id)
    return vpc_id
82+
83+
def get_nat_gateway_id(vpc_id, subnet_id):
    """Return the ID of the NAT Gateway that routes should be pointed at.

    The ``NAT_GATEWAY_ID`` environment variable takes precedence when it is set
    to a non-empty value. Otherwise the NAT Gateways in the given VPC and
    subnet are described and the first one returned is used.

    Args:
        vpc_id: VPC ID used for the ``vpc-id`` filter.
        subnet_id: Subnet ID used for the ``subnet-id`` filter.

    Returns:
        The NAT Gateway ID.

    Raises:
        botocore.exceptions.ClientError: If the describe call fails (logged,
            then re-raised).
        MissingNatGatewayError: If the response lists no NAT Gateways.
    """
    nat_gateway_id = os.getenv("NAT_GATEWAY_ID")
    if nat_gateway_id:
        logger.info("Using NAT_GATEWAY_ID env. variable (%s)", nat_gateway_id)
        return nat_gateway_id

    filters = [
        {"Name": "vpc-id", "Values": [vpc_id]},
        {"Name": "subnet-id", "Values": [subnet_id]},
    ]
    try:
        nat_gateways = ec2_client.describe_nat_gateways(Filters=filters)
    except botocore.exceptions.ClientError:
        logger.error("Unable to describe nat gateway")
        raise

    logger.debug("NAT Gateways: %s", nat_gateways)

    # ``.get`` returns None when the key is absent; treat that like an empty
    # list so the caller gets MissingNatGatewayError rather than a TypeError.
    candidates = nat_gateways.get("NatGateways") or []
    if not candidates:
        raise MissingNatGatewayError(nat_gateways)

    nat_gateway_id = candidates[0]["NatGatewayId"]
    logger.debug("NAT Gateway ID: %s", nat_gateway_id)
    return nat_gateway_id
114+
115+
def replace_route(route_table_id, nat_gateway_id):
    """Point the ``0.0.0.0/0`` route of a route table at a NAT Gateway.

    Args:
        route_table_id: ID of the route table whose route is replaced.
        nat_gateway_id: ID of the NAT Gateway that becomes the route target.

    Raises:
        botocore.exceptions.ClientError: If the replace call fails (logged,
            then re-raised).
    """
    new_route = {
        "DestinationCidrBlock": "0.0.0.0/0",
        "NatGatewayId": nat_gateway_id,
        "RouteTableId": route_table_id,
    }
    # Argument order follows the format string: the route first, then the
    # route table it belongs to.
    logger.info("Replacing existing route %s for route table %s", new_route, route_table_id)
    try:
        ec2_client.replace_route(**new_route)
    except botocore.exceptions.ClientError:
        logger.error("Unable to replace route")
        raise
128+
129+
def check_connection(check_urls):
    """Check connectivity and fail over to the standby NAT Gateway if it is lost.

    Each URL in ``check_urls`` is requested in turn. The first URL that yields
    any HTTP response, including an HTTP error status, counts as success. If
    every URL fails, the ``0.0.0.0/0`` route of every route table listed in
    the ``ROUTE_TABLE_IDS_CSV`` environment variable is replaced so that it
    points at a NAT Gateway (see ``get_nat_gateway_id``).

    Args:
        check_urls: Iterable of URLs to probe.

    Returns:
        ``True`` if any URL was reachable, ``False`` if all failed and the
        routes were replaced.

    Raises:
        MissingEnvironmentVariableError: If ``PUBLIC_SUBNET_ID`` or
            ``ROUTE_TABLE_IDS_CSV`` is unset or empty when a failover is needed.
        botocore.exceptions.ClientError: If an AWS call made during the
            failover fails.
    """
    for url in check_urls:
        try:
            req = urllib.request.Request(url)
            req.add_header("User-Agent", "alternat/1.0")
            # Close the response right away instead of leaving the connection
            # open until garbage collection.
            with urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT):
                pass
            logger.debug("Successfully connected to %s", url)
            return True
        except urllib.error.HTTPError as error:
            # The server answered, so the network path works.
            logger.warning("Response error from %s: %s, treating as success", url, error)
            return True
        except urllib.error.URLError as error:
            logger.error("error connecting to %s: %s", url, error)
        except socket.timeout as error:
            logger.error("timeout error connecting to %s: %s", url, error)
        except OSError as error:
            # Errors raised while reading the response (connection reset,
            # remote disconnect) are not wrapped in URLError; treat them as a
            # failed check rather than aborting before the failover can run.
            logger.error("error connecting to %s: %s", url, error)

    logger.warning("Failed connectivity tests! Replacing route")

    public_subnet_id = os.getenv("PUBLIC_SUBNET_ID")
    if not public_subnet_id:
        raise MissingEnvironmentVariableError("PUBLIC_SUBNET_ID")

    # Drop blank entries so that an empty variable is reported as missing
    # instead of producing an empty route table ID.
    route_tables = [
        rtb.strip()
        for rtb in os.getenv("ROUTE_TABLE_IDS_CSV", "").split(",")
        if rtb.strip()
    ]
    if not route_tables:
        raise MissingEnvironmentVariableError("ROUTE_TABLE_IDS_CSV")

    vpc_id = get_vpc_id(route_tables[0])
    nat_gateway_id = get_nat_gateway_id(vpc_id, public_subnet_id)

    for rtb in route_tables:
        replace_route(rtb, nat_gateway_id)
        logger.info("Route replacement succeeded")
    return False
168+
169+
def connectivity_test_handler(event, context):
    """Lambda entry point that runs connectivity checks for about one minute.

    Configuration is read from environment variables:

    * ``CONNECTIVITY_CHECK_INTERVAL``: seconds between checks (default
      ``DEFAULT_CONNECTIVITY_CHECK_INTERVAL``). Non-positive values fall back
      to the default.
    * ``CHECK_URLS``: comma-separated URLs to probe (default
      ``DEFAULT_CHECK_URLS``). Blank entries are ignored.
    * ``HAS_IPV6``: when false, lookups are forced to IPv4 via ``disable_ipv6``.

    Checking stops early as soon as ``check_connection`` reports a failure,
    because it has then already replaced the routes.

    Args:
        event: The invocation event. Non-dict events are logged and ignored.
        context: Lambda context object (unused).

    Raises:
        UnknownEventTypeError: If the event's ``source`` is not ``aws.events``.
        ValueError: If ``CONNECTIVITY_CHECK_INTERVAL`` is not an integer.
    """
    if not isinstance(event, dict):
        logger.error("Unknown event: %s", event)
        return

    if event.get("source") != "aws.events":
        logger.error("Unable to handle unknown event type: %s", json.dumps(event))
        raise UnknownEventTypeError

    logger.debug("Starting NAT instance connectivity test")

    check_interval = int(
        os.getenv("CONNECTIVITY_CHECK_INTERVAL", DEFAULT_CONNECTIVITY_CHECK_INTERVAL)
    )
    if check_interval <= 0:
        # Zero would divide by zero below and a negative value would skip every
        # check, so fall back to the default interval.
        logger.warning(
            "Invalid CONNECTIVITY_CHECK_INTERVAL (%s), using default of %s",
            check_interval,
            DEFAULT_CONNECTIVITY_CHECK_INTERVAL,
        )
        check_interval = int(DEFAULT_CONNECTIVITY_CHECK_INTERVAL)

    # Ignore blank entries so that an empty CHECK_URLS falls back to the
    # defaults instead of probing an empty URL.
    configured_urls = [
        url.strip() for url in os.getenv("CHECK_URLS", "").split(",") if url.strip()
    ]
    check_urls = configured_urls or DEFAULT_CHECK_URLS

    has_ipv6 = get_env_bool("HAS_IPV6", DEFAULT_HAS_IPV6)
    if not has_ipv6:
        disable_ipv6()

    # Run connectivity checks for approximately 1 minute
    run = 0
    num_runs = 60 / check_interval
    while run < num_runs:
        if not check_connection(check_urls):
            break
        time.sleep(check_interval)
        run += 1
197+
198+
def get_env_bool(var_name, default_value=False):
    """Read an environment variable and interpret it as a boolean.

    The value (or ``default_value`` when the variable is unset) is converted
    to a string, stripped of surrounding whitespace and compared
    case-insensitively against ``t``, ``true``, ``y``, ``yes`` and ``1``.

    Args:
        var_name: Name of the environment variable.
        default_value: Value used when the variable is not set.

    Returns:
        ``True`` if the value is one of the accepted true spellings, otherwise
        ``False``.
    """
    value = os.getenv(var_name, default_value)
    # Strip first so that values such as " true " are not misread as False.
    return str(value).strip().lower() in ("t", "true", "y", "yes", "1")
203+
204+
def handler(event, _):
    """Lambda entry point that replaces routes in response to a lifecycle message.

    Every record's ``Sns.Message`` is parsed as JSON and must contain both
    ``LIFECYCLE_KEY`` and ``ASG_KEY``. The Auto Scaling Group named in the last
    record is then used to find the availability zone and public subnet. The
    route tables to update are read from the environment variable named after
    the availability zone, upper-cased with ``-`` replaced by ``_``.

    Args:
        event: The invocation event; must contain a ``Records`` list.
        _: Lambda context object (unused).

    Raises:
        LifecycleMessageError: If a record lacks the lifecycle keys, or if the
            event contains no usable record.
        MissingEnvironmentVariableError: If the availability zone variable is
            unset or empty.
        Exception: Any error raised while parsing the records is logged and
            re-raised.
    """
    asg = None
    try:
        for record in event["Records"]:
            message = json.loads(record["Sns"]["Message"])
            if LIFECYCLE_KEY not in message or ASG_KEY not in message:
                logger.error("Failed to find lifecycle message to parse")
                raise LifecycleMessageError
            asg = message[ASG_KEY]
        if asg is None:
            # An empty Records list would otherwise leave ``asg`` unbound and
            # fail later with a NameError.
            logger.error("Failed to find lifecycle message to parse")
            raise LifecycleMessageError("No lifecycle message found in event records")
    except Exception as error:
        logger.error("Error: %s", error)
        raise

    availability_zone, vpc_zone_identifier = get_az_and_vpc_zone_identifier(asg)
    public_subnet_id = vpc_zone_identifier.split(",")[0]

    az = availability_zone.upper().replace("-", "_")
    # Drop blank entries so that an empty variable is reported as missing.
    route_tables = [
        rtb.strip() for rtb in os.getenv(az, "").split(",") if rtb.strip()
    ]
    if not route_tables:
        raise MissingEnvironmentVariableError(az)

    vpc_id = get_vpc_id(route_tables[0])
    nat_gateway_id = get_nat_gateway_id(vpc_id, public_subnet_id)

    for rtb in route_tables:
        replace_route(rtb, nat_gateway_id)
        logger.info("Route replacement succeeded")
231+
232+
class UnknownEventTypeError(Exception):
    """Raised by connectivity_test_handler when the event source is not "aws.events"."""
234+
235+
class MissingVpcConfigError(Exception):
    """Presumably signals a missing VPC configuration; not raised in the visible code."""
237+
238+
class MissingFunctionSubnetError(Exception):
    """Presumably signals a missing function subnet; not raised in the visible code."""
240+
241+
class MissingAZSubnetError(Exception):
    """Presumably signals a missing availability-zone subnet; not raised in the visible code."""
243+
244+
class MissingVPCZoneIdentifierError(Exception):
    """Raised by get_az_and_vpc_zone_identifier when no Auto Scaling Group is returned."""
246+
247+
class MissingVPCandSubnetError(Exception):
    """Presumably signals a missing VPC and subnet; not raised in the visible code."""
249+
250+
class MissingNatGatewayError(Exception):
    """Raised by get_nat_gateway_id when no NAT Gateway matches the VPC and subnet."""
252+
253+
class MissingRouteTableError(Exception):
    """Presumably signals a missing route table; not raised in the visible code."""
255+
256+
class LifecycleMessageError(Exception):
    """Raised by handler when an SNS message lacks the lifecycle hook or ASG name keys."""
258+
259+
class MissingEnvironmentVariableError(Exception):
    """Raised by check_connection and handler when a required environment variable is missing."""
0 commit comments