@@ -225,7 +225,7 @@ def test_single_container_local_mode_s3_data_not_remove_input(modules_sagemaker_
225225 delete_local_path (path )
226226
227227
228- def test_multi_container_local_mode (modules_sagemaker_session ):
228+ def test_multi_container_local_mode_remove_input (modules_sagemaker_session ):
229229 with lock .lock (LOCK_PATH ):
230230 try :
231231 source_code = SourceCode (
@@ -265,6 +265,68 @@ def test_multi_container_local_mode(modules_sagemaker_session):
265265
266266 model_trainer .train ()
267267 assert os .path .exists (os .path .join (CWD , "compressed_artifacts/model.tar.gz" ))
268+
269+ finally :
270+ subprocess .run (["docker" , "compose" , "down" , "-v" ])
271+
272+ assert not os .path .exists (os .path .join (CWD , "shared" ))
273+ assert not os .path .exists (os .path .join (CWD , "input" ))
274+ assert not os .path .exists (os .path .join (CWD , "algo-1" ))
275+ assert not os .path .exists (os .path .join (CWD , "algo-2" ))
276+
277+ directories = [
278+ "compressed_artifacts" ,
279+ "artifacts" ,
280+ "model" ,
281+ "output" ,
282+ ]
283+
284+ for directory in directories :
285+ path = os .path .join (CWD , directory )
286+ delete_local_path (path )
287+
288+
289+ def test_multi_container_local_mode_not_remove_input (modules_sagemaker_session ):
290+ with lock .lock (LOCK_PATH ):
291+ try :
292+ source_code = SourceCode (
293+ source_dir = SOURCE_DIR ,
294+ entry_script = "local_training_script.py" ,
295+ )
296+
297+ distributed = Torchrun (
298+ process_count_per_node = 1 ,
299+ )
300+
301+ compute = Compute (
302+ instance_type = "local_cpu" ,
303+ instance_count = 2 ,
304+ )
305+
306+ train_data = InputData (
307+ channel_name = "train" ,
308+ data_source = os .path .join (SOURCE_DIR , "data/train/" ),
309+ )
310+
311+ test_data = InputData (
312+ channel_name = "test" ,
313+ data_source = os .path .join (SOURCE_DIR , "data/test/" ),
314+ )
315+
316+ model_trainer = ModelTrainer (
317+ training_image = DEFAULT_CPU_IMAGE ,
318+ sagemaker_session = modules_sagemaker_session ,
319+ source_code = source_code ,
320+ distributed = distributed ,
321+ compute = compute ,
322+ input_data_config = [train_data , test_data ],
323+ base_job_name = "local_mode_multi_container" ,
324+ training_mode = Mode .LOCAL_CONTAINER ,
325+ remove_inputs_and_container_artifacts = False ,
326+ )
327+
328+ model_trainer .train ()
329+ assert os .path .exists (os .path .join (CWD , "compressed_artifacts/model.tar.gz" ))
268330 assert os .path .exists (os .path .join (CWD , "algo-1" ))
269331 assert os .path .exists (os .path .join (CWD , "algo-2" ))
270332
@@ -274,7 +336,11 @@ def test_multi_container_local_mode(modules_sagemaker_session):
274336 "compressed_artifacts" ,
275337 "artifacts" ,
276338 "model" ,
339+ "shared" ,
340+ "input" ,
277341 "output" ,
342+ "algo-1" ,
343+ "algo-2" ,
278344 ]
279345
280346 for directory in directories :
0 commit comments