Skip to content

Commit 825586b

Browse files
authored
Fix ZstdError import handling for Python 3.14+ compatibility (Lightning-AI#767)
- Import ZstdError from compression.zstd for Python >= 3.14 - Import ZstdError from zstd for older Python versions - Add test to verify ZstdError is properly caught for corrupted data Fixes Lightning-AI#766
1 parent 0770595 commit 825586b

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/litdata/raw/indexer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,28 @@ def _load_index_file(self, index_path: str) -> Optional[list[FileMetadata]]:
149149
"""Loads and decodes an index file."""
150150
if _PYTHON_GREATER_EQUAL_3_14:
151151
from compression import zstd
152+
from compression.zstd import ZstdError
152153
else:
153154
import zstd
155+
from zstd import Error as ZstdError
154156

155157
try:
156158
with open(index_path, "rb") as f:
157159
compressed_data = f.read()
158160
metadata = json.loads(zstd.decompress(compressed_data).decode("utf-8"))
159161
return [FileMetadata.from_dict(file_data) for file_data in metadata["files"]]
160-
except (FileNotFoundError, json.JSONDecodeError, zstd.ZstdError, KeyError) as e:
162+
except (FileNotFoundError, json.JSONDecodeError, ZstdError, KeyError) as e:
161163
logger.warning(f"Failed to load index from local cache at `{index_path}`: {e}. ")
162164
return None
163165

164166
def _save_index_file(self, index_path: str, files: list[FileMetadata], source: str) -> None:
165167
"""Encodes and saves an index file."""
166168
if _PYTHON_GREATER_EQUAL_3_14:
167169
from compression import zstd
170+
from compression.zstd import ZstdError
168171
else:
169172
import zstd
173+
from zstd import Error as ZstdError
170174

171175
try:
172176
metadata = {
@@ -176,7 +180,7 @@ def _save_index_file(self, index_path: str, files: list[FileMetadata], source: s
176180
}
177181
with open(index_path, "wb") as f:
178182
f.write(zstd.compress(json.dumps(metadata).encode("utf-8")))
179-
except (OSError, zstd.ZstdError) as e:
183+
except (OSError, ZstdError) as e:
180184
logger.warning(f"Error caching index to {index_path}: {e}")
181185

182186
def _download_from_cloud(

tests/raw/test_indexer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,26 @@ def test_recompute_index_excludes_index_file(tmp_path):
315315
assert len(files) == 2
316316
for f in files:
317317
assert _INDEX_FILENAME not in f.path
318+
319+
320+
def test_load_index_file_handles_corrupted_zstd(tmp_path):
321+
"""Test that _load_index_file catches ZstdError for corrupted data."""
322+
if _PYTHON_GREATER_EQUAL_3_14:
323+
import compression.zstd as zstd
324+
from compression.zstd import ZstdError
325+
else:
326+
import zstd
327+
from zstd import Error as ZstdError
328+
329+
index_path = tmp_path / _INDEX_FILENAME
330+
331+
with open(index_path, "wb") as f:
332+
f.write(b"dummy data")
333+
334+
# Verify corrupted data raises ZstdError
335+
with pytest.raises(ZstdError):
336+
zstd.decompress(b"dummy data")
337+
338+
# Verify the indexer catches this and returns None
339+
indexer = FileIndexer()
340+
assert indexer._load_index_file(str(index_path)) is None

0 commit comments

Comments
 (0)