Commit 24ec556

exp: Updated Optuna hyperparameter search
1 parent 75b9066 · commit 24ec556

File tree

1 file changed: +99 -32 lines changed


experiment/main.py

Lines changed: 99 additions & 32 deletions
@@ -17,8 +17,8 @@
 from sklearn.model_selection import train_test_split
 from sklearn.utils import compute_class_weight
 import optuna
-from optuna.samplers import TPESampler
-from optuna.pruners import HyperbandPruner
+from optuna.samplers import TPESampler, GridSampler, RandomSampler
+from optuna.pruners import HyperbandPruner, MedianPruner
 
 from experiment.plot import save_plot
 from experiment.src.data_loader import read_detected_data, read_metadata, join_label, get_y_labels
@@ -27,9 +27,15 @@
 from experiment.src.model_config_preprocess import model_config_preprocess
 from experiment.src.prepare_data import prepare_train_data, data_checksum
 
+GPU_SAMPLE_LIMIT = 1024
+
 
 def objective(trial, train_loader: DataLoader, test_loader: DataLoader, model_inputs_size: List[tuple],
-              hp: Dict[str, tuple], device: torch.device):
+              hp: Dict[str, tuple]):
+    best_val_loss = trial.study.user_attrs["best_val_loss"]
+    epochs = trial.study.user_attrs["epochs"]
+    device = trial.study.user_attrs["device"]
+    best_model_path = trial.study.user_attrs["best_model_path"]
     params = {}
     for param_name, ((low, high, step), default) in hp.items():
         params[param_name] = trial.suggest_float(param_name, low, high, step=step)
@@ -38,29 +44,82 @@ def objective(trial, train_loader: DataLoader, test_loader: DataLoader, model_in
     optimizer = optim.Adam(model.parameters(), lr=0.001)
     criterion = nn.BCELoss()
 
-    model.train()
-    for _ in range(5):
+    best_loss = float('inf')
+
+    patience_counter = 0
+
+    if device == torch.device("cuda") and GPU_SAMPLE_LIMIT < train_loader.batch_size:
+        accumulation_steps = (train_loader.batch_size + GPU_SAMPLE_LIMIT - 1) // GPU_SAMPLE_LIMIT
+    else:
+        accumulation_steps = 1
+
+    for epoch in range(epochs):
+        model.train()
         for batch in train_loader:
             x_tensors = [x.to(device) for x in batch[:-1]]
             y_batch = batch[-1].to(device)
-            optimizer.zero_grad()
-            outputs = model(*x_tensors).squeeze()
-            loss = criterion(outputs, y_batch)
-            loss.backward()
+            batch_size = y_batch.shape[0]
+            sub_batch_size = batch_size // accumulation_steps  # sub-batch size
+
+            optimizer.zero_grad()  # clean up gradients before calculations
+
+            for i in range(accumulation_steps):
+                start = i * sub_batch_size
+                end = (i + 1) * sub_batch_size if i < accumulation_steps - 1 else batch_size
+                inputs_sub = [tens[start:end] for tens in x_tensors]
+                labels_sub = y_batch[start:end]
+
+                outputs = model(*inputs_sub).squeeze()
+                loss = criterion(outputs, labels_sub)
+                loss = loss / accumulation_steps  # normalize losses
+
+                loss.backward()  # calculate gradients
+
             optimizer.step()
 
-    model.eval()
-    val_loss = 0.0
-    with torch.no_grad():
-        for batch in test_loader:
-            x_tensors = [x.to(device) for x in batch[:-1]]
-            y_batch = batch[-1].to(device)
-            outputs = model(*x_tensors).squeeze()
-            loss = criterion(outputs, y_batch)
-            val_loss += loss.item()
-    val_loss /= len(test_loader)
-    return val_loss
+        model.eval()
+        val_loss = 0.0
+        with torch.no_grad():
+            for batch in test_loader:
+                x_tensors = [x.to(device) for x in batch[:-1]]
+                y_batch = batch[-1].to(device)
+
+                batch_size = y_batch.shape[0]
+                sub_batch_size = batch_size // accumulation_steps
+                for i in range(accumulation_steps):
+                    start = i * sub_batch_size
+                    end = (i + 1) * sub_batch_size if i < accumulation_steps - 1 else batch_size
+                    inputs_sub = [tens[start:end] for tens in x_tensors]
+                    labels_sub = y_batch[start:end]
+
+                    outputs = model(*inputs_sub).squeeze()
+                    loss = criterion(outputs, labels_sub)
+                    loss = loss / accumulation_steps
+
+                    val_loss += loss.item()
+        val_loss /= len(test_loader)
+
+        trial.report(val_loss, epoch)
+
+        if val_loss < best_loss:
+            best_loss = val_loss
+            patience_counter = 0
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+                trial.study.set_user_attr("best_val_loss", best_val_loss)
+                torch.save(model.state_dict(), best_model_path)
+        else:
+            patience_counter += 1
+
+        if patience_counter >= 5:
+            print(f"Early stop on {epoch} - 5 epochs without improvement")
+            break
+
+        if trial.should_prune():
+            print(f"Early stop on {epoch} - Raise TrialPruned")
+            raise optuna.TrialPruned()
 
+    return best_loss
 
 def evaluate_model(thresholds: dict,
                    model: nn.Module,
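
Editor's note: the training loop above uses gradient accumulation so that each CUDA step processes at most GPU_SAMPLE_LIMIT samples while the optimizer still updates once per logical batch. Below is a minimal standalone sketch of that pattern; the toy model, sizes, and data are invented for illustration and are not part of the commit.

    # Gradient accumulation: split one logical batch into sub-batches,
    # accumulate gradients across backward passes, then step once.
    import torch
    import torch.nn as nn
    import torch.optim as optim

    GPU_SAMPLE_LIMIT = 4          # hypothetical per-step sample limit
    model = nn.Linear(8, 1)       # toy stand-in for MlModel
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCELoss()

    x = torch.randn(10, 8)                    # one logical batch of 10 samples
    y = torch.randint(0, 2, (10,)).float()

    batch_size = y.shape[0]
    accumulation_steps = (batch_size + GPU_SAMPLE_LIMIT - 1) // GPU_SAMPLE_LIMIT
    sub_batch_size = batch_size // accumulation_steps

    optimizer.zero_grad()                     # clear gradients once per logical batch
    for i in range(accumulation_steps):
        start = i * sub_batch_size
        end = (i + 1) * sub_batch_size if i < accumulation_steps - 1 else batch_size
        outputs = torch.sigmoid(model(x[start:end])).squeeze()
        loss = criterion(outputs, y[start:end]) / accumulation_steps  # keep gradient scale comparable
        loss.backward()                       # gradients accumulate in .grad
    optimizer.step()                          # one update for the whole logical batch

Dividing each sub-batch loss by accumulation_steps keeps the accumulated gradient on roughly the same scale as a single full-batch backward pass; the match is exact only when the sub-batches are equal in size, which the integer split approximates for the last slice.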
@@ -205,6 +264,7 @@ def main(cred_data_location: str,
 
     x_train = [x_train_line, x_train_variable, x_train_value, x_train_features]
     x_test = [x_test_line, x_test_variable, x_test_value, x_test_features]
+    x_full = [x_full_line, x_full_variable, x_full_value, x_full_features]
 
     print(f"Create pytorch train and test datasets...")
     train_dataset = TensorDataset(*[torch.tensor(x, dtype=torch.float32) for x in x_train],
@@ -214,22 +274,31 @@ def main(cred_data_location: str,
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
     test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=2)
 
-    model_inputs_size = [x_full_line.shape, x_full_variable.shape, x_full_value.shape, x_full_features.shape]
+    inputs_size = [x_full_line.shape, x_full_variable.shape, x_full_value.shape, x_full_features.shape]
 
     if use_tuner:
         print(f"Start model train with optimization")
-        study = optuna.create_study(sampler=TPESampler(), pruner=HyperbandPruner(), direction="minimize")
-        study.optimize(lambda trial: objective(trial, train_loader, test_loader, model_inputs_size, hp_dict, device),
-                       n_trials=20)
+        search_space = {}  # Only for GridSearch
+        for param_name, ((low, high, step), default) in hp_dict.items():
+            search_space[param_name] = list(np.arange(low, high + step, step))
+
+        study = optuna.create_study(sampler=GridSampler(search_space), direction="minimize")
+        study.set_user_attr("best_val_loss", float("inf"))  # initialize best value
+        study.set_user_attr("epochs", epochs)  # initialize epochs
+        study.set_user_attr("device", device)
+        study.set_user_attr("best_model_path", str(dir_path / f"{current_time}.trials.best_model.pth"))
+        study.optimize(lambda trial: objective(trial, train_loader, test_loader, inputs_size, hp_dict), n_trials=10)
         param_kwargs = study.best_params
         print(f"Best hyperparameters: {param_kwargs}")
+        df_trials = study.trials_dataframe()
+        df_trials.to_csv(dir_path / f"{current_time}_trials_df.csv", sep=';')
     else:
-        param_kwargs = {k: v[1] for k, v in hp_dict.items()}
+        param_kwargs = {param_name: default for param_name, ((low, high, step), default) in hp_dict.items()}
 
     print(f"Model will be trained using the following params:{param_kwargs}")
 
     # repeat train step to obtain actual history chart
-    ml_model = MlModel(*model_inputs_size, param_kwargs).to(device)
+    ml_model = MlModel(*inputs_size, param_kwargs).to(device)
 
     optimizer = optim.Adam(ml_model.parameters(), lr=0.001)
     criterion = nn.BCELoss()
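
Editor's note: the commit replaces the TPESampler/HyperbandPruner combination with an exhaustive GridSampler whose grid is expanded from each (low, high, step) range, and uses study user attributes to pass fixed settings into objective(). Here is a minimal sketch of that round trip; the parameter name, range, and dummy loss are hypothetical, not taken from hp_dict.

    # GridSampler over a search space expanded from (low, high, step) ranges,
    # with fixed settings carried into the objective via study user attributes.
    import numpy as np
    import optuna
    from optuna.samplers import GridSampler

    hp = {"dropout": ((0.1, 0.5, 0.1), 0.3)}          # ((low, high, step), default)

    search_space = {name: list(np.arange(low, high + step, step))
                    for name, ((low, high, step), _default) in hp.items()}

    study = optuna.create_study(sampler=GridSampler(search_space), direction="minimize")
    study.set_user_attr("epochs", 10)                 # fixed setting travels with the study

    def objective(trial):
        epochs = trial.study.user_attrs["epochs"]     # read back inside the trial
        value = trial.suggest_float("dropout", 0.1, 0.5)
        return (value - 0.3) ** 2 / epochs            # dummy loss

    study.optimize(objective, n_trials=5)
    print(study.best_params, study.trials_dataframe().shape)

Since create_study is no longer given an explicit pruner, Optuna falls back to its default MedianPruner, which is what the trial.report()/trial.should_prune() calls in objective() rely on; the HyperbandPruner import is kept but now unused. Note also that n_trials=10 caps the search even if the expanded grid contains more combinations.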
@@ -306,30 +375,28 @@ def main(cred_data_location: str,
     ml_model.load_state_dict(torch.load(dir_path / f"{current_time}.best_model.pth"))
 
     print(f"Validate results on the train subset. Size: {len(y_train)} {np.mean(y_train):.4f}")
-    evaluate_model(thresholds, ml_model, [x_train_line, x_train_variable, x_train_value, x_train_features], y_train,
-                   device, batch_size)
+    evaluate_model(thresholds, ml_model, x_train, y_train, device, batch_size)
     del x_train_line
     del x_train_variable
     del x_train_value
     del x_train_features
     del y_train
 
     print(f"Validate results on the test subset. Size: {len(y_test)} {np.mean(y_test):.4f}")
-    evaluate_model(thresholds, ml_model, [x_test_line, x_test_variable, x_test_value, x_test_features], y_test, device,
-                   batch_size)
+    evaluate_model(thresholds, ml_model, x_test, y_test, device, batch_size)
    del x_test_line
     del x_test_variable
     del x_test_value
     del x_test_features
     del y_test
 
     print(f"Validate results on the full set. Size: {len(y_full)} {np.mean(y_full):.4f}")
-    evaluate_model(thresholds, ml_model, [x_full_line, x_full_variable, x_full_value, x_full_features], y_full, device,
-                   batch_size)
+    evaluate_model(thresholds, ml_model, x_full, y_full, device, batch_size)
     del x_full_line
     del x_full_variable
     del x_full_value
     del x_full_features
+    del x_full
     del y_full
 
     onnx_model_file = pathlib.Path(__file__).parent.parent / "credsweeper" / "ml_model" / "ml_model.onnx"
