Skip to content

Commit 4b8d70d

Browse files
[feat] fix a bug and adapt general_nn for use with rdagent_qlib (#1928)
* update qlib general_nn for rdagent_qlib
* fix install lightgbm error
* fix install lightgbm error & format with black

---------

Co-authored-by: Linlang <[email protected]>
1 parent a2996f7 commit 4b8d70d

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

.github/workflows/test_qlib_from_pip.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ jobs:
4646
python -m pip install pyqlib
4747
python -m pip install "joblib<=1.4.2"
4848
49+
# install.sh file contents from: https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh
50+
# brew_install.sh file contents from: https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh
4951
- name: Install Lightgbm for MacOS
5052
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
5153
run: |
52-
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
54+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/install.sh)"
55+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/brew_install.sh)"
5356
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
5457
# FIX MacOS error: Segmentation fault
5558
# reference: https://github.com/microsoft/LightGBM/issues/4229

.github/workflows/test_qlib_from_source.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,13 @@ jobs:
8383
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
8484
python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl
8585
86+
# install.sh file contents from: https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh
87+
# brew_install.sh file contents from: https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh
8688
- name: Install Lightgbm for MacOS
8789
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
8890
run: |
89-
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
91+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/install.sh)"
92+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/brew_install.sh)"
9093
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
9194
# FIX MacOS error: Segmentation fault
9295
# reference: https://github.com/microsoft/LightGBM/issues/4229

.github/workflows/test_qlib_from_source_slow.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,13 @@ jobs:
3737
run: |
3838
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
3939
40+
# install.sh file contents from: https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh
41+
# brew_install.sh file contents from: https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh
4042
- name: Install Lightgbm for MacOS
4143
if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-15' }}
4244
run: |
43-
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)"
45+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/install.sh)"
46+
/bin/bash -c "$(curl -fsSL https://github.com/SunsetWolf/qlib_dataset/releases/download/maocs_lightgbm/brew_install.sh)"
4447
HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm
4548
# FIX MacOS error: Segmentation fault
4649
# reference: https://github.com/microsoft/LightGBM/issues/4229

qlib/contrib/model/pytorch_general_nn.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
import torch.optim as optim
16+
from torch.optim.lr_scheduler import ReduceLROnPlateau
1617

1718
from qlib.data.dataset.weight import Reweighter
1819

@@ -136,6 +137,10 @@ def __init__(
136137
else:
137138
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
138139

140+
# === ReduceLROnPlateau learning rate scheduler ===
141+
self.lr_scheduler = ReduceLROnPlateau(
142+
self.train_optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6, threshold=1e-5
143+
)
139144
self.fitted = False
140145
self.dnn_model.to(self.device)
141146

@@ -154,15 +159,15 @@ def loss_fn(self, pred, label, weight=None):
154159
weight = torch.ones_like(label)
155160

156161
if self.loss == "mse":
157-
return self.mse(pred[mask], label[mask], weight[mask])
162+
return self.mse(pred[mask], label[mask].view(-1, 1), weight[mask])
158163

159164
raise ValueError("unknown loss `%s`" % self.loss)
160165

161166
def metric_fn(self, pred, label):
162167
mask = torch.isfinite(label)
163168

164169
if self.metric in ("", "loss"):
165-
return -self.loss_fn(pred[mask], label[mask])
170+
return self.loss_fn(pred[mask], label[mask])
166171

167172
raise ValueError("unknown metric `%s`" % self.metric)
168173

@@ -238,6 +243,8 @@ def fit(
238243

239244
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
240245
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
246+
self.logger.info(f"Train samples: {len(dl_train)}")
247+
self.logger.info(f"Valid samples: {len(dl_valid)}")
241248
if dl_train.empty or dl_valid.empty:
242249
raise ValueError("Empty data from dataset, please check your dataset config.")
243250

@@ -279,7 +286,7 @@ def fit(
279286

280287
stop_steps = 0
281288
train_loss = 0
282-
best_score = -np.inf
289+
best_score = np.inf
283290
best_epoch = 0
284291
evals_result["train"] = []
285292
evals_result["valid"] = []
@@ -295,13 +302,18 @@ def fit(
295302
self.logger.info("evaluating...")
296303
train_loss, train_score = self.test_epoch(train_loader)
297304
val_loss, val_score = self.test_epoch(valid_loader)
298-
self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
305+
self.logger.info("Epoch%d: train %.6f, valid %.6f" % (step, train_score, val_score))
299306
evals_result["train"].append(train_score)
300307
evals_result["valid"].append(val_score)
301308

309+
# current_lr = self.train_optimizer.param_groups[0]["lr"]
310+
# self.logger.info("Current learning rate: %.6e" % current_lr)
311+
312+
self.lr_scheduler.step(val_score)
313+
302314
if step == 0:
303315
best_param = copy.deepcopy(self.dnn_model.state_dict())
304-
if val_score > best_score:
316+
if val_score < best_score:
305317
best_score = val_score
306318
stop_steps = 0
307319
best_epoch = step
@@ -312,7 +324,7 @@ def fit(
312324
self.logger.info("early stop")
313325
break
314326

315-
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
327+
self.logger.info("best score: %.6lf @ %d epoch" % (best_score, best_epoch))
316328
self.dnn_model.load_state_dict(best_param)
317329
torch.save(best_param, save_path)
318330

@@ -329,6 +341,7 @@ def predict(
329341
raise ValueError("model is not fitted yet!")
330342

331343
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
344+
self.logger.info(f"Test samples: {len(dl_test)}")
332345

333346
if isinstance(dataset, TSDatasetH):
334347
dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader

0 commit comments

Comments
 (0)