Skip to content

Commit a06e374

Browse files
machichimaAtharva1723
authored andcommitted
[Fix] Issue when using FlyteFile with Elastic (flyteorg#3313)
Signed-off-by: machichima <[email protected]> Signed-off-by: Atharva <[email protected]>
1 parent 7c8bfc7 commit a06e374

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

flytekit/types/file/file.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import typing
88
from contextlib import contextmanager
99
from dataclasses import dataclass, field
10-
from functools import partial
1110
from typing import Dict, cast
1211
from urllib.parse import unquote
1312

@@ -40,6 +39,22 @@
4039
def noop(): ...
4140

4241

42+
class _FlyteFileDownloader:
43+
"""Downloader for FlyteFile that uses current context when called."""
44+
45+
def __init__(self, remote_path: str, local_path: str, is_multipart: bool = False):
46+
self.remote_path = remote_path
47+
self.local_path = local_path
48+
self.is_multipart = is_multipart
49+
50+
def __call__(self):
51+
"""Download the file using current context's synced get_data method."""
52+
current_ctx = FlyteContextManager.current_context()
53+
return current_ctx.file_access.get_data(
54+
remote_path=self.remote_path, local_path=self.local_path, is_multipart=self.is_multipart
55+
)
56+
57+
4358
T = typing.TypeVar("T")
4459

4560

@@ -318,11 +333,9 @@ def __init__(
318333
if ctx.file_access.is_remote(self.path):
319334
self._remote_source = self.path
320335
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
321-
self._downloader = partial(
322-
ctx.file_access.get_data,
323-
ctx=ctx,
324-
remote_path=self._remote_source, # type: ignore
325-
local_path=self._local_path,
336+
self._downloader = _FlyteFileDownloader(
337+
remote_path=str(self._remote_source), # type: ignore
338+
local_path=str(self._local_path),
326339
)
327340

328341
def __fspath__(self):
@@ -755,7 +768,7 @@ async def async_to_python_value(
755768
# For the remote case, return an FlyteFile object that can download
756769
local_path = ctx.file_access.get_random_local_path(uri)
757770

758-
_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)
771+
_downloader = _FlyteFileDownloader(remote_path=uri, local_path=local_path, is_multipart=False)
759772

760773
expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
761774
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata)

0 commit comments

Comments
 (0)