
Commit aa449bd

[Feature] Add training code for the baseline of VLLN Bench (#198)
* Add VL-LN Bench training code
* Add VL-LN Bench training code
* Remove the VLLN trainer; unify training for the VLN and IION datasets
* Address the issues raised by kellyiss and kew6688
* Address the issues raised by Tai-Wang
* (1) Remove `dataset_utils.py`. (2) Add standard docstrings to the main class and key functions.
* Address the issues raised by Tai-Wang
* Address the issues raised by Tai-Wang
* Refine the docstrings
1 parent 0e23d4c commit aa449bd

File tree: 4 files changed (+804, -7 lines)


internnav/dataset/internvla_n1_lerobot_dataset.py

Lines changed: 47 additions & 7 deletions
@@ -17,6 +17,7 @@
 from torchcodec.decoders import VideoDecoder
 from transformers.image_utils import to_numpy_array
 
+from .vlln_lerobot_dataset import VLLNDataset
 from .rope2d import get_rope_index_2, get_rope_index_25
 
 # Define placeholders for dataset paths
@@ -150,6 +151,11 @@ def parse_sampling_rate(dataset_name):
     return 1.0
 
 
+def read_jsonl(path):
+    with open(path, "r") as f:
+        return [json.loads(line) for line in f]
+
+
 def data_list(dataset_names):
     config_list = []
     for dataset_name in dataset_names:
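
For context, the `read_jsonl` helper (moved earlier in the file by this commit; see the matching removal in the next hunk) expects a JSON Lines file with one JSON object per line. Below is a minimal, self-contained sketch of the format it consumes; the file path and record fields are illustrative only, not taken from the repository.

import json

# Illustrative records only; the real annotation files in the repo may differ.
records = [
    {"episode": 0, "instruction": "go to the kitchen"},
    {"episode": 1, "instruction": "stop near the sofa"},
]

# Write one JSON object per line (the JSONL convention).
with open("/tmp/example_annotations.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

def read_jsonl(path):
    # Same helper as in the diff above.
    with open(path, "r") as f:
        return [json.loads(line) for line in f]

assert read_jsonl("/tmp/example_annotations.jsonl") == records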
@@ -180,11 +186,6 @@ def rank0_print(*args):
         print(*args)
 
 
-def read_jsonl(path):
-    with open(path, "r") as f:
-        return [json.loads(line) for line in f]
-
-
 def preprocess_qwen_2_visual(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -1329,11 +1330,50 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         return batch
 
+class CombinedDataset(Dataset):
+    """
+    Combine multiple datasets into a single dataset interface.
+
+    This class is used to merge different datasets for joint training.
+    It concatenates samples from all provided datasets and optionally shuffles
+    the global index mapping (without changing the underlying datasets).
+    """
+    def __init__(self, datasets, shuffle=False):
+        super(CombinedDataset, self).__init__()
+        self.datasets = datasets
+        self.lengths = [len(dataset) for dataset in datasets]
+        self.cum_lengths = np.cumsum(self.lengths)
+        self.total_length = self.cum_lengths[-1]
+        self.shuffle_enabled = shuffle
+        self.indices = np.arange(self.total_length)
+        if self.shuffle_enabled:
+            self.shuffle()
+
+    def shuffle(self):
+        np.random.shuffle(self.indices)
+
+    def _map_index(self, idx):
+        return self.indices[idx]
+
+    def __len__(self):
+        return self.cum_lengths[-1]
+
+    def __getitem__(self, i):
+        real_idx = self._map_index(i)
+        for idx, cum_len in enumerate(self.cum_lengths):
+            if real_idx < cum_len:
+                return self.datasets[idx][real_idx - cum_len + self.lengths[idx]]
+        raise ValueError(f"Index {real_idx} out of bound")
+
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
-    train_dataset = NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args)
-    # train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_args=data_args)
+    train_datasets = []
+    if data_args.iion_dataset_use:
+        train_datasets.append(VLLNDataset(tokenizer=tokenizer, data_args=data_args))
+    if data_args.vln_dataset_use:
+        train_datasets.append(NavPixelGoalDataset(tokenizer=tokenizer, data_args=data_args))
+    train_dataset = CombinedDataset(train_datasets, shuffle=False)
     if data_args.data_flatten:
         data_collator = FlattenedDataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
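
To make the index arithmetic in `CombinedDataset.__getitem__` concrete, here is a small, self-contained sketch. The `ListDataset` class and the sample counts are illustrative stand-ins for `VLLNDataset` and `NavPixelGoalDataset`, not code from the repository; only the mapping logic mirrors the class above. For the first cumulative length that exceeds the global index, the local index is real_idx - cum_len + lengths[idx], i.e. the offset into that dataset.

import numpy as np
from torch.utils.data import Dataset

class ListDataset(Dataset):
    # Toy stand-in for the real navigation datasets, used only for illustration.
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

vln = ListDataset([f"vln_{i}" for i in range(3)])    # 3 samples
iion = ListDataset([f"iion_{i}" for i in range(5)])  # 5 samples

# Same arithmetic as CombinedDataset.__getitem__:
lengths = [len(vln), len(iion)]
cum_lengths = np.cumsum(lengths)  # array([3, 8])
for real_idx in range(int(cum_lengths[-1])):
    for idx, cum_len in enumerate(cum_lengths):
        if real_idx < cum_len:
            local = real_idx - cum_len + lengths[idx]
            print(real_idx, "->", ["vln", "iion"][idx], local)
            break
# Global indices 0-2 map to vln[0..2]; 3-7 map to iion[0..4].

In `make_supervised_data_module`, this mapping is what lets a single DataLoader draw transparently from both `VLLNDataset` and `NavPixelGoalDataset`; which datasets participate is controlled by the `data_args.iion_dataset_use` and `data_args.vln_dataset_use` flags shown in the diff.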
