Skip to content

Commit 8afe7b6

Browse files
committed
Added saving the model and metrics locally into corresponding folders
1 parent 93c1013 commit 8afe7b6

File tree

2 files changed

+208
-11
lines changed

2 files changed

+208
-11
lines changed

main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4-
import wandb
54
from torch.utils.data import DataLoader
65
from torchvision import transforms
76
from tqdm import tqdm
87

8+
import wandb
99
from CollaborativeCoding import (
1010
MetricWrapper,
1111
createfolders,
@@ -132,6 +132,7 @@ def main():
132132
wandb.init(
133133
entity="ColabCode",
134134
project=args.run_name,
135+
dir=args.resultfolder,
135136
tags=[args.modelname, args.dataset],
136137
config=args,
137138
)
@@ -178,6 +179,9 @@ def main():
178179
train_metrics.resetmetric()
179180
val_metrics.resetmetric()
180181

182+
if args.savemodel:
183+
th.save(model, args.modelfolder / f"{args.modelname}_run:{args.run_name}.pth")
184+
181185
testloss = []
182186
model.eval()
183187
with th.no_grad():

0 commit comments

Comments
 (0)