|
47 | 47 | get_mp_parameters, |
48 | 48 | tar_and_upload_dir, |
49 | 49 | validate_source_dir, |
| 50 | + validate_source_code_input_against_pipeline_variables, |
50 | 51 | ) |
51 | 52 | from sagemaker.inputs import TrainingInput, FileSystemInput |
52 | 53 | from sagemaker.job import _Job |
@@ -140,12 +141,12 @@ def __init__( |
140 | 141 | disable_profiler: bool = False, |
141 | 142 | environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
142 | 143 | max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
143 | | - source_dir: Optional[str] = None, |
| 144 | + source_dir: Optional[Union[str, PipelineVariable]] = None, |
144 | 145 | git_config: Optional[Dict[str, str]] = None, |
145 | 146 | hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
146 | 147 | container_log_level: Union[int, PipelineVariable] = logging.INFO, |
147 | 148 | code_location: Optional[str] = None, |
148 | | - entry_point: Optional[str] = None, |
| 149 | + entry_point: Optional[Union[str, PipelineVariable]] = None, |
149 | 150 | dependencies: Optional[List[Union[str]]] = None, |
150 | 151 | instance_groups: Optional[Dict[str, Union[str, int]]] = None, |
151 | 152 | **kwargs, |
@@ -461,6 +462,13 @@ def __init__( |
461 | 462 | "train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs |
462 | 463 | ) |
463 | 464 |
|
| 465 | + validate_source_code_input_against_pipeline_variables( |
| 466 | + entry_point=entry_point, |
| 467 | + source_dir=source_dir, |
| 468 | + git_config=git_config, |
| 469 | + enable_network_isolation=enable_network_isolation, |
| 470 | + ) |
| 471 | + |
464 | 472 | self.role = role |
465 | 473 | self.instance_count = instance_count |
466 | 474 | self.instance_type = instance_type |
@@ -663,7 +671,11 @@ def _prepare_for_training(self, job_name=None): |
663 | 671 | # validate source dir will raise a ValueError if there is something wrong with |
664 | 672 | # the source directory. We are intentionally not handling it because this is a |
665 | 673 | # critical error. |
666 | | - if self.source_dir and not self.source_dir.lower().startswith("s3://"): |
| 674 | + if ( |
| 675 | + self.source_dir |
| 676 | + and not is_pipeline_variable(self.source_dir) |
| 677 | + and not self.source_dir.lower().startswith("s3://") |
| 678 | + ): |
667 | 679 | validate_source_dir(self.entry_point, self.source_dir) |
668 | 680 |
|
669 | 681 | # if we are in local mode with local_code=True. We want the container to just |
@@ -2151,11 +2163,11 @@ def __init__( |
2151 | 2163 | disable_profiler: bool = False, |
2152 | 2164 | environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
2153 | 2165 | max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, |
2154 | | - source_dir: Optional[str] = None, |
| 2166 | + source_dir: Optional[Union[str, PipelineVariable]] = None, |
2155 | 2167 | git_config: Optional[Dict[str, str]] = None, |
2156 | 2168 | container_log_level: Union[int, PipelineVariable] = logging.INFO, |
2157 | 2169 | code_location: Optional[str] = None, |
2158 | | - entry_point: Optional[str] = None, |
| 2170 | + entry_point: Optional[Union[str, PipelineVariable]] = None, |
2159 | 2171 | dependencies: Optional[List[str]] = None, |
2160 | 2172 | instance_groups: Optional[Dict[str, Union[str, int]]] = None, |
2161 | 2173 | **kwargs, |
@@ -2603,8 +2615,8 @@ class Framework(EstimatorBase): |
2603 | 2615 |
|
2604 | 2616 | def __init__( |
2605 | 2617 | self, |
2606 | | - entry_point: str, |
2607 | | - source_dir: Optional[str] = None, |
| 2618 | + entry_point: Union[str, PipelineVariable], |
| 2619 | + source_dir: Optional[Union[str, PipelineVariable]] = None, |
2608 | 2620 | hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
2609 | 2621 | container_log_level: Union[int, PipelineVariable] = logging.INFO, |
2610 | 2622 | code_location: Optional[str] = None, |
@@ -2783,7 +2795,14 @@ def __init__( |
2783 | 2795 | """ |
2784 | 2796 | super(Framework, self).__init__(enable_network_isolation=enable_network_isolation, **kwargs) |
2785 | 2797 | image_uri = renamed_kwargs("image_name", "image_uri", image_uri, kwargs) |
2786 | | - if entry_point.startswith("s3://"): |
| 2798 | + |
| 2799 | + validate_source_code_input_against_pipeline_variables( |
| 2800 | + entry_point=entry_point, |
| 2801 | + source_dir=source_dir, |
| 2802 | + git_config=git_config, |
| 2803 | + enable_network_isolation=enable_network_isolation, |
| 2804 | + ) |
| 2805 | + if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"): |
2787 | 2806 | raise ValueError( |
2788 | 2807 | "Invalid entry point script: {}. Must be a path to a local file.".format( |
2789 | 2808 | entry_point |
|
0 commit comments