|
| 1 | +import os |
| 2 | +import json |
| 3 | +import yaml |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +def load_workspace_config(workspace_path): |
| 7 | + """ |
| 8 | + Unified configuration loader for QuantPits workspaces. |
| 9 | + Merges model_config.json, prod_config.json, and strategy_config.yaml. |
| 10 | + |
| 11 | + Returns a unified dict with all necessary parameters. |
| 12 | + """ |
| 13 | + workspace_path = Path(workspace_path) |
| 14 | + config_dir = workspace_path / "config" |
| 15 | + |
| 16 | + # Files |
| 17 | + model_cfg_path = config_dir / "model_config.json" |
| 18 | + prod_cfg_path = config_dir / "prod_config.json" |
| 19 | + strat_cfg_path = config_dir / "strategy_config.yaml" |
| 20 | + |
| 21 | + config = {} |
| 22 | + |
| 23 | + # 1. Load Model Config (Base environment properties) |
| 24 | + if model_cfg_path.exists(): |
| 25 | + with open(model_cfg_path, 'r') as f: |
| 26 | + config.update(json.load(f)) |
| 27 | + |
| 28 | + # 2. Load Strategy Config (Single Source of Truth for strategy params) |
| 29 | + if strat_cfg_path.exists(): |
| 30 | + with open(strat_cfg_path, 'r') as f: |
| 31 | + strat_data = yaml.safe_load(f) |
| 32 | + if strat_data: |
| 33 | + config['strategy'] = strat_data.get('strategy', {}) |
| 34 | + config['backtest'] = strat_data.get('backtest', {}) |
| 35 | + |
| 36 | + # Promote core strategy params to top-level for convenience/compatibility |
| 37 | + strat_params = config['strategy'].get('params', {}) |
| 38 | + config['topk'] = strat_params.get('topk', config.get('TopK')) |
| 39 | + config['n_drop'] = strat_params.get('n_drop', config.get('DropN')) |
| 40 | + config['buy_suggestion_factor'] = strat_params.get('buy_suggestion_factor', config.get('buy_suggestion_factor')) |
| 41 | + |
| 42 | + # Compatibility mapping (Upper case versions if they don't exist) |
| 43 | + if 'TopK' not in config: config['TopK'] = config['topk'] |
| 44 | + if 'DropN' not in config: config['DropN'] = config['n_drop'] |
| 45 | + |
| 46 | + # 3. Load Prod Config (Current state - handles cash/holding) |
| 47 | + if prod_cfg_path.exists(): |
| 48 | + with open(prod_cfg_path, 'r') as f: |
| 49 | + prod_data = json.load(f) |
| 50 | + # We only want State fields from prod_config, others should come from model/strategy |
| 51 | + state_fields = [ |
| 52 | + 'current_date', 'last_processed_date', 'initial_cash', |
| 53 | + 'current_full_cash', 'initial_holding', 'current_cash', |
| 54 | + 'current_holding', 'model', 'experiment_name', |
| 55 | + 'current_train_record_id', 'current_pred_record_id' |
| 56 | + ] |
| 57 | + for field in state_fields: |
| 58 | + if field in prod_data: |
| 59 | + config[field] = prod_data[field] |
| 60 | + |
| 61 | + # Sanity checks / Cross-file consistency (Optional but recommended) |
| 62 | + # If market/benchmark exist in both, we prefer model_config but can log warnings if they mismatch |
| 63 | + |
| 64 | + return config |
| 65 | + |
| 66 | +if __name__ == "__main__": |
| 67 | + # Test loading |
| 68 | + import sys |
| 69 | + if len(sys.argv) > 1: |
| 70 | + c = load_workspace_config(sys.argv[1]) |
| 71 | + print(json.dumps(c, indent=2)) |
0 commit comments