@@ -25,9 +25,9 @@
 import subprocess
 import sys
 import tempfile
-from subprocess import Popen
+from fcntl import fcntl, F_GETFL, F_SETFL
 from six.moves.urllib.parse import urlparse
-from time import sleep
+from threading import Thread
 
 import yaml
 
@@ -91,42 +91,7 @@ def train(self, input_data_config, hyperparameters):
         os.mkdir(shared_dir)
 
         data_dir = self._create_tmp_folder()
-        volumes = []
-
-        # Set up the channels for the containers. For local data we will
-        # mount the local directory to the container. For S3 Data we will download the S3 data
-        # first.
-        for channel in input_data_config:
-            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
-                uri = channel['DataSource']['S3DataSource']['S3Uri']
-            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
-                uri = channel['DataSource']['FileDataSource']['FileUri']
-            else:
-                raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
-
-            parsed_uri = urlparse(uri)
-            key = parsed_uri.path.lstrip('/')
-
-            channel_name = channel['ChannelName']
-            channel_dir = os.path.join(data_dir, channel_name)
-            os.mkdir(channel_dir)
-
-            if parsed_uri.scheme == 's3':
-                bucket_name = parsed_uri.netloc
-                self._download_folder(bucket_name, key, channel_dir)
-            elif parsed_uri.scheme == 'file':
-                path = parsed_uri.path
-                volumes.append(_Volume(path, channel=channel_name))
-            else:
-                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
-
-        # If the training script directory is a local directory, mount it to the container.
-        training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
-        parsed_uri = urlparse(training_dir)
-        if parsed_uri.scheme == 'file':
-            volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
-        # Also mount a directory that all the containers can access.
-        volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
+        volumes = self._prepare_training_volumes(data_dir, input_data_config, hyperparameters)
 
         # Create the configuration files for each container that we will create
         # Each container will map the additional local volumes (if any).
@@ -139,7 +104,15 @@ def train(self, input_data_config, hyperparameters):
         compose_command = self._compose()
 
         _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
-        _execute_and_stream_output(compose_command)
+        process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+        try:
+            _stream_output(process)
+        except RuntimeError as e:
+            # _stream_output() doesn't have the command line. We will handle the exception
+            # which contains the exit code and append the command line to it.
+            msg = "Failed to run: %s, %s" % (compose_command, e.message)
+            raise RuntimeError(msg)
 
         s3_artifacts = self.retrieve_artifacts(compose_data)
 
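Note: the hunk above replaces the old _execute_and_stream_output() helper with an explicit subprocess.Popen call so that the same _stream_output() helper (redefined at the bottom of this diff) can be shared between training and serving. A minimal sketch of the calling pattern, with a made-up command line; str(e) is used here in place of e.message purely for illustration:

    import subprocess

    compose_command = ['docker-compose', 'up', '--build']  # hypothetical command line
    process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        _stream_output(process)  # streams stdout/stderr until the process exits
    except RuntimeError as e:
        # _stream_output() only knows the exit code, so re-raise with the command attached.
        raise RuntimeError("Failed to run: %s, %s" % (compose_command, str(e)))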
@@ -196,7 +169,7 @@ def serve(self, primary_container):
                                     additional_volumes=volumes)
         compose_command = self._compose()
         self.container = _HostingContainer(compose_command)
-        self.container.up()
+        self.container.start()
 
     def stop_serving(self):
         """Stop the serving container.
@@ -205,6 +178,7 @@ def stop_serving(self):
205178 """
206179 if self .container :
207180 self .container .down ()
181+ self .container .join ()
208182 self ._cleanup ()
209183 # for serving we can delete everything in the container root.
210184 _delete_tree (self .container_root )
@@ -304,6 +278,47 @@ def _download_folder(self, bucket_name, prefix, target):
 
             obj.download_file(file_path)
 
+    def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):
+        shared_dir = os.path.join(self.container_root, 'shared')
+        volumes = []
+        # Set up the channels for the containers. For local data we will
+        # mount the local directory to the container. For S3 Data we will download the S3 data
+        # first.
+        for channel in input_data_config:
+            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
+                uri = channel['DataSource']['S3DataSource']['S3Uri']
+            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
+                uri = channel['DataSource']['FileDataSource']['FileUri']
+            else:
+                raise ValueError('Need channel[\'DataSource\'] to have'
+                                 ' [\'S3DataSource\'] or [\'FileDataSource\']')
+
+            parsed_uri = urlparse(uri)
+            key = parsed_uri.path.lstrip('/')
+
+            channel_name = channel['ChannelName']
+            channel_dir = os.path.join(data_dir, channel_name)
+            os.mkdir(channel_dir)
+
+            if parsed_uri.scheme == 's3':
+                bucket_name = parsed_uri.netloc
+                self._download_folder(bucket_name, key, channel_dir)
+            elif parsed_uri.scheme == 'file':
+                path = parsed_uri.path
+                volumes.append(_Volume(path, channel=channel_name))
+            else:
+                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
+
+        # If the training script directory is a local directory, mount it to the container.
+        training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
+        parsed_uri = urlparse(training_dir)
+        if parsed_uri.scheme == 'file':
+            volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
+        # Also mount a directory that all the containers can access.
+        volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
+
+        return volumes
+
     def _generate_compose_file(self, command, additional_volumes=None, additional_env_vars=None):
         """Writes a config file describing a training/hosting environment.
 
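Note: _prepare_training_volumes() expects each entry in input_data_config to provide either an S3DataSource or a FileDataSource, matching the branches above. A minimal sketch of the expected shape (bucket, paths, and channel names are hypothetical):

    input_data_config = [
        {'ChannelName': 'training',
         'DataSource': {'S3DataSource': {'S3Uri': 's3://my-bucket/train'}}},    # downloaded into data_dir/training
        {'ChannelName': 'validation',
         'DataSource': {'FileDataSource': {'FileUri': 'file:///home/me/val'}}}  # mounted directly as a _Volume
    ]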
@@ -452,15 +467,23 @@ def _cleanup(self):
         pass
 
 
-class _HostingContainer(object):
-    def __init__(self, command, startup_delay=5):
+class _HostingContainer(Thread):
+    def __init__(self, command):
+        Thread.__init__(self)
         self.command = command
-        self.startup_delay = startup_delay
         self.process = None
 
-    def up(self):
-        self.process = Popen(self.command)
-        sleep(self.startup_delay)
+    def run(self):
+        self.process = subprocess.Popen(self.command,
+                                        stdout=subprocess.PIPE,
+                                        stderr=subprocess.PIPE)
+        try:
+            _stream_output(self.process)
+        except RuntimeError as e:
+            # _stream_output() doesn't have the command line. We will handle the exception
+            # which contains the exit code and append the command line to it.
+            msg = "Failed to run: %s, %s" % (self.command, e.message)
+            raise RuntimeError(msg)
 
     def down(self):
         self.process.terminate()
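Note: because _HostingContainer now subclasses threading.Thread, the caller drives it through the standard thread lifecycle: start() invokes run() in the background (launching the compose process and streaming its output), down() terminates the process, and join() waits for run() to return, as serve()/stop_serving() above now do. A minimal usage sketch with a hypothetical command:

    container = _HostingContainer(['docker-compose', 'up'])  # hypothetical compose command
    container.start()  # run() starts the process and streams output in a background thread
    # ... serve requests ...
    container.down()   # terminate the docker-compose process
    container.join()   # wait for run() (and its output streaming) to finish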
@@ -495,26 +518,41 @@ def __init__(self, host_dir, container_dir=None, channel=None):
         self.map = '{}:{}'.format(self.host_dir, self.container_dir)
 
 
-def _execute_and_stream_output(cmd):
-    """Execute a command and stream the output to stdout
+def _stream_output(process):
+    """Stream the output of a process to stdout
+
+    This function takes an existing process that will be polled for output. Both stdout and
+    stderr will be polled and both will be sent to sys.stdout.
 
     Args:
-        cmd(str or List): either a string or a List (in Popen Format) with the command to execute.
+        process(subprocess.Popen): a process that has been started with
+            stdout=PIPE and stderr=PIPE
 
-    Returns (int): command exit code
+    Returns (int): process exit code
     """
-    if isinstance(cmd, str):
-        cmd = shlex.split(cmd)
-    process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
     exit_code = None
+
+    # Get the current flags for the stderr file descriptor
+    # And add the NONBLOCK flag to allow us to read even if there is no data.
+    # Since usually stderr will be empty unless there is an error.
+    flags = fcntl(process.stderr, F_GETFL)  # get current process.stderr flags
+    fcntl(process.stderr, F_SETFL, flags | os.O_NONBLOCK)
+
     while exit_code is None:
         stdout = process.stdout.readline().decode("utf-8")
         sys.stdout.write(stdout)
+        try:
+            stderr = process.stderr.readline().decode("utf-8")
+            sys.stdout.write(stderr)
+        except IOError:
+            # If there is nothing to read on stderr we will get an IOError
+            # this is fine.
+            pass
 
         exit_code = process.poll()
 
     if exit_code != 0:
-        raise Exception("Failed to run %s, exit code: %s" % (",".join(cmd), exit_code))
+        raise RuntimeError("Process exited with code: %s" % exit_code)
 
     return exit_code
 
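Note: the fcntl/O_NONBLOCK change above makes readline() on stderr return immediately when nothing is buffered (raising IOError, or BlockingIOError on Python 3, which the except IOError clause still catches) instead of stalling the stdout loop. A standalone sketch of the same technique, POSIX only since it relies on fcntl:

    import os
    import subprocess
    from fcntl import fcntl, F_GETFL, F_SETFL

    proc = subprocess.Popen(['echo', 'hello'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    flags = fcntl(proc.stderr, F_GETFL)                 # current status flags of the stderr pipe
    fcntl(proc.stderr, F_SETFL, flags | os.O_NONBLOCK)  # make reads non-blocking
    try:
        err = proc.stderr.readline()                    # empty, or raises IOError when no data is ready
    except IOError:
        err = b''
    print(proc.stdout.readline().decode('utf-8'))       # 'hello'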