59 changes: 35 additions & 24 deletions scripts/preprocess.py
@@ -43,35 +43,46 @@
     print '  Test size: %d' % test_size
 
   # Choose the datatype based on the vocabulary size
-  dtype = np.uint8
-  if len(token_to_idx) > 255:
+  if len(token_to_idx) > 4294967295:
+    dtype = np.uint64
+  elif len(token_to_idx) > 65535:
     dtype = np.uint32
+  elif len(token_to_idx) > 255:
+    dtype = np.uint16
+  else:
+    dtype = np.uint8
+
   if not args.quiet:
     print 'Using dtype ', dtype
 
-  # Just load data into memory ... we'll have to do something more clever
-  # for huge datasets but this should be fine for now
-  train = np.zeros(train_size, dtype=dtype)
-  val = np.zeros(val_size, dtype=dtype)
-  test = np.zeros(test_size, dtype=dtype)
-  splits = [train, val, test]
-
-  # Go through the file again and write data to numpy arrays
-  split_idx, cur_idx = 0, 0
+  # Create, fill, and store each dataset,
+  # one at a time to save memory
   with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      for char in line:
-        splits[split_idx][cur_idx] = token_to_idx[char]
-        cur_idx += 1
-        if cur_idx == splits[split_idx].size:
-          split_idx += 1
-          cur_idx = 0
-
-  # Write data to HDF5 file
-  with h5py.File(args.output_h5, 'w') as f:
-    f.create_dataset('train', data=train)
-    f.create_dataset('val', data=val)
-    f.create_dataset('test', data=test)
+    with h5py.File(args.output_h5, 'w') as h:
+      def fill_and_store(arr_size, set_name):
+        """Create a one-dimensional numpy array
+        of the given size, fill it,
+        and write the result to h under the given name.
+
+        Leaves the source file advanced as far
+        as it had to go to fill the array.
+
+        If the remaining part of the file is shorter
+        than arr_size, the remainder of the array is
+        filled with zeroes.
+        """
+        arr = np.zeros(arr_size, dtype=dtype)
+        for idx in xrange(arr_size):
+          char = f.read(1)
+          if not char:
+            break
+          arr[idx] = token_to_idx[char]
+
+        h.create_dataset(set_name, data=arr)
+
+      fill_and_store(train_size, 'train')
+      fill_and_store(val_size, 'val')
+      fill_and_store(test_size, 'test')
 
   # For 'bytes' encoding, replace non-ascii characters so the json dump
   # doesn't crash
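The thresholds in the new dtype selection (255, 65535, 4294967295) are the maximum values of the unsigned integer types, so the script picks a type wide enough to hold every token index for the given vocabulary size. A minimal sketch, not part of the PR and assuming only that numpy is installed, showing where each cutoff comes from:

import numpy as np

# The cutoffs used in the diff are the largest value each unsigned dtype can hold.
for dt in (np.uint8, np.uint16, np.uint32, np.uint64):
  print np.dtype(dt).name, np.iinfo(dt).max
# uint8 255, uint16 65535, uint32 4294967295, uint64 18446744073709551615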
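For reviewers, the output layout stays the same as before: three one-dimensional datasets named 'train', 'val', and 'test', each holding token indices in the chosen dtype; per the docstring, a too-short input leaves trailing zeros rather than a shorter array. A rough read-back check, assuming the script was run with --output_h5 data.h5 (the path is a placeholder, and this snippet is an illustration rather than part of the PR):

import h5py

with h5py.File('data.h5', 'r') as h:  # placeholder output path
  for name in ('train', 'val', 'test'):
    ds = h[name]
    # Each split is a flat array of token indices; its length matches the
    # requested split size even if the input ran out early.
    print name, ds.shape, ds.dtype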