Commit 48dbb37

sahgerlad and yuki-97 authored

fix: Support datasets saved with save_to_disk in ResponseDataset (#1610)

Signed-off-by: Sahger Lad <[email protected]>
Signed-off-by: sahgerlad <[email protected]>
Co-authored-by: Yuki Huang <[email protected]>
1 parent 5bf56a9 commit 48dbb37
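
With this change, a ResponseDataset can point directly at a directory written by HuggingFace's Dataset.save_to_disk(). A minimal sketch of the workflow the fix enables, mirroring the config keys used in the new unit test (the load_response_dataset import path is assumed; the temp path is illustrative):

    import tempfile

    from datasets import Dataset
    from nemo_rl.data.datasets import load_response_dataset  # import path assumed

    train = Dataset.from_list(
        [
            {
                "messages": [
                    {"role": "user", "content": "What is 2+2?"},
                    {"role": "assistant", "content": "4"},
                ]
            }
        ]
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        # save_to_disk writes Arrow files plus metadata into the directory
        train.save_to_disk(f"{tmpdir}/train")

        dataset = load_response_dataset(
            {"dataset_name": "ResponseDataset", "train_data_path": f"{tmpdir}/train"}
        )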

File tree

3 files changed: +87 −8 lines changed

nemo_rl/data/datasets/response_datasets/response_dataset.py

Lines changed: 6 additions & 5 deletions

@@ -56,11 +56,12 @@ def __init__(
         else:
             val_ds = None
 
-        # format the dataset
-        train_ds = train_ds.map(
-            self.add_messages_key, fn_kwargs={"task_name": self.task_name}
-        )
-        if val_ds:
+        # Only apply add_messages_key if 'messages' column doesn't exist
+        if "messages" not in train_ds.column_names:
+            train_ds = train_ds.map(
+                self.add_messages_key, fn_kwargs={"task_name": self.task_name}
+            )
+        if val_ds is not None and "messages" not in val_ds.column_names:
             val_ds = val_ds.map(
                 self.add_messages_key, fn_kwargs={"task_name": self.task_name}
             )
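
The guard above makes the formatting step idempotent: a dataset restored from save_to_disk may already carry a 'messages' column, and running add_messages_key over it again would be redundant (or break if the raw columns it expects are absent). A minimal standalone sketch of the column check, assuming a pre-formatted dataset:

    from datasets import Dataset

    ds = Dataset.from_list(
        [{"messages": [{"role": "user", "content": "What is 2+2?"}]}]
    )
    if "messages" not in ds.column_names:  # False here, so the map is skipped
        ds = ds.map(lambda example: example)  # stand-in for add_messages_key
    print(ds.column_names)  # ['messages']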

nemo_rl/data/datasets/utils.py

Lines changed: 9 additions & 3 deletions

@@ -17,7 +17,7 @@
 from typing import Optional, Union
 
 import torch
-from datasets import DatasetDict, load_dataset
+from datasets import DatasetDict, load_dataset, load_from_disk
 from PIL import Image
 from transformers import AutoProcessor, PreTrainedTokenizerBase
 
@@ -62,7 +62,7 @@ def pil_to_base64(image: Image.Image, format: str = "PNG") -> str:
 
 
 def load_dataset_from_path(data_path: str, data_split: Optional[str] = "train"):
-    """Load a dataset from a json or huggingface dataset.
+    """Load a dataset from a json, huggingface dataset, or Arrow dataset (saved with save_to_disk).
 
     Args:
         data_path: The path to the dataset.
@@ -72,7 +72,13 @@ def load_dataset_from_path(data_path: str, data_split: Optional[str] = "train"):
     if suffix in [".json", ".jsonl"]:
         raw_dataset = load_dataset("json", data_files=data_path)
     else:
-        raw_dataset = load_dataset(data_path)
+        try:
+            raw_dataset = load_dataset(data_path)
+        except ValueError as e:
+            if "load_from_disk" in str(e):
+                raw_dataset = load_from_disk(data_path)
+            else:
+                raise e
 
     if data_split:
         raw_dataset = raw_dataset[data_split]
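
The fallback works because load_dataset() refuses to open a save_to_disk directory and raises a ValueError whose message points to load_from_disk; the new code matches on that substring and retries with the right loader. A standalone sketch of the same pattern, using an illustrative temp path:

    import tempfile

    from datasets import Dataset, load_dataset, load_from_disk

    with tempfile.TemporaryDirectory() as path:
        Dataset.from_list([{"text": "hello"}]).save_to_disk(path)

        try:
            raw_dataset = load_dataset(path)
        except ValueError as e:
            if "load_from_disk" in str(e):  # message names the correct loader
                raw_dataset = load_from_disk(path)
            else:
                raise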

tests/unit/data/datasets/test_response_dataset.py

Lines changed: 72 additions & 0 deletions

@@ -165,3 +165,75 @@ def test_squad_dataset():
         + " Answer: "
         + example["messages"][2]["content"]
     )
+
+
+def test_load_dataset_saved_with_save_to_disk():
+    """Test loading a dataset that was saved using HuggingFace's save_to_disk().
+
+    This tests the fix for datasets that already have a 'messages' column,
+    which should be preserved without applying add_messages_key again.
+    """
+    from datasets import Dataset
+
+    # Create a dataset with 'messages' column already present
+    train_data = [
+        {
+            "messages": [
+                {"role": "user", "content": "What is 2+2?"},
+                {"role": "assistant", "content": "4"},
+            ]
+        },
+        {
+            "messages": [
+                {"role": "user", "content": "What is the capital of France?"},
+                {"role": "assistant", "content": "Paris"},
+            ]
+        },
+    ]
+    val_data = [
+        {
+            "messages": [
+                {"role": "user", "content": "What is 3+3?"},
+                {"role": "assistant", "content": "6"},
+            ]
+        },
+    ]
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # Create HF datasets and save using save_to_disk
+        train_dataset = Dataset.from_list(train_data)
+        val_dataset = Dataset.from_list(val_data)
+
+        train_path = f"{tmpdir}/train"
+        val_path = f"{tmpdir}/val"
+
+        train_dataset.save_to_disk(train_path)
+        val_dataset.save_to_disk(val_path)
+
+        # Load using load_response_dataset
+        data_config = {
+            "dataset_name": "ResponseDataset",
+            "train_data_path": train_path,
+            "val_data_path": val_path,
+        }
+        dataset = load_response_dataset(data_config)
+
+        # Verify the dataset loaded correctly
+        assert "train" in dataset.formatted_ds
+        assert "validation" in dataset.formatted_ds
+        assert len(dataset.formatted_ds["train"]) == 2
+        assert len(dataset.formatted_ds["validation"]) == 1
+
+        # Verify messages are preserved correctly
+        first_train_example = dataset.formatted_ds["train"][0]
+        assert "messages" in first_train_example
+        assert len(first_train_example["messages"]) == 2
+        assert first_train_example["messages"][0]["role"] == "user"
+        assert first_train_example["messages"][0]["content"] == "What is 2+2?"
+        assert first_train_example["messages"][1]["role"] == "assistant"
+        assert first_train_example["messages"][1]["content"] == "4"
+
+        # Verify validation data
+        first_val_example = dataset.formatted_ds["validation"][0]
+        assert first_val_example["messages"][0]["content"] == "What is 3+3?"
+        assert first_val_example["messages"][1]["content"] == "6"