@@ -237,10 +237,14 @@ def save(self, buffer: io.BytesIO, image_format: str | None = None, **kwargs) ->
237237 assert dtype == np .uint8
238238
239239
240- def test_write_tar_collision_and_content (tmp_path : pathlib .Path ) -> None :
241- _module , image_writer_stage_cls = _import_writer_with_stubbed_pyarrow ()
240+ def test_write_tar_collision_and_content (monkeypatch : pytest . MonkeyPatch , tmp_path : pathlib .Path ) -> None :
241+ module , image_writer_stage_cls = _import_writer_with_stubbed_pyarrow ()
242242 stage = image_writer_stage_cls (output_dir = str (tmp_path ))
243243
244+ # Mock logger to verify warning
245+ warnings : list [str ] = []
246+ monkeypatch .setattr (module .logger , "warning" , lambda msg : warnings .append (msg ))
247+
244248 base = "images-abc123"
245249 path = stage ._write_tar (base , [("a.jpg" , b"a" ), ("b.jpg" , b"bb" )])
246250 assert pathlib .Path (path ).exists ()
@@ -249,8 +253,18 @@ def test_write_tar_collision_and_content(tmp_path: pathlib.Path) -> None:
249253 names = {m .name for m in tf .getmembers ()}
250254 assert names == {"a.jpg" , "b.jpg" }
251255
252- with pytest .raises (AssertionError , match = "Collision detected" ):
253- stage ._write_tar (base , [("c.jpg" , b"c" )])
256+ # Overwrite existing file
257+ path2 = stage ._write_tar (base , [("c.jpg" , b"c" )])
258+ assert path2 == path
259+ assert pathlib .Path (path ).exists ()
260+
261+ # Verify warning was logged
262+ assert any ("already exists" in w for w in warnings )
263+
264+ # Verify content was overwritten
265+ with tarfile .open (path , "r" ) as tf :
266+ names = {m .name for m in tf .getmembers ()}
267+ assert names == {"c.jpg" }
254268
255269
256270def test_write_parquet_collision_and_path (monkeypatch : pytest .MonkeyPatch , tmp_path : pathlib .Path ) -> None :
@@ -261,6 +275,10 @@ def test_write_parquet_collision_and_path(monkeypatch: pytest.MonkeyPatch, tmp_p
261275
262276 module = importlib .import_module ("nemo_curator.stages.image.io.image_writer" )
263277
278+ # Mock logger to verify warning
279+ warnings : list [str ] = []
280+ monkeypatch .setattr (module .logger , "warning" , lambda msg : warnings .append (msg ))
281+
264282 # Patch module-local pyarrow bindings
265283 stub_pa = types .SimpleNamespace (Table = types .SimpleNamespace (from_pylist = lambda rows : ("T" , rows )))
266284 stub_pq = types .SimpleNamespace (write_table = lambda tbl , p : written .update ({"tbl" : tbl , "path" : p }))
@@ -272,14 +290,21 @@ def test_write_parquet_collision_and_path(monkeypatch: pytest.MonkeyPatch, tmp_p
272290 base = "images-parq"
273291 # Pre-create to trigger collision
274292 (tmp_path / f"{ base } .parquet" ).write_bytes (b"" )
275- with pytest .raises (AssertionError , match = "Collision detected" ):
276- stage ._write_parquet (base , [{"k" : 1 }])
277293
278- # Remove and write successfully
294+ # Should overwrite and warn
295+ out_path = stage ._write_parquet (base , [{"k" : 1 }])
296+
297+ assert any ("already exists" in w for w in warnings )
298+ assert out_path == str (tmp_path / f"{ base } .parquet" )
299+ # Verify write was called
300+ assert written .get ("path" ) == out_path
301+
302+ # Clean up and write again (no collision)
303+ warnings .clear ()
279304 (tmp_path / f"{ base } .parquet" ).unlink ()
280305 out_path = stage ._write_parquet (base , [{"k" : 2 }])
281306 assert out_path == str (tmp_path / f"{ base } .parquet" )
282- assert written . get ( "path" ) == out_path
307+ assert not warnings
283308
284309
285310@pytest .mark .parametrize ("remove" , [True , False ])
0 commit comments