Commit 4f89bbd

huvunvidia (Huy Vu2) and suiyoubi authored
Nemotron-CC SDG section pipelines (#1268)
* add pipeline
* workable all high quality stages, with NIM LLM
* fix ruff, lint
* fix ruff
* fix ruff
* fix ruff
* update output
* workable with 10k generation records (num_input_tasks = 1000)
* workable with 100k generation records (num_input_tasks = 10,000)
* restore mock data nemotron_cc_sdg_high_quality_example_pipeline.py to num_input_tasks = 100
* Add NemotronCC Low Quality Pipeline (#1273)
  * Add DocumentJoiner and DocumentSplitter modules with corresponding tests:
    - Implemented DocumentJoiner for merging document segments based on a common ID.
    - Implemented DocumentSplitter for dividing documents into segments using a specified separator.
    - Updated __init__.py to include new modules in the text processing pipeline.
    - Added unit tests for DocumentJoiner and DocumentSplitter to ensure functionality and edge case handling.
    - Enhanced the preprocessing pipeline to utilize DocumentSplitter and DocumentJoiner for better document management.
  * enable pre/postprocessing in low_quality_example
  * Add unit tests for DocumentJoiner and DocumentSplitter:
    - Introduced comprehensive test cases for DocumentJoiner, covering basic joining, custom separators, and handling of segment IDs.
    - Added tests for DocumentSplitter, including basic splitting, custom text fields, and metadata preservation.
    - Ensured validation of input parameters and handling of edge cases in both modules.
    - Removed outdated example script demonstrating usage of DocumentSplitter and DocumentJoiner.
  * ruff fix
  * ruff
  * revert
  * PR comment resolved
  * rename
  * Add original file back
  * fix
  * do not use index internal representation
  * clarify in doc
  * align with high_quality_example_pipeline
* add 'id' to data.parquet and mock data of nemotron_cc_sdg_high_quality_example_pipeline.py
* remove tutorials/synthetic/nemotron_cc/nemotron_cc_sdg_low_quality_example.py and nemotron_cc_sdg_high_quality_example.py
* fix small errors
* address comments

Signed-off-by: Ao Tang <[email protected]>
Co-authored-by: Huy Vu2 <[email protected]>
Co-authored-by: Ao Tang <[email protected]>
1 parent 33fc84f commit 4f89bbd

File tree

11 files changed: +2463 -571 lines


nemo_curator/stages/text/modules/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -13,11 +13,15 @@
 # limitations under the License.
 
 from .add_id import AddId
+from .joiner import DocumentJoiner
 from .modifier import Modify
 from .score_filter import Filter, Score, ScoreFilter
+from .splitter import DocumentSplitter
 
 
 __all__ = [
     "AddId",
+    "DocumentJoiner",
+    "DocumentSplitter",
     "Filter",
     "Modify",
     "Score",
nemo_curator/stages/text/modules/joiner.py (new file)

Lines changed: 193 additions & 0 deletions

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import pandas as pd

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch


@dataclass
class DocumentJoiner(ProcessingStage[DocumentBatch, DocumentBatch]):
    """
    Joins documents that have a common id back into a single document.
    The order of the documents is dictated by an additional segment_id column.
    A maximum length can be specified to limit the size of the joined documents.

    The joined documents are joined by a separator.

    This stage performs the inverse operation of DocumentSplitter, allowing you
    to reconstruct documents from their segments.

    Important:
        This stage assumes that all segments belonging to the same document are
        contained within a single DocumentBatch. Segments from the same document
        split across multiple batches will NOT be joined together. Ensure your
        batching logic keeps all segments of a document together.

    Example:
        If you have segments with document_id=1, segment_id=[0,1] and text=["Hello", "World"],
        they will be joined into a single row with document_id=1 and text="Hello\\n\\nWorld"
        (assuming separator="\\n\\n").

    Args:
        separator (str): The separator to join the documents on.
        text_field (str): The name of the column containing the text to join.
            Defaults to "text".
        segment_id_field (str): The name of the column containing the segment id.
            Defaults to "segment_id".
        document_id_field (str): The name of the column containing the document id.
            Defaults to "id".
        drop_segment_id_field (bool): Whether to drop the segment_id_field after joining.
            Defaults to True.
        max_length (int, optional): The maximum length of the joined documents.
            Both max_length and length_field must be specified or neither can be specified.
        length_field (str, optional): The name of the column containing the length of the documents.
            Both max_length and length_field must be specified or neither can be specified.
    """

    separator: str = "\n\n"
    text_field: str = "text"
    segment_id_field: str = "segment_id"
    document_id_field: str = "id"
    drop_segment_id_field: bool = True
    max_length: int | None = None
    length_field: str | None = None
    name: str = "document_joiner"

    def __post_init__(self):
        if self.max_length is not None and self.length_field is None:
            msg = "max_length is specified but length_field is not"
            raise ValueError(msg)
        if self.max_length is None and self.length_field is not None:
            msg = "length_field is specified but max_length is not"
            raise ValueError(msg)

    def inputs(self) -> tuple[list[str], list[str]]:
        """Define stage input requirements."""
        required_cols = [self.text_field, self.segment_id_field, self.document_id_field]
        if self.length_field is not None:
            required_cols.append(self.length_field)
        return ["data"], required_cols

    def outputs(self) -> tuple[list[str], list[str]]:
        """Define stage output specification."""
        output_cols = [self.text_field, self.document_id_field]
        if not self.drop_segment_id_field:
            output_cols.append(self.segment_id_field)
        if self.length_field is not None:
            output_cols.append(self.length_field)
        return ["data"], output_cols

    def _join_segments(self, group: pd.DataFrame) -> pd.DataFrame:
        """Join segments with max_length constraint."""
        # Ensure segments are processed in order.
        group = group.sort_values(self.segment_id_field)
        joined_rows = []
        current_seg_id = 0
        accumulator_text = None
        accumulator_length = 0
        accumulator_row = None

        for _, row in group.iterrows():
            if accumulator_row is None:
                # Start a new accumulation with the first segment.
                accumulator_text = row[self.text_field]
                accumulator_length = row[self.length_field]
                accumulator_row = row.copy()
            else:
                # Calculate what the new length would be if we joined this segment.
                proposed_length = accumulator_length + row[self.length_field] + len(self.separator)
                if proposed_length <= self.max_length:
                    accumulator_text = accumulator_text + self.separator + row[self.text_field]
                    accumulator_length = proposed_length
                else:
                    # Commit the current accumulation as one joined segment.
                    new_row = accumulator_row.copy()
                    new_row[self.text_field] = accumulator_text
                    new_row[self.length_field] = accumulator_length
                    new_row[self.segment_id_field] = current_seg_id
                    joined_rows.append(new_row)
                    current_seg_id += 1
                    # Start a new accumulation with the current row.
                    accumulator_text = row[self.text_field]
                    accumulator_length = row[self.length_field]
                    accumulator_row = row.copy()

        # Commit the last accumulated segment.
        if accumulator_row is not None:
            new_row = accumulator_row.copy()
            new_row[self.text_field] = accumulator_text
            new_row[self.length_field] = accumulator_length
            new_row[self.segment_id_field] = current_seg_id
            joined_rows.append(new_row)

        if joined_rows:
            return pd.DataFrame(joined_rows)
        else:
            return pd.DataFrame(columns=group.columns)

    def process(self, batch: DocumentBatch) -> DocumentBatch:
        """
        Joins the documents back into a single document while preserving all the original fields.

        Args:
            batch (DocumentBatch): Input batch to process

        Returns:
            DocumentBatch: Batch with documents joined by document_id
        """
        df = batch.to_pandas()

        if df.empty:
            return batch

        if self.max_length is None:
            # Sort the segments by the segment_id_field to maintain proper order before aggregating.
            df_sorted = df.sort_values(self.segment_id_field)

            # Build aggregation functions to preserve all original columns:
            # - For self.text_field, join all segments using the separator.
            # - For all other columns (except self.document_id_field, which is our grouping key), take the first occurrence.
            agg_funcs = {}
            for col in df_sorted.columns:
                if col == self.text_field:
                    agg_funcs[col] = lambda texts: self.separator.join(texts.astype(str))
                elif col != self.document_id_field:
                    agg_funcs[col] = "first"

            # Group by document_id_field while keeping the key as a column.
            joined = df_sorted.groupby(self.document_id_field, as_index=False).agg(agg_funcs)
        else:
            # Use the more complex joining logic with max_length constraint.
            joined_groups = []
            for _doc_id, group in df.groupby(self.document_id_field):
                joined_group = self._join_segments(group)
                joined_groups.append(joined_group)

            joined = pd.concat(joined_groups, ignore_index=True) if joined_groups else pd.DataFrame(columns=df.columns)

        if self.drop_segment_id_field and self.segment_id_field in joined.columns:
            joined = joined.drop(columns=self.segment_id_field)

        return DocumentBatch(
            task_id=f"{batch.task_id}_{self.name}",
            dataset_name=batch.dataset_name,
            data=joined,
            _metadata=batch._metadata,
            _stage_perf=batch._stage_perf,
        )
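
To make the default join path concrete, here is a minimal pandas-only sketch (an illustration, not code from this commit) of what process() does when max_length is unset, using the stage's default column names:

import pandas as pd

# Segments in the joiner's default schema: document id, segment order, text.
df = pd.DataFrame({
    "id": [1, 1, 2],
    "segment_id": [1, 0, 0],
    "text": ["World", "Hello", "Lone segment"],
})

separator = "\n\n"

# Mirror of the default path: order segments, then join text per document id.
df_sorted = df.sort_values("segment_id")
joined = df_sorted.groupby("id", as_index=False).agg(
    {"text": lambda texts: separator.join(texts.astype(str)), "segment_id": "first"}
)

print(joined["text"].tolist())  # ['Hello\n\nWorld', 'Lone segment']

The stage additionally drops the segment_id column by default; when max_length and length_field are set, it instead re-chunks greedily in _join_segments, committing a joined row whenever appending the next segment would push the accumulated length past max_length.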
nemo_curator/stages/text/modules/splitter.py (new file)

Lines changed: 94 additions & 0 deletions

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import DocumentBatch


@dataclass
class DocumentSplitter(ProcessingStage[DocumentBatch, DocumentBatch]):
    """
    Splits documents into segments based on a separator.
    Each segment becomes a new row within the batch with an additional column
    indicating the segment id.

    To restore the original document, ensure that each document
    has a unique id prior to splitting.

    Example:
        If a document has text="Hello\\n\\nWorld", and separator="\\n\\n",
        it will be split into two rows: one with text="Hello" and segment_id=0,
        and another with text="World" and segment_id=1.

    Args:
        separator (str): The separator to split the documents on.
        text_field (str): The name of the column containing the text to split.
            Defaults to "text".
        segment_id_field (str): The name of the column to add to indicate the segment id.
            Defaults to "segment_id".
    """

    separator: str
    text_field: str = "text"
    segment_id_field: str = "segment_id"
    name: str = "document_splitter"

    def inputs(self) -> tuple[list[str], list[str]]:
        """Define stage input requirements."""
        return ["data"], [self.text_field]

    def outputs(self) -> tuple[list[str], list[str]]:
        """Define stage output specification."""
        return ["data"], [self.text_field, self.segment_id_field]

    def process(self, batch: DocumentBatch) -> DocumentBatch:
        """
        Splits the documents into segments based on the separator and
        adds a column indicating the segment id.

        Args:
            batch (DocumentBatch): Input batch to process

        Returns:
            DocumentBatch: Batch with documents split into segments
        """
        df = batch.to_pandas()

        # Split the text field into segments using the separator.
        df["_split_text"] = df[self.text_field].str.split(self.separator)

        # Explode the list so that each segment becomes a separate row.
        # The index is preserved and duplicated for each segment from the same document.
        df = df.explode("_split_text")

        # For each original document (grouped by index level 0), assign a segment id.
        # level=0 refers to the (duplicated) index after explode.
        df[self.segment_id_field] = df.groupby(level=0).cumcount()

        # Replace the original text field with the split segment.
        df[self.text_field] = df["_split_text"]

        # Drop the temporary column and reset index to sequential.
        df = df.drop(columns=["_split_text"]).reset_index(drop=True)

        return DocumentBatch(
            task_id=f"{batch.task_id}_{self.name}",
            dataset_name=batch.dataset_name,
            data=df,
            _metadata=batch._metadata,
            _stage_perf=batch._stage_perf,
        )
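
A round-trip sketch tying the two stages together. It assumes DocumentBatch can be constructed from a pandas DataFrame with the keyword arguments seen in this commit (task_id, dataset_name, data) and that empty _metadata/_stage_perf values are acceptable; adjust the construction if the actual DocumentBatch signature differs:

import pandas as pd

from nemo_curator.stages.text.modules import DocumentJoiner, DocumentSplitter
from nemo_curator.tasks import DocumentBatch

df = pd.DataFrame({"id": [0, 1], "text": ["Hello\n\nWorld", "Single paragraph"]})

# Assumption: empty metadata/perf containers are valid here.
batch = DocumentBatch(task_id="demo", dataset_name="demo", data=df, _metadata={}, _stage_perf=[])

splitter = DocumentSplitter(separator="\n\n")  # adds the segment_id column
joiner = DocumentJoiner(separator="\n\n")      # inverse operation; drops segment_id by default

segments = splitter.process(batch)    # 3 rows: "Hello", "World", "Single paragraph"
restored = joiner.process(segments)   # back to 2 rows keyed by "id"

assert restored.to_pandas()["text"].tolist() == df["text"].tolist()

Because DocumentSplitter assigns segment ids per original document, the join only reconstructs documents faithfully when each input row has a unique "id" before splitting, as the splitter's docstring notes.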
