import argparse
import numpy
import os
import pickle
import torch
import tqdm

from typing import Literal, Optional, Union

# Local imports
from parameters_only import get_params_NN
from hybrid_1 import get_hybrid_1_NN
from hybrid_2 import get_hybrid_2_NN
from black_box import get_bb_NN
from utils import epoch

def train(*,
          L: int = 30,
          n: int = 1,
          key: Literal['params', 'hybrid_1', 'hybrid_2', 'bb'] = 'params',
          seed: Optional[int] = None,
          dt: Union[float, torch.Tensor] = 0.2,
          training_data: str,
          recursive: bool = False,
          out_dir: Optional[str] = None,
          N_epochs: Optional[int] = None
          ):
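    """Train one of the candidate models ('params', 'hybrid_1', 'hybrid_2', 'bb') on a time series.

    Keyword-only parameters:
        L:              length of the training time series (number of time steps used).
        n:              number of training datasets to use.
        key:            which model to train.
        seed:           random seed for numpy and torch (optional).
        dt:             time step between samples.
        training_data:  path to the training data, a tensor of shape (n_datasets, n_steps, state_dim).
        recursive:      whether predictions are generated recursively.
        out_dir:        if given, checkpoints and the loss history are stored here; an existing
                        checkpoint in that folder is resumed.
        N_epochs:       total number of training epochs (defaults to 10000 for 'params', else 20000).
    """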
    # Set the seed, if passed
    if seed is not None:
        numpy.random.seed(seed)
        torch.manual_seed(seed)

    # Load the training data and set the initial conditions of the n training trajectories
    Y = torch.load(training_data, weights_only=True)
    y0 = Y[:n, 0, :]

    # Set up a dictionary with all the required data
    data = {
        'Y_target': Y[:n],
        'X_input': Y[:n, :(L + 1)].flatten(start_dim=1),
        'loss': []
    }

    # Add a dataset identifier if sweeping over more than one training dataset:
    # the first three time steps of the first two state components form a 6-value label,
    # passed separately as z in the recursive case and concatenated onto the inputs otherwise.
    if n > 1:
        if recursive:
            data['Y_input'] = Y[:n, :L]
            data['z'] = Y[:n, :3, :2].flatten(start_dim=1)
        else:
            data['Y_input'] = torch.cat([
                Y[:n, :L, :], Y[:n, :3, :2].flatten(start_dim=1)[:, None, :].repeat(1, L, 1)
            ], dim=2)
    else:
        data['Y_input'] = Y[:n, :(L + 1)]

    # Get the neural network
    if key == 'params':
        data['NN'] = get_params_NN(input_size=data['X_input'].shape[1], z=6 if n > 1 else 0)
    elif key == 'hybrid_1':
        data['NN'] = get_hybrid_1_NN(input_size=data['X_input'].shape[1], z=6 if n > 1 else 0)
    elif key == 'hybrid_2':
        data['NN'] = get_hybrid_2_NN(z=6 if n > 1 else 0)
    elif key == 'bb':
        data['NN'] = get_bb_NN(z=6 if n > 1 else 0)
    else:
        raise ValueError(f"Unknown key: {key}")

    # Build the output folder and, if a previous run left a checkpoint there, resume from it
    if out_dir is not None:
        path_name = f"{key}__n_{n}__L_{L}"
        if recursive:
            path_name += "__recursive"
        if seed is not None:
            path_name += f"__seed_{seed}"
        path_name = os.path.expanduser(os.path.join(out_dir, path_name))
        os.makedirs(path_name, exist_ok=True)
        if os.path.isfile(f"{path_name}/loss.pickle"):
            if key != 'hybrid_1':
                data['NN'].load_state_dict(torch.load(f"{path_name}/NN.pt", weights_only=True))
                data['NN'].eval()
            else:
                # hybrid_1 consists of two sub-networks stored as separate checkpoints
                data['NN']['const_params'].load_state_dict(torch.load(f"{path_name}/NN_const_params.pt", weights_only=True))
                data['NN']['time_dep_params'].load_state_dict(torch.load(f"{path_name}/NN_time_dep_params.pt", weights_only=True))
                data['NN']['const_params'].eval()
                data['NN']['time_dep_params'].eval()
            with open(f"{path_name}/loss.pickle", "rb") as f:
                data['loss'] = pickle.load(f)

    # Train for N_epochs in total (subtracting any epochs already completed in a resumed run)
    if N_epochs is None:
        N_epochs = 10000 if key == 'params' else 20000

    N_epochs -= len(data['loss'])
    print(f'Remaining: {N_epochs}')
    for i in tqdm.tqdm(range(N_epochs)):
        epoch(key=key, NN=data['NN'], X_input=data['X_input'], Y_target=data['Y_target'][:, :L], Y_input=data['Y_input'],
              dt=dt, y0=y0, t_span=(0, (L - 1) * dt), recursive=recursive, z=data.get('z', None), loss_array=data['loss'])

        # Store the results every 100 epochs
        if ((i > 0 and i % 100 == 0) or i == N_epochs - 1) and out_dir is not None:
            if key != 'hybrid_1':
                torch.save(data['NN'].state_dict(), f"{path_name}/NN.pt")
            else:
                torch.save(data['NN']['const_params'].state_dict(), f"{path_name}/NN_const_params.pt")
                torch.save(data['NN']['time_dep_params'].state_dict(), f"{path_name}/NN_time_dep_params.pt")
            with open(f"{path_name}/loss.pickle", "wb") as f:
                pickle.dump(data['loss'], f)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--L", type=int, default=30, help="Length of training time series")
    parser.add_argument("--n", type=int, default=1, help="Number of training datasets to use")
    parser.add_argument("--key", type=str, default='params', choices=['params', 'hybrid_1', 'hybrid_2', 'bb'], help="Model to use")
    parser.add_argument("--training_data", type=str, required=True, help="Path to training data")
    parser.add_argument("--recursive", action="store_true", help="Whether to generate predictions recursively")
    parser.add_argument("--seed", type=int, default=None, help="Set the seed")
    parser.add_argument("--N_epochs", type=int, default=None, help="Number of training epochs")
    parser.add_argument("--out_dir", type=str, help="Output directory")
    args = parser.parse_args()

    train(key=args.key, n=args.n, L=args.L, training_data=args.training_data, recursive=args.recursive, seed=args.seed,
          N_epochs=args.N_epochs, out_dir=args.out_dir)
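
    # Example invocation (script name and paths are illustrative placeholders):
    #   python train.py --key hybrid_2 --n 4 --L 30 --seed 0 \
    #       --training_data data/train.pt --out_dir results/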