Skip to content

Commit 35e2bda

Browse files
committed
Escape Telegram Markdown characters in error messages and enhance test coverage.
- Add `escape_markdown` utility function to sanitize dynamic content in Telegram messages. - Update `ErrorReporter` to escape special characters in error messages, URLs, user agents, and additional context fields. - Refactor and expand test cases for `escape_markdown` function under various scenarios, including real-world error messages. - Adjust test mocks in `kis.py` to reflect changes in method naming for better clarity (`get_holdings` -> `get_holdings_by_user`).
1 parent 3f89dbc commit 35e2bda

File tree

4 files changed

+106
-17
lines changed

4 files changed

+106
-17
lines changed

app/monitoring/error_reporter.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import hashlib
1212
import logging
13+
import re
1314
import traceback
1415
from typing import Dict, Optional
1516

@@ -22,6 +23,17 @@
2223
logger = logging.getLogger(__name__)
2324

2425

26+
def escape_markdown(text: str) -> str:
27+
"""Escape Telegram Markdown special characters.
28+
29+
Telegram Markdown v1 reserves these characters: _ * ` [
30+
"""
31+
# 백틱(`) 안의 텍스트는 그대로 두고, 나머지만 이스케이프
32+
# 간단하게 모든 특수문자 이스케이프
33+
escape_chars = r'_*`['
34+
return re.sub(f'([{re.escape(escape_chars)}])', r'\\\1', text)
35+
36+
2537
class ErrorReporter:
2638
"""
2739
Singleton error reporter with Telegram integration and Redis-based deduplication.
@@ -177,13 +189,16 @@ def _format_error_message(
177189
"""
178190
timestamp = format_datetime()
179191

192+
# Escape Markdown special characters in dynamic content
193+
safe_error_message = escape_markdown(error_message)
194+
180195
# Build message parts
181196
parts = [
182197
"🚨 *Error Alert*",
183198
f"🕒 {timestamp}",
184199
"",
185200
f"*Type:* `{error_type}`",
186-
f"*Message:* {error_message}",
201+
f"*Message:* {safe_error_message}",
187202
]
188203

189204
# Add request info if available
@@ -193,26 +208,29 @@ def _format_error_message(
193208
if "method" in request_info:
194209
parts.append(f" • Method: `{request_info['method']}`")
195210
if "url" in request_info:
196-
parts.append(f" • URL: `{request_info['url']}`")
211+
safe_url = escape_markdown(str(request_info['url']))
212+
parts.append(f" • URL: {safe_url}")
197213
if "client" in request_info:
198214
parts.append(f" • Client: `{request_info['client']}`")
199215
if "user_agent" in request_info:
200216
user_agent = request_info["user_agent"][:100] # Truncate
201-
parts.append(f" • User-Agent: `{user_agent}`")
217+
safe_ua = escape_markdown(user_agent)
218+
parts.append(f" • User-Agent: {safe_ua}")
202219

203220
# Add additional context if available
204221
if additional_context:
205222
parts.append("")
206223
parts.append("*Additional Context:*")
207224
for key, value in additional_context.items():
208225
# Format specific keys nicely
226+
safe_value = escape_markdown(str(value))
209227
if key == "request_id":
210228
parts.append(f" • Request ID: `{value}`")
211229
elif key == "duration_ms":
212230
parts.append(f" • Duration: `{value:.2f}ms`")
213231
else:
214-
# Generic key-value pair
215-
parts.append(f" • {key}: `{value}`")
232+
# Generic key-value pair - escape value
233+
parts.append(f" • {key}: {safe_value}")
216234

217235
# Add stack trace (truncated if too long)
218236
parts.append("")

app/tasks/kis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ async def _run() -> dict:
438438
manual_service = ManualHoldingsService(db)
439439
# USER_ID는 현재 1로 고정 (추후 다중 사용자 지원 시 변경 필요)
440440
user_id = 1
441-
manual_holdings = await manual_service.get_holdings(user_id=user_id, market_type=MarketType.KR)
441+
manual_holdings = await manual_service.get_holdings_by_user(user_id=user_id, market_type=MarketType.KR)
442442

443443
# 3. 수동 잔고 종목을 한투 형식으로 변환하여 병합
444444
for holding in manual_holdings:

tests/test_kis_tasks.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class MockManualService:
4242
def __init__(self, db):
4343
pass
4444

45-
async def get_holdings(self, user_id, market_type):
45+
async def get_holdings_by_user(self, user_id, market_type):
4646
return [] # No manual holdings
4747

4848
buy_calls = []
@@ -194,7 +194,7 @@ class MockManualService:
194194
def __init__(self, db):
195195
pass
196196

197-
async def get_holdings(self, user_id, market_type):
197+
async def get_holdings_by_user(self, user_id, market_type):
198198
return [] # No manual holdings
199199

200200
with patch('app.core.db.AsyncSessionLocal') as mock_session_cls, \
@@ -281,7 +281,7 @@ class MockManualService:
281281
def __init__(self, db):
282282
pass
283283

284-
async def get_holdings(self, user_id, market_type):
284+
async def get_holdings_by_user(self, user_id, market_type):
285285
return []
286286

287287
sell_calls = []
@@ -369,7 +369,7 @@ class MockManualService:
369369
def __init__(self, db):
370370
pass
371371

372-
async def get_holdings(self, user_id, market_type):
372+
async def get_holdings_by_user(self, user_id, market_type):
373373
return []
374374

375375
async def fake_buy(*_, **__):
@@ -462,7 +462,7 @@ class MockManualService:
462462
def __init__(self, db):
463463
pass
464464

465-
async def get_holdings(self, user_id, market_type):
465+
async def get_holdings_by_user(self, user_id, market_type):
466466
return []
467467

468468
sell_calls: List[Dict[str, Any]] = []
@@ -635,7 +635,7 @@ class MockManualService:
635635
def __init__(self, db):
636636
pass
637637

638-
async def get_holdings(self, user_id, market_type):
638+
async def get_holdings_by_user(self, user_id, market_type):
639639
return []
640640

641641
error_reports = []
@@ -723,7 +723,7 @@ class MockManualService:
723723
def __init__(self, db):
724724
pass
725725

726-
async def get_holdings(self, user_id, market_type):
726+
async def get_holdings_by_user(self, user_id, market_type):
727727
return []
728728

729729
error_reports = []
@@ -1247,7 +1247,7 @@ class MockManualService:
12471247
def __init__(self, db):
12481248
pass
12491249

1250-
async def get_holdings(self, user_id, market_type):
1250+
async def get_holdings_by_user(self, user_id, market_type):
12511251
return []
12521252

12531253
async def fake_buy(*_, **__):
@@ -1339,7 +1339,7 @@ class MockManualService:
13391339
def __init__(self, db):
13401340
pass
13411341

1342-
async def get_holdings(self, user_id, market_type):
1342+
async def get_holdings_by_user(self, user_id, market_type):
13431343
return []
13441344

13451345
async def fake_buy(*_, **__):
@@ -1432,7 +1432,7 @@ class MockManualService:
14321432
def __init__(self, db):
14331433
pass
14341434

1435-
async def get_holdings(self, user_id, market_type):
1435+
async def get_holdings_by_user(self, user_id, market_type):
14361436
return []
14371437

14381438
async def fake_buy(*_, **__):

tests/test_monitoring.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from fastapi import HTTPException
99

1010
from app.middleware.monitoring import MonitoringMiddleware
11-
from app.monitoring.error_reporter import ErrorReporter
11+
from app.monitoring.error_reporter import ErrorReporter, escape_markdown
1212

1313

1414
@pytest.fixture
@@ -132,3 +132,74 @@ def __str__(self):
132132
assert reporter.last_context["status_code"] == 500
133133
assert reporter.last_context["request_id"] == "req-1"
134134
assert isinstance(reporter.last_context["duration_ms"], float)
135+
136+
137+
class TestEscapeMarkdown:
138+
"""escape_markdown 함수 테스트."""
139+
140+
def test_escape_underscore(self):
141+
"""언더스코어(_)가 이스케이프되어야 함."""
142+
text = "process_kis_domestic_sell_orders"
143+
result = escape_markdown(text)
144+
assert result == r"process\_kis\_domestic\_sell\_orders"
145+
146+
def test_escape_asterisk(self):
147+
"""별표(*)가 이스케이프되어야 함."""
148+
text = "**bold** and *italic*"
149+
result = escape_markdown(text)
150+
assert r"\*\*bold\*\*" in result
151+
assert r"\*italic\*" in result
152+
153+
def test_escape_backtick(self):
154+
"""백틱(`)이 이스케이프되어야 함."""
155+
text = "use `code` here"
156+
result = escape_markdown(text)
157+
assert r"\`code\`" in result
158+
159+
def test_escape_square_bracket(self):
160+
"""대괄호([)가 이스케이프되어야 함."""
161+
text = "see [link] here"
162+
result = escape_markdown(text)
163+
assert r"\[link]" in result
164+
165+
def test_no_escape_for_normal_text(self):
166+
"""특수문자가 없는 일반 텍스트는 변경되지 않아야 함."""
167+
text = "APBK0400 주문 가능한 수량을 초과했습니다."
168+
result = escape_markdown(text)
169+
assert result == text
170+
171+
def test_complex_error_message(self):
172+
"""실제 에러 메시지 시나리오."""
173+
text = "File '/app/services/kis_trading_service.py', line 343"
174+
result = escape_markdown(text)
175+
# 언더스코어만 이스케이프됨
176+
assert r"kis\_trading\_service" in result
177+
178+
179+
class TestFormatErrorMessageWithEscape:
180+
"""에러 메시지 포맷팅 시 이스케이프 테스트."""
181+
182+
def test_format_message_escapes_error_message(self, error_reporter):
183+
"""에러 메시지의 특수문자가 이스케이프되어야 함."""
184+
message = error_reporter._format_error_message(
185+
error_type="RuntimeError",
186+
error_message="Error in func_name with _underscore",
187+
stack_trace="simple trace",
188+
)
189+
# 에러 메시지 부분에서 언더스코어가 이스케이프되어야 함
190+
assert r"func\_name" in message
191+
assert r"\_underscore" in message
192+
193+
def test_format_message_escapes_additional_context(self, error_reporter):
194+
"""추가 컨텍스트의 특수문자가 이스케이프되어야 함."""
195+
message = error_reporter._format_error_message(
196+
error_type="RuntimeError",
197+
error_message="Error",
198+
stack_trace="trace",
199+
additional_context={
200+
"task_name": "kis.run_per_domestic_stock_automation",
201+
"stock": "삼성전자우 (005935)"
202+
}
203+
)
204+
# task_name의 언더스코어가 이스케이프되어야 함
205+
assert r"run\_per\_domestic\_stock\_automation" in message

0 commit comments

Comments
 (0)