Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 63 additions & 11 deletions qlib/contrib/online/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pathlib
import pandas as pd
import shutil
import os
from ruamel.yaml import YAML
from ...backtest.account import Account
from .user import User
Expand Down Expand Up @@ -43,6 +44,55 @@ def __init__(self, user_data_path, save_report=True):
self.users = {}
self.user_record = None

@staticmethod
def _validate_user_id(user_id):
"""
Validate user_id to prevent path traversal / absolute paths.

user_id is used as a directory name under self.data_path and must not
contain path separators, parent traversal, or drive/UNC prefixes.
"""
if not isinstance(user_id, str):
raise TypeError("user_id must be a string")
if user_id == "" or "\x00" in user_id:
raise ValueError("Invalid user_id")

# Forbid any multi-part paths (e.g. "../x", "a/b", "C:\\x", "\\\\server\\share").
parts = pathlib.PurePath(user_id).parts
if len(parts) != 1 or parts[0] in (".", ".."):
raise ValueError("Invalid user_id")

# Extra guard: explicit separators (platform-specific and cross-platform)
if "/" in user_id or "\\" in user_id:
raise ValueError("Invalid user_id")

return user_id

def _user_path(self, user_id):
"""
Return the resolved user directory path under self.data_path.

Ensures the resulting path is contained within self.data_path even if
symlinks are involved.
"""
user_id = self._validate_user_id(user_id)
base = self.data_path.resolve(strict=False)
candidate = (self.data_path / user_id).resolve(strict=False)

# Ensure candidate is inside base (and on same drive on Windows).
try:
common = os.path.commonpath([str(base), str(candidate)])
except ValueError:
# Different drives or invalid paths
raise ValueError("Invalid user_id") from None

if common != str(base):
raise ValueError("Invalid user_id")
if candidate == base:
raise ValueError("Invalid user_id")

return candidate

def load_users(self):
"""
load all users' data into manager
Expand All @@ -60,9 +110,10 @@ def load_user(self, user_id):
:return
user : User()
"""
account_path = self.data_path / user_id
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
user_path = self._user_path(user_id)
account_path = user_path
strategy_file = user_path / "strategy_{}.pickle".format(user_id)
model_file = user_path / "model_{}.pickle".format(user_id)
cur_user_list = list(self.users)
if user_id in cur_user_list:
raise ValueError("User {} has been loaded".format(user_id))
Expand All @@ -82,14 +133,15 @@ def save_user_data(self, user_id):
"""
if not user_id in self.users:
raise ValueError("Cannot find user {}".format(user_id))
self.users[user_id].account.save_account(self.data_path / user_id)
user_path = self._user_path(user_id)
self.users[user_id].account.save_account(user_path)
save_instance(
self.users[user_id].strategy,
self.data_path / user_id / "strategy_{}.pickle".format(user_id),
user_path / "strategy_{}.pickle".format(user_id),
)
save_instance(
self.users[user_id].model,
self.data_path / user_id / "model_{}.pickle".format(user_id),
user_path / "model_{}.pickle".format(user_id),
)

def add_user(self, user_id, config_file, add_date):
Expand All @@ -105,7 +157,7 @@ def add_user(self, user_id, config_file, add_date):
config_file = pathlib.Path(config_file)
if not config_file.exists():
raise ValueError("Cannot find config file {}".format(config_file))
user_path = self.data_path / user_id
user_path = self._user_path(user_id)
if user_path.exists():
raise ValueError("User data for {} already exists".format(user_id))

Expand All @@ -125,9 +177,9 @@ def add_user(self, user_id, config_file, add_date):

# save user
user_path.mkdir()
save_instance(model, self.data_path / user_id / "model_{}.pickle".format(user_id))
save_instance(strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id))
trade_account.save_account(self.data_path / user_id)
save_instance(model, user_path / "model_{}.pickle".format(user_id))
save_instance(strategy, user_path / "strategy_{}.pickle".format(user_id))
trade_account.save_account(user_path)
user_record = pd.read_csv(self.users_file, index_col=0)
user_record.loc[user_id] = [add_date]
user_record.to_csv(self.users_file)
Expand All @@ -139,7 +191,7 @@ def remove_user(self, user_id):
:param
user_id : string
"""
user_path = self.data_path / user_id
user_path = self._user_path(user_id)
if not user_path.exists():
raise ValueError("Cannot find user data {}".format(user_id))
shutil.rmtree(user_path)
Expand Down
Loading