|
20 | 20 | from typing import List, Union |
21 | 21 | from sagemaker import image_uris |
22 | 22 | from sagemaker.inputs import TrainingInput |
23 | | -from sagemaker.s3 import ( |
24 | | - S3Downloader, |
25 | | - S3Uploader, |
26 | | -) |
27 | 23 | from sagemaker.estimator import EstimatorBase |
28 | 24 | from sagemaker.sklearn.estimator import SKLearn |
29 | 25 | from sagemaker.workflow.entities import RequestType |
|
35 | 31 | Step, |
36 | 32 | ConfigurableRetryStep, |
37 | 33 | ) |
| 34 | +from sagemaker.utils import _save_model, download_file_from_url |
38 | 35 | from sagemaker.workflow.retry import RetryPolicy |
39 | 36 |
|
40 | 37 | FRAMEWORK_VERSION = "0.23-1" |
@@ -203,40 +200,36 @@ def _establish_source_dir(self): |
203 | 200 | self._entry_point = self._entry_point_basename |
204 | 201 |
|
205 | 202 | def _inject_repack_script(self): |
206 | | - """Injects the _repack_model.py script where it belongs. |
| 203 | + """Injects the _repack_model.py script into the S3 or local source directory. |
207 | 204 |
|
208 | 205 | If the source_dir is an S3 path: |
209 | 206 | 1) downloads the source_dir tar.gz |
210 | | - 2) copies the _repack_model.py script where it belongs |
211 | | - 3) uploads the mutated source_dir |
| 207 | + 2) extracts it |
| 208 | + 3) copies the _repack_model.py script into the extracted directory |
| 209 | + 4) re-archives the directory as a new tar.gz |
| 210 | + 5) overwrites the S3 source_dir with the new tar.gz |
212 | 211 |
|
213 | 212 | If the source_dir is a local path: |
214 | 213 | 1) copies the _repack_model.py script into the source dir |
215 | 214 | """ |
216 | 215 | fname = os.path.join(os.path.dirname(__file__), REPACK_SCRIPT) |
217 | 216 | if self._source_dir.lower().startswith("s3://"): |
218 | 217 | with tempfile.TemporaryDirectory() as tmp: |
219 | | - local_path = os.path.join(tmp, "local.tar.gz") |
220 | | - |
221 | | - S3Downloader.download( |
222 | | - s3_uri=self._source_dir, |
223 | | - local_path=local_path, |
224 | | - sagemaker_session=self.sagemaker_session, |
225 | | - ) |
226 | | - |
227 | | - src_dir = os.path.join(tmp, "src") |
228 | | - with tarfile.open(name=local_path, mode="r:gz") as tf: |
229 | | - tf.extractall(path=src_dir) |
230 | | - |
231 | | - shutil.copy2(fname, os.path.join(src_dir, REPACK_SCRIPT)) |
232 | | - with tarfile.open(name=local_path, mode="w:gz") as tf: |
233 | | - tf.add(src_dir, arcname=".") |
234 | | - |
235 | | - S3Uploader.upload( |
236 | | - local_path=local_path, |
237 | | - desired_s3_uri=self._source_dir, |
238 | | - sagemaker_session=self.sagemaker_session, |
239 | | - ) |
| 218 | + targz_contents_dir = os.path.join(tmp, "extracted") |
| 219 | + |
| 220 | + old_targz_path = os.path.join(tmp, "old.tar.gz") |
| 221 | + download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session) |
| 222 | + |
| 223 | + with tarfile.open(name=old_targz_path, mode="r:gz") as t: |
| 224 | + t.extractall(path=targz_contents_dir) |
| 225 | + |
| 226 | + shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT)) |
| 227 | + |
| 228 | + new_targz_path = os.path.join(tmp, "new.tar.gz") |
| 229 | + with tarfile.open(new_targz_path, mode="w:gz") as t: |
| 230 | + t.add(targz_contents_dir, arcname=os.path.sep) |
| 231 | + |
| 232 | + _save_model(self._source_dir, new_targz_path, self.sagemaker_session, kms_key=None) |
240 | 233 | else: |
241 | 234 | shutil.copy2(fname, os.path.join(self._source_dir, REPACK_SCRIPT)) |
242 | 235 |
|
|
0 commit comments