diff --git a/LanguageModel.lua b/LanguageModel.lua
index d6248184..b7d0859f 100644
--- a/LanguageModel.lua
+++ b/LanguageModel.lua
@@ -162,12 +162,28 @@ function LM:sample(kwargs)
   local verbose = utils.get_kwarg(kwargs, 'verbose', 0)
   local sample = utils.get_kwarg(kwargs, 'sample', 1)
   local temperature = utils.get_kwarg(kwargs, 'temperature', 1)
+  local start_tokens = utils.get_kwarg(kwargs,'start_tokens','')
 
   local sampled = torch.LongTensor(1, T)
   self:resetStates()
 
   local scores, first_t
-  if #start_text > 0 then
+  if #start_tokens > 0 then
+    local json_tokens = utils.read_json(start_tokens)
+
+    local num_tokens = table.getn(json_tokens.tokens)
+
+    local tokenTensor = torch.LongTensor(num_tokens)
+    for i = 1,num_tokens do
+      tokenTensor[i] = json_tokens.tokens[i]
+    end
+
+    local x = tokenTensor:view(1,-1)
+    local T0 = x:size(2)
+    sampled[{{}, {1, T0}}]:copy(x)
+    scores = self:forward(x)[{{}, {T0, T0}}]
+    first_t = T0 + 1
+  elseif #start_text > 0 then
     if verbose > 0 then
       print('Seeding with: "' .. start_text .. '"')
     end
diff --git a/README.md b/README.md
index cb09269b..0e527003 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 # torch-rnn
-torch-rnn provides high-performance, reusable RNN and LSTM modules for torch7, and uses these modules for character-level
+torch-rnn provides high-performance, reusable RNN and LSTM modules for torch7, and uses these modules for character-level and word-level
 language modeling similar to [char-rnn](https://github.com/karpathy/char-rnn).
 
 You can find documentation for the RNN and LSTM modules [here](doc/modules.md); they have no dependencies other than `torch`
@@ -92,7 +92,7 @@ Jeff Thompson has written a very detailed installation guide for OSX that you [c
 To train a model and use it to generate new text, you'll need to follow three simple steps:
 
 ## Step 1: Preprocess the data
-You can use any text file for training models. Before training, you'll need to preprocess the data using the script
+You can use any text file or folder of .txt files for training models. Before training, you'll need to preprocess the data using the script
 `scripts/preprocess.py`; this will generate an HDF5 file and JSON file containing a preprocessed version of the data.
 If you have training data stored in `my_data.txt`, you can run the script like this:
 
@@ -104,10 +104,29 @@ python scripts/preprocess.py \
   --output_json my_data.json
 ```
 
+If you instead have multiple .txt files in a folder `my_data`, you can run the script like this:
+
+```bash
+python scripts/preprocess.py \
+  --input_folder my_data
+```
+
 This will produce files `my_data.h5` and `my_data.json` that will be passed to the training script.
 
 There are a few more flags you can use to configure preprocessing; [read about them here](doc/flags.md#preprocessing)
 
+### Preprocess Word Tokens
+To preprocess the input data with words as tokens, add the flag `--use_words`.
+
+A large text corpus will contain many rare words, usually typos or unusual names. Adding a token for each of these is not practical and can result in a very large token space. The options `--min_occurrences` and `--min_documents` let you specify how many times, or in how many documents, a word must occur before it is given a token. Words that fail to meet these criteria are replaced by wildcard tokens, which are assigned at random to avoid overtraining on any single placeholder.
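+
+For example, to build a word-level dataset from the folder `my_data`, keeping only words that appear at least 20 times (the flag values here are only an illustration):
+
+```bash
+python scripts/preprocess.py \
+  --input_folder my_data \
+  --use_words \
+  --min_occurrences 20 \
+  --output_h5 my_data.h5 \
+  --output_json my_data.json
+```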
+
+More information on additional flags is available [here](doc/flags.md#preprocessing).
+
+### Preprocess Data With Existing Token Schema
+If you have an existing token schema (a .json file generated by `preprocess.py`), you can use the script `scripts/tokenize.py` to tokenize a file based on that schema. It accepts a text file or a folder of text files as input (just like the preprocessing script), plus an argument `--input_json` which specifies the token schema file. This is useful for transfer learning onto a new dataset.
+
+To learn more about the tokenizing script [see here](doc/flags.md#tokenizing).
+
 ## Step 2: Train the model
 After preprocessing the data, you'll need to train the model using the `train.lua` script. This will be the slowest step. You can
 run the training script like this:
@@ -144,6 +163,12 @@ and print the results to the console.
 By default the sampling script will run in GPU mode using CUDA; to run in CPU-only mode add the flag `-gpu -1` and
 to run in OpenCL mode add the flag `-gpu_backend opencl`.
 
+To pre-seed the model with text, there are two options. If you used character-based preprocessing, pass the flag `-start_text` with a quoted string.
+
+If you used word-based preprocessing, use the Python script `scripts/tokenize.py` to generate a JSON file of seed tokens and pass it with the flag `-start_tokens`. Python was used to split the training data into tokens, and Lua does not have full regex support, so the seed text has to be tokenized with Python as well; hence the extra step.
+
+To learn more about the tokenizing script [see here](doc/flags.md#tokenizing).
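+
+For example, assuming your token schema is `my_data.json`, you could tokenize a seed string and then sample from it like this (the paths and values are only illustrative):
+
+```bash
+python scripts/tokenize.py \
+  --input_str "to be or not to be" \
+  --input_json my_data.json \
+  --output_json start_tokens.json
+
+th sample.lua -checkpoint cv/checkpoint_10000.t7 \
+  -length 500 -start_tokens start_tokens.json
+```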
+
 There are more flags you can use to configure sampling; [read about them here](doc/flags.md#sampling).
 
 # Benchmarks
diff --git a/doc/flags.md b/doc/flags.md
index f2652bbf..ddedcc50 100644
--- a/doc/flags.md
+++ b/doc/flags.md
@@ -3,12 +3,22 @@ Here we'll describe in detail the full set of command line flags available for p
 
 # Preprocessing
 The preprocessing script `scripts/preprocess.py` accepts the following command-line flags:
 - `--input_txt`: Path to the text file to be used for training. Default is the `tiny-shakespeare.txt` dataset.
+- `--input_folder`: Path to a folder containing .txt files to use for training. Overrides the `--input_txt` option.
 - `--output_h5`: Path to the HDF5 file where preprocessed data should be written.
 - `--output_json`: Path to the JSON file where preprocessed data should be written.
 - `--val_frac`: What fraction of the data to use as a validation set; default is `0.1`.
 - `--test_frac`: What fraction of the data to use as a test set; default is `0.1`.
 - `--quiet`: If you pass this flag then no output will be printed to the console.
+- `--use_words`: Passing this flag preprocesses the input as word tokens rather than character tokens. It activates the additional options below, which are ignored otherwise.
+
+## Preprocessing Word Tokens
+- `--case_sensitive`: Makes word tokens case-sensitive. The default is to convert everything to lowercase; character tokens are always case-sensitive.
+- `--min_occurrences`: Minimum number of times a word needs to be seen to be given a token. Default is 20.
+- `--min_documents`: Minimum number of documents a word needs to be seen in to be given a token. Default is 1.
+- `--use_ascii`: Convert the input files to ASCII by transliterating or removing non-ASCII characters. Default is to keep the input as Unicode.
+- `--wildcard_rate`: Number of wildcard tokens to generate, as a fraction of the number of ignored words. For example, `0.01` generates wildcards equal to 1 percent of the number of ignored words. Default is `0.01`.
+- `--wildcard_max`: If set, the maximum number of wildcard tokens that will be generated. Default is unlimited.
+- `--wildcard_min`: Minimum number of wildcard tokens that will be generated. Cannot be less than 1. Default is 10.
 
 # Training
 The training script `train.lua` accepts the following command-line flags:
@@ -51,9 +61,22 @@ The sampling script `sample.lua` accepts the following command-line flags:
 - `-checkpoint`: Path to a `.t7` checkpoint file from `train.lua`
 - `-length`: The length of the generated text, in characters.
-- `-start_text`: You can optionally start off the generation process with a string; if this is provided the start text will be processed by the trained network before we start sampling. Without this flag, the first character is chosen randomly.
+- `-start_text`: You can optionally start off the generation process with a string; if this is provided the start text will be processed by the trained network before we start sampling. Without this flag or the `-start_tokens` flag, the first character is chosen randomly.
+- `-start_tokens`: As an alternative to `-start_text` for word-based models, accepts a JSON file generated by `scripts/tokenize.py` containing the tokens for the start text. Without this flag or the `-start_text` flag, the first token is chosen randomly.
 - `-sample`: Set this to 1 to sample from the next-character distribution at each timestep; set to 0 to instead just pick the argmax at every timestep. Sampling tends to produce more interesting results.
 - `-temperature`: Softmax temperature to use when sampling; default is 1. Higher temperatures give noiser samples. Not used when using argmax sampling (`sample` set to 0).
 - `-gpu`: The ID of the GPU to use (zero-indexed). Default is 0. Set this to -1 to run in CPU-only mode.
 - `-gpu_backend`: The GPU backend to use; either `cuda` or `opencl`. Default is `cuda`.
 - `-verbose`: By default just the sampled text is printed to the console. Set this to 1 to also print some diagnostic information.
+
+# Tokenizing
+The tokenizing script `scripts/tokenize.py` accepts the following command-line flags (an example invocation follows the list):
+- `--input_str`: A string to tokenize directly, passed as a quoted block, e.g. `--input_str "lorem ipsum"`.
+- `--input_txt`: Path to a text file to tokenize. Default is the `tiny-shakespeare.txt` dataset.
+- `--input_folder`: Path to a folder containing .txt files to tokenize. Overrides the `--input_txt` option.
+- `--input_json`: The token schema JSON produced by `scripts/preprocess.py`, used to tokenize the input.
+- `--output_json`: Optional. The output JSON file to save the token list to.
+- `--output_h5`: Optional. Path to the HDF5 file where the preprocessed data should be written.
+- `--val_frac`: What fraction of the data to use as a validation set; default is `0.1`.
+- `--test_frac`: What fraction of the data to use as a test set; default is `0.1`.
+- `--quiet`: If you pass this flag then no output will be printed to the console except in case of error.
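+
+For example, to tokenize a new folder of .txt files against an existing schema and write a training-ready HDF5 file (the file names are illustrative):
+
+```bash
+python scripts/tokenize.py \
+  --input_folder new_data \
+  --input_json my_data.json \
+  --output_h5 new_data.h5
+```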
diff --git a/sample.lua b/sample.lua
index 4e6ebae0..fdb2792c 100644
--- a/sample.lua
+++ b/sample.lua
@@ -8,6 +8,7 @@ local cmd = torch.CmdLine()
 cmd:option('-checkpoint', 'cv/checkpoint_4000.t7')
 cmd:option('-length', 2000)
 cmd:option('-start_text', '')
+cmd:option('-start_tokens','')
 cmd:option('-sample', 1)
 cmd:option('-temperature', 1)
 cmd:option('-gpu', 0)
diff --git a/scripts/preprocess.py b/scripts/preprocess.py
index fecd56e5..339d2a46 100644
--- a/scripts/preprocess.py
+++ b/scripts/preprocess.py
@@ -1,97 +1,194 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function
-import argparse
-import json
-import os
-import six
+import argparse, json, os, codecs, h5py, re, string, random, six
+from unidecode import unidecode
+
 import numpy as np
-import h5py
-import codecs
 
+def load_from_files(file_list,use_ascii,encoding):
+  file_contents = []
+  for path in file_list:
+    with codecs.open(path, 'r', encoding) as infile:
+      if use_ascii:
+        file_contents.append(unidecode(infile.read()).encode('ascii', 'ignore'))
+      else:
+        file_contents.append(infile.read())
+      infile.close()
+  return file_contents
+
+def parse_file(file_contents,regex,case_sensitive):
+  # Split into tokens
+  if not case_sensitive:
+    file_contents = file_contents.lower()
+  if regex != '':
+    return [item for item in re.split(regex,file_contents,flags=re.UNICODE) if item != '']
+  else:
+    return list(file_contents)
+
+def compute_frequency(parsed_files):
+  tokenlist = {}
+  for item in parsed_files:
+    item_tokens = set()
+    for token in item:
+      if token == '':
+        continue
+      if token in tokenlist:
+        tokenlist[token][0] += 1
+        if token not in item_tokens:
+          item_tokens.add(token)
+          tokenlist[token][1] += 1
+      else:
+        item_tokens.add(token)
+        tokenlist[token] = [1,1]
+  return tokenlist
+
+def tokenize_data(data_per_file,token_to_idx,wildcard_ids):
+  unified_idx = []
+  wildcard_replace_count = 0
+  for item in data_per_file:
+    for token in item:
+      if token in token_to_idx:
+        unified_idx.append(token_to_idx[token])
+      else:
+        if len(wildcard_ids) != 0:
+          unified_idx.append(random.choice(wildcard_ids))
+        wildcard_replace_count += 1
+  return unified_idx,wildcard_replace_count
+
+def build_tokenset(wordlist,min_documents,min_occurrences,min_wildcards,max_wildcards,wildcard_rate):
+  token_to_idx = {}
+  wordid = 1
+  ignore_counts = set(string.punctuation).union(string.whitespace)  # Preserve tokens for all encountered punctuation or whitespace
+
+  total_eliminated = 0
+
+  for item in wordlist:
+    if item in ignore_counts or (wordlist[item][0] >= min_occurrences and wordlist[item][1] >= min_documents):
+      token_to_idx[item] = wordid
+      wordid += 1
+    else:
+      total_eliminated += 1
+
+  wildcard_ids = []
+
+  if total_eliminated > 0:
+    num_distinct_wild = max(min_wildcards,int(wildcard_rate*total_eliminated))
+    if max_wildcards > 0:
+      num_distinct_wild = min(max_wildcards,num_distinct_wild)
+
+    for wcnum in xrange(num_distinct_wild):
+      token_to_idx['*/WILDCARD/*{0}'.format(wcnum)] = wordid
+      wildcard_ids.append(wordid)
+      wordid += 1
-parser = argparse.ArgumentParser()
-parser.add_argument('--input_txt', default='data/tiny-shakespeare.txt')
-parser.add_argument('--output_h5', default='data/tiny-shakespeare.h5')
-parser.add_argument('--output_json', default='data/tiny-shakespeare.json')
-parser.add_argument('--val_frac', type=float, default=0.1)
-parser.add_argument('--test_frac', type=float, default=0.1)
-parser.add_argument('--quiet', action='store_true')
-parser.add_argument('--encoding', default='utf-8')
-args = parser.parse_args()
+
+  maxtoken = wordid
+  return token_to_idx,wildcard_ids,maxtoken
+
+
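+# Write the train/val/test splits of the token stream to an HDF5 file using the
+# dataset names ('train', 'val', 'test') that the training code expects.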
+def save_to_hdf5(data,filename,train_size,val_size,test_size, dtype):
+  # Split data up into train,val, and test sets. This avoids zeros popping up (might have been the cause of earlier issues)
+  train = np.array(data[:train_size], dtype=dtype)
+  val = np.array(data[train_size:train_size+val_size], dtype=dtype)
+  test = np.array(data[-test_size:], dtype=dtype)
+  splits = [train, val, test]
+  # Write data to HDF5 file
+  with h5py.File(filename, 'w') as f:
+    f.create_dataset('train', data=train)
+    f.create_dataset('val', data=val)
+    f.create_dataset('test', data=test)
 
 if __name__ == '__main__':
-  if args.encoding == 'bytes': args.encoding = None
-
-  # First go the file once to see how big it is and to build the vocab
-  token_to_idx = {}
-  total_size = 0
-  with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      total_size += len(line)
-      for char in line:
-        if char not in token_to_idx:
-          token_to_idx[char] = len(token_to_idx) + 1
-
-  # Now we can figure out the split sizes
-  val_size = int(args.val_frac * total_size)
-  test_size = int(args.test_frac * total_size)
-  train_size = total_size - val_size - test_size
-
-  if not args.quiet:
-    print('Total vocabulary size: %d' % len(token_to_idx))
-    print('Total tokens in file: %d' % total_size)
-    print('  Training size: %d' % train_size)
-    print('  Val size: %d' % val_size)
-    print('  Test size: %d' % test_size)
-
-  # Choose the datatype based on the vocabulary size
-  dtype = np.uint8
-  if len(token_to_idx) > 255:
-    dtype = np.uint32
-  if not args.quiet:
-    print('Using dtype ', dtype)
-
-  # Just load data into memory ... we'll have to do something more clever
-  # for huge datasets but this should be fine for now
-  train = np.zeros(train_size, dtype=dtype)
-  val = np.zeros(val_size, dtype=dtype)
-  test = np.zeros(test_size, dtype=dtype)
-  splits = [train, val, test]
-
-  # Go through the file again and write data to numpy arrays
-  split_idx, cur_idx = 0, 0
-  with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      for char in line:
-        splits[split_idx][cur_idx] = token_to_idx[char]
-        cur_idx += 1
-        if cur_idx == splits[split_idx].size:
-          split_idx += 1
-          cur_idx = 0
-
-  # Write data to HDF5 file
-  with h5py.File(args.output_h5, 'w') as f:
-    f.create_dataset('train', data=train)
-    f.create_dataset('val', data=val)
-    f.create_dataset('test', data=test)
-
-  # For 'bytes' encoding, replace non-ascii characters so the json dump
-  # doesn't crash
-  if args.encoding is None:
-    new_token_to_idx = {}
-    for token, idx in six.iteritems(token_to_idx):
-      if ord(token) > 127:
-        new_token_to_idx['[%d]' % ord(token)] = idx
-      else:
-        new_token_to_idx[token] = idx
-    token_to_idx = new_token_to_idx
-
-  # Dump a JSON file for the vocab
-  json_data = {
-    'token_to_idx': token_to_idx,
-    'idx_to_token': {v: k for k, v in six.iteritems(token_to_idx)},
-  }
-  with open(args.output_json, 'w') as f:
-    json.dump(json_data, f)
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--input_txt', default='data/tiny-shakespeare.txt')
+  parser.add_argument('--input_folder', default='')
+  parser.add_argument('--output_h5', default='data/tiny-shakespeare.h5')
+  parser.add_argument('--output_json', default='data/tiny-shakespeare.json')
+  parser.add_argument('--val_frac', type=float, default=0.1)
+  parser.add_argument('--test_frac', type=float, default=0.1)
+  parser.add_argument('--quiet', action='store_true')
+  parser.add_argument('--use_ascii', action='store_true')
+  parser.add_argument('--encoding', default='utf-8')
+
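+  # The options below only apply to word-token preprocessing (--use_words);
+  # they are ignored for character-level preprocessing.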
+  parser.add_argument('--use_words',action='store_true')
+  parser.add_argument('--case_sensitive', action='store_true')
+  parser.add_argument('--min_occurrences',type=int,default=20)
+  parser.add_argument('--min_documents', type=int,default=1)
+  parser.add_argument('--wildcard_rate',type=float,default=0.01)
+  parser.add_argument('--wildcard_max',type=int, default=-1)
+  parser.add_argument('--wildcard_min',type=int,default=10)
+  args = parser.parse_args()
+
+  if args.encoding == 'bytes': args.encoding = None
+
+  # Build list of files
+  infiles = []
+  if args.input_folder != '':
+    infiles = [os.path.join(args.input_folder,item) for item in os.listdir(args.input_folder) if item[-4:]=='.txt']
+  else:
+    infiles = [args.input_txt]
+
+  # Sanity check, words can't be in more documents than there are in the corpus
+  if args.min_documents > len(infiles):
+    args.min_documents = len(infiles)
+
+  # Regex to split on
+  regex = '(\W)' if args.use_words else ''
+  if not args.use_words:
+    args.case_sensitive = True
+    args.min_occurrences = 0
+    args.min_documents = 0
+
+  files_parsed = [parse_file(f,regex,args.case_sensitive) for f in load_from_files(infiles,args.use_ascii,args.encoding)]
+
+  wordlist = compute_frequency(files_parsed)
+
+  # Build the final dictionary: word to token number
+  token_to_idx,wildcard_ids,maxtoken = build_tokenset(wordlist,args.min_documents,args.min_occurrences,args.wildcard_min,args.wildcard_max,args.wildcard_rate)
+
+  # Now we create the final token array
+  outdata,wildcard_replace_count = tokenize_data(files_parsed,token_to_idx,wildcard_ids)
+
+  total_size = len(outdata)
+
+  # Now we can figure out the split sizes
+  val_size = int(args.val_frac * total_size)
+  test_size = int(args.test_frac * total_size)
+  train_size = total_size - val_size - test_size
+
+  if not args.quiet:
+    if len(wildcard_ids) > 0:
+      wildcard_spec = ' ({0} wildcards)'.format(len(wildcard_ids))
+      print('Total unique tokens: {0}'.format(len(wordlist)))
+    else:
+      wildcard_spec = ''
+    print('Total vocabulary size: {0}{1}'.format(len(token_to_idx), wildcard_spec))
+    print('Total tokens in file: {0}'.format(total_size))
+    if len(wildcard_ids) > 0:
+      print('Total wildcards in file: {0} ({1}%)'.format(wildcard_replace_count,100.0*wildcard_replace_count/total_size))
+    print('  Training size: {0}'.format(train_size))
+    print('  Val size: {0}'.format(val_size))
+    print('  Test size: {0}'.format(test_size))
+
+  # Choose the datatype based on the vocabulary size
+  dtype = np.uint8
+  if len(token_to_idx) > 255:
+    dtype = np.uint32
+  if not args.quiet:
+    print('Using dtype {0}'.format(dtype))
+
+  save_to_hdf5(outdata,args.output_h5,train_size,val_size,test_size,dtype)
+
+  # Dump a JSON file for the vocab
+  json_data = {
+    'token_to_idx': token_to_idx,
+    'idx_to_token': {v: k for k, v in token_to_idx.iteritems()},
+    'wildcards':wildcard_ids,
+    'tokenize_regex':regex,
+    'case_sensitive':args.case_sensitive,
+    'use_ascii':args.use_ascii
+  }
+  with open(args.output_json, 'w') as f:
+    json.dump(json_data, f)
diff --git a/scripts/preprocessLegacy.py b/scripts/preprocessLegacy.py
new file mode 100644
index 00000000..90b834b6
--- /dev/null
+++ b/scripts/preprocessLegacy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
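+# Character-level preprocessing script: the previous version of scripts/preprocess.py,
+# kept alongside the new word-token aware version.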
+
+import argparse, json, os
+import numpy as np
+import h5py
+import codecs
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input_txt', default='data/tiny-shakespeare.txt')
+parser.add_argument('--output_h5', default='data/tiny-shakespeare.h5')
+parser.add_argument('--output_json', default='data/tiny-shakespeare.json')
+parser.add_argument('--val_frac', type=float, default=0.1)
+parser.add_argument('--test_frac', type=float, default=0.1)
+parser.add_argument('--quiet', action='store_true')
+parser.add_argument('--encoding', default='utf-8')
+args = parser.parse_args()
+
+
+if __name__ == '__main__':
+  if args.encoding == 'bytes': args.encoding = None
+
+  # First go the file once to see how big it is and to build the vocab
+  token_to_idx = {}
+  total_size = 0
+  with codecs.open(args.input_txt, 'r', args.encoding) as f:
+    for line in f:
+      total_size += len(line)
+      for char in line:
+        if char not in token_to_idx:
+          token_to_idx[char] = len(token_to_idx) + 1
+
+  # Now we can figure out the split sizes
+  val_size = int(args.val_frac * total_size)
+  test_size = int(args.test_frac * total_size)
+  train_size = total_size - val_size - test_size
+
+  if not args.quiet:
+    print 'Total vocabulary size: %d' % len(token_to_idx)
+    print 'Total tokens in file: %d' % total_size
+    print '  Training size: %d' % train_size
+    print '  Val size: %d' % val_size
+    print '  Test size: %d' % test_size
+
+  # Choose the datatype based on the vocabulary size
+  dtype = np.uint8
+  if len(token_to_idx) > 255:
+    dtype = np.uint32
+  if not args.quiet:
+    print 'Using dtype ', dtype
+
+  # Just load data into memory ... we'll have to do something more clever
+  # for huge datasets but this should be fine for now
+  train = np.zeros(train_size, dtype=dtype)
+  val = np.zeros(val_size, dtype=dtype)
+  test = np.zeros(test_size, dtype=dtype)
+  splits = [train, val, test]
+
+  # Go through the file again and write data to numpy arrays
+  split_idx, cur_idx = 0, 0
+  with codecs.open(args.input_txt, 'r', args.encoding) as f:
+    for line in f:
+      for char in line:
+        splits[split_idx][cur_idx] = token_to_idx[char]
+        cur_idx += 1
+        if cur_idx == splits[split_idx].size:
+          split_idx += 1
+          cur_idx = 0
+
+  # Write data to HDF5 file
+  with h5py.File(args.output_h5, 'w') as f:
+    f.create_dataset('train', data=train)
+    f.create_dataset('val', data=val)
+    f.create_dataset('test', data=test)
+
+  # For 'bytes' encoding, replace non-ascii characters so the json dump
+  # doesn't crash
+  if args.encoding is None:
+    new_token_to_idx = {}
+    for token, idx in token_to_idx.iteritems():
+      if ord(token) > 127:
+        new_token_to_idx['[%d]' % ord(token)] = idx
+      else:
+        new_token_to_idx[token] = idx
+    token_to_idx = new_token_to_idx
+
+  # Dump a JSON file for the vocab
+  json_data = {
+    'token_to_idx': token_to_idx,
+    'idx_to_token': {v: k for k, v in token_to_idx.iteritems()},
+  }
+  with open(args.output_json, 'w') as f:
+    json.dump(json_data, f)
diff --git a/scripts/tokenize.py b/scripts/tokenize.py
new file mode 100644
index 00000000..8ae38dbf
--- /dev/null
+++ b/scripts/tokenize.py
@@ -0,0 +1,91 @@
+import argparse, json, re, random, h5py
+import numpy as np
+from preprocess import *
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+
+  parser.add_argument('--input_str',type=str, default='')
+  parser.add_argument('--input_txt',type=str,default='')
+  parser.add_argument('--input_folder', default='')
+  parser.add_argument('--input_json',type=str, default='data/tiny-shakespeare.json')
+
+  parser.add_argument('--output_json',type=str, default='')
+  parser.add_argument('--output_h5', default='')
+  parser.add_argument('--val_frac', type=float, default=0.1)
+  parser.add_argument('--test_frac', type=float, default=0.1)
+  parser.add_argument('--encoding', default='utf-8')
+  parser.add_argument('--quiet', action='store_true')
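+  # How the text is split (regex, case handling, ASCII conversion, wildcards) is
+  # read from the --input_json schema rather than from command-line flags.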
+  args = parser.parse_args()
+
+  token_to_idx = []
+  wildcard_set = []
+  regex = ''
+  case_sensitive = False
+  use_ascii = False
+  unified = []
+
+  if args.output_json == '' and args.output_h5 == '':
+    print 'No output file specified'
+  else:
+    with open(args.input_json) as jsonfile:
+      json_data = json.load(jsonfile)
+      token_to_idx = json_data['token_to_idx']
+      wildcard_set = json_data['wildcards']
+      case_sensitive = json_data['case_sensitive']
+      regex = json_data['tokenize_regex']
+      use_ascii = json_data['use_ascii']
+
+    # Build list of files
+    infiles = []
+    file_contents = []
+    if args.input_folder != '':
+      infiles = [os.path.join(args.input_folder,item) for item in os.listdir(args.input_folder) if item[-4:]=='.txt']
+    elif args.input_txt != '':
+      infiles = [args.input_txt]
+
+    if len(infiles) != 0:
+      file_contents = load_from_files(infiles,use_ascii,args.encoding)
+
+    if args.input_str != '':
+      file_contents.append(args.input_str)
+
+    files_parsed = [parse_file(f,regex,case_sensitive) for f in file_contents]
+
+    outdata,wildcard_replace_count = tokenize_data(files_parsed,token_to_idx,wildcard_set)
+
+    total_size = len(outdata)
+
+    if args.output_h5 != '':
+      # Choose the datatype based on the vocabulary size
+      dtype = np.uint8
+      if len(token_to_idx) > 255:
+        dtype = np.uint32
+      if not args.quiet:
+        print 'Using dtype ', dtype
+      val_size = int(args.val_frac * total_size)
+      test_size = int(args.test_frac * total_size)
+      train_size = total_size - val_size - test_size
+      save_to_hdf5(outdata,args.output_h5,train_size,val_size,test_size,dtype)
+
+    if not args.quiet:
+      if len(wildcard_set) > 0:
+        wildcard_spec = ' ({0} wildcards)'.format(len(wildcard_set))
+      else:
+        wildcard_spec = ''
+      print 'Total vocabulary size: {0}{1}'.format(len(token_to_idx), wildcard_spec)
+      print 'Total tokens in file: {0}'.format(total_size)
+      if len(wildcard_set) > 0:
+        print 'Total wildcards in file: {0} ({1}%)'.format(wildcard_replace_count,100.0*wildcard_replace_count/total_size)
+      else:
+        print 'Total Ignored: {0}'.format(wildcard_replace_count)
+      if args.output_h5 != '':
+        # Split sizes are only computed when an HDF5 output is written
+        print '  Training size: {0}'.format(train_size)
+        print '  Val size: {0}'.format(val_size)
+        print '  Test size: {0}'.format(test_size)
+
+    if args.output_json != '':
+      json_data = {'tokens':outdata}
+      with open(args.output_json,'w') as jsonfile:
+        json.dump(json_data,jsonfile)
+
+
\ No newline at end of file