diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py
index 7475bb6fc52..7107dac570c 100644
--- a/qlib/contrib/online/manager.py
+++ b/qlib/contrib/online/manager.py
@@ -7,6 +7,7 @@
 import pathlib
 import pandas as pd
 import shutil
+import os
 from ruamel.yaml import YAML
 from ...backtest.account import Account
 from .user import User
@@ -43,6 +44,55 @@ def __init__(self, user_data_path, save_report=True):
         self.users = {}
         self.user_record = None
 
+    @staticmethod
+    def _validate_user_id(user_id):
+        """
+        Validate user_id (used as a directory name under self.data_path) and return it.
+
+        Raises TypeError if user_id is not a str; raises ValueError if it is empty,
+        contains NUL, "/" or a backslash, is "." or "..", or has more than one part.
+        """
+        if not isinstance(user_id, str):
+            raise TypeError("user_id must be a string")
+        if user_id == "" or "\x00" in user_id:
+            raise ValueError("Invalid user_id")
+
+        # Forbid multi-part paths (e.g. "../x", "a/b") and the bare "." / ".." components.
+        parts = pathlib.PurePath(user_id).parts
+        if len(parts) != 1 or parts[0] in (".", ".."):
+            raise ValueError("Invalid user_id")
+
+        # Extra guard: reject "/" and backslash on every OS; POSIX pathlib treats backslash as an ordinary character.
+        if "/" in user_id or "\\" in user_id:
+            raise ValueError("Invalid user_id")
+
+        return user_id
+
+    def _user_path(self, user_id):
+        """
+        Return the resolved directory of user_id, which must lie strictly below self.data_path.
+
+        Propagates TypeError/ValueError from _validate_user_id; raises ValueError when the
+        resolved path (symlinks followed) is self.data_path itself or lies outside it.
+        """
+        user_id = self._validate_user_id(user_id)
+        base = self.data_path.resolve(strict=False)
+        candidate = (self.data_path / user_id).resolve(strict=False)
+
+        # Ensure candidate is inside base (and on same drive on Windows).
+        try:
+            common = os.path.commonpath([str(base), str(candidate)])
+        except ValueError:
+            # commonpath() raises ValueError, e.g. for paths on different Windows drives
+            raise ValueError("Invalid user_id") from None
+
+        if common != str(base):
+            raise ValueError("Invalid user_id")
+        if candidate == base:
+            raise ValueError("Invalid user_id")
+
+        return candidate
+
     def load_users(self):
         """
         load all users' data into manager
@@ -60,9 +110,10 @@ def load_user(self, user_id):
             :return
                 user : User()
         """
-        account_path = self.data_path / user_id
-        strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
-        model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
+        user_path = self._user_path(user_id)
+        account_path = user_path
+        strategy_file = user_path / "strategy_{}.pickle".format(user_id)
+        model_file = user_path / "model_{}.pickle".format(user_id)
         cur_user_list = list(self.users)
         if user_id in cur_user_list:
             raise ValueError("User {} has been loaded".format(user_id))
@@ -82,14 +133,15 @@ def save_user_data(self, user_id):
         """
         if not user_id in self.users:
             raise ValueError("Cannot find user {}".format(user_id))
-        self.users[user_id].account.save_account(self.data_path / user_id)
+        user_path = self._user_path(user_id)
+        self.users[user_id].account.save_account(user_path)
         save_instance(
             self.users[user_id].strategy,
-            self.data_path / user_id / "strategy_{}.pickle".format(user_id),
+            user_path / "strategy_{}.pickle".format(user_id),
         )
         save_instance(
             self.users[user_id].model,
-            self.data_path / user_id / "model_{}.pickle".format(user_id),
+            user_path / "model_{}.pickle".format(user_id),
         )
 
     def add_user(self, user_id, config_file, add_date):
@@ -105,7 +157,7 @@ def add_user(self, user_id, config_file, add_date):
         config_file = pathlib.Path(config_file)
         if not config_file.exists():
             raise ValueError("Cannot find config file {}".format(config_file))
-        user_path = self.data_path / user_id
+        user_path = self._user_path(user_id)
         if user_path.exists():
             raise ValueError("User data for {} already exists".format(user_id))
 
@@ -125,9 +177,9 @@ def add_user(self, user_id, config_file, add_date):
 
         # save user
         user_path.mkdir()
-        save_instance(model, self.data_path / user_id / "model_{}.pickle".format(user_id))
-        save_instance(strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id))
-        trade_account.save_account(self.data_path / user_id)
+        save_instance(model, user_path / "model_{}.pickle".format(user_id))
+        save_instance(strategy, user_path / "strategy_{}.pickle".format(user_id))
+        trade_account.save_account(user_path)
         user_record = pd.read_csv(self.users_file, index_col=0)
         user_record.loc[user_id] = [add_date]
         user_record.to_csv(self.users_file)
@@ -139,7 +191,7 @@ def remove_user(self, user_id):
             :param
                 user_id : string
         """
-        user_path = self.data_path / user_id
+        user_path = self._user_path(user_id)
         if not user_path.exists():
             raise ValueError("Cannot find user data {}".format(user_id))
         shutil.rmtree(user_path)