|
49 | 49 | def noop(): ... |
50 | 50 |
|
51 | 51 |
|
| 52 | +class _FlyteDirectoryDownloader: |
| 53 | + """Downloader for FlyteDirectory that uses current context when called.""" |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, remote_path: str, local_path: str, is_multipart: bool = True, batch_size: typing.Optional[int] = None |
| 57 | + ): |
| 58 | + self.remote_path = remote_path |
| 59 | + self.local_path = local_path |
| 60 | + self.is_multipart = is_multipart |
| 61 | + self.batch_size = batch_size |
| 62 | + |
| 63 | + def __call__(self): |
| 64 | + """Download the directory using current context's synced get_data method.""" |
| 65 | + current_ctx = FlyteContextManager.current_context() |
| 66 | + return current_ctx.file_access.get_data( |
| 67 | + remote_path=self.remote_path, |
| 68 | + local_path=self.local_path, |
| 69 | + is_multipart=self.is_multipart, |
| 70 | + batch_size=self.batch_size, |
| 71 | + ) |
| 72 | + |
| 73 | + |
52 | 74 | @dataclass |
53 | 75 | class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]): |
54 | 76 | path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore |
@@ -409,9 +431,11 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir |
409 | 431 | paths.append(flyte_file) |
410 | 432 | else: |
411 | 433 | local_folder = file_access.get_random_local_directory() |
412 | | - downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True) |
| 434 | + dir_downloader: typing.Callable = _FlyteDirectoryDownloader( |
| 435 | + remote_path=remote_path, local_path=local_folder, is_multipart=True |
| 436 | + ) |
413 | 437 |
|
414 | | - flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader) |
| 438 | + flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=dir_downloader) |
415 | 439 | flyte_directory._remote_source = remote_path |
416 | 440 | paths.append(flyte_directory) |
417 | 441 |
|
@@ -684,7 +708,9 @@ async def async_to_python_value( |
684 | 708 |
|
685 | 709 | batch_size = get_batch_size(expected_python_type) |
686 | 710 |
|
687 | | - _downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size) |
| 711 | + _downloader = _FlyteDirectoryDownloader( |
| 712 | + remote_path=uri, local_path=local_folder, is_multipart=True, batch_size=batch_size |
| 713 | + ) |
688 | 714 |
|
689 | 715 | expected_format = self.get_format(expected_python_type) |
690 | 716 |
|
|
0 commit comments