@@ -117,46 +117,64 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
117117 sagemaker_container .hosts = ['algo-1' , 'algo-2' ] # avoid any randomness
118118 sagemaker_container .container_root = str (tmpdir .mkdir ('container-root' ))
119119
120- volume1 = os .path .join (sagemaker_container .container_root , 'algo-1/output/ ' )
121- volume2 = os .path .join (sagemaker_container .container_root , 'algo-2/output/ ' )
122- os .makedirs (volume1 )
123- os .makedirs (volume2 )
120+ volume1 = os .path .join (sagemaker_container .container_root , 'algo-1' )
121+ volume2 = os .path .join (sagemaker_container .container_root , 'algo-2' )
122+ os .mkdir (volume1 )
123+ os .mkdir (volume2 )
124124
125125 compose_data = {
126126 'services' : {
127127 'algo-1' : {
128- 'volumes' : ['%s:/opt/ml/model' % volume1 ]
128+ 'volumes' : ['%s:/opt/ml/model' % os .path .join (volume1 , 'model' ),
129+ '%s:/opt/ml/output' % os .path .join (volume1 , 'output' )]
129130 },
130131 'algo-2' : {
131- 'volumes' : ['%s:/opt/ml/model' % volume2 ]
132+ 'volumes' : ['%s:/opt/ml/model' % os .path .join (volume2 , 'model' ),
133+ '%s:/opt/ml/output' % os .path .join (volume2 , 'output' )]
132134 }
133135 }
134136 }
135137
136138 dirs1 = ['model' , 'model/data' ]
137139 dirs2 = ['model' , 'model/data' , 'model/tmp' ]
140+ dirs3 = ['output' , 'output/data' ]
141+ dirs4 = ['output' , 'output/data' , 'output/log' ]
138142
139143 files1 = ['model/data/model.json' , 'model/data/variables.csv' ]
140144 files2 = ['model/data/model.json' , 'model/data/variables2.csv' , 'model/tmp/something-else.json' ]
145+ files3 = ['output/data/loss.json' , 'output/data/accuracy.json' ]
146+ files4 = ['output/data/loss.json' , 'output/data/accuracy2.json' , 'output/log/warnings.txt' ]
141147
142148 expected = ['model' , 'model/data/' , 'model/data/model.json' , 'model/data/variables.csv' ,
143- 'model/data/variables2.csv' , 'model/tmp/something-else.json' ]
149+ 'model/data/variables2.csv' , 'model/tmp/something-else.json' , 'output' , 'output/data' , 'output/log' ,
150+ 'output/data/loss.json' , 'output/data/accuracy.json' , 'output/data/accuracy2.json' ,
151+ 'output/log/warnings.txt' ]
144152
145153 for d in dirs1 :
146154 os .mkdir (os .path .join (volume1 , d ))
147155 for d in dirs2 :
148156 os .mkdir (os .path .join (volume2 , d ))
157+ for d in dirs3 :
158+ os .mkdir (os .path .join (volume1 , d ))
159+ for d in dirs4 :
160+ os .mkdir (os .path .join (volume2 , d ))
149161
150162 # create all the files
151163 for f in files1 :
152164 open (os .path .join (volume1 , f ), 'a' ).close ()
153165 for f in files2 :
154166 open (os .path .join (volume2 , f ), 'a' ).close ()
167+ for f in files3 :
168+ open (os .path .join (volume1 , f ), 'a' ).close ()
169+ for f in files4 :
170+ open (os .path .join (volume2 , f ), 'a' ).close ()
155171
156- s3_model_artifacts = sagemaker_container .retrieve_model_artifacts (compose_data )
172+ s3_model_artifacts = sagemaker_container .retrieve_artifacts (compose_data )
173+ s3_artifacts = os .path .dirname (s3_model_artifacts )
157174
158175 for f in expected :
159- assert os .path .exists (os .path .join (s3_model_artifacts , f ))
176+ assert set (os .listdir (s3_artifacts )) == set (['model' , 'output' ])
177+ assert os .path .exists (os .path .join (s3_artifacts , f ))
160178
161179
162180def test_stream_output ():
0 commit comments