6666 RemoteDebugConfig ,
6767 SessionChainingConfig ,
6868 InputData ,
69+ MetricDefinition ,
6970)
7071
7172from sagemaker .modules .local_core .local_container import _LocalContainer
@@ -119,7 +120,8 @@ class ModelTrainer(BaseModel):
119120 from sagemaker.modules.train import ModelTrainer
120121 from sagemaker.modules.configs import SourceCode, Compute, InputData
121122
122- source_code = SourceCode(source_dir="source", entry_script="train.py")
123+ ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
124+ source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123125 training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124126 model_trainer = ModelTrainer(
125127 training_image=training_image,
@@ -238,6 +240,7 @@ class ModelTrainer(BaseModel):
238240 _infra_check_config : Optional [InfraCheckConfig ] = PrivateAttr (default = None )
239241 _session_chaining_config : Optional [SessionChainingConfig ] = PrivateAttr (default = None )
240242 _remote_debug_config : Optional [RemoteDebugConfig ] = PrivateAttr (default = None )
243+ _metric_definitions : Optional [List [MetricDefinition ]] = PrivateAttr (default = None )
241244
242245 _temp_recipe_train_dir : Optional [TemporaryDirectory ] = PrivateAttr (default = None )
243246
@@ -654,6 +657,7 @@ def train(
654657 channel_name = SM_CODE ,
655658 data_source = self .source_code .source_dir ,
656659 key_prefix = input_data_key_prefix ,
660+ ignore_patterns = self .source_code .ignore_patterns ,
657661 )
658662 final_input_data_config .append (source_code_channel )
659663
@@ -675,6 +679,7 @@ def train(
675679 channel_name = SM_DRIVERS ,
676680 data_source = tmp_dir .name ,
677681 key_prefix = input_data_key_prefix ,
682+ ignore_patterns = self .source_code .ignore_patterns ,
678683 )
679684 final_input_data_config .append (sm_drivers_channel )
680685
@@ -693,6 +698,7 @@ def train(
693698 training_image_config = self .training_image_config ,
694699 container_entrypoint = container_entrypoint ,
695700 container_arguments = container_arguments ,
701+ metric_definitions = self ._metric_definitions ,
696702 )
697703
698704 resource_config = self .compute ._to_resource_config ()
@@ -755,7 +761,11 @@ def train(
755761 local_container .train (wait )
756762
757763 def create_input_data_channel (
758- self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
764+ self ,
765+ channel_name : str ,
766+ data_source : DataSourceType ,
767+ key_prefix : Optional [str ] = None ,
768+ ignore_patterns : Optional [List [str ]] = None ,
759769 ) -> Channel :
760770 """Create an input data channel for the training job.
761771
@@ -771,6 +781,10 @@ def create_input_data_channel(
771781
772782 If specified, local data will be uploaded to:
773783 ``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
784+ ignore_patterns: (Optional[List[str]]) :
785+ The ignore patterns to ignore specific files/folders when uploading to S3.
786+ If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store',
787+ '.cache', '.ipynb_checkpoints'].
774788 """
775789 channel = None
776790 if isinstance (data_source , str ):
@@ -810,11 +824,28 @@ def create_input_data_channel(
810824 )
811825 if self .sagemaker_session .default_bucket_prefix :
812826 key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
813- s3_uri = self .sagemaker_session .upload_data (
814- path = data_source ,
815- bucket = self .sagemaker_session .default_bucket (),
816- key_prefix = key_prefix ,
817- )
827+ if ignore_patterns and _is_valid_path (data_source , path_type = "Directory" ):
828+ tmp_dir = TemporaryDirectory ()
829+ copied_path = os .path .join (
830+ tmp_dir .name , os .path .basename (os .path .normpath (data_source ))
831+ )
832+ shutil .copytree (
833+ data_source ,
834+ copied_path ,
835+ dirs_exist_ok = True ,
836+ ignore = shutil .ignore_patterns (* ignore_patterns ),
837+ )
838+ s3_uri = self .sagemaker_session .upload_data (
839+ path = copied_path ,
840+ bucket = self .sagemaker_session .default_bucket (),
841+ key_prefix = key_prefix ,
842+ )
843+ else :
844+ s3_uri = self .sagemaker_session .upload_data (
845+ path = data_source ,
846+ bucket = self .sagemaker_session .default_bucket (),
847+ key_prefix = key_prefix ,
848+ )
818849 channel = Channel (
819850 channel_name = channel_name ,
820851 data_source = DataSource (
@@ -861,7 +892,9 @@ def _get_input_data_config(
861892 channels .append (input_data )
862893 elif isinstance (input_data , InputData ):
863894 channel = self .create_input_data_channel (
864- input_data .channel_name , input_data .data_source , key_prefix = key_prefix
895+ input_data .channel_name ,
896+ input_data .data_source ,
897+ key_prefix = key_prefix ,
865898 )
866899 channels .append (channel )
867900 else :
@@ -1260,3 +1293,33 @@ def with_checkpoint_config(
12601293 """
12611294 self .checkpoint_config = checkpoint_config or configs .CheckpointConfig ()
12621295 return self
1296+
def with_metric_definitions(
    self, metric_definitions: List[MetricDefinition]
) -> "ModelTrainer":  # noqa: D412
    """Set the metric definitions for the training job.

    The list is stored on the trainer as-is (no copy, no validation) and is
    later read by ``train()``, which forwards it as ``metric_definitions``.

    Example:

    .. code:: python

        from sagemaker.modules.train import ModelTrainer
        from sagemaker.modules.configs import MetricDefinition

        # A lazy group such as ``(.*?)`` must be followed by a delimiter;
        # with nothing after it, the group always captures an empty string
        # and no metric value is extracted from the logs.
        metric_definitions = [
            MetricDefinition(
                name="loss",
                regex="Loss: (.*?);",
            )
        ]

        model_trainer = ModelTrainer(
            ...
        ).with_metric_definitions(metric_definitions)

    Args:
        metric_definitions (List[MetricDefinition]):
            The metric definitions for the training job.

    Returns:
        ModelTrainer: This same instance, so that ``with_*`` calls can be chained.
    """
    self._metric_definitions = metric_definitions
    # Return the trainer itself to keep the builder-style API fluent.
    return self
0 commit comments