forked from microsoft/qlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpytorch_sfm.py
More file actions
486 lines (386 loc) · 15.5 KB
/
pytorch_sfm.py
File metadata and controls
486 lines (386 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class SFM_Model(nn.Module):
    """Recurrent SFM network that maps a flattened feature window to one score per sample.

    Parameters
    ----------
    d_feat : int
        Number of features per time step (stored as ``input_dim``).
    output_dim : int
        Width of the per-step projection ``p`` that is fed to ``fc_out``.
    freq_dim : int
        Number of frequency components kept in the memory state.
    hidden_size : int
        Width of the hidden state (stored as ``hidden_dim``).
    dropout_W, dropout_U : float
        Stored on the module but not applied anywhere in ``forward``.
    device : str or torch.device
        Device on which the recurrent state and constants are created.
    """

    def __init__(
        self,
        d_feat=6,
        output_dim=1,
        freq_dim=10,
        hidden_size=64,
        dropout_W=0.0,
        dropout_U=0.0,
        device="cpu",
    ):
        super().__init__()

        self.input_dim = d_feat
        self.output_dim = output_dim
        self.freq_dim = freq_dim
        self.hidden_dim = hidden_size
        self.device = device

        # NOTE: the creation order below is kept fixed so that a given random
        # seed keeps producing the same initial weights.

        # input gate
        self.W_i = nn.Parameter(init.xavier_uniform_(torch.empty((self.input_dim, self.hidden_dim))))
        self.U_i = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
        self.b_i = nn.Parameter(torch.zeros(self.hidden_dim))

        # state part of the forget gate
        self.W_ste = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
        self.U_ste = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
        self.b_ste = nn.Parameter(torch.ones(self.hidden_dim))

        # frequency part of the forget gate
        self.W_fre = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.freq_dim)))
        self.U_fre = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.freq_dim)))
        self.b_fre = nn.Parameter(torch.ones(self.freq_dim))

        # candidate
        self.W_c = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
        self.U_c = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
        self.b_c = nn.Parameter(torch.zeros(self.hidden_dim))

        # output gate
        self.W_o = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)))
        self.U_o = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)))
        self.b_o = nn.Parameter(torch.zeros(self.hidden_dim))

        # aggregation of the frequency amplitudes
        self.U_a = nn.Parameter(init.orthogonal_(torch.empty(self.freq_dim, 1)))
        self.b_a = nn.Parameter(torch.zeros(self.hidden_dim))

        # per-step projection
        self.W_p = nn.Parameter(init.xavier_uniform_(torch.empty(self.hidden_dim, self.output_dim)))
        self.b_p = nn.Parameter(torch.zeros(self.output_dim))

        self.activation = nn.Tanh()
        self.inner_activation = nn.Hardsigmoid()
        self.dropout_W, self.dropout_U = (dropout_W, dropout_U)
        self.fc_out = nn.Linear(self.output_dim, 1)

        # Recurrent state; only populated while ``forward`` is running.
        self.states = []

    def forward(self, input):
        """Run the recurrence over every time step and score the last one.

        Parameters
        ----------
        input : torch.Tensor
            Batch that can be reshaped to ``[N, d_feat, T]``.

        Returns
        -------
        torch.Tensor
            1-D tensor with one value per sample, also when ``N == 1``.
        """
        input = input.reshape(len(input), self.input_dim, -1)  # [N, F, T]
        input = input.permute(0, 2, 1)  # [N, T, F]
        time_step = input.shape[1]

        try:
            first_x = input[:, 0, :]
            if len(self.states) == 0:  # hasn't initialized yet
                self.init_states(first_x)
            # The constants do not depend on the time step, so build them once
            # per sequence instead of once per step.
            self.get_constants(first_x)
            B_U, B_W, frequency = self.states[5:8]
            two_pi = torch.tensor(2 * np.pi)

            for ts in range(time_step):
                x = input[:, ts, :]
                _, h_tm1, S_re_tm1, S_im_tm1, time_tm1 = self.states[:5]

                x_scaled = x * B_W[0]
                h_scaled = h_tm1 * B_U[0]

                x_i = torch.matmul(x_scaled, self.W_i) + self.b_i
                x_ste = torch.matmul(x_scaled, self.W_ste) + self.b_ste
                x_fre = torch.matmul(x_scaled, self.W_fre) + self.b_fre
                x_c = torch.matmul(x_scaled, self.W_c) + self.b_c
                x_o = torch.matmul(x_scaled, self.W_o) + self.b_o

                i = self.inner_activation(x_i + torch.matmul(h_scaled, self.U_i))
                ste = self.inner_activation(x_ste + torch.matmul(h_scaled, self.U_ste))
                fre = self.inner_activation(x_fre + torch.matmul(h_scaled, self.U_fre))

                # Outer product of the two forget-gate parts: [N, hidden, freq].
                ste = torch.reshape(ste, (-1, self.hidden_dim, 1))
                fre = torch.reshape(fre, (-1, 1, self.freq_dim))
                f = ste * fre

                c = i * self.activation(x_c + torch.matmul(h_scaled, self.U_c))

                time = time_tm1 + 1
                omega = two_pi * time * frequency
                re = torch.cos(omega)
                im = torch.sin(omega)

                c = torch.reshape(c, (-1, self.hidden_dim, 1))
                S_re = f * S_re_tm1 + c * re
                S_im = f * S_im_tm1 + c * im

                # Squared amplitude per frequency, cast back to float32 because
                # ``frequency`` is built from a NumPy float array.
                A = torch.square(S_re) + torch.square(S_im)
                A = torch.reshape(A, (-1, self.freq_dim)).float()
                A_a = torch.matmul(A * B_U[0], self.U_a)
                A_a = torch.reshape(A_a, (-1, self.hidden_dim))
                a = self.activation(A_a + self.b_a)

                o = self.inner_activation(x_o + torch.matmul(h_scaled, self.U_o))
                h = o * a
                p = torch.matmul(h, self.W_p) + self.b_p

                self.states = [p, h, S_re, S_im, time, B_U, B_W, frequency]
        finally:
            # Always drop the recurrent state, so a failure in the middle of a
            # sequence cannot leak into the next call.
            self.states = []

        # squeeze(-1) rather than squeeze(): a batch of one sample must stay 1-D.
        return self.fc_out(p).squeeze(-1)

    def init_states(self, x):
        """Reset ``self.states`` to the all-zero initial state.

        ``x`` is accepted for interface compatibility and is not read.
        """
        init_state_p = torch.zeros(self.output_dim, device=self.device)
        init_state_h = torch.zeros(self.hidden_dim, device=self.device)
        # Leading dimension of 1 broadcasts against the batch in ``forward``.
        init_state_S_re = torch.zeros((1, self.hidden_dim, self.freq_dim), device=self.device)
        init_state_S_im = torch.zeros((1, self.hidden_dim, self.freq_dim), device=self.device)
        init_state_time = torch.tensor(0).to(self.device)

        self.states = [
            init_state_p,
            init_state_h,
            init_state_S_re,
            init_state_S_im,
            init_state_time,
            None,
            None,
            None,
        ]

    def get_constants(self, x):
        """Fill ``self.states[5:]`` with the scaling constants and the frequency grid.

        ``x`` is accepted for interface compatibility and is not read.
        """
        constants = []
        # Both scaling lists are all ones, so the multiplications in
        # ``forward`` leave their operands unchanged.
        constants.append([torch.tensor(1.0).to(self.device) for _ in range(6)])
        constants.append([torch.tensor(1.0).to(self.device) for _ in range(7)])
        # Evenly spaced frequencies 0, 1/freq_dim, ..., (freq_dim - 1)/freq_dim.
        array = np.array([float(ii) / self.freq_dim for ii in range(self.freq_dim)])
        constants.append(torch.tensor(array).to(self.device))

        self.states[5:] = constants
class SFM(Model):
    """SFM Model

    Parameters
    ----------
    d_feat : int
        input dimension for each time step
    hidden_size : int
        hidden state dimension
    output_dim : int
        output dimension
    freq_dim : int
        number of frequency components
    dropout_W, dropout_U : float
        forwarded to ``SFM_Model``
    n_epochs : int
        maximum number of training epochs
    lr : float
        learning rate
    metric : str
        evaluation metric; only ``""`` and ``"loss"`` are supported
    batch_size : int
        number of samples per batch
    early_stop : int
        number of epochs without improvement before training stops
    eval_steps : int
        stored on the instance; not read by this class
    loss : str
        loss name; only ``"mse"`` is supported
    optimizer : str
        optimizer name, ``"adam"`` or ``"gd"``
    GPU : int
        the GPU ID used for training; a negative value selects the CPU
    seed : int or None
        seed for NumPy and PyTorch; ``None`` leaves the generators untouched
    """

    def __init__(
        self,
        d_feat=6,
        hidden_size=64,
        output_dim=1,
        freq_dim=10,
        dropout_W=0.0,
        dropout_U=0.0,
        n_epochs=200,
        lr=0.001,
        metric="",
        batch_size=2000,
        early_stop=20,
        eval_steps=5,
        loss="mse",
        optimizer="gd",
        GPU=0,
        seed=None,
        **kwargs
    ):
        # Set logger.
        self.logger = get_module_logger("SFM")
        self.logger.info("SFM pytorch version...")

        # set hyper-parameters.
        self.d_feat = d_feat
        self.hidden_size = hidden_size
        self.output_dim = output_dim
        self.freq_dim = freq_dim
        self.dropout_W = dropout_W
        self.dropout_U = dropout_U
        self.n_epochs = n_epochs
        self.lr = lr
        self.metric = metric
        self.batch_size = batch_size
        self.early_stop = early_stop
        self.eval_steps = eval_steps
        self.optimizer = optimizer.lower()
        self.loss = loss
        self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
        self.seed = seed

        self.logger.info(
            "SFM parameters setting:"
            "\nd_feat : {}"
            "\nhidden_size : {}"
            "\noutput_size : {}"
            "\nfrequency_dimension : {}"
            "\ndropout_W: {}"
            "\ndropout_U: {}"
            "\nn_epochs : {}"
            "\nlr : {}"
            "\nmetric : {}"
            "\nbatch_size : {}"
            "\nearly_stop : {}"
            "\neval_steps : {}"
            "\noptimizer : {}"
            "\nloss_type : {}"
            "\ndevice : {}"
            "\nuse_GPU : {}"
            "\nseed : {}".format(
                d_feat,
                hidden_size,
                output_dim,
                freq_dim,
                dropout_W,
                dropout_U,
                n_epochs,
                lr,
                metric,
                batch_size,
                early_stop,
                eval_steps,
                optimizer.lower(),
                loss,
                self.device,
                self.use_gpu,
                seed,
            )
        )

        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)

        self.sfm_model = SFM_Model(
            d_feat=self.d_feat,
            output_dim=self.output_dim,
            hidden_size=self.hidden_size,
            freq_dim=self.freq_dim,
            dropout_W=self.dropout_W,
            dropout_U=self.dropout_U,
            device=self.device,
        )
        self.logger.info("model:\n{:}".format(self.sfm_model))
        self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))

        if optimizer.lower() == "adam":
            self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
        elif optimizer.lower() == "gd":
            self.train_optimizer = optim.SGD(self.sfm_model.parameters(), lr=self.lr)
        else:
            raise NotImplementedError("optimizer {} is not supported!".format(optimizer))

        self.fitted = False
        self.sfm_model.to(self.device)

    @property
    def use_gpu(self):
        """Whether the model runs on a device other than the CPU."""
        return self.device != torch.device("cpu")

    def test_epoch(self, data_x, data_y):
        """Evaluate the model and return ``(mean loss, mean score)`` over batches.

        Only full batches are scored, mirroring ``train_epoch``. When the data
        holds fewer rows than ``batch_size`` it is scored as one short batch
        instead of yielding NaN from an empty mean.
        """
        x_values = data_x.values
        # atleast_1d keeps a single-row label sliceable after the squeeze.
        y_values = np.atleast_1d(np.squeeze(data_y.values))

        self.sfm_model.eval()

        scores = []
        losses = []

        n_samples = len(x_values)
        starts = list(range(0, n_samples - self.batch_size + 1, self.batch_size))
        if not starts and n_samples > 0:
            starts = [0]

        # No gradients are needed for evaluation.
        with torch.no_grad():
            for begin in starts:
                end = begin + self.batch_size
                feature = torch.from_numpy(x_values[begin:end]).float().to(self.device)
                label = torch.from_numpy(y_values[begin:end]).float().to(self.device)

                # reshape(-1) keeps a one-sample prediction 1-D for the masks below.
                pred = self.sfm_model(feature).reshape(-1)
                loss = self.loss_fn(pred, label)
                losses.append(loss.item())

                score = self.metric_fn(pred, label)
                scores.append(score.item())

        return np.mean(losses), np.mean(scores)

    def train_epoch(self, x_train, y_train):
        """Run one optimisation pass over the shuffled training data.

        The trailing batch is skipped when it is shorter than ``batch_size``.
        """
        x_train_values = x_train.values
        y_train_values = np.squeeze(y_train.values)

        self.sfm_model.train()

        indices = np.arange(len(x_train_values))
        np.random.shuffle(indices)

        for i in range(0, len(indices) - self.batch_size + 1, self.batch_size):
            batch = indices[i : i + self.batch_size]
            feature = torch.from_numpy(x_train_values[batch]).float().to(self.device)
            label = torch.from_numpy(y_train_values[batch]).float().to(self.device)

            pred = self.sfm_model(feature).reshape(-1)
            loss = self.loss_fn(pred, label)

            self.train_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.sfm_model.parameters(), 3.0)
            self.train_optimizer.step()

    def fit(
        self,
        dataset: DatasetH,
        evals_result=None,
        save_path=None,
    ):
        """Train with early stopping on the validation score and keep the best weights.

        Parameters
        ----------
        dataset : DatasetH
            Must provide ``"train"`` and ``"valid"`` segments.
        evals_result : dict, optional
            Receives the per-epoch scores under ``"train"`` and ``"valid"``.
            A fresh dict is used when omitted.
        save_path : str, optional
            Passed to ``get_or_create_path``; the best weights are saved there.

        Raises
        ------
        ValueError
            If the train or valid segment is empty.
        """
        if evals_result is None:
            evals_result = {}

        df_train, df_valid = dataset.prepare(
            ["train", "valid"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        if df_train.empty or df_valid.empty:
            raise ValueError("Empty data from dataset, please check your dataset config.")

        x_train, y_train = df_train["feature"], df_train["label"]
        x_valid, y_valid = df_valid["feature"], df_valid["label"]

        save_path = get_or_create_path(save_path)
        stop_steps = 0
        best_score = -np.inf
        best_epoch = 0
        # Start from the current weights so there is always something to
        # restore, even if no epoch improves on ``best_score``.
        best_param = copy.deepcopy(self.sfm_model.state_dict())
        evals_result["train"] = []
        evals_result["valid"] = []

        # train
        self.logger.info("training...")
        self.fitted = True

        for step in range(self.n_epochs):
            self.logger.info("Epoch%d:", step)
            self.logger.info("training...")
            self.train_epoch(x_train, y_train)
            self.logger.info("evaluating...")
            _, train_score = self.test_epoch(x_train, y_train)
            _, val_score = self.test_epoch(x_valid, y_valid)
            self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
            evals_result["train"].append(train_score)
            evals_result["valid"].append(val_score)

            if val_score > best_score:
                best_score = val_score
                stop_steps = 0
                best_epoch = step
                best_param = copy.deepcopy(self.sfm_model.state_dict())
            else:
                stop_steps += 1
                if stop_steps >= self.early_stop:
                    self.logger.info("early stop")
                    break

        self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
        self.sfm_model.load_state_dict(best_param)
        torch.save(best_param, save_path)

        # ``self.device`` is a torch.device, so use the property rather than
        # comparing it with the string "cpu".
        if self.use_gpu:
            torch.cuda.empty_cache()

    def mse(self, pred, label):
        """Mean squared error between ``pred`` and ``label``."""
        loss = (pred - label) ** 2
        return torch.mean(loss)

    def loss_fn(self, pred, label):
        """Compute the configured loss over the entries whose label is not NaN.

        Raises
        ------
        ValueError
            If ``self.loss`` is not ``"mse"``.
        """
        mask = ~torch.isnan(label)

        if self.loss == "mse":
            return self.mse(pred[mask], label[mask])

        raise ValueError("unknown loss `%s`" % self.loss)

    def metric_fn(self, pred, label):
        """Return the negated loss over finite labels, so that higher is better.

        Raises
        ------
        ValueError
            If ``self.metric`` is neither ``""`` nor ``"loss"``.
        """
        mask = torch.isfinite(label)

        if self.metric == "" or self.metric == "loss":
            return -self.loss_fn(pred[mask], label[mask])

        raise ValueError("unknown metric `%s`" % self.metric)

    def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
        """Predict one value per row of ``segment`` and return them as a Series.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if not self.fitted:
            raise ValueError("model is not fitted yet!")

        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
        index = x_test.index
        self.sfm_model.eval()
        x_values = x_test.values
        sample_num = x_values.shape[0]
        preds = []

        for begin in range(0, sample_num, self.batch_size):
            end = min(begin + self.batch_size, sample_num)
            x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)

            with torch.no_grad():
                # reshape(-1) keeps a trailing one-sample batch 1-D so that it
                # can be concatenated with the others.
                pred = self.sfm_model(x_batch).reshape(-1).detach().cpu().numpy()

            preds.append(pred)

        # An empty segment yields an empty Series instead of a concatenate error.
        values = np.concatenate(preds) if preds else np.empty(0)
        return pd.Series(values, index=index)
class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average.

        The average stays at 0 while ``count`` is 0, so ``update(val, n=0)`` on
        a fresh meter no longer raises ``ZeroDivisionError``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0