Commit 55a6c79
update lbfgs optimizer and annealing strategy
1 parent 1272df3

File tree

3 files changed: +710 −59 lines changed

run/meshMotion/wingMotion/mesh_trainer_pinn.py

Lines changed: 57 additions & 59 deletions
@@ -12,6 +12,19 @@
 import matplotlib.pyplot as plt
 
 from sklearn.metrics import mean_squared_error
+
+
+def annealing_weight(epoch, T_start, T_end, sharpness=3):
+    """Sigmoid-shaped ramp for the physics-loss weight.
+
+    Returns 0.0 before T_start and 1.0 after T_end; in between, a scaled
+    logistic curve (the factor of 100 keeps the weight small for most of
+    the window).
+    """
+    if epoch < T_start:
+        return 0.0
+    elif epoch > T_end:
+        return 1.0
+    else:
+        # map epoch into the range [0, 1]
+        x = (epoch - T_start) / (T_end - T_start)
+        return float(1 / (1 + np.exp(-sharpness * (x - 0.5)) * 100))
+
 class EarlyStopping:
     """Early stopping with absolute threshold and patience-based logic."""
 
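Note: with the values set later in train() (T_start = 0, T_end = 0.5 * epochs = 1000, sharpness = 10), the ×100 factor inside the logistic keeps the weight below roughly 0.6 for the whole ramp, so the schedule jumps from ≈0.6 to 1.0 once epoch passes T_end. A quick self-contained check (the sampled epochs are arbitrary):

import numpy as np

def annealing_weight(epoch, T_start, T_end, sharpness=3):
    # copy of the function added above
    if epoch < T_start:
        return 0.0
    elif epoch > T_end:
        return 1.0
    else:
        x = (epoch - T_start) / (T_end - T_start)
        return float(1 / (1 + np.exp(-sharpness * (x - 0.5)) * 100))

for epoch in [0, 500, 900, 1000, 1001]:
    print(epoch, round(annealing_weight(epoch, 0, 1000, sharpness=10), 4))
# -> 0 0.0001, 500 0.0099, 900 0.3532, 1000 0.5974, 1001 1.0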

@@ -20,6 +33,7 @@ def __init__(
         patience: int = 40,
         min_delta: float = 1.0e-4,
         model: Union[nn.Module, None] = None,
+        max_epochs: int = 10000
     ):
         self._patience = patience
         self._min_delta = min_delta
@@ -29,28 +43,39 @@ def __init__(
         self._stop = False
         self._model_buffer = None
         self._model_script = None
+        self._epoch = 0
+        self._best_loss_epoch = 0
+        self._max_epochs = max_epochs
+        self._T_start = 0
 
-    def __call__(self, loss: float) -> bool:
+    def __call__(self, loss: float, epoch) -> bool:
         """Check if training should stop."""
-        if loss < self._best_loss * (1.0 - self._min_delta):
-            self._best_loss = loss
-            self._counter = 0
-            if self._model is not None:
-                self.save_model()
-
-        else:
-            self._counter += 1
-            if self._counter >= self._patience:
-                self._stop = True
+        self._epoch = epoch
+        if self._epoch >= self._max_epochs:
+            self._stop = True
+            print(f"epoch: {self._epoch} reached max epochs.")
+        if self._epoch >= self._T_start:
+            if loss < self._best_loss * (1.0 - self._min_delta):
+                self._best_loss = loss
+                self._counter = 0
+                self._best_loss_epoch = self._epoch
+                if self._model is not None:
+                    self._save_model()
+            else:
+                self._counter += 1
+                if self._counter > self._patience:
+                    self._stop = True
+
         return self._stop
     def reset(self):
         """Reset the early stopping state."""
         self._model.train()
         self._best_loss = float("inf")
         self._counter = 0
         self._stop = False
+        self._epoch = 0
 
-    def save_model(self):
+    def _save_model(self):
         self._model.eval()
         with io.BytesIO() as buffer:
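Note: the improvement test loss < best * (1.0 - min_delta) is relative, so with min_delta = 1e-3 (the value used in train()) a new validation RMSE only resets the patience counter when it beats the best loss by more than 0.1%. A throwaway check with invented numbers:

best, min_delta = 0.5, 1e-3
print(0.4996 < best * (1.0 - min_delta))  # False: only 0.08% better, counts toward patience
print(0.4990 < best * (1.0 - min_delta))  # True: 0.2% better, resets the counter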

@@ -139,7 +164,7 @@ def train(num_mpi_ranks):
     torch.set_default_dtype(torch.float64)
 
     # Initialize the model
-    model = MLP(num_layers=3, layer_width=50, input_size=2, output_size=2, activation_fn=torch.nn.ReLU()).to(device)
+    model = MLP(num_layers=3, layer_width=50, input_size=2, output_size=2, activation_fn=torch.nn.Tanh()).to(device)
 
     # Initialize the optimizer
     learning_rate = 1e-04
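
Note: the switch from ReLU to Tanh matters for a PINN because the physics residual differentiates the network output twice, and ReLU's second derivative is zero almost everywhere, which silences the PDE term. A minimal autograd check, illustrative only (not repo code):

import torch

x = torch.linspace(-1.0, 1.0, 5, dtype=torch.float64, requires_grad=True)

for act in (torch.relu, torch.tanh):
    y = act(x)
    (g,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # dy/dx
    (h,) = torch.autograd.grad(g.sum(), x, allow_unused=True)  # d2y/dx2
    print(act.__name__, h if h is not None else "identically zero (graph disconnected)")
# relu -> zero second derivative everywhere; tanh -> non-zero curvature
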
@@ -148,15 +173,20 @@ def train(num_mpi_ranks):
     # # L-BFGS optimizer (alternative; currently commented out)
     # optimizer = optim.LBFGS(model.parameters(), lr=1.0, max_iter=20, tolerance_grad=1e-7, tolerance_change=1e-9, history_size=100)
 
+    epochs = 2000
+    # Annealing schedule parameters
+    T_start = 0
+    T_end = 0.5 * epochs
+
     early_stopper = EarlyStopping(
-        patience=50,
+        patience=100,
         min_delta=1e-3,
-        model=model
+        model=model,
+        max_epochs=epochs
     )
     # Make sure all datasets are available in the smartredis database.
     local_time_index = 1
     while True:
-
         print(f"Time step {local_time_index}")
         # Fetch datasets from SmartRedis
@@ -228,10 +258,9 @@ def train(num_mpi_ranks):
         loss_func = nn.MSELoss()
 
         model.train()
-        epochs = 5000
         n_epochs = 0
         rmse_loss_val = 1
-
+
         for epoch in range(epochs):
             # Zero the gradients
             optimizer.zero_grad()
@@ -245,63 +274,32 @@ def train(num_mpi_ranks):
 
             # Annealed weight: physics weight ramps up on a sigmoid
             # schedule between T_start and T_end (see annealing_weight above)
-            physics_weight = max(0.0001, 0.001 * epoch / epochs + 0.0001)
+            physics_weight = annealing_weight(epoch, T_start, T_end, sharpness=10)
             data_weight = 1.0
 
             loss_train = data_weight * data_loss + physics_weight * p_loss
-            print(
-                f"[Epoch {epoch}/{epochs}] "
-                f"data loss: {data_loss.item():.6f}, "
-                f"physics loss: {p_loss.item():.6f}, "
-                f"physics_weight: {physics_weight:.4f}"
-            )
+            if epoch % 50 == 0 or epoch == epochs - 1:
+                print(
+                    f"[Epoch {epoch}/{epochs}] "
+                    f"data loss: {data_loss.item()}, "
+                    f"physics loss: {p_loss.item()}, "
+                    f"physics_weight: {physics_weight}"
+                )
             # Backward pass and optimization
             loss_train.backward()
             optimizer.step()
 
-            # for epoch in range(epochs):
-            #     # Define closure function for L-BFGS
-            #     def closure():
-            #         optimizer.zero_grad()
-
-            #         # Forward pass on the training data
-            #         displ_pred = model(points_train)
-
-            #         # Compute loss on the training data with annealed weight
-            #         data_loss = loss_func(displ_pred, displ_train)
-            #         p_loss = pinn_loss(points_train, displ_pred)
-
-            #         # Annealed weight: start with high physics weight, gradually decrease
-            #         # Physics weight decreases from 1.0 to 0.01 over training
-            #         physics_weight = max(0.01, 1.0 * (1.0 - epoch / epochs))
-            #         data_weight = 1.0
-
-            #         loss_train = data_weight * data_loss + physics_weight * p_loss
-            #         loss_train.backward()
-            #         return loss_train
-
-            #     # L-BFGS optimization step
-            #     optimizer.step(closure)
-
             n_epochs = n_epochs + 1
             # Forward pass on the validation data, with torch.no_grad() for efficiency
             with torch.no_grad():
                 displ_pred_val = model(points_val)
                 mse_loss_val = loss_func(displ_pred_val, displ_val)
                 rmse_loss_val = torch.sqrt(mse_loss_val)
-            if early_stopper(rmse_loss_val.item()):
+            if early_stopper(rmse_loss_val.item(), epoch):
                 print(f"Training stopped at epoch {epoch}")
-                print(f"RMSE {early_stopper._best_loss}, number of epochs {n_epochs}")
+                print(f"RMSE {early_stopper._best_loss}, epoch of smallest loss: {early_stopper._best_loss_epoch}")
                 early_stopper.reset()
                 break
-
-            # if epoch % 1000 == 0 or epoch == epochs - 1:
-            #     print(f"[Epoch {epoch}]")
-            #     print(f"  Data Loss      : {data_loss.item():.6e}")
-            #     print(f"  PINN Loss      : {p_loss.item():.6e}")
-            #     print(f"  Physics Weight : {physics_weight:.4f}")
-            #     print(f"  Data Weight    : {data_weight:.4f}")
-            #     print(f"  Validation RMSE: {rmse_loss_val:.6e}")
 
         # Store the model into SmartRedis
         client.set_model("MLP", early_stopper._model_buffer, "TORCH", "CPU")
