Skip to content

Commit fa85073

Browse files
authored
overwriting a file with list of weights during upgrade (#1110)
1 parent e37c6c4 commit fa85073

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

platform/services/weights_uploader/app/weights_uploader.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,24 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17-
def upload_file(client: boto3.client, path: str, filename: str) -> None:
18-
"""Uploads file from local_path to s3 if it does not already exist"""
17+
def upload_file(client: boto3.client, path: str, filename: str, overwrite: bool = False) -> None:
18+
"""
19+
Uploads file from local_path to s3. If the "ovewrite" parameter is set to True, skips validation
20+
if the file already exist.
21+
"""
1922
try:
2023
file_path = os.path.join(path, filename)
2124
bucket_name = "pretrainedweights"
2225

23-
try:
24-
client.head_object(Bucket=bucket_name, Key=filename)
25-
logger.info(f"File {filename} already exists on S3. Skipping upload.")
26-
return
27-
except ClientError as err:
28-
if err.response["Error"]["Code"] != "404":
29-
logger.error(f"Error checking existence of {filename} on S3: {err}")
30-
raise err
26+
if not overwrite:
27+
try:
28+
client.head_object(Bucket=bucket_name, Key=filename)
29+
logger.info(f"File {filename} already exists on S3. Skipping upload.")
30+
return
31+
except ClientError as err:
32+
if err.response["Error"]["Code"] != "404":
33+
logger.error(f"Error checking existence of {filename} on S3: {err}")
34+
raise err
3135

3236
# Upload the file if it does not exist
3337
logger.info(f"Uploading file {filename} to S3")
@@ -55,7 +59,7 @@ def main() -> None:
5559
"s3", endpoint_url=s3_host, aws_access_key_id=s3_access_key, aws_secret_access_key=s3_secret_key
5660
)
5761

58-
upload_file(client, config_dir, model_list)
62+
upload_file(client, config_dir, model_list, True)
5963

6064
if os.environ.get("DISABLE_WEIGHT_UPLOADING", None) is not None:
6165
logger.info("Downloading pretrained weights is disabled. Exiting.")

0 commit comments

Comments
 (0)