|
#!/usr/bin/env python3
# Copyright 2018-present, HKUST-KnowComp.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Model architecture/optimization options for WRMCQA document reader."""
| 8 | + |
import argparse
import logging

# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
| 13 | + |
# Names of the arguments that define the core model architecture.
# Used by get_model_args()/override_model_args() to separate architecture
# options (which must match a saved checkpoint) from training options.
MODEL_ARCHITECTURE = {
    'model_type',
    'embedding_dim',
    'char_embedding_dim',
    'hidden_size',
    'char_hidden_size',
    'doc_layers',
    'question_layers',
    'rnn_type',
    'concat_rnn_layers',
    'question_merge',
    'use_qemb',
    'use_exact_match',
    'use_pos',
    'use_ner',
    'use_lemma',
    'use_tf',
    'hop',
}

# Names of the arguments that control the optimizer / training procedure.
# These may safely be overridden when resuming from a saved model.
MODEL_OPTIMIZER = {
    'fix_embeddings',
    'optimizer',
    'learning_rate',
    'momentum',
    'weight_decay',
    'rho',
    'eps',
    'max_len',
    'grad_clipping',
    'tune_partial',
    'rnn_padding',
    'dropout_rnn',
    'dropout_rnn_output',
    'dropout_emb',
}
| 27 | + |
| 28 | + |
def str2bool(v):
    """Interpret common affirmative strings as True; anything else is False.

    Case-insensitive; accepts 'yes', 'true', 't', '1', and 'y'.
    """
    return v.lower() in {'yes', 'true', 't', '1', 'y'}
| 31 | + |
| 32 | + |
def add_model_args(parser):
    """Attach all WRMCQA reader model and optimization flags to *parser*.

    Registers a 'bool' argparse type (backed by str2bool) so boolean flags
    accept yes/no style values, then adds three argument groups:
    architecture, model details, and optimization.
    """
    parser.register('type', 'bool', str2bool)

    def _register(group, specs):
        # Each spec is a (flag, kwargs) pair for one add_argument call.
        for flag, kwargs in specs:
            group.add_argument(flag, **kwargs)

    # Model architecture
    model = parser.add_argument_group('WRMCQA Reader Model Architecture')
    _register(model, [
        ('--model-type', dict(type=str, default='rnn',
                              help='Model architecture type: rnn, r_net, mnemonic')),
        ('--embedding-dim', dict(type=int, default=300,
                                 help='Embedding size if embedding_file is not given')),
        ('--char-embedding-dim', dict(type=int, default=50,
                                      help='Embedding size if char_embedding_file is not given')),
        ('--hidden-size', dict(type=int, default=100,
                               help='Hidden size of RNN units')),
        ('--char-hidden-size', dict(type=int, default=50,
                                    help='Hidden size of char RNN units')),
        ('--doc-layers', dict(type=int, default=3,
                              help='Number of encoding layers for document')),
        ('--question-layers', dict(type=int, default=3,
                                   help='Number of encoding layers for question')),
        ('--rnn-type', dict(type=str, default='lstm',
                            help='RNN type: LSTM, GRU, or RNN')),
    ])

    # Model specific details
    detail = parser.add_argument_group('WRMCQA Reader Model Details')
    _register(detail, [
        ('--concat-rnn-layers', dict(type='bool', default=True,
                                     help='Combine hidden states from each encoding layer')),
        ('--question-merge', dict(type=str, default='self_attn',
                                  help='The way of computing the question representation')),
        ('--use-qemb', dict(type='bool', default=True,
                            help='Whether to use weighted question embeddings')),
        ('--use-exact-match', dict(type='bool', default=True,
                                   help='Whether to use in_question_* features')),
        ('--use-pos', dict(type='bool', default=True,
                           help='Whether to use pos features')),
        ('--use-ner', dict(type='bool', default=True,
                           help='Whether to use ner features')),
        ('--use-lemma', dict(type='bool', default=True,
                             help='Whether to use lemma features')),
        ('--use-tf', dict(type='bool', default=True,
                          help='Whether to use term frequency features')),
        ('--hop', dict(type=int, default=2,
                       help='The number of hops for both aligner and the answer pointer in m-reader')),
    ])

    # Optimization details
    optim = parser.add_argument_group('WRMCQA Reader Optimization')
    _register(optim, [
        ('--dropout-emb', dict(type=float, default=0.2,
                               help='Dropout rate for word embeddings')),
        ('--dropout-rnn', dict(type=float, default=0.2,
                               help='Dropout rate for RNN states')),
        ('--dropout-rnn-output', dict(type='bool', default=True,
                                      help='Whether to dropout the RNN output')),
        ('--optimizer', dict(type=str, default='adamax',
                             help='Optimizer: sgd, adamax, adadelta')),
        ('--learning-rate', dict(type=float, default=1.0,
                                 help='Learning rate for sgd, adadelta')),
        ('--grad-clipping', dict(type=float, default=10,
                                 help='Gradient clipping')),
        ('--weight-decay', dict(type=float, default=0,
                                help='Weight decay factor')),
        ('--momentum', dict(type=float, default=0,
                            help='Momentum factor')),
        ('--rho', dict(type=float, default=0.95,
                       help='Rho for adadelta')),
        ('--eps', dict(type=float, default=1e-6,
                       help='Eps for adadelta')),
        ('--fix-embeddings', dict(type='bool', default=True,
                                  help='Keep word embeddings fixed (use pretrained)')),
        ('--tune-partial', dict(type=int, default=0,
                                help='Backprop through only the top N question words')),
        ('--rnn-padding', dict(type='bool', default=False,
                               help='Explicitly account for padding in RNN encoding')),
        ('--max-len', dict(type=int, default=15,
                           help='The max span allowed during decoding')),
    ])
| 106 | + |
| 107 | + |
def get_model_args(args):
    """Filter args for model ones.

    Given an args Namespace, build a new Namespace that carries *only* the
    entries named in MODEL_ARCHITECTURE or MODEL_OPTIMIZER (i.e. the ones
    defined in this module).
    """
    allowed = MODEL_ARCHITECTURE | MODEL_OPTIMIZER
    kept = {name: value for name, value in vars(args).items() if name in allowed}
    return argparse.Namespace(**kept)
| 118 | + |
| 119 | + |
def override_model_args(old_args, new_args):
    """Set args to new parameters.

    Decide which model args to keep and which to override when resolving a set
    of saved args and new args.

    We keep the new optimization settings, but leave the model architecture
    alone: architecture options must match the saved checkpoint's weights.

    Returns a new argparse.Namespace with the resolved values.
    """
    old_args, new_args = vars(old_args), vars(new_args)
    for k in old_args.keys():
        if k in new_args and old_args[k] != new_args[k]:
            if k in MODEL_OPTIMIZER:
                # Training/optimizer settings are safe to take from the new run.
                # Lazy %-style args: formatting is deferred until a handler
                # actually emits the record.
                logger.info('Overriding saved %s: %s --> %s',
                            k, old_args[k], new_args[k])
                old_args[k] = new_args[k]
            else:
                # Architecture settings stay pinned to the saved model.
                logger.info('Keeping saved %s: %s', k, old_args[k])
    return argparse.Namespace(**old_args)