|
| 1 | +from boto3 import session |
| 2 | +from botocore.exceptions import ClientError |
| 3 | +from typing import Dict, List |
| 4 | +import os |
| 5 | +import logging |
| 6 | +import uuid |
| 7 | +import time |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | +ENDPOINT_ENV = "aws_endpoint_url" |
| 12 | +INVALIDATION_BATCH_DEFAULT = 3000 |
| 13 | +INVALIDATION_BATCH_WILDCARD = 15 |
| 14 | + |
| 15 | +INVALIDATION_STATUS_COMPLETED = "Completed" |
| 16 | +INVALIDATION_STATUS_INPROGRESS = "InProgress" |
| 17 | + |
| 18 | +DEFAULT_BUCKET_TO_DOMAIN = { |
| 19 | + "prod-ga": "maven.repository.redhat.com", |
| 20 | + "prod-maven-ga": "maven.repository.redhat.com", |
| 21 | + "prod-ea": "maven.repository.redhat.com", |
| 22 | + "prod-maven-ea": "maven.repository.redhat.com", |
| 23 | + "stage-ga": "maven.stage.repository.redhat.com", |
| 24 | + "stage-maven-ga": "maven.stage.repository.redhat.com", |
| 25 | + "stage-ea": "maven.stage.repository.redhat.com", |
| 26 | + "stage-maven-ea": "maven.stage.repository.redhat.com", |
| 27 | + "prod-npm": "npm.registry.redhat.com", |
| 28 | + "prod-npm-npmjs": "npm.registry.redhat.com", |
| 29 | + "stage-npm": "npm.stage.registry.redhat.com", |
| 30 | + "stage-npm-npmjs": "npm.stage.registry.redhat.com" |
| 31 | +} |
| 32 | + |
| 33 | + |
| 34 | +class CFClient(object): |
| 35 | + """The CFClient is a wrapper of the original boto3 clouldfrong client, |
| 36 | + which will provide CloudFront functions to be used in the charon. |
| 37 | + """ |
| 38 | + |
| 39 | + def __init__( |
| 40 | + self, |
| 41 | + aws_profile=None, |
| 42 | + extra_conf=None |
| 43 | + ) -> None: |
| 44 | + self.__client = self.__init_aws_client(aws_profile, extra_conf) |
| 45 | + |
| 46 | + def __init_aws_client( |
| 47 | + self, aws_profile=None, extra_conf=None |
| 48 | + ): |
| 49 | + if aws_profile: |
| 50 | + logger.debug("[CloudFront] Using aws profile: %s", aws_profile) |
| 51 | + cf_session = session.Session(profile_name=aws_profile) |
| 52 | + else: |
| 53 | + cf_session = session.Session() |
| 54 | + endpoint_url = self.__get_endpoint(extra_conf) |
| 55 | + return cf_session.client( |
| 56 | + 'cloudfront', |
| 57 | + endpoint_url=endpoint_url |
| 58 | + ) |
| 59 | + |
| 60 | + def __get_endpoint(self, extra_conf) -> str: |
| 61 | + endpoint_url = os.getenv(ENDPOINT_ENV) |
| 62 | + if not endpoint_url or not endpoint_url.strip(): |
| 63 | + if isinstance(extra_conf, Dict): |
| 64 | + endpoint_url = extra_conf.get(ENDPOINT_ENV, None) |
| 65 | + if endpoint_url: |
| 66 | + logger.info( |
| 67 | + "[CloudFront] Using endpoint url for aws CF client: %s", |
| 68 | + endpoint_url |
| 69 | + ) |
| 70 | + else: |
| 71 | + logger.debug("[CloudFront] No user-specified endpoint url is used.") |
| 72 | + return endpoint_url |
| 73 | + |
| 74 | + def invalidate_paths( |
| 75 | + self, distr_id: str, paths: List[str], |
| 76 | + batch_size=INVALIDATION_BATCH_DEFAULT |
| 77 | + ) -> List[Dict[str, str]]: |
| 78 | + """Send a invalidating requests for the paths in distribution to CloudFront. |
| 79 | + This will invalidate the paths in the distribution to enforce the refreshment |
| 80 | + from backend S3 bucket for these paths. For details see: |
| 81 | + https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html |
| 82 | + * The distr_id is the id for the distribution. This id can be get through |
| 83 | + get_dist_id_by_domain(domain) function |
| 84 | + * Can specify the invalidating paths through paths param. |
| 85 | + * Batch size is the number of paths to be invalidated in one request. |
| 86 | + The default value is 3000 which is the maximum number in official doc: |
| 87 | + https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html#InvalidationLimits |
| 88 | + """ |
| 89 | + INPRO_W_SECS = 5 |
| 90 | + NEXT_W_SECS = 1 |
| 91 | + real_paths = [paths] |
| 92 | + # Split paths into batches by batch_size |
| 93 | + if batch_size: |
| 94 | + real_paths = [paths[i:i + batch_size] for i in range(0, len(paths), batch_size)] |
| 95 | + total_time_approx = len(real_paths) * (INPRO_W_SECS * 2 + NEXT_W_SECS) |
| 96 | + logger.info("There will be %d invalidating requests in total," |
| 97 | + " will take more than %d seconds", |
| 98 | + len(real_paths), total_time_approx) |
| 99 | + results = [] |
| 100 | + current_invalidation = {} |
| 101 | + processed_count = 0 |
| 102 | + for batch_paths in real_paths: |
| 103 | + while (current_invalidation and |
| 104 | + INVALIDATION_STATUS_INPROGRESS == current_invalidation.get('Status', '')): |
| 105 | + time.sleep(INPRO_W_SECS) |
| 106 | + try: |
| 107 | + result = self.check_invalidation(distr_id, current_invalidation.get('Id')) |
| 108 | + if result: |
| 109 | + current_invalidation = { |
| 110 | + 'Id': result.get('Id', None), |
| 111 | + 'Status': result.get('Status', None) |
| 112 | + } |
| 113 | + logger.debug("Check invalidation: %s", current_invalidation) |
| 114 | + except Exception as err: |
| 115 | + logger.warning( |
| 116 | + "[CloudFront] Error occurred while checking invalidation status during" |
| 117 | + " creating invalidation, invalidation: %s, error: %s", |
| 118 | + current_invalidation, err |
| 119 | + ) |
| 120 | + break |
| 121 | + if current_invalidation: |
| 122 | + results.append(current_invalidation) |
| 123 | + processed_count += 1 |
| 124 | + if processed_count % 10 == 0: |
| 125 | + logger.info( |
| 126 | + "[CloudFront] ######### %d/%d requests finished", |
| 127 | + processed_count, len(real_paths)) |
| 128 | + # To avoid conflict rushing request, we can wait 1s here |
| 129 | + # for next invalidation request sending. |
| 130 | + time.sleep(NEXT_W_SECS) |
| 131 | + caller_ref = str(uuid.uuid4()) |
| 132 | + logger.debug( |
| 133 | + "Processing invalidation for batch with ref %s, size: %s", |
| 134 | + caller_ref, len(batch_paths) |
| 135 | + ) |
| 136 | + try: |
| 137 | + response = self.__client.create_invalidation( |
| 138 | + DistributionId=distr_id, |
| 139 | + InvalidationBatch={ |
| 140 | + 'CallerReference': caller_ref, |
| 141 | + 'Paths': { |
| 142 | + 'Quantity': len(batch_paths), |
| 143 | + 'Items': batch_paths |
| 144 | + } |
| 145 | + } |
| 146 | + ) |
| 147 | + if response: |
| 148 | + invalidation = response.get('Invalidation', {}) |
| 149 | + current_invalidation = { |
| 150 | + 'Id': invalidation.get('Id', None), |
| 151 | + 'Status': invalidation.get('Status', None) |
| 152 | + } |
| 153 | + except Exception as err: |
| 154 | + logger.error( |
| 155 | + "[CloudFront] Error occurred while creating invalidation" |
| 156 | + " for paths %s, error: %s", batch_paths, err |
| 157 | + ) |
| 158 | + if current_invalidation: |
| 159 | + results.append(current_invalidation) |
| 160 | + return results |
| 161 | + |
| 162 | + def check_invalidation(self, distr_id: str, invalidation_id: str) -> dict: |
| 163 | + try: |
| 164 | + response = self.__client.get_invalidation( |
| 165 | + DistributionId=distr_id, |
| 166 | + Id=invalidation_id |
| 167 | + ) |
| 168 | + if response: |
| 169 | + invalidation = response.get('Invalidation', {}) |
| 170 | + return { |
| 171 | + 'Id': invalidation.get('Id', None), |
| 172 | + 'CreateTime': str(invalidation.get('CreateTime', None)), |
| 173 | + 'Status': invalidation.get('Status', None) |
| 174 | + } |
| 175 | + except Exception as err: |
| 176 | + logger.error( |
| 177 | + "[CloudFront] Error occurred while check invalidation of id %s, " |
| 178 | + "error: %s", invalidation_id, err |
| 179 | + ) |
| 180 | + |
| 181 | + def get_dist_id_by_domain(self, domain: str) -> str: |
| 182 | + """Get distribution id by a domain name. The id can be used to send invalidating |
| 183 | + request through #invalidate_paths function |
| 184 | + * Domain are Ronda domains, like "maven.repository.redhat.com" |
| 185 | + or "npm.registry.redhat.com" |
| 186 | + """ |
| 187 | + try: |
| 188 | + response = self.__client.list_distributions() |
| 189 | + if response: |
| 190 | + dist_list_items = response.get("DistributionList", {}).get("Items", []) |
| 191 | + for distr in dist_list_items: |
| 192 | + aliases_items = distr.get('Aliases', {}).get('Items', []) |
| 193 | + if aliases_items and domain in aliases_items: |
| 194 | + return distr['Id'] |
| 195 | + logger.error("[CloudFront]: Distribution not found for domain %s", domain) |
| 196 | + except ClientError as err: |
| 197 | + logger.error( |
| 198 | + "[CloudFront]: Error occurred while get distribution for domain %s: %s", |
| 199 | + domain, err |
| 200 | + ) |
| 201 | + return None |
| 202 | + |
| 203 | + def get_domain_by_bucket(self, bucket: str) -> str: |
| 204 | + return DEFAULT_BUCKET_TO_DOMAIN.get(bucket, None) |
0 commit comments