Skip to content

Commit 4a134f6

Browse files
authored
[Fix] Issue when using FlyteDirectory with Elastic (#3320)
Signed-off-by: machichima <nary12321@gmail.com>
1 parent 12ce1b9 commit 4a134f6

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

flytekit/types/directory/types.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,28 @@
4949
def noop(): ...
5050

5151

52+
class _FlyteDirectoryDownloader:
53+
"""Downloader for FlyteDirectory that uses current context when called."""
54+
55+
def __init__(
56+
self, remote_path: str, local_path: str, is_multipart: bool = True, batch_size: typing.Optional[int] = None
57+
):
58+
self.remote_path = remote_path
59+
self.local_path = local_path
60+
self.is_multipart = is_multipart
61+
self.batch_size = batch_size
62+
63+
def __call__(self):
64+
"""Download the directory using current context's synced get_data method."""
65+
current_ctx = FlyteContextManager.current_context()
66+
return current_ctx.file_access.get_data(
67+
remote_path=self.remote_path,
68+
local_path=self.local_path,
69+
is_multipart=self.is_multipart,
70+
batch_size=self.batch_size,
71+
)
72+
73+
5274
@dataclass
5375
class FlyteDirectory(SerializableType, DataClassJsonMixin, os.PathLike, typing.Generic[T]):
5476
path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
@@ -409,9 +431,11 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir
409431
paths.append(flyte_file)
410432
else:
411433
local_folder = file_access.get_random_local_directory()
412-
downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True)
434+
dir_downloader: typing.Callable = _FlyteDirectoryDownloader(
435+
remote_path=remote_path, local_path=local_folder, is_multipart=True
436+
)
413437

414-
flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader)
438+
flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=dir_downloader)
415439
flyte_directory._remote_source = remote_path
416440
paths.append(flyte_directory)
417441

@@ -684,7 +708,9 @@ async def async_to_python_value(
684708

685709
batch_size = get_batch_size(expected_python_type)
686710

687-
_downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size)
711+
_downloader = _FlyteDirectoryDownloader(
712+
remote_path=uri, local_path=local_folder, is_multipart=True, batch_size=batch_size
713+
)
688714

689715
expected_format = self.get_format(expected_python_type)
690716

0 commit comments

Comments
 (0)