1313"""Distributed module."""
1414from __future__ import absolute_import
1515
16+ import os
17+
18+ from abc import ABC , abstractmethod
1619from typing import Optional , Dict , Any , List
17- from pydantic import PrivateAttr
1820from sagemaker .modules .utils import safe_serialize
21+ from sagemaker .modules .constants import SM_DRIVERS_LOCAL_PATH
1922from sagemaker .modules .configs import BaseConfig
2023
2124
@@ -73,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7376 return hyperparameters
7477
7578
76- class DistributedConfig (BaseConfig ):
77- """Base class for distributed training configurations."""
79+ class DistributedConfig (BaseConfig , ABC ):
80+ """Abstract base class for distributed training configurations.
81+
82+ This class defines the interface that all distributed training configurations
83+ must implement. It provides a standardized way to specify driver scripts and
84+ their locations for distributed training jobs.
85+ """
86+
87+ @property
88+ @abstractmethod
89+ def driver_dir (self ) -> str :
90+ """Directory containing the driver script.
91+
92+ This property should return the path to the directory containing
93+ the driver script, relative to the container's working directory.
7894
79- _type : str = PrivateAttr ()
95+ Returns:
96+ str: Path to directory containing the driver script
97+ """
8098
81- def model_dump (self , * args , ** kwargs ):
82- """Dump the model to a dictionary."""
83- result = super ().model_dump (* args , ** kwargs )
84- result ["_type" ] = self ._type
85- return result
99+ @property
100+ @abstractmethod
101+ def driver_script (self ) -> str :
102+ """Name of the driver script.
103+
104+ This property should return the name of the Python script that implements
105+ the distributed training driver logic.
106+
107+ Returns:
108+ str: Name of the driver script file
109+ """
86110
87111
88112class Torchrun (DistributedConfig ):
@@ -99,11 +123,27 @@ class Torchrun(DistributedConfig):
99123 The SageMaker Model Parallelism v2 parameters.
100124 """
101125
102- _type : str = PrivateAttr (default = "torchrun" )
103-
104126 process_count_per_node : Optional [int ] = None
105127 smp : Optional ["SMP" ] = None
106128
129+ @property
130+ def driver_dir (self ) -> str :
131+ """Directory containing the driver script.
132+
133+ Returns:
134+ str: Path to directory containing the driver script
135+ """
136+ return os .path .join (SM_DRIVERS_LOCAL_PATH , "distributed_drivers" )
137+
138+ @property
139+ def driver_script (self ) -> str :
140+ """Name of the driver script.
141+
142+ Returns:
143+ str: Name of the driver script file
144+ """
145+ return "torchrun_driver.py"
146+
107147
108148class MPI (DistributedConfig ):
109149 """MPI.
@@ -119,7 +159,23 @@ class MPI(DistributedConfig):
119159 The custom MPI options to use for the training job.
120160 """
121161
122- _type : str = PrivateAttr (default = "mpi" )
123-
124162 process_count_per_node : Optional [int ] = None
125163 mpi_additional_options : Optional [List [str ]] = None
164+
165+ @property
166+ def driver_dir (self ) -> str :
167+ """Directory containing the driver script.
168+
169+ Returns:
170+ str: Path to directory containing the driver script
171+ """
172+ return os .path .join (SM_DRIVERS_LOCAL_PATH , "distributed_drivers" )
173+
174+ @property
175+ def driver_script (self ) -> str :
176+ """Name of the driver script.
177+
178+ Returns:
179+ str: Name of the driver script
180+ """
181+ return "mpi_driver.py"
0 commit comments