|
1 | 1 | from abc import ABC |
2 | | -from dataclasses import asdict, dataclass, fields |
| 2 | +from dataclasses import asdict, dataclass, fields, is_dataclass |
3 | 3 | from typing import Any, Dict, List, Optional, Union |
4 | 4 |
|
5 | 5 | import networkx as nx |
@@ -87,14 +87,19 @@ def read_from_buffer(self): |
87 | 87 | datasets = [] |
88 | 88 | for buffer in self.input_buffers: |
89 | 89 | exp_list = buffer.read() |
90 | | - if self.original_dataclass is None: |
91 | | - self.original_dataclass = exp_list[0].__class__ |
92 | | - datasets.append(Dataset.from_list([asdict(exp) for exp in exp_list])) |
| 90 | + if len(exp_list) > 0 and is_dataclass(exp_list[0]): |
| 91 | + exp_list = [asdict(exp) for exp in exp_list] |
| 92 | + if self.original_dataclass is None: |
| 93 | + self.original_dataclass = exp_list[0].__class__ |
| 94 | + datasets.append(Dataset.from_list([exp for exp in exp_list])) |
93 | 95 | self.data = concatenate_datasets(datasets) |
94 | 96 | logger.info(f"Read {len(self.data)} samples from input buffers") |
95 | 97 |
|
96 | 98 | def write_to_buffer(self): |
97 | | - exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()] |
| 99 | + if self.original_dataclass is not None: |
| 100 | + exp_list = [dict_to_dataclass(self.original_dataclass, d) for d in self.data.to_list()] |
| 101 | + else: |
| 102 | + exp_list = self.data.to_list() |
98 | 103 | self.output_buffer.write(exp_list) |
99 | 104 | logger.info(f"Wrote {len(self.data)} samples to output buffer") |
100 | 105 | self.data = Dataset.from_list([]) |
|
0 commit comments