1414from __future__ import absolute_import
1515import datetime
1616from difflib import get_close_matches
17- from typing import List , Optional
17+ import os
18+ from typing import List , Optional , Tuple , Union
1819import json
1920import boto3
2021import botocore
2122from packaging .version import Version
2223from packaging .specifiers import SpecifierSet
2324from sagemaker .jumpstart .constants import (
25+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ,
26+ ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
2427 JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
2528 JUMPSTART_DEFAULT_REGION_NAME ,
2629)
@@ -90,7 +93,7 @@ def __init__(
9093 self ._s3_cache = LRUCache [JumpStartCachedS3ContentKey , JumpStartCachedS3ContentValue ](
9194 max_cache_items = max_s3_cache_items ,
9295 expiration_horizon = s3_cache_expiration_horizon ,
93- retrieval_function = self ._get_file_from_s3 ,
96+ retrieval_function = self ._retrieval_function ,
9497 )
9598 self ._model_id_semantic_version_manifest_key_cache = LRUCache [
9699 JumpStartVersionedModelId , JumpStartVersionedModelId
@@ -235,7 +238,64 @@ def _get_manifest_key_from_model_id_semantic_version(
235238
236239 raise KeyError (error_msg )
237240
238- def _get_file_from_s3 (
def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]:
    """Fetch a json object from the cache's s3 bucket together with its etag.

    Args:
        key (str): s3 key of the object to download from ``self.s3_bucket_name``.

    Returns:
        Tuple[Union[dict, list], str]: the object body decoded as UTF-8 and
        parsed as json, followed by the ``ETag`` field of the same
        ``get_object`` response.
    """
    s3_response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key)
    raw_body = s3_response["Body"].read()
    parsed_content = json.loads(raw_body.decode("utf-8"))
    return parsed_content, s3_response["ETag"]
245+
def _is_local_metadata_mode(self) -> bool:
    """Report whether metadata should be read from the local file system.

    Local metadata mode is active only when both the manifest and the specs
    root-directory override environment variables are set and each of them
    names an existing directory.
    """
    override_env_vars = (
        ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
        ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
    )
    # Checked in order (manifest first, then specs); stops at the first failure.
    return all(
        env_var in os.environ and os.path.isdir(os.environ[env_var])
        for env_var in override_env_vars
    )
252+
def _get_json_file(
    self,
    key: str,
    filetype: JumpStartS3FileType,
) -> Tuple[Union[dict, list], Optional[str]]:
    """Load a json metadata file from s3 or from the local override directory.

    Args:
        key (str): s3 key of the file; in local metadata mode it is handed to
            ``_get_json_file_from_local_override`` instead.
        filetype (JumpStartS3FileType): kind of file being requested; only
            forwarded to the local override reader.

    Returns:
        Tuple[Union[dict, list], Optional[str]]: the parsed json content and
        its etag. The etag is None when the file was read locally.
    """
    if not self._is_local_metadata_mode():
        s3_content, s3_etag = self._get_json_file_and_etag_from_s3(key)
        return s3_content, s3_etag

    # Local files carry no etag, so None is reported in its place.
    local_content = self._get_json_file_from_local_override(key, filetype)
    return local_content, None
268+
def _get_json_md5_hash(self, key: str):
    """Look up the hash of an s3 object without downloading its body.

    Issues ``s3.head_object`` against ``self.s3_bucket_name`` and returns the
    ``ETag`` field of the response, which this cache treats as the md5 hash.

    Args:
        key (str): s3 key of the object to inspect.

    Raises:
        ValueError: if the cache should use local metadata mode.
    """
    if self._is_local_metadata_mode():
        raise ValueError("Cannot get md5 hash of local file.")
    head_response = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)
    return head_response["ETag"]
278+
def _get_json_file_from_local_override(
    self,
    key: str,
    filetype: JumpStartS3FileType,
) -> Union[dict, list]:
    """Reads json file from local filesystem and returns data.

    The root directory is taken from the manifest or specs override
    environment variable, depending on ``filetype``, and ``key`` is joined
    onto it to form the file path.

    Args:
        key (str): path of the file relative to the override root directory.
        filetype (JumpStartS3FileType): selects which override root to use.

    Returns:
        Union[dict, list]: the parsed json content of the file.

    Raises:
        ValueError: if ``filetype`` is neither MANIFEST nor SPECS.
        KeyError: if the required override environment variable is not set.
    """
    if filetype == JumpStartS3FileType.MANIFEST:
        metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
    elif filetype == JumpStartS3FileType.SPECS:
        metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
    else:
        raise ValueError(f"Unsupported file type for local override: {filetype}")

    file_path = os.path.join(metadata_local_root, key)
    # Decode explicitly as UTF-8, matching the s3 code path, instead of
    # relying on the platform's default text encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
297+
298+ def _retrieval_function (
239299 self ,
240300 key : JumpStartCachedS3ContentKey ,
241301 value : Optional [JumpStartCachedS3ContentValue ],
@@ -256,20 +316,17 @@ def _get_file_from_s3(
256316 file_type , s3_key = key .file_type , key .s3_key
257317
258318 if file_type == JumpStartS3FileType .MANIFEST :
259- if value is not None :
260- etag = self ._s3_client . head_object ( Bucket = self . s3_bucket_name , Key = s3_key )[ "ETag" ]
319+ if value is not None and not self . _is_local_metadata_mode () :
320+ etag = self ._get_json_md5_hash ( s3_key )
261321 if etag == value .md5_hash :
262322 return value
263- response = self ._s3_client .get_object (Bucket = self .s3_bucket_name , Key = s3_key )
264- formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
265- etag = response ["ETag" ]
323+ formatted_body , etag = self ._get_json_file (s3_key , file_type )
266324 return JumpStartCachedS3ContentValue (
267325 formatted_content = utils .get_formatted_manifest (formatted_body ),
268326 md5_hash = etag ,
269327 )
270328 if file_type == JumpStartS3FileType .SPECS :
271- response = self ._s3_client .get_object (Bucket = self .s3_bucket_name , Key = s3_key )
272- formatted_body = json .loads (response ["Body" ].read ().decode ("utf-8" ))
329+ formatted_body , _ = self ._get_json_file (s3_key , file_type )
273330 return JumpStartCachedS3ContentValue (
274331 formatted_content = JumpStartModelSpecs (formatted_body )
275332 )
0 commit comments