Commit 8392acd

Add tests for cloud I/O changes (#1257)
1 parent d377992 · commit 8392acd

6 files changed, +230 -57 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ test = [
     "pytest-cov",
     "pytest-loguru",
     "scikit-learn",
+    "s3fs",  # added for testing cloud fs
 ]
 
 [tool.uv]

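A note on why this one-line dependency is enough: s3fs is the fsspec implementation for S3, so once it is installed, the same fsspec machinery the readers and writers already use for local paths can also resolve s3:// URLs. A minimal sketch of that resolution; the bucket and key below are placeholders, not paths referenced by these tests.

import fsspec
from fsspec.core import url_to_fs

# With s3fs installed, fsspec can hand back an S3 filesystem object; without it,
# this raises ImportError. No network call is made just to construct the object.
fs = fsspec.filesystem("s3", anon=True)
print(type(fs).__name__)  # "S3FileSystem"

# url_to_fs splits a URL into (filesystem, protocol-stripped path).
fs, path = url_to_fs("s3://example-bucket/prefix/data.parquet")
print(path)  # "example-bucket/prefix/data.parquet"
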
tests/stages/deduplication/semantic/test_pairwise_io.py

Lines changed: 29 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 # ruff: noqa: E402
 from pathlib import Path
+from unittest.mock import Mock
 
 import pytest
 
@@ -28,6 +29,21 @@
 class TestClusterWiseFilePartitioningStage:
     """Test cases for ClusterWiseFilePartitioningStage."""
 
+    def test_setup(self):
+        # Test fs and path_normalizer are set correctly
+        stage = ClusterWiseFilePartitioningStage("s3://test-bucket/test-path")
+        stage.setup()
+        assert stage.fs is not None
+        assert stage.path_normalizer is not None
+        assert stage.path_normalizer("test-bucket/test-path") == "s3://test-bucket/test-path"
+
+        # Test for local filesystem
+        stage = ClusterWiseFilePartitioningStage("/test/path")
+        stage.setup()
+        assert stage.fs is not None
+        assert stage.path_normalizer is not None
+        assert stage.path_normalizer("/test/path") == "/test/path"
+
     def test_process_finds_all_centroid_files(self, tmp_path: Path):
         """Test that process method finds all files in centroid directories."""
 
@@ -58,9 +74,22 @@ def test_process_finds_all_centroid_files(self, tmp_path: Path):
         stage = ClusterWiseFilePartitioningStage(str(tmp_path))
         stage.setup()
 
+        # Mock path_normalizer to track calls and verify it's used correctly
+        # For local filesystem, path_normalizer is lambda x: x, so mock should return input
+        mock_path_normalizer = Mock(side_effect=lambda x: x)
+        stage.path_normalizer = mock_path_normalizer
+
         empty_task = _EmptyTask(task_id="test", dataset_name="test", data=None)
         result = stage.process(empty_task)
 
+        # Verify path_normalizer was called exactly 3 times (once per centroid directory)
+        assert mock_path_normalizer.call_count == 3
+
+        # Verify it was called with centroid directory paths
+        # fs.ls() returns entries that contain "centroid="
+        call_args = [call[0][0] for call in mock_path_normalizer.call_args_list]
+        assert all("centroid=" in str(arg) for arg in call_args)
+
         # Should create 3 FileGroupTasks for 3 centroids
         assert len(result) == 3
         assert all(isinstance(task, FileGroupTask) for task in result)

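The test_setup assertions pin down the contract the rest of the commit relies on: fs.ls() on an S3 filesystem returns protocol-less keys such as "test-bucket/test-path", so the stage's path_normalizer has to re-attach the scheme before paths are handed downstream, while local paths pass through unchanged. Below is a hedged sketch of a normalizer with exactly that behaviour, built on fsspec; the real stage may construct it differently, and make_path_normalizer is an illustrative helper, not part of the codebase.

from fsspec.core import url_to_fs


def make_path_normalizer(root: str):
    """Return a callable matching the behaviour asserted in test_setup (illustrative only)."""
    fs, _ = url_to_fs(root)
    protocols = fs.protocol if isinstance(fs.protocol, tuple) else (fs.protocol,)
    if "file" in protocols:
        return lambda p: p  # local filesystem: identity, e.g. "/test/path" -> "/test/path"
    return fs.unstrip_protocol  # remote: "test-bucket/test-path" -> "s3://test-bucket/test-path"


normalize = make_path_normalizer("s3://test-bucket/test-path")
assert normalize("test-bucket/test-path") == "s3://test-bucket/test-path"

normalize = make_path_normalizer("/test/path")
assert normalize("/test/path") == "/test/path"
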
tests/stages/text/io/reader/test_parquet.py

Lines changed: 46 additions & 11 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
+from unittest.mock import patch
 
 import pandas as pd
 import pyarrow as pa
@@ -77,13 +78,29 @@ def test_parquet_reader_stage_pandas_reads_and_concatenates(sample_parquet_files
     task = _make_file_group_task(sample_parquet_files[:2])
     stage = ParquetReaderStage(fields=None)
 
-    out = stage.process(task)
-    assert isinstance(out, DocumentBatch)
+    # Track calls to pd.read_parquet and pd.concat using mock.patch with wraps
+    with (
+        patch(
+            "nemo_curator.stages.text.io.reader.parquet.pd.read_parquet", wraps=pd.read_parquet
+        ) as mock_read_parquet,
+        patch("nemo_curator.stages.text.io.reader.parquet.pd.concat", wraps=pd.concat) as mock_concat,
+    ):
+        out = stage.process(task)
+        assert isinstance(out, DocumentBatch)
+
+        df = out.to_pandas()
+        assert isinstance(df, pd.DataFrame)
+        assert len(df) == 4  # 2 files * 2 records each = 4 records
+        assert {"text", "category", "score"}.issubset(set(df.columns))
+
+        # Verify pd.read_parquet was called once per file
+        assert mock_read_parquet.call_count == 2
+        assert mock_read_parquet.call_args_list[0][0][0] == sample_parquet_files[0]
+        assert mock_read_parquet.call_args_list[1][0][0] == sample_parquet_files[1]
 
-    df = out.to_pandas()
-    assert isinstance(df, pd.DataFrame)
-    assert len(df) == 4  # 2 files * 2 records each = 4 records
-    assert {"text", "category", "score"}.issubset(set(df.columns))
+    # Verify pd.concat was called once with ignore_index=True
+    assert mock_concat.call_count == 1
+    assert mock_concat.call_args[1].get("ignore_index") is True
 
 
 class TestParquetReaderStorageOptionsAndColumns:
@@ -150,11 +167,29 @@ def test_parquet_reader_stage_pyarrow_reads_and_concatenates(tmp_path: Path):
     task = _make_file_group_task([str(f1), str(f2)])
     stage = ParquetReaderStage(read_kwargs={"engine": "pyarrow"}, fields=None)
 
-    out = stage.process(task)
-    table = out.to_pyarrow()
-    assert isinstance(table, pa.Table)
-    assert table.num_rows == 3
-    assert {"text", "category", "score"}.issubset(set(table.column_names))
+    # Track calls to pd.read_parquet and pd.concat using mock.patch with wraps
+    with (
+        patch(
+            "nemo_curator.stages.text.io.reader.parquet.pd.read_parquet", wraps=pd.read_parquet
+        ) as mock_read_parquet,
+        patch("nemo_curator.stages.text.io.reader.parquet.pd.concat", wraps=pd.concat) as mock_concat,
+    ):
+        out = stage.process(task)
+        table = out.to_pyarrow()
+        assert isinstance(table, pa.Table)
+        assert table.num_rows == 3
+        assert {"text", "category", "score"}.issubset(set(table.column_names))
+
+        # Verify pd.read_parquet was called once per file
+        assert mock_read_parquet.call_count == 2
+        assert mock_read_parquet.call_args_list[0][0][0] == str(f1)
+        assert mock_read_parquet.call_args_list[1][0][0] == str(f2)
+        # Verify engine was passed correctly
+        assert mock_read_parquet.call_args_list[0][1].get("engine") == "pyarrow"
+
+        # Verify pd.concat was called once with ignore_index=True
+        assert mock_concat.call_count == 1
+        assert mock_concat.call_args[1].get("ignore_index") is True
 
 
 def test_parquet_reader_stage_pyarrow_errors_when_some_columns_missing(tmp_path: Path):

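The pattern both reader tests lean on is patch(..., wraps=...): the real pandas function still runs, so the existing behavioural assertions are unchanged, while the surrounding Mock records how it was called. A self-contained illustration against pandas itself (nothing here touches the curator code; the DataFrames are throwaway):

from unittest.mock import patch

import pandas as pd

# wraps=pd.concat keeps the real implementation; the Mock only observes calls.
with patch("pandas.concat", wraps=pd.concat) as mock_concat:
    combined = pd.concat(
        [pd.DataFrame({"a": [1]}), pd.DataFrame({"a": [2]})], ignore_index=True
    )

assert len(combined) == 2                                    # real concat did the work
assert mock_concat.call_count == 1                           # the wrapper saw one call
assert mock_concat.call_args[1].get("ignore_index") is True  # and its keyword arguments
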
tests/stages/text/io/writer/test_jsonl.py

Lines changed: 24 additions & 0 deletions
@@ -22,6 +22,7 @@
 
 import nemo_curator.stages.text.io.writer.utils as writer_utils
 from nemo_curator.stages.text.io.writer import JsonlWriter
+from nemo_curator.stages.text.io.writer import base as writer_base
 from nemo_curator.tasks import DocumentBatch
 
 
@@ -266,3 +267,26 @@ def test_jsonl_writer_overwrites_existing_file(
         pd.testing.assert_frame_equal(
             pd.read_json(result1.data[0], lines=True), pd.read_json(result2.data[0], lines=True)
         )
+
+    @pytest.mark.parametrize(
+        "path",
+        [
+            "s3://test-bucket/output",
+            "/local/path",
+        ],
+    )
+    def test_jsonl_writer_write_data_path_protocol_handling(self, pandas_document_batch: DocumentBatch, path: str):
+        """Test that write_data is called with correct protocol handling for cloud and local paths."""
+        with mock.patch.object(writer_base, "check_output_mode", return_value=None):
+            writer = JsonlWriter(path=path)
+            writer.setup()
+
+        with (
+            mock.patch.object(writer.fs, "exists", return_value=False),
+            mock.patch.object(writer, "write_data") as mock_write_data,
+        ):
+            writer.process(pandas_document_batch)
+
+        mock_write_data.assert_called_once()
+        file_path = mock_write_data.call_args[0][1]
+        assert file_path.startswith(path), f"Path should start with {path}"

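Two design points in this test are worth noting. Stubbing check_output_mode during setup() lets the writer be constructed against a bucket that does not exist, and replacing writer.fs.exists and writer.write_data means process() never touches S3 or the local disk; only the output path the writer builds is inspected. The mock.patch.object variant used here patches an attribute on an object you already hold a reference to and restores it when the context exits. A tiny standalone illustration (Greeter is a made-up class, not from the codebase):

from unittest import mock


class Greeter:
    def greet(self) -> str:
        return "hello"


g = Greeter()

# Replace the attribute for the duration of the context; the Mock records the call.
with mock.patch.object(g, "greet", return_value="stubbed") as mock_greet:
    assert g.greet() == "stubbed"
    mock_greet.assert_called_once()

# The original attribute is restored once the context exits.
assert g.greet() == "hello"
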
tests/stages/text/io/writer/test_parquet.py

Lines changed: 24 additions & 0 deletions
@@ -21,6 +21,7 @@
 import pytest
 
 from nemo_curator.stages.text.io.writer import ParquetWriter
+from nemo_curator.stages.text.io.writer import base as writer_base
 from nemo_curator.stages.text.io.writer import utils as writer_utils
 from nemo_curator.tasks import DocumentBatch
 
@@ -251,3 +252,26 @@ def test_jsonl_writer_overwrites_existing_file(
         )
 
         pd.testing.assert_frame_equal(pd.read_parquet(result1.data[0]), pd.read_parquet(result2.data[0]))
+
+    @pytest.mark.parametrize(
+        "path",
+        [
+            "s3://test-bucket/output",
+            "/local/path",
+        ],
+    )
+    def test_parquet_writer_write_data_path_protocol_handling(self, pandas_document_batch: DocumentBatch, path: str):
+        """Test that write_data is called with correct protocol handling for cloud and local paths."""
+        with mock.patch.object(writer_base, "check_output_mode", return_value=None):
+            writer = ParquetWriter(path=path)
+            writer.setup()
+
+        with (
+            mock.patch.object(writer.fs, "exists", return_value=False),
+            mock.patch.object(writer, "write_data") as mock_write_data,
+        ):
+            writer.process(pandas_document_batch)
+
+        mock_write_data.assert_called_once()
+        file_path = mock_write_data.call_args[0][1]
+        assert file_path.startswith(path), f"Path should start with {path}"

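What "correct protocol handling" means for both writer tests: the path handed to write_data must keep the caller's prefix, whether that is s3://test-bucket/output or a plain local directory, rather than the protocol-stripped form fsspec filesystems use internally. Here is a hedged sketch of path construction with that property, assuming fsspec underneath; build_output_path is an illustrative helper, not the writers' actual code.

from fsspec.core import url_to_fs


def build_output_path(base: str, filename: str) -> str:
    """Join filename onto base and keep the original protocol prefix (illustrative only)."""
    fs, root = url_to_fs(base)
    full = f"{root.rstrip('/')}/{filename}"
    protocols = fs.protocol if isinstance(fs.protocol, tuple) else (fs.protocol,)
    # Local paths stay as-is; remote paths get their scheme re-attached.
    return full if "file" in protocols else fs.unstrip_protocol(full)


assert build_output_path("s3://test-bucket/output", "part-0.parquet").startswith("s3://test-bucket/output")
assert build_output_path("/local/path", "part-0.jsonl").startswith("/local/path")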