@@ -74,7 +74,7 @@
     get_config_value,
     name_from_base,
 )
-from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow import is_pipeline_variable
 
 logger = logging.getLogger(__name__)
 
@@ -600,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A
     current_hyperparameters = hyperparameters
     if current_hyperparameters is not None:
         hyperparameters = {
-            str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v))
+            str(k): (v.to_string() if is_pipeline_variable(v) else json.dumps(v))
             for (k, v) in current_hyperparameters.items()
         }
     return hyperparameters
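
Both hyperparameter call sites swap the direct `isinstance(v, PipelineVariable)` check for the `is_pipeline_variable` helper imported above; the hunk at line 1814 below applies the same substitution where non-pipeline values fall back to `str(v)` rather than `json.dumps(v)`. A minimal sketch of the behavior this branch enables, assuming the public `ParameterString` class from `sagemaker.workflow.parameters` (the hyperparameter names and values are illustrative only):

```python
import json

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString

hyperparameters = {
    "epochs": 10,                                            # concrete value
    "learning_rate": ParameterString(name="LearningRate"),   # pipeline variable
}

# Mirrors the encoding loop in the diff: concrete values are JSON-encoded
# immediately, while pipeline variables are kept as expressions via
# to_string() so SageMaker resolves them at pipeline execution time.
encoded = {
    str(k): (v.to_string() if is_pipeline_variable(v) else json.dumps(v))
    for k, v in hyperparameters.items()
}
```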
@@ -1811,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
         current_hyperparameters = estimator.hyperparameters()
         if current_hyperparameters is not None:
             hyperparameters = {
-                str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v))
+                str(k): (v.to_string() if is_pipeline_variable(v) else str(v))
                 for (k, v) in current_hyperparameters.items()
             }
 
@@ -1879,7 +1879,9 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
         if estimator.use_spot_instances:
             if local_mode:
                 raise ValueError("Spot training is not supported in local mode.")
-            train_args["use_spot_instances"] = True
+            # estimator.use_spot_instances may be a pipeline ParameterBoolean,
+            # which is resolved to a concrete value at pipeline execution time
+            train_args["use_spot_instances"] = estimator.use_spot_instances
 
         if estimator.checkpoint_s3_uri:
             if local_mode:
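
The hunk above is the behavioral fix: hard-coding the literal `True` discarded the parameter object at pipeline definition time, so a `ParameterBoolean` could never reach the training request. A minimal sketch of what the passthrough makes possible, assuming the `ParameterBoolean` class from `sagemaker.workflow.parameters` (the estimator and pipeline wiring are shown in comments because they sit outside this diff):

```python
from sagemaker.workflow.parameters import ParameterBoolean

# Defined once in the pipeline; the concrete value is supplied per execution.
use_spot = ParameterBoolean(name="UseSpotInstances", default_value=False)

# estimator = Estimator(..., use_spot_instances=use_spot, ...)
#
# With the passthrough, each run can flip the flag without redefining
# the pipeline, e.g.:
# pipeline.start(parameters={"UseSpotInstances": True})
```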