Skip to content

Commit 8eb38bc

Browse files
committed
feat: indicator parameterization support for kline chart
1 parent 5818fbe commit 8eb38bc

File tree

13 files changed

+1166
-12
lines changed

13 files changed

+1166
-12
lines changed

backend_api_python/app/routes/indicator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,48 @@ def delete_indicator():
306306
return jsonify({"code": 0, "msg": str(e), "data": None}), 500
307307

308308

309+
@indicator_bp.route("/getIndicatorParams", methods=["GET"])
310+
@login_required
311+
def get_indicator_params():
312+
"""
313+
获取指标的参数声明
314+
315+
用于前端在策略创建时显示可配置的参数表单。
316+
317+
Query params:
318+
indicator_id: 指标ID
319+
320+
Returns:
321+
params: [
322+
{
323+
"name": "ma_fast",
324+
"type": "int",
325+
"default": 5,
326+
"description": "短期均线周期"
327+
},
328+
...
329+
]
330+
"""
331+
try:
332+
from app.services.indicator_params import get_indicator_params as get_params
333+
334+
indicator_id = request.args.get("indicator_id")
335+
if not indicator_id:
336+
return jsonify({"code": 0, "msg": "indicator_id is required", "data": None}), 400
337+
338+
try:
339+
indicator_id = int(indicator_id)
340+
except ValueError:
341+
return jsonify({"code": 0, "msg": "indicator_id must be an integer", "data": None}), 400
342+
343+
params = get_params(indicator_id)
344+
return jsonify({"code": 1, "msg": "success", "data": params})
345+
346+
except Exception as e:
347+
logger.error(f"get_indicator_params failed: {str(e)}", exc_info=True)
348+
return jsonify({"code": 0, "msg": str(e), "data": None}), 500
349+
350+
309351
@indicator_bp.route("/verifyCode", methods=["POST"])
310352
@login_required
311353
def verify_code():

backend_api_python/app/services/backtest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from app.data_sources import DataSourceFactory
1313
from app.utils.logger import get_logger
14+
from app.services.indicator_params import IndicatorParamsParser, IndicatorCaller
1415

1516
logger = get_logger(__name__)
1617

@@ -1114,6 +1115,21 @@ def _execute_indicator(self, code: str, df: pd.DataFrame, backtest_params: dict
11141115
local_vars['commission'] = backtest_params.get('commission', 0.0002)
11151116
local_vars['trade_direction'] = backtest_params.get('trade_direction', 'both')
11161117

1118+
# === 指标参数支持 ===
1119+
# 从 backtest_params 获取用户设置的指标参数
1120+
user_indicator_params = (backtest_params or {}).get('indicator_params', {})
1121+
# 解析指标代码中声明的参数
1122+
declared_params = IndicatorParamsParser.parse_params(code)
1123+
# 合并参数(用户值优先,否则使用默认值)
1124+
merged_params = IndicatorParamsParser.merge_params(declared_params, user_indicator_params)
1125+
local_vars['params'] = merged_params
1126+
1127+
# === 指标调用器支持 ===
1128+
user_id = (backtest_params or {}).get('user_id', 1)
1129+
indicator_id = (backtest_params or {}).get('indicator_id')
1130+
indicator_caller = IndicatorCaller(user_id, indicator_id)
1131+
local_vars['call_indicator'] = indicator_caller.call_indicator
1132+
11171133
# Add technical indicator functions
11181134
local_vars.update(self._get_indicator_functions())
11191135

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
"""
2+
Indicator Parameters Parser and Helper Functions
3+
4+
支持两个核心功能:
5+
1. 指标参数外部传递 - 解析指标代码中的 @param 声明
6+
2. 指标调用其他指标 - 提供 call_indicator() 函数
7+
8+
参数声明格式:
9+
# @param param_name type default_value 描述
10+
# @param ma_fast int 5 短期均线周期
11+
# @param ma_slow int 20 长期均线周期
12+
# @param threshold float 0.5 阈值
13+
14+
支持的类型:int, float, bool, str
15+
"""
16+
17+
import re
18+
import json
19+
from typing import Dict, Any, List, Optional, Tuple
20+
from app.utils.logger import get_logger
21+
from app.utils.db import get_db_connection
22+
23+
logger = get_logger(__name__)
24+
25+
26+
class IndicatorParamsParser:
27+
"""解析指标代码中的参数声明"""
28+
29+
# 参数声明正则:# @param name type default description
30+
PARAM_PATTERN = re.compile(
31+
r'#\s*@param\s+(\w+)\s+(int|float|bool|str|string)\s+(\S+)\s*(.*)',
32+
re.IGNORECASE
33+
)
34+
35+
@classmethod
36+
def parse_params(cls, indicator_code: str) -> List[Dict[str, Any]]:
37+
"""
38+
解析指标代码中的参数声明
39+
40+
Returns:
41+
List of param definitions:
42+
[
43+
{
44+
"name": "ma_fast",
45+
"type": "int",
46+
"default": 5,
47+
"description": "短期均线周期"
48+
},
49+
...
50+
]
51+
"""
52+
params = []
53+
if not indicator_code:
54+
return params
55+
56+
for line in indicator_code.split('\n'):
57+
line = line.strip()
58+
match = cls.PARAM_PATTERN.match(line)
59+
if match:
60+
name = match.group(1)
61+
param_type = match.group(2).lower()
62+
default_str = match.group(3)
63+
description = match.group(4).strip() if match.group(4) else ''
64+
65+
# 转换默认值类型
66+
default = cls._convert_value(default_str, param_type)
67+
68+
# 规范化类型名
69+
if param_type == 'string':
70+
param_type = 'str'
71+
72+
params.append({
73+
"name": name,
74+
"type": param_type,
75+
"default": default,
76+
"description": description
77+
})
78+
79+
return params
80+
81+
@classmethod
82+
def _convert_value(cls, value_str: str, param_type: str) -> Any:
83+
"""转换字符串值为对应类型"""
84+
try:
85+
param_type = param_type.lower()
86+
if param_type == 'int':
87+
return int(value_str)
88+
elif param_type == 'float':
89+
return float(value_str)
90+
elif param_type == 'bool':
91+
return value_str.lower() in ('true', '1', 'yes', 'on')
92+
else: # str/string
93+
return value_str
94+
except (ValueError, TypeError):
95+
return value_str
96+
97+
@classmethod
98+
def merge_params(cls, declared_params: List[Dict], user_params: Dict[str, Any]) -> Dict[str, Any]:
99+
"""
100+
合并声明的参数和用户提供的参数
101+
102+
Args:
103+
declared_params: 从代码中解析的参数声明
104+
user_params: 用户提供的参数值
105+
106+
Returns:
107+
合并后的参数字典(使用用户值或默认值)
108+
"""
109+
result = {}
110+
for param in declared_params:
111+
name = param['name']
112+
param_type = param['type']
113+
default = param['default']
114+
115+
if name in user_params:
116+
# 用户提供了值,转换为正确类型
117+
result[name] = cls._convert_value(str(user_params[name]), param_type)
118+
else:
119+
# 使用默认值
120+
result[name] = default
121+
122+
return result
123+
124+
125+
class IndicatorCaller:
126+
"""
127+
指标调用器 - 允许一个指标调用另一个指标
128+
129+
使用方式(在指标代码中):
130+
# 按ID调用
131+
rsi_df = call_indicator(5, df)
132+
133+
# 按名称调用(自己的指标)
134+
macd_df = call_indicator('My MACD', df)
135+
"""
136+
137+
# 最大调用深度,防止循环依赖
138+
MAX_CALL_DEPTH = 5
139+
140+
def __init__(self, user_id: int, current_indicator_id: int = None):
141+
self.user_id = user_id
142+
self.current_indicator_id = current_indicator_id
143+
self._call_stack = [] # 调用栈,用于检测循环依赖
144+
145+
def call_indicator(
146+
self,
147+
indicator_ref: Any, # int (ID) 或 str (名称)
148+
df: 'pd.DataFrame',
149+
params: Dict[str, Any] = None,
150+
_depth: int = 0
151+
) -> Optional['pd.DataFrame']:
152+
"""
153+
调用另一个指标并返回结果
154+
155+
Args:
156+
indicator_ref: 指标ID或名称
157+
df: 输入的K线数据
158+
params: 传递给被调用指标的参数
159+
_depth: 内部使用,跟踪调用深度
160+
161+
Returns:
162+
执行后的DataFrame,包含被调用指标计算的列
163+
"""
164+
import pandas as pd
165+
import numpy as np
166+
167+
# 检查调用深度
168+
if _depth >= self.MAX_CALL_DEPTH:
169+
logger.error(f"Indicator call depth exceeded {self.MAX_CALL_DEPTH}")
170+
return df.copy()
171+
172+
# 获取指标代码
173+
indicator_code, indicator_id = self._get_indicator_code(indicator_ref)
174+
if not indicator_code:
175+
logger.warning(f"Indicator not found: {indicator_ref}")
176+
return df.copy()
177+
178+
# 检查循环依赖
179+
if indicator_id in self._call_stack:
180+
logger.error(f"Circular dependency detected: {self._call_stack} -> {indicator_id}")
181+
return df.copy()
182+
183+
self._call_stack.append(indicator_id)
184+
185+
try:
186+
# 解析并合并参数
187+
declared_params = IndicatorParamsParser.parse_params(indicator_code)
188+
merged_params = IndicatorParamsParser.merge_params(declared_params, params or {})
189+
190+
# 准备执行环境
191+
df_copy = df.copy()
192+
local_vars = {
193+
'df': df_copy,
194+
'open': df_copy['open'].astype('float64') if 'open' in df_copy.columns else pd.Series(dtype='float64'),
195+
'high': df_copy['high'].astype('float64') if 'high' in df_copy.columns else pd.Series(dtype='float64'),
196+
'low': df_copy['low'].astype('float64') if 'low' in df_copy.columns else pd.Series(dtype='float64'),
197+
'close': df_copy['close'].astype('float64') if 'close' in df_copy.columns else pd.Series(dtype='float64'),
198+
'volume': df_copy['volume'].astype('float64') if 'volume' in df_copy.columns else pd.Series(dtype='float64'),
199+
'signals': pd.Series(0, index=df_copy.index, dtype='float64'),
200+
'np': np,
201+
'pd': pd,
202+
'params': merged_params,
203+
# 递归调用支持
204+
'call_indicator': lambda ref, d, p=None: self.call_indicator(ref, d, p, _depth + 1)
205+
}
206+
207+
# 安全执行
208+
import builtins
209+
def safe_import(name, *args, **kwargs):
210+
allowed_modules = ['numpy', 'pandas', 'math', 'json', 'time']
211+
if name in allowed_modules or name.split('.')[0] in allowed_modules:
212+
return builtins.__import__(name, *args, **kwargs)
213+
raise ImportError(f"Module not allowed: {name}")
214+
215+
safe_builtins = {k: getattr(builtins, k) for k in dir(builtins)
216+
if not k.startswith('_') and k not in [
217+
'eval', 'exec', 'compile', 'open', 'input',
218+
'help', 'exit', 'quit', '__import__',
219+
'copyright', 'credits', 'license'
220+
]}
221+
safe_builtins['__import__'] = safe_import
222+
223+
exec_env = local_vars.copy()
224+
exec_env['__builtins__'] = safe_builtins
225+
226+
pre_import = "import numpy as np\nimport pandas as pd\n"
227+
exec(pre_import, exec_env)
228+
exec(indicator_code, exec_env)
229+
230+
return exec_env.get('df', df_copy)
231+
232+
except Exception as e:
233+
logger.error(f"Error calling indicator {indicator_ref}: {e}")
234+
return df.copy()
235+
finally:
236+
self._call_stack.pop()
237+
238+
def _get_indicator_code(self, indicator_ref: Any) -> Tuple[Optional[str], Optional[int]]:
239+
"""获取指标代码"""
240+
try:
241+
with get_db_connection() as db:
242+
cursor = db.cursor()
243+
244+
if isinstance(indicator_ref, int):
245+
# 按ID查询
246+
cursor.execute("""
247+
SELECT id, code FROM qd_indicator_codes
248+
WHERE id = %s AND (user_id = %s OR publish_to_community = 1)
249+
""", (indicator_ref, self.user_id))
250+
else:
251+
# 按名称查询(优先自己的指标)
252+
cursor.execute("""
253+
SELECT id, code FROM qd_indicator_codes
254+
WHERE name = %s AND user_id = %s
255+
UNION
256+
SELECT id, code FROM qd_indicator_codes
257+
WHERE name = %s AND publish_to_community = 1
258+
LIMIT 1
259+
""", (str(indicator_ref), self.user_id, str(indicator_ref)))
260+
261+
row = cursor.fetchone()
262+
cursor.close()
263+
264+
if row:
265+
return row['code'], row['id']
266+
return None, None
267+
268+
except Exception as e:
269+
logger.error(f"Error fetching indicator code: {e}")
270+
return None, None
271+
272+
273+
def get_indicator_params(indicator_id: int) -> List[Dict[str, Any]]:
274+
"""
275+
获取指标的参数声明(供API调用)
276+
277+
Args:
278+
indicator_id: 指标ID
279+
280+
Returns:
281+
参数声明列表
282+
"""
283+
try:
284+
with get_db_connection() as db:
285+
cursor = db.cursor()
286+
cursor.execute("SELECT code FROM qd_indicator_codes WHERE id = %s", (indicator_id,))
287+
row = cursor.fetchone()
288+
cursor.close()
289+
290+
if row and row['code']:
291+
return IndicatorParamsParser.parse_params(row['code'])
292+
return []
293+
except Exception as e:
294+
logger.error(f"Error getting indicator params: {e}")
295+
return []

0 commit comments

Comments
 (0)