File tree Expand file tree Collapse file tree 3 files changed +25
-15
lines changed Expand file tree Collapse file tree 3 files changed +25
-15
lines changed Original file line number Diff line number Diff line change @@ -372,10 +372,18 @@ def _get_json_file(
372372 object and None when reading from the local file system.
373373 """
374374 if self ._is_local_metadata_mode ():
375- file_content , etag = self ._get_json_file_from_local_override (key , filetype ), None
376- else :
377- file_content , etag = self ._get_json_file_and_etag_from_s3 (key )
378- return file_content , etag
375+ if filetype in {
376+ JumpStartS3FileType .OPEN_WEIGHT_MANIFEST ,
377+ JumpStartS3FileType .OPEN_WEIGHT_SPECS ,
378+ }:
379+ return self ._get_json_file_from_local_override (key , filetype ), None
380+ else :
381+ JUMPSTART_LOGGER .warning (
382+ "Local metadata mode is enabled, but the file type %s is not supported "
383+ "for local override. Falling back to s3." ,
384+ filetype ,
385+ )
386+ return self ._get_json_file_and_etag_from_s3 (key )
379387
380388 def _get_json_md5_hash (self , key : str ):
381389 """Retrieves md5 object hash for s3 objects, using `s3.head_object`.
Original file line number Diff line number Diff line change @@ -632,13 +632,7 @@ def _add_model_reference_arn_to_kwargs(
632632
633633def _add_model_uri_to_kwargs (kwargs : JumpStartEstimatorInitKwargs ) -> JumpStartEstimatorInitKwargs :
634634 """Sets model uri in kwargs based on default or override, returns full kwargs."""
635- # hub_arn is by default None unless the user specifies the hub_name
636- # If no hub_name is specified, it is assumed the public hub
637- is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs .hub_arn if kwargs .hub_arn else False
638- if (
639- _model_supports_training_model_uri (** get_model_info_default_kwargs (kwargs ))
640- or is_private_hub
641- ):
635+ if _model_supports_training_model_uri (** get_model_info_default_kwargs (kwargs )):
642636 default_model_uri = model_uris .retrieve (
643637 model_scope = JumpStartScriptScope .TRAINING ,
644638 instance_type = kwargs .instance_type ,
Original file line number Diff line number Diff line change @@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:
19401940
19411941 def use_training_model_artifact (self ) -> bool :
19421942 """Returns True if the model should use a model uri when kicking off training job."""
1943- # gated model never use training model artifact
1944- if self .gated_bucket :
1943+ # old models with this environment variable present don't use model channel
1944+ if any (
1945+ self .training_instance_type_variants .get_instance_specific_gated_model_key_env_var_value (
1946+ instance_type
1947+ )
1948+ for instance_type in self .supported_training_instance_types
1949+ ):
1950+ return False
1951+
1952+ # even older models with training model package artifact uris present also don't use model channel
1953+ if len (self .training_model_package_artifact_uris or {}) > 0 :
19451954 return False
19461955
1947- # otherwise, return true is a training model package is not set
1948- return len (self .training_model_package_artifact_uris or {}) == 0
1956+ return getattr (self , "training_artifact_key" , None ) is not None
19491957
19501958 def is_gated_model (self ) -> bool :
19511959 """Returns True if the model has a EULA key or the model bucket is gated."""
You can’t perform that action at this time.
0 commit comments