diff --git a/scripts/preprocess.py b/scripts/preprocess.py
index 90b834b6..e8ec8103 100644
--- a/scripts/preprocess.py
+++ b/scripts/preprocess.py
@@ -43,35 +43,46 @@
     print '  Test size: %d' % test_size
 
   # Choose the datatype based on the vocabulary size
-  dtype = np.uint8
-  if len(token_to_idx) > 255:
+  if len(token_to_idx) > 4294967295:
+    dtype = np.uint64
+  elif len(token_to_idx) > 65535:
     dtype = np.uint32
+  elif len(token_to_idx) > 255:
+    dtype = np.uint16
+  else:
+    dtype = np.uint8
+
   if not args.quiet:
     print 'Using dtype ', dtype
 
-  # Just load data into memory ... we'll have to do something more clever
-  # for huge datasets but this should be fine for now
-  train = np.zeros(train_size, dtype=dtype)
-  val = np.zeros(val_size, dtype=dtype)
-  test = np.zeros(test_size, dtype=dtype)
-  splits = [train, val, test]
-
-  # Go through the file again and write data to numpy arrays
-  split_idx, cur_idx = 0, 0
+  # Create, fill, and store each dataset,
+  # one at a time to save memory
   with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      for char in line:
-        splits[split_idx][cur_idx] = token_to_idx[char]
-        cur_idx += 1
-        if cur_idx == splits[split_idx].size:
-          split_idx += 1
-          cur_idx = 0
-
-  # Write data to HDF5 file
-  with h5py.File(args.output_h5, 'w') as f:
-    f.create_dataset('train', data=train)
-    f.create_dataset('val', data=val)
-    f.create_dataset('test', data=test)
+    with h5py.File(args.output_h5, 'w') as h:
+      def fill_and_store(arr_size, set_name):
+        """Create a one-dimensional numpy array
+        of the given size, fill it,
+        and write the result to h under the given name.
+
+        Leaves the source file advanced as far
+        as it had to go to fill the array.
+
+        If the remaining part of the file is shorter
+        than arr_size, the remainder of the array is
+        filled with zeroes.
+        """
+        arr = np.zeros(arr_size, dtype=dtype)
+        for idx in xrange(arr_size):
+          char = f.read(1)
+          if not char:
+            break
+          arr[idx] = token_to_idx[char]
+
+        h.create_dataset(set_name, data=arr)
+
+      fill_and_store(train_size, 'train')
+      fill_and_store(val_size, 'val')
+      fill_and_store(test_size, 'test')
 
   # For 'bytes' encoding, replace non-ascii characters so the json dump
   # doesn't crash
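
A note on the new dtype ladder: the comparisons against 255, 65535, and
4294967295 pick the smallest unsigned integer type that can hold every token
index, assuming the largest stored index is at most len(token_to_idx). If you
want the same choice in a single call, numpy's min_scalar_type does exactly
this; a minimal sketch (the choose_dtype helper is hypothetical, not part of
the patch):

    import numpy as np

    def choose_dtype(max_index):
        # np.min_scalar_type returns the smallest dtype able to hold
        # the given value, matching the if/elif ladder in the patch.
        return np.min_scalar_type(max_index)

    assert choose_dtype(255) == np.uint8          # still fits in a byte
    assert choose_dtype(256) == np.uint16         # first value past uint8
    assert choose_dtype(65536) == np.uint32       # first value past uint16
    assert choose_dtype(4294967296) == np.uint64  # first value past uint32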
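
Since fill_and_store zero-pads a split when the input runs short rather than
raising an error, it can be worth sanity-checking the written file after a
run. A minimal read-back sketch, assuming the dataset names used above
('data.h5' stands in for args.output_h5):

    import h5py

    with h5py.File('data.h5', 'r') as h:
        for name in ('train', 'val', 'test'):
            # Each split is a 1-D array of token indices.
            print('%s: %d tokens, dtype %s' % (name, h[name].shape[0], h[name].dtype))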