2323import torch
2424import torch .nn
2525
26- import lm . data as d
27- from lm . models import LanguageModel
28- from lm . eval import evaluate
26+ import data as d
27+ from models import LanguageModel
28+ from eval import evaluate
2929
3030
3131def get_args ():
@@ -37,7 +37,6 @@ def get_args():
3737 argparser .add_argument ('--batch_size' , type = int , default = 80 )
3838 argparser .add_argument ('--directory' , type = str , required = False , help = 'model directory for checkpoints and config' )
3939 argparser .add_argument ('--hidden' , action = 'store_true' , help = 'returns the hidden states of the whole dataset to perform analysis' )
40- argparser .add_argument ('--prune' , type = float , default = 0.0 )
4140
4241 return argparser .parse_args ()
4342
@@ -85,14 +84,19 @@ def main(args):
8584 model = LanguageModel (** model_args ).to (device )
8685 elif config ['rnn_type' ] == 'egru' :
8786 model = LanguageModel (** model_args ,
88- dampening_factor = config ['damp_factor ' ],
87+ dampening_factor = config ['pseudo_derivative_width ' ],
8988 pseudo_derivative_support = config ['pseudo_derivative_width' ]).to (device )
9089 else :
9190 raise RuntimeError ("Unknown RNN type: %s" % config ['rnn_type' ])
9291
9392 best_model_path = os .path .join (args .directory , 'checkpoints' , f"{ config ['rnn_type' ].upper ()} _best_model.cpt" )
9493 model .load_state_dict (torch .load (best_model_path , map_location = device ))
9594
95+ if model_args ['rnn_type' ] == 'egru' :
96+ hidden_dims = [rnn .hidden_size for rnn in model .rnns ]
97+ else :
98+ hidden_dims = [rnn .module .hidden_size if args .dropout_connect > 0 else rnn .hidden_size for rnn in model .rnns ]
99+
96100 criterion = torch .nn .CrossEntropyLoss ()
97101
98102 if args .hidden :
@@ -104,6 +108,7 @@ def main(args):
104108 bptt = config ['bptt' ],
105109 ntokens = vocab_size ,
106110 device = device ,
111+ hidden_dims = hidden_dims ,
107112 return_hidden = True )
108113 save_file = os .path .join (args .directory , f'hidden_states_{ args .datasplit } .hdf' )
109114 with h5py .File (save_file , 'w' ) as f :
@@ -121,6 +126,7 @@ def main(args):
121126 bptt = config ['bptt' ],
122127 ntokens = vocab_size ,
123128 device = device ,
129+ hidden_dims = hidden_dims ,
124130 return_hidden = False )
125131
126132 test_ppl = math .exp (test_loss )
@@ -131,58 +137,6 @@ def main(args):
131137 print (f'Layerwise activity { test_layerwise_activity_mean .tolist ()} +- { test_layerwise_activity_std .tolist ()} ' )
132138 print ('=' * 89 )
133139
134- if args .prune > 0.0 and args .hidden :
135- print (f"Model Parameter Count: { sum (p .numel () for p in model .parameters () if p .requires_grad )} " )
136- input_indices = torch .arange (model .rnns [0 ].input_size ).to (device )
137- for i in range (model .nlayers ):
138- if i < model .nlayers - 1 :
139- # get event frequencies
140- hid_dim = all_hiddens [i ].shape [2 ]
141- hid_cells = all_hiddens [i ].reshape (- 1 , hid_dim )
142- seq_len = hid_cells .shape [0 ]
143- spike_frequency = torch .sum (hid_cells != 0 , dim = 0 ) / seq_len
144- print (
145- f"Layer { i + 1 } : "
146- f"less than 1/100: { torch .sum (spike_frequency < 0.01 )} / { spike_frequency .shape } "
147- f"// never: { torch .sum (hid_cells .sum (dim = 0 ) == 0 )} / { spike_frequency .shape } " )
148-
149- # compute remaining indicies from spike frequencies
150- topk = int (model .rnns [i ].hidden_size * (1 - args .prune ))
151- hidden_indices , _ = torch .sort (torch .argsort (spike_frequency , descending = True )[:topk ], descending = False )
152- hidden_indices = hidden_indices .to (device )
153- else :
154- hidden_indices = torch .arange (model .rnns [i ].hidden_size ).to (device )
155- model .rnns [i ].prune_units (input_indices , hidden_indices )
156- input_indices = hidden_indices
157-
158- print (f"Model Parameter Count: { sum (p .numel () for p in model .parameters () if p .requires_grad )} " )
159-
160- test_loss , test_activity , test_layerwise_activity_mean , test_layerwise_activity_std , centered_cell_states , all_hiddens = \
161- evaluate (model = model ,
162- eval_data = test_data ,
163- criterion = criterion ,
164- batch_size = args .batch_size ,
165- bptt = config ['bptt' ],
166- ntokens = vocab_size ,
167- device = device ,
168- return_hidden = True )
169- for i in range (model .nlayers - 1 ):
170- # get event frequencies
171- hid_dim = all_hiddens [i ].shape [2 ]
172- hid_cells = all_hiddens [i ].reshape (- 1 , hid_dim )
173- seq_len = hid_cells .shape [0 ]
174- spike_frequency = torch .sum (hid_cells != 0 , dim = 0 ) / seq_len
175- print (
176- f"less than 1/100: { torch .sum (spike_frequency < 0.01 )} / { spike_frequency .shape } "
177- f"// never: { torch .sum (hid_cells .sum (dim = 0 ) == 0 )} / { spike_frequency .shape } " )
178- test_ppl = math .exp (test_loss )
179- print ('=' * 89 )
180- print (f'| Inference | test loss { test_loss :5.2f} | '
181- f'test ppl { test_ppl :8.2f} | '
182- f'test mean activity { test_activity } ' )
183- print (f'Layerwise activity { test_layerwise_activity_mean .tolist ()} +- { test_layerwise_activity_std .tolist ()} ' )
184- print ('=' * 89 )
185-
186140
187141if __name__ == "__main__" :
188142 args = get_args ()
0 commit comments