diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index a5eff0b68f..780188f9e5 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -7,7 +7,6 @@ import typing from contextlib import contextmanager from dataclasses import dataclass, field -from functools import partial from typing import Dict, cast from urllib.parse import unquote @@ -40,6 +39,22 @@ def noop(): ... +class _FlyteFileDownloader: + """Downloader for FlyteFile that uses current context when called.""" + + def __init__(self, remote_path: str, local_path: str, is_multipart: bool = False): + self.remote_path = remote_path + self.local_path = local_path + self.is_multipart = is_multipart + + def __call__(self): + """Download the file using current context's synced get_data method.""" + current_ctx = FlyteContextManager.current_context() + return current_ctx.file_access.get_data( + remote_path=self.remote_path, local_path=self.local_path, is_multipart=self.is_multipart + ) + + T = typing.TypeVar("T") @@ -318,11 +333,9 @@ def __init__( if ctx.file_access.is_remote(self.path): self._remote_source = self.path self._local_path = ctx.file_access.get_random_local_path(self._remote_source) - self._downloader = partial( - ctx.file_access.get_data, - ctx=ctx, - remote_path=self._remote_source, # type: ignore - local_path=self._local_path, + self._downloader = _FlyteFileDownloader( + remote_path=str(self._remote_source), # type: ignore + local_path=str(self._local_path), ) def __fspath__(self): @@ -755,7 +768,7 @@ async def async_to_python_value( # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) - _downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False) + _downloader = _FlyteFileDownloader(remote_path=uri, local_path=local_path, is_multipart=False) expected_format = FlyteFilePathTransformer.get_format(expected_python_type) ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata)