@@ -17,6 +17,7 @@
 from torchcodec.decoders import VideoDecoder
 from transformers.image_utils import to_numpy_array
 
+from .vlln_lerobot_dataset import VLLNDataset
 from .rope2d import get_rope_index_2, get_rope_index_25
 
 # Define placeholders for dataset paths
@@ -150,6 +151,11 @@ def parse_sampling_rate(dataset_name):
     return 1.0
 
 
+def read_jsonl(path):
+    with open(path, "r") as f:
+        return [json.loads(line) for line in f]
+
+
 def data_list(dataset_names):
     config_list = []
     for dataset_name in dataset_names:
@@ -180,11 +186,6 @@ def rank0_print(*args):
     print(*args)
 
 
-def read_jsonl(path):
-    with open(path, "r") as f:
-        return [json.loads(line) for line in f]
-
-
 def preprocess_qwen_2_visual(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -1329,3 +1330,42 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         return batch
 
+
+class CombinedDataset(Dataset):
+    """
+    Combine multiple datasets behind a single dataset interface.
+
+    Merges different datasets for joint training: samples from all provided
+    datasets are concatenated, and the global index mapping can optionally
+    be shuffled (without changing the underlying datasets).
+    """
+
+    def __init__(self, datasets, shuffle=False):
+        super().__init__()
+        self.datasets = datasets
+        self.lengths = [len(dataset) for dataset in datasets]
+        self.cum_lengths = np.cumsum(self.lengths)
+        self.total_length = self.cum_lengths[-1]
+        self.shuffle_enabled = shuffle
+        self.indices = np.arange(self.total_length)
+        if self.shuffle_enabled:
+            self.shuffle()
+
+    def shuffle(self):
+        np.random.shuffle(self.indices)
+
+    def _map_index(self, idx):
+        return self.indices[idx]
+
+    def __len__(self):
+        return self.total_length
+
+    def __getitem__(self, i):
+        real_idx = self._map_index(i)
+        # Find the first cumulative length past real_idx; the local index is
+        # real_idx minus the start of that dataset's range (cum_len - length).
+        for idx, cum_len in enumerate(self.cum_lengths):
+            if real_idx < cum_len:
+                return self.datasets[idx][real_idx - cum_len + self.lengths[idx]]
+        raise IndexError(f"Index {real_idx} out of bounds")
+
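A minimal sanity check of the index arithmetic in `__getitem__`, as a sketch rather than part of the patch (the toy datasets `ds_a` and `ds_b` are hypothetical; `CombinedDataset` is assumed to be in scope). With lengths 3 and 2, `cum_lengths` is `[3, 5]`, so global indices 0-2 resolve to `ds_a` and 3-4 to `ds_b` (e.g. `3 - 5 + 2 = 0`):

import torch
from torch.utils.data import TensorDataset

ds_a = TensorDataset(torch.arange(3))  # hypothetical toy dataset, len 3
ds_b = TensorDataset(torch.arange(2))  # hypothetical toy dataset, len 2
combined = CombinedDataset([ds_a, ds_b], shuffle=False)

assert len(combined) == 5
assert combined[2][0].item() == 2  # global index 2 -> ds_a[2]
assert combined[3][0].item() == 0  # global index 3 -> ds_b[3 - 5 + 2] = ds_b[0]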
@@ -1332,8 +1372,12 @@
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
-    train_dataset = NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args)
-    # train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
+    train_datasets = []
+    if data_args.iion_dataset_use:
+        train_datasets.append(VLLNDataset(tokenizer=tokenizer, data_args=data_args))
+    if data_args.vln_dataset_use:
+        train_datasets.append(NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args))
+    train_dataset = CombinedDataset(train_datasets, shuffle=False)
     if data_args.data_flatten:
         data_collator = FlattenedDataCollatorForSupervisedDataset(tokenizer=tokenizer)
         return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
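One edge case worth flagging as a review note (an observation, not part of the patch): if both `data_args.iion_dataset_use` and `data_args.vln_dataset_use` are disabled, `train_datasets` is empty and `CombinedDataset.__init__` fails with an opaque IndexError, since `np.cumsum([])` is an empty array and `[-1]` is out of bounds. A hypothetical guard could fail fast instead:

# Hypothetical guard before constructing CombinedDataset (not in the patch).
if not train_datasets:
    raise ValueError(
        "No training datasets enabled; set data_args.iion_dataset_use "
        "and/or data_args.vln_dataset_use."
    )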