55 changes: 33 additions & 22 deletions scripts/preprocess.py
@@ -45,33 +45,44 @@
   # Choose the datatype based on the vocabulary size
   dtype = np.uint8
   if len(token_to_idx) > 255:
Collaborator:
Minor comment #1: use a single level of branching:

if len(..) > 4294967295:
  ....
elif len(..) > 65535:
  ...
else:
  ...
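
To make the suggestion concrete, a flattened version of this hunk might look as follows. This is only a sketch, not code from the PR; the token_to_idx dict here is a stand-in for the character-to-index vocabulary that preprocess.py builds from the input text.

    import numpy as np

    # Stand-in vocabulary; the real dict maps characters to integer indices.
    token_to_idx = {i: i for i in range(70000)}

    # Choose the datatype based on the vocabulary size, one level of branching deep.
    dtype = np.uint8
    if len(token_to_idx) > 255:
      if len(token_to_idx) > 4294967295:
        dtype = np.uint64
      elif len(token_to_idx) > 65535:
        dtype = np.uint32
      else:
        dtype = np.uint16

    assert dtype == np.uint32  # a 70,000-entry vocabulary fits in 32 bits but not 16

The behaviour is the same as the nested version in the diff below; the elif chain just keeps all of the thresholds at one indentation level, which is presumably what the comment is after.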

-    dtype = np.uint32
+    if len(token_to_idx) > 65535:
+      if len(token_to_idx) > 4294967295:
+        dtype = np.uint64
+      else:
+        dtype = np.uint32
+    else:
+      dtype = np.uint16
   if not args.quiet:
     print 'Using dtype ', dtype

-  # Just load data into memory ... we'll have to do something more clever
-  # for huge datasets but this should be fine for now
-  train = np.zeros(train_size, dtype=dtype)
-  val = np.zeros(val_size, dtype=dtype)
-  test = np.zeros(test_size, dtype=dtype)
-  splits = [train, val, test]
-
-  # Go through the file again and write data to numpy arrays
-  split_idx, cur_idx = 0, 0
+  # Create, fill, and store each dataset,
+  # one at a time to save memory
   with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      for char in line:
-        splits[split_idx][cur_idx] = token_to_idx[char]
-        cur_idx += 1
-        if cur_idx == splits[split_idx].size:
-          split_idx += 1
-          cur_idx = 0
+    with h5py.File(args.output_h5, 'w') as h:
+      def fill_and_store(arr_size, set_name):
+        """Create a one-dimensional numpy array
+        of the given size, fill it,
+        and write the result to h under the given name.
+
+        Leaves the source file advanced as far
+        as it had to go to fill the array.
+
+        If the remaining part of the file is shorter
+        than arr_size, the remainder of the array is
+        filled with zeroes.
+        """
+        arr = np.zeros(arr_size, dtype=dtype)
+        for idx in xrange(arr_size):
+          char = f.read(1)
+          if not char:
+            break
+          arr[idx] = token_to_idx[char]
+
+        h.create_dataset(set_name, data=arr)

   # Write data to HDF5 file
Collaborator:
Minor comment #2: Indentation doesn't match the rest of the file.

-  with h5py.File(args.output_h5, 'w') as f:
-    f.create_dataset('train', data=train)
-    f.create_dataset('val', data=val)
-    f.create_dataset('test', data=test)
+      fill_and_store(train_size, 'train')
+      fill_and_store(val_size, 'val')
+      fill_and_store(test_size, 'test')

   # For 'bytes' encoding, replace non-ascii characters so the json dump
   # doesn't crash
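
Stepping back, the new code streams the input file once, carving it into train/val/test arrays of the precomputed sizes and writing each one to the HDF5 file before building the next, rather than holding all three arrays plus a splits list in memory. A standalone sketch of that pattern is below; the file name, vocabulary, dtype, and split sizes are made up for illustration, and none of this is code from the PR.

    import codecs

    import h5py
    import numpy as np

    # Made-up inputs, purely for illustration.
    input_txt, output_h5 = 'tiny.txt', 'tiny.h5'
    with codecs.open(input_txt, 'w', 'utf-8') as f:
      f.write(u'hello world')

    token_to_idx = {u'h': 1, u'e': 2, u'l': 3, u'o': 4, u' ': 5, u'w': 6, u'r': 7, u'd': 8}
    dtype = np.uint8
    train_size, val_size, test_size = 6, 3, 4  # 6 + 3 + 4 = 13 > 11 chars, so 'test' is padded

    with codecs.open(input_txt, 'r', 'utf-8') as f:
      with h5py.File(output_h5, 'w') as h:
        def fill_and_store(arr_size, set_name):
          # Read arr_size characters from wherever f currently is; stop early at
          # end of file and leave the rest of the array as zeros.
          arr = np.zeros(arr_size, dtype=dtype)
          for idx in range(arr_size):
            char = f.read(1)
            if not char:
              break
            arr[idx] = token_to_idx[char]
          h.create_dataset(set_name, data=arr)

        fill_and_store(train_size, 'train')
        fill_and_store(val_size, 'val')
        fill_and_store(test_size, 'test')

    with h5py.File(output_h5, 'r') as h:
      print(h['train'][:].tolist())  # [1, 2, 3, 3, 4, 5] -> 'hello '
      print(h['val'][:].tolist())    # [6, 4, 7]          -> 'wor'
      print(h['test'][:].tolist())   # [3, 8, 0, 0]       -> 'ld' plus zero padding

Because fill_and_store closes over both the open text handle f and the HDF5 file h, the three calls consume consecutive, non-overlapping stretches of the input, matching the docstring's note that the source file is left advanced as far as it had to go.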