
Commit c74b436

Add more tests to test_dataset and test_io (#594)
* Add more tests to test_dataset
* Add more read_custom tests
* ruff

Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
Signed-off-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
1 parent a4ce2de commit c74b436

File tree

3 files changed: +119 -1 lines changed

nemo_curator/datasets/doc_dataset.py

Lines changed: 4 additions & 1 deletion
@@ -187,7 +187,7 @@ def read_custom(  # noqa: PLR0913
                 and read all files under the directory.
                 If input_file is a list of strings, we assume each string is a file path.
             file_type: The type of the file to read.
-            read_func_single_partition: A function that reads a single file or a list of files in an single dask partition.
+            read_func_single_partition: A function that reads a single file or a list of files in a single Dask partition.
                 The function should take the following arguments:
                 - files: A list of file paths.
                 - file_type: The type of the file to read (in case you want to handle different file types differently).
@@ -204,6 +204,7 @@ def read_custom(  # noqa: PLR0913
             input_meta: A dictionary or a string formatted as a dictionary, which outlines
                 the field names and their respective data types within the JSONL input file.
         """
+
         if isinstance(input_files, str):
             if input_files.endswith(file_type):
                 files = [input_files]
@@ -218,9 +219,11 @@ def read_custom(  # noqa: PLR0913
         else:
             msg = "input_files must be a string or list"
             raise TypeError(msg)
+
         return cls(
             read_data(
                 input_files=files,
+                file_type=file_type,
                 backend=backend,
                 files_per_partition=files_per_partition,
                 blocksize=None,
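For orientation, here is a minimal sketch of the call path this patch fixes: with file_type now forwarded to read_data, a custom per-partition reader receives it alongside the file list. The read_jsonl helper mirrors the one added in tests/test_io.py below; the input paths are hypothetical.

import pandas as pd

from nemo_curator.datasets import DocumentDataset


def read_jsonl(files: list[str], **kwargs) -> pd.DataFrame:  # noqa: ARG001
    # Read every file handed to this Dask partition and stack the rows.
    return pd.concat([pd.read_json(f, lines=True) for f in files], ignore_index=True)


# Hypothetical paths; any list of .jsonl files works the same way.
dataset = DocumentDataset.read_custom(
    input_files=["data/part_1.jsonl", "data/part_2.jsonl"],
    file_type="jsonl",
    read_func_single_partition=read_jsonl,
    files_per_partition=1,
)
print(dataset.df.compute())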

tests/test_dataset.py

Lines changed: 75 additions & 0 deletions
@@ -12,13 +12,88 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+from pathlib import Path
+
 import pandas as pd
+import pytest

 from nemo_curator.datasets import DocumentDataset
+from nemo_curator.datasets.doc_dataset import _read_json_or_parquet


 def test_to_from_pandas() -> None:
     original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
     dataset = DocumentDataset.from_pandas(original_df)
     converted_df = dataset.to_pandas()
     pd.testing.assert_frame_equal(original_df, converted_df)
+
+
+def test_persist() -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+    dataset = DocumentDataset.from_pandas(original_df)
+    dataset.persist()
+
+
+def test_repartition() -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+    dataset = DocumentDataset.from_pandas(original_df)
+    dataset = dataset.repartition(npartitions=3)
+    assert dataset.df.npartitions == 3  # noqa: PLR2004
+
+
+def test_head() -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+    dataset = DocumentDataset.from_pandas(original_df)
+    expected_df = pd.DataFrame({"first_col": [1, 2], "second_col": ["a", "b"]})
+    pd.testing.assert_frame_equal(expected_df, dataset.head(2))
+
+
+def test_read_pickle(tmpdir: Path) -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+    output_file = str(tmpdir / "output.pkl")
+    original_df.to_pickle(output_file)
+    dataset = DocumentDataset.read_pickle(output_file)
+    pd.testing.assert_frame_equal(original_df, dataset.df.compute())
+
+
+def test_to_pickle(tmpdir: Path) -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+    dataset = DocumentDataset.from_pandas(original_df)
+
+    output_file = str(tmpdir / "output.pkl")
+    with pytest.raises(NotImplementedError):
+        dataset.to_pickle(output_file)
+
+
+def test_read_json_or_parquet(tmpdir: Path) -> None:
+    original_df = pd.DataFrame({"first_col": [1, 2, 3], "second_col": ["a", "b", "c"]})
+
+    directory_1 = str(tmpdir / "directory_1")
+    directory_2 = str(tmpdir / "directory_2")
+    os.makedirs(directory_1, exist_ok=True)
+    os.makedirs(directory_2, exist_ok=True)
+
+    file_1 = directory_1 + "/file_1.jsonl"
+    file_2 = directory_2 + "/file_2.jsonl"
+    original_df.to_json(file_1, orient="records", lines=True)
+    original_df.to_json(file_2, orient="records", lines=True)
+
+    # List of directories
+    data = _read_json_or_parquet(
+        input_files=[directory_1, directory_2],
+        file_type="jsonl",
+        backend="pandas",
+        files_per_partition=1,
+    )
+    assert len(data) == 6  # noqa: PLR2004
+
+    file_series = pd.Series([file_1, file_2])
+    # Non string or list input
+    with pytest.raises(TypeError):
+        data = _read_json_or_parquet(
+            input_files=file_series,
+            file_type="jsonl",
+            backend="pandas",
+            files_per_partition=1,
+        )
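The len(data) == 6 assertion follows from directory expansion: each directory input is scanned for files of the given file_type, so two three-row JSONL files yield six rows. The public readers go through the same helper; a minimal sketch of the equivalent call via DocumentDataset.read_json (assuming its current signature; the directory names are hypothetical):

from nemo_curator.datasets import DocumentDataset

# read_json delegates to _read_json_or_parquet, so directories are
# expanded to the .jsonl files beneath them.
dataset = DocumentDataset.read_json(
    input_files=["directory_1", "directory_2"],
    backend="pandas",
    files_per_partition=1,
)
assert len(dataset.df) == 6  # two files x three rows each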

tests/test_io.py

Lines changed: 40 additions & 0 deletions
@@ -157,6 +157,7 @@ def read_npy_file(files: list[str], backend: Literal["cudf", "pandas"], **kwargs
             [{**json.loads(pickle.load(open(file, "rb")))} for file in files],  # noqa: S301
         )

+        # Directory
         dataset = DocumentDataset.read_custom(
             input_files=tmp_dir,
             file_type="pkl",
@@ -172,6 +173,45 @@ def read_npy_file(files: list[str], backend: Literal["cudf", "pandas"], **kwargs
             ),  # because we sort columns by name
         )

+    def test_read_custom_input_files(self, tmp_path: Path) -> None:
+        # Prepare files
+        df = pd.DataFrame({"id": [1, 2, 3], "text": ["a", "b", "c"]})
+        file_1 = str(tmp_path / "test_file_1.jsonl")
+        file_2 = str(tmp_path / "test_file_2.jsonl")
+        df.to_json(file_1, orient="records", lines=True)
+        df.to_json(file_2, orient="records", lines=True)
+
+        def read_jsonl(files: list[str], **kwargs) -> pd.DataFrame:  # noqa: ARG001
+            return pd.concat([pd.read_json(f, lines=True) for f in files], ignore_index=True)
+
+        # Single file
+        dataset = DocumentDataset.read_custom(
+            input_files=file_1,
+            file_type="jsonl",
+            read_func_single_partition=read_jsonl,
+            files_per_partition=1,
+        )
+        assert dataset.df.compute().equals(df)
+
+        # List of files
+        dataset = DocumentDataset.read_custom(
+            input_files=[file_1, file_2],
+            file_type="jsonl",
+            read_func_single_partition=read_jsonl,
+            files_per_partition=1,
+        )
+        assert len(dataset.df) == 6  # noqa: PLR2004
+
+        file_series = pd.Series([file_1, file_2])
+        # Non string or list input
+        with pytest.raises(TypeError):
+            dataset = DocumentDataset.read_custom(
+                input_files=file_series,
+                file_type="jsonl",
+                read_func_single_partition=read_jsonl,
+                files_per_partition=1,
+            )
+

 class TestWriteWithFilename:
     @pytest.mark.parametrize("keep_filename_column", [True, False])
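Because file_type is now forwarded by read_custom (the doc_dataset.py fix above), a single read_func_single_partition can dispatch on it instead of being format-specific. A minimal sketch under that assumption; read_any is a hypothetical helper, not part of the test suite:

import pandas as pd


def read_any(files: list[str], file_type: str = "jsonl", **kwargs) -> pd.DataFrame:  # noqa: ARG001
    # file_type arrives via the arguments read_data forwards to the reader,
    # so one callable can cover several formats.
    if file_type == "jsonl":
        frames = [pd.read_json(f, lines=True) for f in files]
    elif file_type == "parquet":
        frames = [pd.read_parquet(f) for f in files]
    else:
        msg = f"Unsupported file_type: {file_type}"
        raise ValueError(msg)
    return pd.concat(frames, ignore_index=True)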
