-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_sagemaker.py
More file actions
47 lines (41 loc) · 1.81 KB
/
train_sagemaker.py
File metadata and controls
47 lines (41 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig
def start_training():
    """Launch a SageMaker PyTorch training job configured from the environment.

    Required environment variables:
        SAGEMAKER_ROLE_ARN: passed to the estimator as ``role``.
        TRAINING_DATA_S3_URI: input for the ``training`` channel.
        VALIDATION_DATA_S3_URI: input for the ``validation`` channel.
        TEST_DATA_S3_URI: input for the ``test`` channel.

    Optional environment variables:
        TENSORBOARD_S3_PATH: S3 destination for TensorBoard output
            (default ``s3://your-default-bucket/tensorboard``).
        SAGEMAKER_INSTANCE_TYPE: training instance type
            (default ``ml.g5.xlarge``).

    Returns:
        None. When any required variable is unset or empty, the names of the
        missing variables are reported on stderr and the function returns
        before any estimator is created.
    """
    import sys  # Local import: only the error path needs it.

    required_names = (
        "SAGEMAKER_ROLE_ARN",
        "TRAINING_DATA_S3_URI",
        "VALIDATION_DATA_S3_URI",
        "TEST_DATA_S3_URI",
    )
    settings = {name: os.getenv(name) for name in required_names}

    # An empty string counts as "not set", matching a plain truthiness check.
    missing = [name for name, value in settings.items() if not value]
    if missing:
        print(
            "Error: One or more required environment variables are not set.",
            file=sys.stderr,
        )
        # Name only the variables that are actually missing so the caller
        # does not have to check all four by hand.
        print("Please set: " + ", ".join(missing), file=sys.stderr)
        # NOTE(review): returning here keeps the existing contract, but a
        # script run still exits with status 0 — consider raising instead.
        return

    tensorboard_s3_path = os.getenv(
        "TENSORBOARD_S3_PATH", "s3://your-default-bucket/tensorboard"
    )
    tensorboard_config = TensorBoardOutputConfig(
        s3_output_path=tensorboard_s3_path,
        container_local_output_path="/opt/ml/output/tensorboard",
    )

    estimator = PyTorch(
        entry_point="train.py",
        source_dir="training",
        role=settings["SAGEMAKER_ROLE_ARN"],
        framework_version="2.5.1",  # Consider making these env vars too if they change often
        py_version="py311",
        instance_count=1,
        # Overridable without a code change; an empty value falls back to the default.
        instance_type=os.getenv("SAGEMAKER_INSTANCE_TYPE") or "ml.g5.xlarge",
        hyperparameters={
            "batch-size": 32,  # Could be env vars
            "epochs": 25,
        },
        tensorboard_config=tensorboard_config,
    )

    # Start training: one input channel per dataset split.
    estimator.fit({
        "training": settings["TRAINING_DATA_S3_URI"],
        "validation": settings["VALIDATION_DATA_S3_URI"],
        "test": settings["TEST_DATA_S3_URI"],
    })
# Script entry point: run the training launcher only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    start_training()