@@ -670,6 +670,7 @@ def _optimize_for_jumpstart(
670
670
output_path : Optional [str ] = None ,
671
671
tags : Optional [Tags ] = None ,
672
672
job_name : Optional [str ] = None ,
673
+ instance_type : Optional [str ] = None ,
673
674
accept_eula : Optional [bool ] = None ,
674
675
quantization_config : Optional [Dict ] = None ,
675
676
compilation_config : Optional [Dict ] = None ,
@@ -685,6 +686,7 @@ def _optimize_for_jumpstart(
685
686
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
686
687
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
687
688
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
689
+ instance_type (Optional[str]): Target deployment instance type that the model is optimized for. Defaults to ``None``.
688
690
accept_eula (bool): For models that require a Model Access Config, specify True or
689
691
False to indicate whether model terms of use have been accepted.
690
692
The `accept_eula` value must be explicitly defined as `True` in order to
@@ -711,12 +713,12 @@ def _optimize_for_jumpstart(
711
713
)
712
714
713
715
is_compilation = (not quantization_config ) and (
714
- (compilation_config is not None ) or _is_inferentia_or_trainium (self . instance_type )
716
+ (compilation_config is not None ) or _is_inferentia_or_trainium (instance_type )
715
717
)
716
718
717
719
pysdk_model_env_vars = dict ()
718
720
if is_compilation :
719
- pysdk_model_env_vars = self ._get_neuron_model_env_vars (self . instance_type )
721
+ pysdk_model_env_vars = self ._get_neuron_model_env_vars (instance_type )
720
722
721
723
optimization_config , override_env = _extract_optimization_config_and_env (
722
724
quantization_config , compilation_config
@@ -752,9 +754,7 @@ def _optimize_for_jumpstart(
752
754
if self .pysdk_model .deployment_config
753
755
else None
754
756
)
755
- self .instance_type = (
756
- self .instance_type or deployment_config_instance_type or _get_nb_instance ()
757
- )
757
+ self .instance_type = instance_type or deployment_config_instance_type or _get_nb_instance ()
758
758
759
759
create_optimization_job_args = {
760
760
"OptimizationJobName" : job_name ,
0 commit comments