forked from RUCAIBox/FMLP-Rec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
116 lines (95 loc) · 4.91 KB
/
main.py
File metadata and controls
116 lines (95 loc) · 4.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# @Time : 2022/2/13
# @Author : Hui Yu
# @Email : ishyu@outlook.com
import os
import torch
import argparse
import numpy as np
from models import FMLPRecModel
from trainers import FMLPRecTrainer
from utils import EarlyStopping, check_path, set_seed, get_local_time, get_seq_dic, get_dataloder, get_rating_matrix
def main():
    """Entry point: parse CLI args, then train or evaluate an FMLP-Rec model.

    Side effects: creates --output_dir if missing, writes a per-run log file
    (`<model>-<data>-<time>.txt`) and a checkpoint (`.pt`) under --output_dir,
    and pins visible GPUs via CUDA_VISIBLE_DEVICES.
    """
    parser = argparse.ArgumentParser()

    # data / run-mode args
    parser.add_argument("--data_dir", default="./data/", type=str)
    parser.add_argument("--output_dir", default="output/", type=str)
    parser.add_argument("--data_name", default="Beauty", type=str)
    parser.add_argument("--do_eval", action="store_true")
    parser.add_argument("--load_model", default=None, type=str)

    # model args
    parser.add_argument("--model_name", default="FMLPRec", type=str)
    parser.add_argument("--hidden_size", default=64, type=int, help="hidden size of model")
    parser.add_argument("--num_hidden_layers", default=2, type=int, help="number of filter-enhanced blocks")
    parser.add_argument("--num_attention_heads", default=2, type=int)
    parser.add_argument("--hidden_act", default="gelu", type=str)  # gelu relu
    parser.add_argument("--attention_probs_dropout_prob", default=0.5, type=float)
    parser.add_argument("--hidden_dropout_prob", default=0.5, type=float)
    parser.add_argument("--initializer_range", default=0.02, type=float)
    parser.add_argument("--max_seq_length", default=50, type=int)
    parser.add_argument("--no_filters", action="store_true", help="if no filters, filter layers transform to self-attention")

    # train args
    parser.add_argument("--lr", default=0.001, type=float, help="learning rate of adam")
    parser.add_argument("--batch_size", default=256, type=int, help="number of batch_size")
    parser.add_argument("--epochs", default=200, type=int, help="number of epochs")
    parser.add_argument("--no_cuda", action="store_true")
    parser.add_argument("--log_freq", default=1, type=int, help="per epoch print res")
    parser.add_argument("--full_sort", action="store_true")
    parser.add_argument("--patience", default=10, type=int, help="how long to wait after last time validation loss improved")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--weight_decay", default=0.0, type=float, help="weight_decay of adam")
    parser.add_argument("--adam_beta1", default=0.9, type=float, help="adam first beta value")
    parser.add_argument("--adam_beta2", default=0.999, type=float, help="adam second beta value")
    parser.add_argument("--gpu_id", default="0", type=str, help="gpu_id")
    parser.add_argument("--variance", default=5, type=float)

    args = parser.parse_args()

    set_seed(args.seed)
    check_path(args.output_dir)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    args.cuda_condition = torch.cuda.is_available() and not args.no_cuda

    seq_dic, max_item = get_seq_dic(args)
    # +1 presumably because item ids start at 1 and 0 is the padding id
    # -- TODO confirm against utils.get_seq_dic
    args.item_size = max_item + 1

    # save model args
    cur_time = get_local_time()
    if args.no_filters:
        # with --no_filters the filter layers become self-attention, i.e. SASRec
        args.model_name = "SASRec"
    args_str = f'{args.model_name}-{args.data_name}-{cur_time}'
    args.log_file = os.path.join(args.output_dir, args_str + '.txt')
    print(str(args))
    with open(args.log_file, 'a') as f:
        f.write(str(args) + '\n')

    # save model
    args.checkpoint_path = os.path.join(args.output_dir, args_str + '.pt')

    train_dataloader, eval_dataloader, test_dataloader = get_dataloder(args, seq_dic)

    model = FMLPRecModel(args=args)
    trainer = FMLPRecTrainer(model, train_dataloader, eval_dataloader,
                             test_dataloader, args)

    # NOTE(review): trainer holds a reference to args, so rating matrices set
    # after construction are still visible during full-sort evaluation.
    if args.full_sort:
        args.valid_rating_matrix, args.test_rating_matrix = get_rating_matrix(args.data_name, seq_dic, max_item)

    if args.do_eval:
        if args.load_model is None:
            print("No model input!")
            exit(0)
        else:
            args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt')
            trainer.load(args.checkpoint_path)
            print(f"Load model from {args.checkpoint_path} for test!")
            scores, result_info = trainer.test(0, full_sort=args.full_sort)
    else:
        early_stopping = EarlyStopping(args.checkpoint_path, patience=args.patience, verbose=True)
        for epoch in range(args.epochs):
            trainer.train(epoch)
            scores, _ = trainer.valid(epoch, full_sort=args.full_sort)
            # evaluate on MRR: early-stop on the last metric valid() returns
            early_stopping(np.array(scores[-1:]), trainer.model)
            if early_stopping.early_stop:
                print("Early stopping")
                break
        print("---------------Sample 99 results---------------")
        # load the best (early-stopped) checkpoint before the final test pass
        trainer.model.load_state_dict(torch.load(args.checkpoint_path))
        scores, result_info = trainer.test(0, full_sort=args.full_sort)

    print(args_str)
    print(result_info)
    with open(args.log_file, 'a') as f:
        f.write(args_str + '\n')
        f.write(result_info + '\n')


if __name__ == "__main__":
    # guard so importing this module does not immediately launch training
    main()