@@ -456,6 +456,7 @@ def train( # noqa: C901
456456 enable_sagemaker_metrics = None ,
457457 profiler_rule_configs = None ,
458458 profiler_config = None ,
459+ environment = None ,
459460 ):
460461 """Create an Amazon SageMaker training job.
461462
@@ -522,9 +523,12 @@ def train( # noqa: C901
522523 Series. For more information see:
523524 https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
524525 (default: ``None``).
525- profiler_rule_configs (list[dict]): A list of profiler rule configurations.
526+ profiler_rule_configs (list[dict]): A list of profiler rule
527+ configurations.src/sagemaker/lineage/artifact.py:285
526528 profiler_config (dict): Configuration for how profiling information is emitted
527529 with SageMaker Profiler. (default: ``None``).
530+ environment (dict[str, str]) : Environment variables to be set for
531+ use during training job (default: ``None``)
528532
529533 Returns:
530534 str: ARN of the training job, if it is created.
@@ -556,6 +560,7 @@ def train( # noqa: C901
556560 enable_sagemaker_metrics = enable_sagemaker_metrics ,
557561 profiler_rule_configs = profiler_rule_configs ,
558562 profiler_config = profiler_config ,
563+ environment = environment ,
559564 )
560565 LOGGER .info ("Creating training-job with name: %s" , job_name )
561566 LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
@@ -588,6 +593,7 @@ def _get_train_request( # noqa: C901
588593 enable_sagemaker_metrics = None ,
589594 profiler_rule_configs = None ,
590595 profiler_config = None ,
596+ environment = None ,
591597 ):
592598 """Constructs a request compatible for creating an Amazon SageMaker training job.
593599
@@ -657,6 +663,8 @@ def _get_train_request( # noqa: C901
657663 profiler_rule_configs (list[dict]): A list of profiler rule configurations.
658664 profiler_config(dict): Configuration for how profiling information is emitted with
659665 SageMaker Profiler. (default: ``None``).
666+ environment (dict[str, str]) : Environment variables to be set for
667+ use during training job (default: ``None``)
660668
661669 Returns:
662670 Dict: a training request dict
@@ -699,6 +707,9 @@ def _get_train_request( # noqa: C901
699707 if hyperparameters and len (hyperparameters ) > 0 :
700708 train_request ["HyperParameters" ] = hyperparameters
701709
710+ if environment is not None :
711+ train_request ["Environment" ] = environment
712+
702713 if tags is not None :
703714 train_request ["Tags" ] = tags
704715
0 commit comments