|
7 | 7 | import typing |
8 | 8 | from contextlib import contextmanager |
9 | 9 | from dataclasses import dataclass, field |
10 | | -from functools import partial |
11 | 10 | from typing import Dict, cast |
12 | 11 | from urllib.parse import unquote |
13 | 12 |
|
|
40 | 39 | def noop(): ... |
41 | 40 |
|
42 | 41 |
|
| 42 | +class _FlyteFileDownloader: |
| 43 | + """Downloader for FlyteFile that uses current context when called.""" |
| 44 | + |
| 45 | + def __init__(self, remote_path: str, local_path: str, is_multipart: bool = False): |
| 46 | + self.remote_path = remote_path |
| 47 | + self.local_path = local_path |
| 48 | + self.is_multipart = is_multipart |
| 49 | + |
| 50 | + def __call__(self): |
| 51 | + """Download the file using current context's synced get_data method.""" |
| 52 | + current_ctx = FlyteContextManager.current_context() |
| 53 | + return current_ctx.file_access.get_data( |
| 54 | + remote_path=self.remote_path, local_path=self.local_path, is_multipart=self.is_multipart |
| 55 | + ) |
| 56 | + |
| 57 | + |
43 | 58 | T = typing.TypeVar("T") |
44 | 59 |
|
45 | 60 |
|
@@ -318,11 +333,9 @@ def __init__( |
318 | 333 | if ctx.file_access.is_remote(self.path): |
319 | 334 | self._remote_source = self.path |
320 | 335 | self._local_path = ctx.file_access.get_random_local_path(self._remote_source) |
321 | | - self._downloader = partial( |
322 | | - ctx.file_access.get_data, |
323 | | - ctx=ctx, |
324 | | - remote_path=self._remote_source, # type: ignore |
325 | | - local_path=self._local_path, |
| 336 | + self._downloader = _FlyteFileDownloader( |
| 337 | + remote_path=str(self._remote_source), # type: ignore |
| 338 | + local_path=str(self._local_path), |
326 | 339 | ) |
327 | 340 |
|
328 | 341 | def __fspath__(self): |
@@ -755,7 +768,7 @@ async def async_to_python_value( |
755 | 768 | # For the remote case, return an FlyteFile object that can download |
756 | 769 | local_path = ctx.file_access.get_random_local_path(uri) |
757 | 770 |
|
758 | | - _downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False) |
| 771 | + _downloader = _FlyteFileDownloader(remote_path=uri, local_path=local_path, is_multipart=False) |
759 | 772 |
|
760 | 773 | expected_format = FlyteFilePathTransformer.get_format(expected_python_type) |
761 | 774 | ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata) |
|
0 commit comments