
Commit d2c26e3
fix: storage backend path handling and async streaming blocking
- Fix ObStoreBackend LocalStore path mismatch when base_path is provided (#336)
- Fix async streaming blocking event loop in all three backends (ObStore, LocalStore, FSSpec)
- Add AsyncChunkedBytesIterator class for mypyc-compatible non-blocking streaming
- Fix LocalStore base_path to combine with URI path instead of overriding
- Fix FSSpecBackend to auto-derive base_path from file:// URIs
- Add comprehensive tests for all fixes
1 parent ac19c46 commit d2c26e3

File tree: 9 files changed, +610 -55 lines


pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "[email protected]" }]
 name = "sqlspec"
 readme = "README.md"
 requires-python = ">=3.10, <4.0"
-version = "0.38.1"
+version = "0.38.2"
 
 [project.urls]
 Discord = "https://discord.gg/litestar"
@@ -240,7 +240,7 @@ opt_level = "3" # Maximum optimization (0-3)
 allow_dirty = true
 commit = false
 commit_args = "--no-verify"
-current_version = "0.38.1"
+current_version = "0.38.2"
 ignore_missing_files = false
 ignore_missing_version = false
 message = "chore(release): bump to v{new_version}"

sqlspec/storage/backends/base.py

Lines changed: 63 additions & 1 deletion
@@ -8,7 +8,7 @@
 
 from sqlspec.typing import ArrowRecordBatch, ArrowTable
 
-__all__ = ("AsyncArrowBatchIterator", "AsyncBytesIterator", "ObjectStoreBase")
+__all__ = ("AsyncArrowBatchIterator", "AsyncBytesIterator", "AsyncChunkedBytesIterator", "ObjectStoreBase")
 
 
 class AsyncArrowBatchIterator:
@@ -58,6 +58,9 @@ class AsyncBytesIterator:
 
     The class wraps a synchronous iterator and exposes it as an async iterator,
     enabling usage with `async for` syntax.
+
+    Note: This class blocks the event loop during I/O. For non-blocking streaming,
+    use AsyncChunkedBytesIterator with pre-loaded data instead.
     """
 
     __slots__ = ("_sync_iter",)
@@ -89,6 +92,65 @@ async def __anext__(self) -> bytes:
         raise StopAsyncIteration from None
 
 
+class AsyncChunkedBytesIterator:
+    """Async iterator that yields pre-loaded bytes data in chunks.
+
+    This class implements the async iterator protocol without using async generators,
+    allowing it to be compiled by mypyc (which doesn't support async generators).
+
+    Unlike AsyncBytesIterator, this class works with pre-loaded data and yields
+    control to the event loop between chunks via asyncio.sleep(0), ensuring
+    the event loop is not blocked during iteration.
+
+    Usage pattern:
+        # Load data in thread pool to avoid blocking
+        data = await asyncio.to_thread(read_bytes, path)
+        # Stream chunks without blocking event loop
+        return AsyncChunkedBytesIterator(data, chunk_size=65536)
+    """
+
+    __slots__ = ("_chunk_size", "_data", "_offset")
+
+    def __init__(self, data: bytes, chunk_size: int = 65536) -> None:
+        """Initialize the chunked bytes iterator.
+
+        Args:
+            data: The bytes data to iterate over in chunks.
+            chunk_size: Size of each chunk to yield (default: 65536 bytes).
+        """
+        self._data = data
+        self._chunk_size = chunk_size
+        self._offset = 0
+
+    def __aiter__(self) -> "AsyncChunkedBytesIterator":
+        """Return self as the async iterator."""
+        return self
+
+    async def __anext__(self) -> bytes:
+        """Get the next chunk of bytes asynchronously.
+
+        Yields control to the event loop via asyncio.sleep(0) before returning
+        each chunk, ensuring other tasks can run during iteration.
+
+        Returns:
+            The next chunk of bytes.
+
+        Raises:
+            StopAsyncIteration: When all data has been yielded.
+        """
+        import asyncio
+
+        if self._offset >= len(self._data):
+            raise StopAsyncIteration
+
+        # Yield to event loop to allow other tasks to run
+        await asyncio.sleep(0)
+
+        chunk = self._data[self._offset : self._offset + self._chunk_size]
+        self._offset += self._chunk_size
+        return chunk
+
+
 @mypyc_attr(allow_interpreted_subclasses=True)
 class ObjectStoreBase(ABC):
     """Base class for storage backends."""

sqlspec/storage/backends/fsspec.py

Lines changed: 42 additions & 8 deletions
@@ -77,11 +77,18 @@ def __init__(self, uri: str, **kwargs: Any) -> None:
             uri: Filesystem URI (protocol://path).
             **kwargs: Additional fsspec configuration options, including an optional base_path.
 
-                For cloud URIs such as S3/GS/Azure, we derive a default base_path from the bucket/path when no explicit base_path is provided.
+                For cloud URIs (S3/GS/Azure) and file:// URIs, we derive a default base_path from the
+                URI path when no explicit base_path is provided. When both URI and base_path are provided,
+                they are combined (base_path is appended to URI-derived path).
+
+        Examples:
+            - FSSpecBackend("s3://bucket/prefix") -> base_path = "bucket/prefix"
+            - FSSpecBackend("file:///home/user/storage") -> base_path = "/home/user/storage"
+            - FSSpecBackend("file:///home/user", base_path="subdir") -> base_path = "/home/user/subdir"
         """
         ensure_fsspec()
 
-        base_path = kwargs.pop("base_path", "")
+        explicit_base_path = kwargs.pop("base_path", "")
 
         if "://" in uri:
             self.protocol = uri.split("://", maxsplit=1)[0]
@@ -93,13 +100,24 @@ def __init__(self, uri: str, **kwargs: Any) -> None:
                 uri_base_path = parsed.netloc
                 if parsed.path and parsed.path != "/":
                     uri_base_path = f"{uri_base_path}{parsed.path}"
-                if not base_path:
-                    base_path = uri_base_path
+                # Combine URI path with explicit base_path if both provided
+                if explicit_base_path:
+                    uri_base_path = f"{uri_base_path.rstrip('/')}/{explicit_base_path.lstrip('/')}"
+                explicit_base_path = uri_base_path
+            elif self.protocol == "file":
+                parsed = urlparse(uri)
+                if parsed.path and parsed.path != "/":
+                    # For file protocol, keep the path as-is (preserve leading slash for absolute paths)
+                    uri_base_path = parsed.path
+                    # Combine URI path with explicit base_path if both provided
+                    if explicit_base_path:
+                        uri_base_path = f"{uri_base_path.rstrip('/')}/{explicit_base_path.lstrip('/')}"
+                    explicit_base_path = uri_base_path
         else:
             self.protocol = uri
             self._fs_uri = f"{uri}://"
 
-        self.base_path = base_path.rstrip("/") if base_path else ""
+        self.base_path = explicit_base_path.rstrip("/") if explicit_base_path else ""
 
         import fsspec
 
@@ -453,10 +471,26 @@ async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any
     async def stream_read_async(
         self, path: "str | Path", chunk_size: "int | None" = None, **kwargs: Any
     ) -> AsyncIterator[bytes]:
-        """Stream bytes from storage asynchronously."""
-        from sqlspec.storage.backends.base import AsyncBytesIterator
+        """Stream bytes from storage asynchronously.
+
+        Uses asyncio.to_thread() to run blocking I/O in a thread pool,
+        ensuring the event loop is not blocked during read operations.
+
+        Args:
+            path: Path to the file to read.
+            chunk_size: Size of chunks to yield (default: 65536 bytes).
+            **kwargs: Additional arguments passed to read_bytes.
+
+        Returns:
+            AsyncIterator yielding chunks of bytes.
+        """
+        import asyncio
+
+        from sqlspec.storage.backends.base import AsyncChunkedBytesIterator
 
-        return AsyncBytesIterator(self.stream_read(path, chunk_size, **kwargs))
+        # Pass original path - read_bytes handles path resolution
+        data = await asyncio.to_thread(self.read_bytes, path, **kwargs)
+        return AsyncChunkedBytesIterator(data, chunk_size or 65536)
 
     def stream_arrow_async(self, pattern: str, **kwargs: Any) -> AsyncIterator["ArrowRecordBatch"]:
         """Stream Arrow record batches from storage asynchronously.

sqlspec/storage/backends/local.py

Lines changed: 31 additions & 8 deletions
@@ -72,10 +72,14 @@ def __init__(self, uri: str = "", **kwargs: Any) -> None:
 
         Args:
             uri: File URI or path (e.g., "file:///path" or "/path")
-            **kwargs: Additional options (base_path for relative operations)
-
-        The URI may be a file:// path (Windows style like file:///C:/path is supported),
-        and an explicit base_path override will take precedence before we ensure the directory exists.
+            **kwargs: Additional options including:
+                - base_path: Subdirectory relative to URI path. If relative, it's combined
+                  with the URI path. If absolute, it takes precedence (backward compatible).
+
+        The URI may be a file:// path (Windows style like file:///C:/path is supported).
+        When both URI and base_path are provided, they are combined:
+            - file:///home/user/storage + base_path="subdir" -> /home/user/storage/subdir
+            - file:///home/user/storage + base_path="/other" -> /other (absolute takes precedence)
         """
         if uri.startswith("file://"):
             parsed = urlparse(uri)
@@ -89,7 +93,9 @@ def __init__(self, uri: str = "", **kwargs: Any) -> None:
             self.base_path = Path.cwd()
 
         if "base_path" in kwargs:
-            self.base_path = Path(kwargs["base_path"]).resolve()
+            # Combine URI path with base_path (Path division handles absolute paths correctly)
+            # If base_path is absolute, it takes precedence (backward compatible)
+            self.base_path = (self.base_path / kwargs["base_path"]).resolve()
 
         if not self.base_path.exists():
             self.base_path.mkdir(parents=True, exist_ok=True)
@@ -377,10 +383,27 @@ async def write_text_async(self, path: "str | Path", data: str, encoding: str = "utf-8", **k
     async def stream_read_async(
         self, path: "str | Path", chunk_size: "int | None" = None, **kwargs: Any
     ) -> AsyncIterator[bytes]:
-        """Stream bytes from file asynchronously."""
-        from sqlspec.storage.backends.base import AsyncBytesIterator
+        """Stream bytes from file asynchronously.
+
+        Uses asyncio.to_thread() to run blocking file I/O in a thread pool,
+        ensuring the event loop is not blocked during read operations.
 
-        return AsyncBytesIterator(self.stream_read(path, chunk_size, **kwargs))
+        Args:
+            path: Path to the file to read.
+            chunk_size: Size of chunks to yield (default: 65536 bytes).
+            **kwargs: Additional arguments (unused).
+
+        Returns:
+            AsyncIterator yielding chunks of bytes.
+        """
+        import asyncio
+
+        from sqlspec.storage.backends.base import AsyncChunkedBytesIterator
+
+        resolved = self._resolve_path(path)
+        # Run blocking I/O in thread pool to avoid blocking event loop
+        data = await asyncio.to_thread(resolved.read_bytes)
+        return AsyncChunkedBytesIterator(data, chunk_size or 65536)
 
     async def list_objects_async(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> "list[str]":
        """List objects asynchronously."""

sqlspec/storage/backends/obstore.py

Lines changed: 54 additions & 29 deletions
@@ -90,8 +90,19 @@ def __init__(self, uri: str, **kwargs: Any) -> None:
         """Initialize obstore backend.
 
         Args:
-            uri: Storage URI (e.g., 's3://bucket', 'file:///path', 'gs://bucket')
-            **kwargs: Additional options including base_path and obstore configuration
+            uri: Storage URI. Supported formats:
+                - file:///absolute/path - Local filesystem
+                - s3://bucket/prefix - AWS S3
+                - gs://bucket/prefix - Google Cloud Storage
+                - az://container/prefix - Azure Blob Storage
+                - memory:// - In-memory storage (for testing)
+            **kwargs: Additional options:
+                - base_path (str): For local files (file://), this is combined with
+                  the URI path to form the storage root. For example:
+                  uri="file:///data" + base_path="uploads" → /data/uploads
+                  If base_path is absolute, it overrides the URI path (backward compat).
+                  For cloud storage, base_path is used as an object key prefix.
+                - Other obstore configuration options (timeouts, credentials, etc.)
         """
         ensure_obstore()
@@ -123,7 +134,9 @@ def __init__(self, uri: str, **kwargs: Any) -> None:
             if path_obj.is_file():
                 path_str = str(path_obj.parent)
 
-            local_store_root = self.base_path or path_str
+            # Combine URI path with base_path for correct storage location
+            # If base_path is absolute, Path division will use it directly (backward compat)
+            local_store_root = str(Path(path_str) / self.base_path) if self.base_path else path_str
 
             self._is_local_store = True
             self._local_store_root = local_store_root
@@ -228,11 +241,14 @@ def write_text(self, path: "str | Path", data: str, encoding: str = "utf-8", **k
 
     def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> "list[str]":  # pyright: ignore[reportUnusedParameter]
         """List objects using obstore."""
-        resolved_prefix = (
-            resolve_storage_path(prefix, self.base_path, self.protocol, strip_file_scheme=True)
-            if prefix
-            else self.base_path or ""
-        )
+        # For LocalStore, the base_path is already included in the store root,
+        # so we use empty prefix when none is given. For cloud stores, use base_path.
+        if prefix:
+            resolved_prefix = resolve_storage_path(prefix, self.base_path, self.protocol, strip_file_scheme=True)
+        elif self._is_local_store:
+            resolved_prefix = ""
+        else:
+            resolved_prefix = self.base_path or ""
         items = self.store.list_with_delimiter(resolved_prefix) if not recursive else self.store.list(resolved_prefix)
         paths = sorted(item["path"] for batch in items for item in batch)
         _log_storage_event(
@@ -629,37 +645,46 @@ async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any
     async def stream_read_async(
         self, path: "str | Path", chunk_size: "int | None" = None, **kwargs: Any
     ) -> AsyncIterator[bytes]:
-        """Stream bytes from storage asynchronously."""
+        """Stream bytes from storage asynchronously.
+
+        Uses asyncio.to_thread() to ensure the event loop is not blocked
+        during I/O operations with cloud storage backends. This prevents
+        heartbeat timeouts and allows concurrent async tasks to execute
+        during large file downloads.
+        """
+        import asyncio
+
+        from sqlspec.storage.backends.base import AsyncChunkedBytesIterator
+
         if self._is_local_store:
             resolved_path = self._resolve_path_for_local_store(path)
         else:
             resolved_path = resolve_storage_path(path, self.base_path, self.protocol, strip_file_scheme=True)
 
-        result = await self.store.get_async(resolved_path)
-        stream = result.stream()
+        # Run blocking I/O in thread pool to avoid blocking event loop
+        data = await asyncio.to_thread(self.read_bytes, resolved_path)
 
-        async def _generator() -> AsyncIterator[bytes]:
-            async for chunk in stream:
-                yield bytes(chunk)
-
-        _log_storage_event(
-            "storage.read",
-            backend_type=self.backend_type,
-            protocol=self.protocol,
-            operation="stream_read",
-            mode="async",
-            path=resolved_path,
-        )
+        _log_storage_event(
+            "storage.read",
+            backend_type=self.backend_type,
+            protocol=self.protocol,
+            operation="stream_read",
+            mode="async",
+            path=resolved_path,
+        )
 
-        return _generator()
+        return AsyncChunkedBytesIterator(data, chunk_size or 65536)
 
     async def list_objects_async(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> "list[str]":  # pyright: ignore[reportUnusedParameter]
         """List objects in storage asynchronously."""
-        resolved_prefix = (
-            resolve_storage_path(prefix, self.base_path, self.protocol, strip_file_scheme=True)
-            if prefix
-            else self.base_path or ""
-        )
+        # For LocalStore, the base_path is already included in the store root,
+        # so we use empty prefix when none is given. For cloud stores, use base_path.
+        if prefix:
+            resolved_prefix = resolve_storage_path(prefix, self.base_path, self.protocol, strip_file_scheme=True)
+        elif self._is_local_store:
+            resolved_prefix = ""
+        else:
+            resolved_prefix = self.base_path or ""
 
         objects: list[str] = []
         async for batch in self.store.list_async(resolved_prefix):  # pyright: ignore[reportAttributeAccessIssue]
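
Putting the pieces together, a sketch of how the reworked streaming path is meant to be consumed. The file:// URI, the /tmp/demo layout, and the write_bytes call shape are assumptions inferred from method names visible in this diff, not confirmed API:

import asyncio

from sqlspec.storage.backends.obstore import ObStoreBackend


async def main() -> None:
    # base_path now combines with the URI path: the storage root is /tmp/demo/uploads
    backend = ObStoreBackend("file:///tmp/demo", base_path="uploads")
    backend.write_bytes("report.bin", b"y" * 100_000)

    total = 0
    # read_bytes runs in a worker thread; chunks arrive without blocking the loop
    async for chunk in await backend.stream_read_async("report.bin", chunk_size=65536):
        total += len(chunk)
    assert total == 100_000


asyncio.run(main())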
