1212import matplotlib .pyplot as plt
1313
1414from sklearn .metrics import mean_squared_error
15+
def annealing_weight(epoch, T_start, T_end, sharpness=3):
    """Return the annealed loss weight for the given epoch.

    The schedule has three phases:

    * ``epoch < T_start``: the weight is ``0.0``.
    * ``T_start <= epoch <= T_end``: the weight follows a logistic curve in
      the normalised position ``x = (epoch - T_start) / (T_end - T_start)``.
    * ``epoch > T_end``: the weight is ``1.0``.

    Args:
        epoch: Current epoch index.
        T_start: Epoch at which the ramp begins.
        T_end: Epoch at which the ramp ends.
        sharpness: Steepness of the logistic curve around ``x = 0.5``.

    Returns:
        The weight as a Python ``float``.
    """
    if epoch < T_start:
        return 0.0
    # A window with no width (T_end <= T_start) is treated as a step at
    # T_start. This avoids the 0/0 division when epoch == T_start == T_end
    # and matches what the range checks already yield for T_end < T_start.
    if epoch > T_end or T_end <= T_start:
        return 1.0

    # Normalised position inside the ramp, in [0, 1].
    x = (epoch - T_start) / (T_end - T_start)

    # The exponential term is scaled by 100, so the curve stays below 1.0 at
    # x == 1 and the weight jumps to 1.0 once epoch exceeds T_end.
    # NOTE(review): confirm the factor 100 (and the resulting jump) is intended.
    return float(1 / (1 + np.exp(-sharpness * (x - 0.5)) * 100))
27+
1528class EarlyStopping :
1629 """Early stopping with absolute threshold and patience-based logic."""
1730
@@ -20,6 +33,7 @@ def __init__(
2033 patience : int = 40 ,
2134 min_delta : float = 1.0e-4 ,
2235 model : Union [nn .Module , None ] = None ,
36+ max_epochs : int = 10000
2337 ):
2438 self ._patience = patience
2539 self ._min_delta = min_delta
@@ -29,28 +43,39 @@ def __init__(
2943 self ._stop = False
3044 self ._model_buffer = None
3145 self ._model_script = None
46+ self ._epoch = 0 ,
47+ self ._best_loss_epoch = 0
48+ self ._max_epochs = max_epochs
49+ self ._T_start = 0
3250
33- def __call__ (self , loss : float ) -> bool :
51+ def __call__ (self , loss : float , epoch ) -> bool :
3452 """Check if training should stop."""
35- if loss < self ._best_loss * (1.0 - self ._min_delta ):
36- self ._best_loss = loss
37- self ._counter = 0
38- if self ._model is not None :
39- self .save_model ()
40-
41- else :
42- self ._counter += 1
43- if self ._counter >= self ._patience :
44- self ._stop = True
53+ self ._epoch = epoch
54+ if self ._epoch >= self ._max_epochs :
55+ self ._stop = True
56+ print (f"epoch: { self ._epoch } reached max epochs." )
57+ if self ._epoch >= self ._T_start :
58+ if loss < self ._best_loss * (1.0 - self ._min_delta ):
59+ self ._best_loss = loss
60+ self ._counter = 0
61+ self ._best_loss_epoch = self ._epoch
62+ if self ._model is not None :
63+ self ._save_model ()
64+ else :
65+ self ._counter += 1
66+ if self ._counter > self ._patience :
67+ self ._stop = True
68+
4569 return self ._stop
4670 def reset (self ):
4771 """Reset the early stopping state."""
4872 self ._model .train ()
4973 self ._best_loss = float ("inf" )
5074 self ._counter = 0
5175 self ._stop = False
76+ self ._epoch = 0
5277
53- def save_model (self ):
78+ def _save_model (self ):
5479 self ._model .eval ()
5580 with io .BytesIO () as buffer :
5681
@@ -139,7 +164,7 @@ def train(num_mpi_ranks):
139164 torch .set_default_dtype (torch .float64 )
140165
141166 # Initialize the model
142- model = MLP (num_layers = 3 , layer_width = 50 , input_size = 2 , output_size = 2 , activation_fn = torch .nn .ReLU ()).to (device )
167+ model = MLP (num_layers = 3 , layer_width = 50 , input_size = 2 , output_size = 2 , activation_fn = torch .nn .Tanh ()).to (device )
143168
144169 # Initialize the optimizer
145170 learning_rate = 1e-04
@@ -148,15 +173,20 @@ def train(num_mpi_ranks):
148173 # # L-BFGS optimizer (currently active)
149174 # optimizer = optim.LBFGS(model.parameters(), lr=1.0, max_iter=20, tolerance_grad=1e-7, tolerance_change=1e-9, history_size=100)
150175
176+ epochs = 2000
177+ # Annealing schedule parameters
178+ T_start = 0
179+ T_end = 0.5 * epochs
180+
151181 early_stopper = EarlyStopping (
152- patience = 50 ,
182+ patience = 100 ,
153183 min_delta = 1e-3 ,
154- model = model
184+ model = model ,
185+ max_epochs = epochs
155186 )
156187 # Make sure all datasets are available in the SmartRedis database.
157188 local_time_index = 1
158189 while True :
159-
160190 print (f"Time step { local_time_index } " )
161191 # Fetch datasets from SmartRedis
162192
@@ -228,10 +258,9 @@ def train(num_mpi_ranks):
228258 loss_func = nn .MSELoss ()
229259
230260 model .train ()
231- epochs = 5000
232261 n_epochs = 0
233262 rmse_loss_val = 1
234-
263+
235264 for epoch in range (epochs ):
236265 # Zero the gradients
237266 optimizer .zero_grad ()
@@ -245,63 +274,32 @@ def train(num_mpi_ranks):
245274
246275 # Annealed weight: the physics weight starts near zero and is ramped up
247276 # by annealing_weight() between T_start and T_end, then held at 1.0
248- physics_weight = max ( 0.0001 , 0.001 * epoch / epochs + 0.0001 )
277+ physics_weight = annealing_weight ( epoch , T_start , T_end , sharpness = 10 )
249278 data_weight = 1.0
250279
251280 loss_train = data_weight * data_loss + physics_weight * p_loss
252- print (
253- f"[Epoch { epoch } /{ epochs } ] "
254- f"data loss: { data_loss .item ():.6f} , "
255- f"physics loss: { p_loss .item ():.6f} , "
256- f"physics_weight: { physics_weight :.4f} "
257- )
281+ if epoch % 50 == 0 or epoch == epochs - 1 :
282+ print (
283+ f"[Epoch { epoch } /{ epochs } ] "
284+ f"data loss: { data_loss .item ()} , "
285+ f"physics loss: { p_loss .item ()} , "
286+ f"physics_weight: { physics_weight } "
287+ )
258288 # Backward pass and optimization
259289 loss_train .backward ()
260290 optimizer .step ()
261291
262- # for epoch in range(epochs):
263- # # Define closure function for L-BFGS
264- # def closure():
265- # optimizer.zero_grad()
266-
267- # # Forward pass on the training data
268- # displ_pred = model(points_train)
269-
270- # # Compute loss on the training data with annealed weight
271- # data_loss = loss_func(displ_pred, displ_train)
272- # p_loss = pinn_loss(points_train, displ_pred)
273-
274- # # Annealed weight: start with high physics weight, gradually decrease
275- # # Physics weight decreases from 1.0 to 0.01 over training
276- # physics_weight = max(0.01, 1.0 * (1.0 - epoch / epochs))
277- # data_weight = 1.0
278-
279- # loss_train = data_weight * data_loss + physics_weight * p_loss
280- # loss_train.backward()
281- # return loss_train
282-
283- # # L-BFGS optimization step
284- # optimizer.step(closure)
285-
286292 n_epochs = n_epochs + 1
287293 # Forward pass on the validation data, with torch.no_grad() for efficiency
288294 with torch .no_grad ():
289295 displ_pred_val = model (points_val )
290296 mse_loss_val = loss_func (displ_pred_val , displ_val )
291297 rmse_loss_val = torch .sqrt (mse_loss_val )
292- if early_stopper (rmse_loss_val .item ()):
298+ if early_stopper (rmse_loss_val .item (), epoch ):
293299 print (f"Training stopped at epoch { epoch } " )
294- print (f"RMSE { early_stopper ._best_loss } , number of epochs { n_epochs } " )
300+ print (f"RMSE { early_stopper ._best_loss } , the epochs of smallest loss: { early_stopper . _best_loss_epoch } " )
295301 early_stopper .reset ()
296302 break
297-
298- # if epoch % 1000 == 0 or epoch == epochs - 1:
299- # print(f"[Epoch {epoch}]")
300- # print(f" Data Loss : {data_loss.item():.6e}")
301- # print(f" PINN Loss : {p_loss.item():.6e}")
302- # print(f" Physics Weight : {physics_weight:.4f}")
303- # print(f" Data Weight : {data_weight:.4f}")
304- # print(f" Validation RMSE: {rmse_loss_val:.6e}")
305303
306304 # Store the model into SmartRedis
307305 client .set_model ("MLP" , early_stopper ._model_buffer , "TORCH" , "CPU" )
0 commit comments