Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions qlib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,33 @@ class QSettings(BaseSettings):

class Config:
def __init__(self, default_conf):
self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflicts with __getattr__
self.__dict__["_default_config"] = copy.deepcopy(default_conf)
self.reset()

def validate(self):
errors = []

if not self.get("provider_uri"):
errors.append(
"provider_uri must be set (e.g. ~/.qlib/qlib_data or a valid path)"
)

if not self.get("region"):
errors.append(
"region must be specified (e.g. 'cn', 'us')"
)

if errors:
raise ValueError(
"Invalid Qlib configuration:\n- " + "\n- ".join(errors)
)

def __getitem__(self, key):
return self.__dict__["_config"][key]

def __getattr__(self, attr):
if attr in self.__dict__["_config"]:
return self.__dict__["_config"][attr]

raise AttributeError(f"No such `{attr}` in self._config")

def get(self, key, default=None):
Expand Down Expand Up @@ -109,14 +126,20 @@ def set_conf_from_C(self, config_c):

@staticmethod
def register_from_C(config, skip_register=True):
from .utils import set_log_with_config # pylint: disable=C0415
from .utils import set_log_with_config

if C.registered and skip_register:
return


C.set_conf_from_C(config)

if not skip_register:
C.validate()

if C.logging_config:
set_log_with_config(C.logging_config)

C.register()


Expand Down Expand Up @@ -523,4 +546,4 @@ def registered(self):


# global config
C = QlibConfig(_default_config)
C = QlibConfig(_default_config)
17 changes: 17 additions & 0 deletions qlib/tests/test_config_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from qlib.config import Config


def test_missing_provider_uri_raises():
default_conf = {
"provider_uri": None,
"region": "us",
}

cfg = Config(default_conf)

with pytest.raises(ValueError) as exc:
cfg.validate()

assert "provider_uri must be set" in str(exc.value)