@@ -139,9 +139,7 @@ def _is_jumpstart_model_id(self) -> bool:
139
139
140
140
def _create_pre_trained_js_model (self ) -> Type [Model ]:
141
141
"""Placeholder docstring"""
142
- pysdk_model = JumpStartModel (
143
- self .model , vpc_config = self .vpc_config , instance_type = self .instance_type
144
- )
142
+ pysdk_model = JumpStartModel (self .model , vpc_config = self .vpc_config )
145
143
pysdk_model .sagemaker_session = self .sagemaker_session
146
144
147
145
self ._original_deploy = pysdk_model .deploy
@@ -696,12 +694,12 @@ def _optimize_for_jumpstart(
696
694
f"Model '{ self .model } ' requires accepting end-user license agreement (EULA)."
697
695
)
698
696
697
+ optimization_env_vars = env_vars
698
+ pysdk_model_env_vars = env_vars
699
+
699
700
if compilation_config :
700
- neuro_model_id = self .pysdk_model .deployment_config .get ("DeploymentArgs" ).get (
701
- "NeuronModelId"
702
- )
703
- self .model = neuro_model_id
704
- self .pysdk_model = self ._create_pre_trained_js_model ()
701
+ neuron_env = self ._get_neuron_model_env_vars (instance_type )
702
+ optimization_env_vars = _update_environment_variables (neuron_env , optimization_env_vars )
705
703
706
704
if speculative_decoding_config :
707
705
self ._set_additional_model_source (speculative_decoding_config )
@@ -714,11 +712,6 @@ def _optimize_for_jumpstart(
714
712
)
715
713
716
714
model_source = _generate_model_source (self .pysdk_model .model_data , accept_eula )
717
- optimization_env_vars = self .pysdk_model .deployment_config .get ("DeploymentArgs" , {}).get (
718
- "Environment"
719
- )
720
- optimization_env_vars = _update_environment_variables (optimization_env_vars , env_vars )
721
- pysdk_model_env_vars = env_vars
722
715
723
716
optimization_config = {}
724
717
if quantization_config :
@@ -874,3 +867,26 @@ def _find_compatible_deployment_config(
874
867
875
868
# fall back to the default jumpstart model deployment config for optimization job
876
869
return self .pysdk_model .deployment_config
870
+
871
+ def _get_neuron_model_env_vars (
872
+ self , instance_type : Optional [str ] = None
873
+ ) -> Optional [Dict [str , Any ]]:
874
+ """Gets Neuron model env vars.
875
+
876
+ Args:
877
+ instance_type (Optional[str]): Instance type.
878
+
879
+ Returns:
880
+ Optional[Dict[str, Any]]: Neuron Model environment variables.
881
+ """
882
+ metadata_config = self .pysdk_model ._metadata_configs .get (self .pysdk_model .config_name )
883
+ resolve_config = metadata_config .resolved_config
884
+ if instance_type not in resolve_config .get ("supported_inference_instance_types" , []):
885
+ neuro_model_id = resolve_config .get ("hosting_neuron_model_id" )
886
+ neuro_model_version = resolve_config .get ("hosting_neuron_model_version" )
887
+ if neuro_model_id :
888
+ job_model = JumpStartModel (
889
+ neuro_model_id , model_version = neuro_model_version , vpc_config = self .vpc_config
890
+ )
891
+ return job_model .env
892
+ return None
0 commit comments