Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, cast
from urllib.parse import unquote

Expand Down Expand Up @@ -40,6 +39,22 @@
def noop(): ...


class _FlyteFileDownloader:
"""Downloader for FlyteFile that uses current context when called."""

def __init__(self, remote_path: str, local_path: str, is_multipart: bool = False):
self.remote_path = remote_path
self.local_path = local_path
self.is_multipart = is_multipart

def __call__(self):
"""Download the file using current context's synced get_data method."""
current_ctx = FlyteContextManager.current_context()
return current_ctx.file_access.get_data(
remote_path=self.remote_path, local_path=self.local_path, is_multipart=self.is_multipart
)


T = typing.TypeVar("T")


Expand Down Expand Up @@ -318,11 +333,9 @@ def __init__(
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = partial(
ctx.file_access.get_data,
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
self._downloader = _FlyteFileDownloader(
remote_path=str(self._remote_source), # type: ignore
local_path=str(self._local_path),
)

def __fspath__(self):
Expand Down Expand Up @@ -755,7 +768,7 @@ async def async_to_python_value(
# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)
_downloader = _FlyteFileDownloader(remote_path=uri, local_path=local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata)
Expand Down
Loading