train_multiconfig.py
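Trains GiBERTino across multiple GNN configurations (GCN, GraphSAGE) and datasets (STAC, MOLWENI, MINECRAFT, BALANCED).
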
import wandb
from typing import Dict

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

from dataset import SubDialogueDataModule
from model import GiBERTino
from utils.constants import RELATIONS


def run_experiment(dataset: str, graph_type: str, params: Dict):
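    """Train one GiBERTino configuration on a single dataset/graph-type pair.

    Returns True if training completes, False if ``trainer.fit`` raises.
    """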
    # Per-dataset logging frequency for the W&B logger.
    log_steps = {"STAC": 5, "MOLWENI": 15, "MINECRAFT": 50, "BALANCED": 50}[dataset]
    experiment_name = f"giBERTino-improved-{params['gnn_model']}-{dataset}"
    data_root = f"data/{dataset}/{graph_type}-graphs"
    log_dir = f"./lightning_logs/baselines/{experiment_name}"
    tokenizer = "Alibaba-NLP/gte-modernbert-base"
    bert_model = "Alibaba-NLP/gte-modernbert-base"
    # TODO: explore DialoGPT as an alternative encoder

    print(f"\n{'#' * 50}")
    print(f"Starting {experiment_name}")
    data_module = SubDialogueDataModule(
        root=data_root,
        batch_size=params["batch_size"],
        num_workers=0,
    )

    model = GiBERTino(
        in_channels=770,
        gnn_model=params["gnn_model"],
        hidden_channels=params["hidden_channels"],
        num_layers=params["num_layers"],
        num_relations=len(RELATIONS["UNIFIED"]),
        lr=params["lr"],
        dataset_name=dataset,  # noqa
        tokenizer=tokenizer,
        bert_model=bert_model,
    )
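
    # Finish any previous W&B run so each experiment logs to a fresh run.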
    wandb.finish()
    logger = WandbLogger(
        name=experiment_name,
        save_dir="lightning_logs/improved",
        project="giBERTino",
        log_model=False,
    )

    # Initialize callbacks
    early_stop = EarlyStopping(
        monitor="val_loss",
        min_delta=0.0,
        patience=10,
        verbose=True,
        mode="min",
        strict=True,
        check_finite=True,
        log_rank_zero_only=False,
    )
    checkpoint_callback = ModelCheckpoint(
        filename=f"{experiment_name}-{{epoch:02d}}-{{val_loss:.2f}}",
        monitor="val_loss",
        save_top_k=1,
        save_weights_only=False,
        mode="min",
        auto_insert_metric_name=True,
        save_on_train_epoch_end=True,
        enable_version_counter=True,
    )
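
    # Mixed-precision training on two GPUs with gradient-norm clipping.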
    trainer = Trainer(
        precision="16-mixed",
        accelerator="auto",
        devices=[0, 1],  # `gpus=` was removed in Lightning 2.x; use `devices=`
        max_epochs=30,
        logger=logger,
        callbacks=[early_stop, checkpoint_callback],
        log_every_n_steps=log_steps,
        default_root_dir=log_dir,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
    )

    try:
        trainer.fit(model, datamodule=data_module)
        return True
    except Exception as e:
        print(f"Training failed: {e}")
        return False


def main():
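    """Run every (GNN model, dataset, graph type) combination in sequence."""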
    configurations = {
        "GCN": {"hidden_channels": 128, "num_layers": 64, "lr": 1e-5, "batch_size": 32},
        "GraphSAGE": {
            "hidden_channels": 512,
            "num_layers": 32,
            "lr": 1e-5,
            "batch_size": 32,
        },
    }
    datasets = ["STAC", "MOLWENI", "MINECRAFT", "BALANCED"]
    graph_types = ["alibaba"]

    for model_name, params in configurations.items():
        # Record the model name in place so run_experiment sees it in params.
        params["gnn_model"] = model_name
        for dataset in datasets:
            for graph_type in graph_types:
                success = run_experiment(dataset, graph_type, params)
                status = "SUCCESS" if success else "FAILED"
                print(f"\n{status} - {model_name} - {dataset} - {graph_type}\n")

if __name__ == "__main__":
    main()