Skip to content

Commit 93becff

Browse files
authored
fix: Identify s3/remote uri path correctly (feast-dev#5076)
Signed-off-by: ntkathole <[email protected]>
1 parent f3a24de commit 93becff

File tree

4 files changed

+46
-18
lines changed

4 files changed

+46
-18
lines changed

sdk/python/feast/infra/offline_stores/dask.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,9 @@ def persist(
100100
# Check if the specified location already exists.
101101
if not allow_overwrite and os.path.exists(storage.file_options.uri):
102102
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
103-
104-
if not Path(storage.file_options.uri).is_absolute():
105-
absolute_path = Path(self.repo_path) / storage.file_options.uri
106-
else:
107-
absolute_path = Path(storage.file_options.uri)
103+
absolute_path = FileSource.get_uri_for_file_path(
104+
repo_path=self.repo_path, uri=storage.file_options.uri
105+
)
108106

109107
filesystem, path = FileSource.create_filesystem_and_path(
110108
str(absolute_path),

sdk/python/feast/infra/offline_stores/duckdb.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,9 @@ def _write_data_source(
5151

5252
file_options = data_source.file_options
5353

54-
if not Path(file_options.uri).is_absolute():
55-
absolute_path = Path(repo_path) / file_options.uri
56-
else:
57-
absolute_path = Path(file_options.uri)
54+
absolute_path = FileSource.get_uri_for_file_path(
55+
repo_path=repo_path, uri=file_options.uri
56+
)
5857

5958
if (
6059
mode == "overwrite"

sdk/python/feast/infra/offline_stores/file_source.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
2-
from typing import Callable, Dict, Iterable, List, Optional, Tuple
2+
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
3+
from urllib.parse import urlparse
34

45
import pyarrow
56
from packaging import version
@@ -154,17 +155,21 @@ def validate(self, config: RepoConfig):
154155
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
155156
return type_map.pa_to_feast_value_type
156157

158+
@staticmethod
def get_uri_for_file_path(repo_path: Union[Path, str, None], uri: str) -> str:
    """Resolve a file source URI to the string that should be opened.

    Args:
        repo_path: Base directory used to anchor relative local paths, or
            ``None`` when no base directory is available.
        uri: A URI such as ``s3://bucket/key`` or ``file:///abs/path``, or
            a local filesystem path (absolute or relative).

    Returns:
        ``uri`` unchanged when it is a URI; otherwise the local path as a
        string, joined onto ``repo_path`` when the path is relative and
        ``repo_path`` is given.
    """
    parsed_uri = urlparse(uri)
    scheme = parsed_uri.scheme
    # A "<scheme>://" prefix marks a URI even when the authority is empty
    # (e.g. "file:///abs/path"). Checking netloc alone would send such URIs
    # through Path(), which collapses the double slash and treats the result
    # as a relative path. A drive letter ("C:/dir") also parses as a scheme
    # but is followed by ":/" only, so it still falls through as a local path.
    has_authority_marker = bool(scheme) and uri[len(scheme):].startswith("://")
    if scheme and (parsed_uri.netloc or has_authority_marker):
        return uri  # Keep URIs exactly as given
    local_path = Path(uri)
    if repo_path is not None and not local_path.is_absolute():
        return str(Path(repo_path) / local_path)
    return str(local_path)
166+
157167
def get_table_column_names_and_types(
158168
self, config: RepoConfig
159169
) -> Iterable[Tuple[str, str]]:
160-
if (
161-
config.repo_path is not None
162-
and not Path(self.file_options.uri).is_absolute()
163-
):
164-
absolute_path = config.repo_path / self.file_options.uri
165-
else:
166-
absolute_path = Path(self.file_options.uri)
167-
170+
absolute_path = self.get_uri_for_file_path(
171+
repo_path=config.repo_path, uri=self.file_options.uri
172+
)
168173
filesystem, path = FileSource.create_filesystem_and_path(
169174
str(absolute_path), self.file_options.s3_endpoint_override
170175
)

sdk/python/tests/unit/infra/offline_stores/test_offline_store.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TrinoRetrievalJob,
2222
)
2323
from feast.infra.offline_stores.dask import DaskRetrievalJob
24+
from feast.infra.offline_stores.file_source import FileSource
2425
from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata
2526
from feast.infra.offline_stores.redshift import (
2627
RedshiftOfflineStoreConfig,
@@ -246,3 +247,28 @@ def test_to_arrow_timeout(retrieval_job, timeout: Optional[int]):
246247
with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_arrow_internal:
247248
retrieval_job.to_arrow(timeout=timeout)
248249
mock_to_arrow_internal.assert_called_once_with(timeout=timeout)
250+
251+
252+
@pytest.mark.parametrize(
    ("repo_path", "uri", "expected"),
    [
        # A remote URI is passed through untouched.
        pytest.param(
            "/some/repo",
            "s3://bucket-name/file.parquet",
            "s3://bucket-name/file.parquet",
            id="s3_uri",
        ),
        # An absolute local path ignores repo_path.
        pytest.param(
            "/some/repo",
            "/abs/path/file.parquet",
            "/abs/path/file.parquet",
            id="absolute_path",
        ),
        # A relative local path is anchored at repo_path.
        pytest.param(
            "/some/repo",
            "data/output.parquet",
            "/some/repo/data/output.parquet",
            id="relative_path",
        ),
        # A drive-letter path with no repo_path is expected back unchanged.
        pytest.param(
            None,
            "C:/path/to/file.parquet",
            "C:/path/to/file.parquet",
            id="windows_path",
        ),
    ],
)
def test_get_uri_for_file_path(
    repo_path: Optional[str], uri: str, expected: str
) -> None:
    """Check that FileSource.get_uri_for_file_path returns the expected string."""
    result = FileSource.get_uri_for_file_path(uri=uri, repo_path=repo_path)
    assert expected == result, f"Expected {expected}, but got {result}"

0 commit comments

Comments
 (0)