Skip to content

Commit ae3f6cd

Browse files
committed
temp valiadtion based on new stock + isolated seq
1 parent b9b43e4 commit ae3f6cd

File tree

7 files changed

+43
-2
lines changed

7 files changed

+43
-2
lines changed

hygdra_forecasting/utils/learning_rate_sheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def lrfn(self, epoch):
5151

5252
# return ....
5353
lr_ramp_ep = 3 # int(0.02 * epochs) # 30% of epochs for warm-up
54-
lr_sus_ep = max(0, int(0.3 * self.epochs) - lr_ramp_ep)
54+
lr_sus_ep = max(0, int(0.15 * self.epochs) - lr_ramp_ep)
5555
if epoch < lr_ramp_ep: # Warm-up phase
5656
lr = (self.lr_max - self.lr_start) / lr_ramp_ep * epoch + self.lr_start
5757
elif epoch < lr_ramp_ep + lr_sus_ep: # Sustain phase at max learning rate

train.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from hygdra_forecasting.dataloader.dataloader import StockDataset
88
from torch import device, cuda
99
from torch.utils.data import DataLoader
10+
import numpy as np
1011

1112
if cuda.is_available():
1213
device = device('cuda:0')
@@ -29,11 +30,30 @@
2930
dataset_val = StockDataset(ticker=tickers_val)
3031
dataloader_val = DataLoader(dataset_val, batch_size=256, shuffle=True, num_workers=1)
3132

33+
# temp (non distinct loss and balance) seq val on known stock
34+
lenval = len(dataset_val)
35+
indval = len(dataset) // 2 # Select half the dataset
36+
37+
# Ensure index bounds are valid
38+
if indval > 0:
39+
# Select random indices without replacement
40+
random_indices = np.random.choice(len(dataset), indval, replace=False)
41+
42+
# Move selected data to dataset_val
43+
dataset_val.data = np.concatenate((dataset_val.data, dataset.data[random_indices].copy()), axis=0)
44+
dataset_val.label = np.concatenate((dataset_val.label, dataset.label[random_indices].copy()), axis=0)
45+
46+
# Remove selected indices from dataset
47+
mask = np.ones(len(dataset), dtype=bool)
48+
mask[random_indices] = False
49+
50+
dataset.data = dataset.data[mask]
51+
dataset.label = dataset.label[mask]
52+
3253
# Initialize your model
3354
input_sample, _ = dataset.__getitem__(0)
3455
setup_seed(20) # test liquid ? # check shaping and batch computation based on Dataset
3556
model = ConvCausalLTSM(input_shape=input_sample.shape)
3657
# LtsmAttentionforecastPred(input_shape=input_sample.shape)
37-
# ConvCausalLTSM(input_shape=input_sample.shape)
3858
model = train_model(model, dataloader, dataloader_val, epochs=100, learning_rate=0.01, lrfn=CosineWarmup(0.01, 100).lrfn)
3959

train_graph.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch.utils.data import DataLoader
88
from hygdra_forecasting.utils.learning_rate_sheduler import CosineWarmup
99
from torch import cuda, device, load, nn
10+
import numpy as np
1011

1112
if cuda.is_available():
1213
device = device('cuda:0')
@@ -30,6 +31,26 @@
3031
dataset_val = StockGraphDataset(ticker=tickers_val, indics=TICKERS_ETF)
3132
dataloader_val = DataLoader(dataset_val, batch_size=32, shuffle=True, num_workers=1)
3233

34+
# temp (non distinct loss and balance) seq val on known stock
35+
lenval = len(dataset_val)
36+
indval = len(dataset) // 2 # Select half the dataset
37+
38+
# Ensure index bounds are valid
39+
if indval > 0:
40+
# Select random indices without replacement
41+
random_indices = np.random.choice(len(dataset), indval, replace=False)
42+
43+
# Move selected data to dataset_val
44+
dataset_val.data = np.concatenate((dataset_val.data, dataset.data[random_indices].copy()), axis=0)
45+
dataset_val.label = np.concatenate((dataset_val.label, dataset.label[random_indices].copy()), axis=0)
46+
47+
# Remove selected indices from dataset
48+
mask = np.ones(len(dataset), dtype=bool)
49+
mask[random_indices] = False
50+
51+
dataset.data = dataset.data[mask]
52+
dataset.label = dataset.label[mask]
53+
3354
# Initialize your model
3455
input_sample, _ = dataset.__getitem__(0)
3556
setup_seed(20)

weight/0_weight.pth

998 KB
Binary file not shown.

weight/40_weight.pth

0 Bytes
Binary file not shown.

weight/80_weight.pth

0 Bytes
Binary file not shown.

weight/best_model.pth

998 KB
Binary file not shown.

0 commit comments

Comments
 (0)