
Commit 9462a81

davidecaroselli authored and facebook-github-bot committed

Enhanced MMapIndexedDataset: less memory, higher speed (#816)

Summary: I have made an upgrade to my previous implementation of MMapIndexedDataset; now:

- It uses up to **4 times less memory and disk space**
- Words per second is slightly improved thanks to less memory access

Pull Request resolved: #816
Differential Revision: D15899848
Pulled By: myleott
fbshipit-source-id: 9ddeb4809729ef69cc6b0867b33ee71184d845e6

1 parent 9c3bb5c · commit 9462a81
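The headline figure follows from the token dtype: assuming the previous mmap builder stored token ids as int64 (8 bytes each, consistent with the int64 check removed from `__getitem__` below), switching to uint16 (2 bytes) for vocabularies under 65500 entries cuts the token payload by 4x. A minimal NumPy sketch of that arithmetic, with a made-up corpus size and vocabulary:

```python
import numpy as np

# Illustrative arithmetic only; the corpus size and vocabulary are made up.
# Assumed previous storage dtype (int64): 8 bytes per token.
# New dtype for vocabularies < 65500 entries (uint16): 2 bytes per token.
tokens = np.random.randint(0, 32000, size=1_000_000)

print(tokens.astype(np.int64).nbytes)    # 8000000 bytes
print(tokens.astype(np.uint16).nbytes)   # 2000000 bytes
print(tokens.astype(np.int64).nbytes // tokens.astype(np.uint16).nbytes)  # 4
```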

2 files changed (+21, -11 lines)

fairseq/data/indexed_dataset.py

Lines changed: 17 additions & 9 deletions

```diff
@@ -15,9 +15,16 @@
 from . import FairseqDataset
 
 
-def make_builder(out_file, impl):
+def __best_fitting_dtype(vocab_size=None):
+    if vocab_size is not None and vocab_size < 65500:
+        return np.uint16
+    else:
+        return np.int32
+
+
+def make_builder(out_file, impl, vocab_size=None):
     if impl == 'mmap':
-        return MMapIndexedDatasetBuilder(out_file)
+        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
     else:
         return IndexedDatasetBuilder(out_file)
 
@@ -63,6 +70,7 @@ def write_longs(f, a):
     5: np.int64,
     6: np.float,
     7: np.double,
+    8: np.uint16
 }
 
 
@@ -143,7 +151,7 @@ def size(self, index):
     @staticmethod
     def exists(path):
         return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
         )
 
     @property
@@ -440,11 +448,11 @@ def __len__(self):
 
     def __getitem__(self, i):
         ptr, size = self._index[i]
-        tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
-        if tensor.dtype == torch.int64:
-            return tensor
-        else:
-            return tensor.long()
+        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr)
+        if self._index.dtype != np.int64:
+            np_array = np_array.astype(np.int64)
+
+        return torch.from_numpy(np_array)
 
     @property
     def sizes(self):
@@ -457,7 +465,7 @@ def supports_prefetch(self):
     @staticmethod
     def exists(path):
         return (
-            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
+            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
         )
 
 
```
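For readers who want to trace the two changed code paths outside of fairseq, here is a self-contained sketch that mirrors the new dtype helper and the rewritten `__getitem__`. The byte buffer and dtype handling below are simplified stand-ins, not the real MMapIndexedDataset internals:

```python
import numpy as np
import torch

def best_fitting_dtype(vocab_size=None):
    # Mirrors __best_fitting_dtype from the diff: token ids fit in uint16
    # for vocabularies below 65500 entries (uint16 max is 65535; the 65500
    # cutoff leaves a small margin).
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    return np.int32

# Writing side (simplified): tokens are stored with the compact dtype.
dtype = best_fitting_dtype(vocab_size=32000)           # -> np.uint16
stored = np.array([5, 17, 42, 9], dtype=dtype).tobytes()

# Reading side (mirrors the new __getitem__): decode from the raw buffer,
# widen to int64 on the NumPy side, then hand the array to torch once.
np_array = np.frombuffer(stored, dtype=dtype, count=4, offset=0)
if dtype != np.int64:
    np_array = np_array.astype(np.int64)
tensor = torch.from_numpy(np_array)
print(tensor.dtype)  # torch.int64, the dtype downstream fairseq code expects
```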
preprocess.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -129,7 +129,8 @@ def merge_result(worker_result):
     )
     pool.close()
 
-    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
+    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
+                                      impl=args.dataset_impl, vocab_size=len(vocab))
     merge_result(
         Binarizer.binarize(
             input_file, vocab, lambda t: ds.add_item(t),
@@ -231,7 +232,8 @@ def make_all(lang, vocab):
 
 
 def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
-    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
+    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
+                                      impl=args.dataset_impl, vocab_size=len(vocab))
 
     def consumer(tensor):
         ds.add_item(tensor)
```
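On the preprocess.py side the only change is that the dictionary size is now threaded through to the builder. A hedged usage sketch of the updated call; the file names, the 32000-entry vocabulary, and the finalize step are illustrative assumptions, since only `make_builder(..., vocab_size=...)` and `add_item(...)` appear in this diff:

```python
# Hypothetical usage sketch; paths and vocabulary size are made up.
import torch
from fairseq.data import indexed_dataset

vocab_size = 32000  # stands in for len(vocab) of a real dictionary
ds = indexed_dataset.make_builder('train.de-en.de.bin',
                                  impl='mmap', vocab_size=vocab_size)

for sentence in [[4, 15, 99, 2], [4, 7, 2]]:   # toy token-id sequences
    ds.add_item(torch.IntTensor(sentence))

# finalize(...) is assumed from the existing builder interface used
# elsewhere in preprocess.py, not shown in this diff.
ds.finalize('train.de-en.de.idx')
```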
