|
14 | 14 | logger = logging.getLogger(__name__) |
15 | 15 |
|
16 | 16 |
|
17 | | -def upload_file(client: boto3.client, path: str, filename: str) -> None: |
18 | | - """Uploads file from local_path to s3 if it does not already exist""" |
| 17 | +def upload_file(client: boto3.client, path: str, filename: str, overwrite: bool = False) -> None: |
| 18 | + """ |
| 19 | + Uploads file from local_path to s3. If the "ovewrite" parameter is set to True, skips validation |
| 20 | + if the file already exist. |
| 21 | + """ |
19 | 22 | try: |
20 | 23 | file_path = os.path.join(path, filename) |
21 | 24 | bucket_name = "pretrainedweights" |
22 | 25 |
|
23 | | - try: |
24 | | - client.head_object(Bucket=bucket_name, Key=filename) |
25 | | - logger.info(f"File {filename} already exists on S3. Skipping upload.") |
26 | | - return |
27 | | - except ClientError as err: |
28 | | - if err.response["Error"]["Code"] != "404": |
29 | | - logger.error(f"Error checking existence of {filename} on S3: {err}") |
30 | | - raise err |
| 26 | + if not overwrite: |
| 27 | + try: |
| 28 | + client.head_object(Bucket=bucket_name, Key=filename) |
| 29 | + logger.info(f"File {filename} already exists on S3. Skipping upload.") |
| 30 | + return |
| 31 | + except ClientError as err: |
| 32 | + if err.response["Error"]["Code"] != "404": |
| 33 | + logger.error(f"Error checking existence of {filename} on S3: {err}") |
| 34 | + raise err |
31 | 35 |
|
32 | 36 | # Upload the file if it does not exist |
33 | 37 | logger.info(f"Uploading file {filename} to S3") |
@@ -55,7 +59,7 @@ def main() -> None: |
55 | 59 | "s3", endpoint_url=s3_host, aws_access_key_id=s3_access_key, aws_secret_access_key=s3_secret_key |
56 | 60 | ) |
57 | 61 |
|
58 | | - upload_file(client, config_dir, model_list) |
| 62 | + upload_file(client, config_dir, model_list, True) |
59 | 63 |
|
60 | 64 | if os.environ.get("DISABLE_WEIGHT_UPLOADING", None) is not None: |
61 | 65 | logger.info("Downloading pretrained weights is disabled. Exiting.") |
|
0 commit comments