Skip to content

Commit 5880fab

Browse files
huvunvidiaHuy Vu2
andauthored
Collision warning code (#1325)
Co-authored-by: Huy Vu2 <[email protected]>
1 parent d136838 commit 5880fab

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

nemo_curator/stages/image/io/image_writer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ def _write_tar(self, base_name: str, members: list[tuple[str, bytes]]) -> str:
107107

108108
# Assert to prevent accidental overwrite if a file with the same name already exists
109109
if os.path.exists(tar_path):
110-
err = f"Collision detected: refusing to overwrite existing tar file: {tar_path}"
111-
raise AssertionError(err)
110+
logger.warning(f"File {tar_path} already exists. Overwriting it.")
112111

113112
with open(tar_path, "wb") as fobj, tarfile.open(fileobj=fobj, mode="w") as tf:
114113
for member_name, payload in members:
@@ -130,8 +129,7 @@ def _write_parquet(self, base_name: str, rows: list[dict[str, Any]]) -> str:
130129

131130
# Assert to prevent accidental overwrite if a file with the same name already exists
132131
if os.path.exists(parquet_path):
133-
err = f"Collision detected: refusing to overwrite existing parquet file: {parquet_path}"
134-
raise AssertionError(err)
132+
logger.warning(f"File {parquet_path} already exists. Overwriting it.")
135133

136134
# Convert rows to Arrow Table (assumes uniform keys across rows)
137135
table = pa.Table.from_pylist(rows)

tests/stages/image/io/test_image_writer.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,14 @@ def save(self, buffer: io.BytesIO, image_format: str | None = None, **kwargs) ->
237237
assert dtype == np.uint8
238238

239239

240-
def test_write_tar_collision_and_content(tmp_path: pathlib.Path) -> None:
241-
_module, image_writer_stage_cls = _import_writer_with_stubbed_pyarrow()
240+
def test_write_tar_collision_and_content(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> None:
241+
module, image_writer_stage_cls = _import_writer_with_stubbed_pyarrow()
242242
stage = image_writer_stage_cls(output_dir=str(tmp_path))
243243

244+
# Mock logger to verify warning
245+
warnings: list[str] = []
246+
monkeypatch.setattr(module.logger, "warning", lambda msg: warnings.append(msg))
247+
244248
base = "images-abc123"
245249
path = stage._write_tar(base, [("a.jpg", b"a"), ("b.jpg", b"bb")])
246250
assert pathlib.Path(path).exists()
@@ -249,8 +253,18 @@ def test_write_tar_collision_and_content(tmp_path: pathlib.Path) -> None:
249253
names = {m.name for m in tf.getmembers()}
250254
assert names == {"a.jpg", "b.jpg"}
251255

252-
with pytest.raises(AssertionError, match="Collision detected"):
253-
stage._write_tar(base, [("c.jpg", b"c")])
256+
# Overwrite existing file
257+
path2 = stage._write_tar(base, [("c.jpg", b"c")])
258+
assert path2 == path
259+
assert pathlib.Path(path).exists()
260+
261+
# Verify warning was logged
262+
assert any("already exists" in w for w in warnings)
263+
264+
# Verify content was overwritten
265+
with tarfile.open(path, "r") as tf:
266+
names = {m.name for m in tf.getmembers()}
267+
assert names == {"c.jpg"}
254268

255269

256270
def test_write_parquet_collision_and_path(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> None:
@@ -261,6 +275,10 @@ def test_write_parquet_collision_and_path(monkeypatch: pytest.MonkeyPatch, tmp_p
261275

262276
module = importlib.import_module("nemo_curator.stages.image.io.image_writer")
263277

278+
# Mock logger to verify warning
279+
warnings: list[str] = []
280+
monkeypatch.setattr(module.logger, "warning", lambda msg: warnings.append(msg))
281+
264282
# Patch module-local pyarrow bindings
265283
stub_pa = types.SimpleNamespace(Table=types.SimpleNamespace(from_pylist=lambda rows: ("T", rows)))
266284
stub_pq = types.SimpleNamespace(write_table=lambda tbl, p: written.update({"tbl": tbl, "path": p}))
@@ -272,14 +290,21 @@ def test_write_parquet_collision_and_path(monkeypatch: pytest.MonkeyPatch, tmp_p
272290
base = "images-parq"
273291
# Pre-create to trigger collision
274292
(tmp_path / f"{base}.parquet").write_bytes(b"")
275-
with pytest.raises(AssertionError, match="Collision detected"):
276-
stage._write_parquet(base, [{"k": 1}])
277293

278-
# Remove and write successfully
294+
# Should overwrite and warn
295+
out_path = stage._write_parquet(base, [{"k": 1}])
296+
297+
assert any("already exists" in w for w in warnings)
298+
assert out_path == str(tmp_path / f"{base}.parquet")
299+
# Verify write was called
300+
assert written.get("path") == out_path
301+
302+
# Clean up and write again (no collision)
303+
warnings.clear()
279304
(tmp_path / f"{base}.parquet").unlink()
280305
out_path = stage._write_parquet(base, [{"k": 2}])
281306
assert out_path == str(tmp_path / f"{base}.parquet")
282-
assert written.get("path") == out_path
307+
assert not warnings
283308

284309

285310
@pytest.mark.parametrize("remove", [True, False])

0 commit comments

Comments
 (0)