Skip to content

Commit 447eb87

Browse files
updates
1 parent 91b1938 commit 447eb87

File tree

4 files changed

+74
-13
lines changed

4 files changed

+74
-13
lines changed

ML/helper_functions/train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
valid_dataloader: DataLoader,
1515
criterion: torch.nn,
1616
optimizer: torch.optim,
17-
lr_schedular: bool = None,
17+
lr_schedular=None,
1818
) -> None:
1919
self.model = model
2020
self.epochs = epochs
@@ -75,4 +75,9 @@ def train(self, run_name: str) -> None:
7575
self.save_model(run_name)
7676

7777
def save_model(self, run_name: str) -> None:
78-
pass
78+
if run_name not in os.listdir("./ML/predictions/"):
79+
os.mkdir(f"./ML/predictions/{run_name}")
80+
torch.save(self.model, f"./ML/predictions/{run_name}/model.pt")
81+
torch.save(self.model, f"./ML/predictions/{run_name}/model.pth")
82+
torch.save(self.model.state_dict(), f"./ML/predictions/{run_name}/model_state_dict.pt")
83+
torch.save(self.model.state_dict(), f"./ML/predictions/{run_name}/model_state_dict.pth")

ML/helper_functions/transformations/transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ def transform(self) -> torchtext.transforms.Sequential:
3838
)
3939
return t
4040

41-
def model_transform(self, model) -> torchtext.transforms:
42-
return model.transforms()
41+
def model_transform(self, model=XLMR_BASE_ENCODER) -> torchtext.transforms:
42+
return model.transform()

run.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,33 @@
11
from ML import *
22

3-
lrs = [1e-0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6]
4-
for lr in lrs:
3+
4+
def train(
5+
batch_size: int = 32,
6+
lr: float = 0.01,
7+
test_split: float = 0.25,
8+
optimizer=optim.Adam,
9+
epochs: int = 5,
10+
name: str = "",
11+
lr_schedular=None,
12+
transforms=None,
13+
):
514
train_data_loader, test_data_loader, valid_data_loader = Load_Data(
615
Main_DL,
716
Valid_Loader,
817
[
918
"/media/user/Main/Programmer-RD-AI/Programming/Learning/JS/NLP-Disaster-Tweets/ML/data/train.csv",
10-
32,
11-
Transformer().transform(),
19+
batch_size,
20+
transforms,
1221
],
1322
[
1423
"/media/user/Main/Programmer-RD-AI/Programming/Learning/JS/NLP-Disaster-Tweets/ML/data/test.csv",
1524
1,
1625
],
17-
0.25,
26+
test_split,
1827
42,
1928
).ld()
2029
model = TL().to(device)
21-
optimizer = optim.Adam(model.parameters(), lr=lr)
30+
optimizer = optimizer(model.parameters(), lr=lr)
2231
criterion = nn.CrossEntropyLoss()
2332
config = {
2433
"model": model,
@@ -28,11 +37,58 @@
2837
}
2938
Train(
3039
model,
31-
5,
40+
epochs,
3241
config,
3342
train_data_loader,
3443
test_data_loader,
3544
valid_data_loader,
3645
criterion,
3746
optimizer,
38-
).train(f"{lr}")
47+
).train(f"{name}")
48+
49+
50+
train(
51+
transforms=Transformer().transform(),
52+
batch_size=16,
53+
lr=1e-3,
54+
test_split=0.25,
55+
optimizer=optim.Adam,
56+
lr_schedular=None,
57+
name=f"1e-3",
58+
)
59+
train(
60+
transforms=Transformer().transform(),
61+
batch_size=16,
62+
lr=1e-4,
63+
test_split=0.25,
64+
optimizer=optim.Adam,
65+
lr_schedular=None,
66+
name=f"1e-4",
67+
)
68+
train(
69+
transforms=Transformer().transform(),
70+
batch_size=16,
71+
lr=1e-5,
72+
test_split=0.25,
73+
optimizer=optim.Adam,
74+
lr_schedular=None,
75+
name=f"1e-5",
76+
)
77+
train(
78+
transforms=Transformer().transform(),
79+
batch_size=16,
80+
lr=1e-6,
81+
test_split=0.25,
82+
optimizer=optim.Adam,
83+
lr_schedular=None,
84+
name=f"1e-6",
85+
)
86+
train(
87+
transforms=Transformer().transform(),
88+
batch_size=16,
89+
lr=1e-7,
90+
test_split=0.25,
91+
optimizer=optim.Adam,
92+
lr_schedular=None,
93+
name=f"1e-7",
94+
)

wandb/latest-run

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
run-20230727_223900-960s76vj
1+
run-20230728_100733-nffptso0

0 commit comments

Comments
 (0)