Skip to content

Commit 55800f6

Browse files
committed
* fix dataset buffer logics and tests
1 parent a1b3b01 commit 55800f6

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

tests/data/core/formatter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_init(self):
4747
self.assertEqual(formatter.config.solution_key, "solution")
4848
self.assertEqual(formatter.config.chat_template, "User: {}\nAssistant: ")
4949
# test for default configs
50-
self.assertEqual(formatter.config.reward_key, "")
50+
self.assertEqual(formatter.config.reward_key, "reward")
5151
self.assertEqual(formatter.config.chosen_key, "chosen")
5252
self.assertEqual(formatter.config.rejected_key, "rejected")
5353
self.assertEqual(formatter.config.label_key, "")

trinity/data/core/dataset.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC
2-
from dataclasses import asdict, dataclass, fields
2+
from dataclasses import asdict, dataclass, fields, is_dataclass
33
from typing import Any, Dict, List, Optional, Union
44

55
import networkx as nx
@@ -87,14 +87,19 @@ def read_from_buffer(self):
8787
datasets = []
8888
for buffer in self.input_buffers:
8989
exp_list = buffer.read()
90-
if self.original_dataclass is None:
91-
self.original_dataclass = exp_list[0].__class__
92-
datasets.append(Dataset.from_list([asdict(exp) for exp in exp_list]))
90+
if len(exp_list) > 0 and is_dataclass(exp_list[0]):
91+
exp_list = [asdict(exp) for exp in exp_list]
92+
if self.original_dataclass is None:
93+
self.original_dataclass = exp_list[0].__class__
94+
datasets.append(Dataset.from_list([exp for exp in exp_list]))
9395
self.data = concatenate_datasets(datasets)
9496
logger.info(f"Read {len(self.data)} samples from input buffers")
9597

9698
def write_to_buffer(self):
97-
exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()]
99+
if self.original_dataclass is not None:
100+
exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()]
101+
else:
102+
exp_list = self.data.to_list()
98103
self.output_buffer.write(exp_list)
99104
logger.info(f"Wrote {len(self.data)} samples to output buffer")
100105
self.data = Dataset.from_list([])

0 commit comments

Comments
 (0)