Skip to content

Commit 4ea9c55

Browse files
committed
Merge branch 'v2-c' into knowledge_workflow
2 parents 3a76770 + ca9fa9d commit 4ea9c55

File tree

11 files changed

+180
-107
lines changed

11 files changed

+180
-107
lines changed

apps/chat/serializers/chat.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,6 @@ def generate_prompt(self, instance: dict):
173173

174174
message = messages[-1]['content']
175175
q = prompt.replace("{userInput}", message)
176-
q = q.replace("{application_name}", application.name)
177-
q = q.replace("{detail}", application.desc)
178176

179177
messages[-1]['content'] = q
180178
SUPPORTED_MODEL_TYPES = ["LLM", "IMAGE"]
@@ -185,13 +183,11 @@ def generate_prompt(self, instance: dict):
185183
if not model_exist:
186184
raise Exception(_("Model does not exists or is not an LLM model"))
187185

188-
system_content = SYSTEM_ROLE.format(application_name=application.name, detail=application.desc)
189-
190186
def process():
191187
model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id,
192188
**application.model_params_setting)
193189
try:
194-
for r in model.stream([SystemMessage(content=system_content),
190+
for r in model.stream([SystemMessage(content=SYSTEM_ROLE),
195191
*[HumanMessage(content=m.get('content')) if m.get(
196192
'role') == 'user' else AIMessage(
197193
content=m.get('content')) for m in messages]]):

apps/chat/template/generate_prompt_system

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,3 @@
6262

6363
输出时不得包含任何解释或附加说明,只能返回符合以上格式的内容。
6464

65-
智能体名称: {application_name}
66-
功能描述: {detail}

apps/common/utils/tool_code.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def exec_code(self, code_str, keywords, function_name=None):
112112
lines = subprocess_result.stdout.splitlines()
113113
result_line = [line for line in lines if line.startswith(_id)]
114114
if not result_line:
115+
maxkb_logger.error("\n".join(lines))
115116
raise Exception("No result found.")
116117
result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode())
117118
if result.get('code') == 200:

apps/ops/celery/heartbeat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,20 @@
66
@heartbeat_sent.connect
def heartbeat(sender, **kwargs):
    """Touch this worker's heartbeat marker file whenever a heartbeat is sent.

    The marker name is derived from the part of ``sender.eventer.hostname``
    that precedes the first ``@``.
    """
    host_prefix = sender.eventer.hostname.partition('@')[0]
    marker = Path('/opt/maxkb-app/tmp/worker_heartbeat_{}'.format(host_prefix))
    marker.touch()
1111

1212

1313
@worker_ready.connect
def worker_ready(sender, **kwargs):
    """Touch this worker's "ready" marker file once the worker reports ready.

    The marker name is derived from the part of ``sender.hostname`` that
    precedes the first ``@``.
    """
    host_prefix = sender.hostname.partition('@')[0]
    marker = Path('/opt/maxkb-app/tmp/worker_ready_{}'.format(host_prefix))
    marker.touch()
1818

1919

2020
@worker_shutdown.connect
def worker_shutdown(sender, **kwargs):
    """Remove this worker's "ready" and "heartbeat" marker files on shutdown.

    Missing files are ignored (``missing_ok=True``), so the handler is safe to
    run even if a marker was never created.
    """
    host_prefix = sender.hostname.partition('@')[0]
    for marker_kind in ('ready', 'heartbeat'):
        marker = Path('/opt/maxkb-app/tmp/worker_{}_{}'.format(marker_kind, host_prefix))
        marker.unlink(missing_ok=True)

apps/oss/serializers/file.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
# coding=utf-8
2+
import base64
3+
import ipaddress
24
import re
5+
import socket
36
import urllib
7+
from urllib.parse import urlparse
48

9+
import requests
510
import uuid_utils.compat as uuid
611
from django.db.models import QuerySet
712
from django.http import HttpResponse
813
from django.utils.translation import gettext_lazy as _
914
from rest_framework import serializers
1015

11-
from common.exception.app_exception import NotFound404
16+
from application.models import Application
17+
from common.exception.app_exception import NotFound404, AppApiException
1218
from knowledge.models import File, FileSourceType
1319
from tools.serializers.tool import UploadedFileField
1420

@@ -158,3 +164,80 @@ def delete(self):
158164
if file is not None:
159165
file.delete()
160166
return True
167+
168+
169+
def get_url_content(url, application_id: str):
    """Fetch ``url`` on behalf of an application and return its payload.

    The application must exist and have file upload enabled. The URL is
    validated with ``validate_url`` before the request is made, redirects are
    not followed, and the response body is rejected when it exceeds the
    application's file size limit (50 MB unless the application configures
    another value, expressed in megabytes).

    :param url: the http/https URL to fetch.
    :param application_id: id of the application the request is made for.
    :return: dict with ``status_code``, ``Content-Length``, ``Content-Type``
        and ``content`` (text for text/json responses, Base64 otherwise).
    :raises AppApiException: application missing, upload disabled, or the
        response is larger than the allowed size.
    :raises ValueError: the URL is missing, not http/https, or points at an
        internal/unsafe host.
    """
    application = Application.objects.filter(id=application_id).first()
    # These errors must be raised, not returned: a returned exception object
    # would be handed back to the caller as if it were a successful payload.
    if application is None:
        raise AppApiException(500, _('Application does not exist'))
    if not application.file_upload_enable:
        raise AppApiException(500, _('File upload is not enabled'))

    file_limit = 50 * 1024 * 1024
    upload_setting = application.file_upload_setting
    if upload_setting:
        # NOTE(review): the type of file_upload_setting is not visible here.
        # Attribute access is kept for object-style settings, and a dict-style
        # setting is tolerated too - confirm the real key name against the model.
        if isinstance(upload_setting, dict):
            configured_limit = upload_setting.get('file_limit')
        else:
            configured_limit = getattr(upload_setting, 'file_limit', None)
        if configured_limit:
            file_limit = configured_limit * 1024 * 1024

    validate_url(url)

    # stream=True defers the body download so an oversized response can be
    # refused from its headers alone; the context manager releases the
    # connection on every exit path.
    with requests.get(url, timeout=3, allow_redirects=False, stream=True) as response:
        final_host = urlparse(response.url).hostname
        if is_private_ip(final_host):
            raise ValueError("Blocked unsafe redirect to internal host")

        # Content-Length arrives as a string (or is absent / malformed), so it
        # has to be parsed before it can be compared with the numeric limit.
        declared_length = response.headers.get('Content-Length')
        try:
            declared_size = int(declared_length) if declared_length is not None else None
        except (TypeError, ValueError):
            declared_size = None
        if declared_size is not None and declared_size > file_limit:
            raise AppApiException(500, _('File size exceeds limit'))

        # The header may be missing or wrong, so enforce the limit on the
        # bytes actually received as well.
        body = response.content
        if len(body) > file_limit:
            raise AppApiException(500, _('File size exceeds limit'))

        # Returned fields: status code, declared size, content type, payload.
        content_type = response.headers.get('Content-Type', '')
        if 'text' in content_type or 'json' in content_type:
            content = response.text
        else:
            # Binary payloads are Base64-encoded.
            content = base64.b64encode(body).decode('utf-8')

        return {
            'status_code': response.status_code,
            'Content-Length': response.headers.get('Content-Length', 0),
            'Content-Type': content_type,
            'content': content,
        }
206+
207+
208+
def is_private_ip(host: str) -> bool:
    """Report whether ``host`` resolves to an address that must not be fetched.

    The host is resolved with ``socket.gethostbyname`` and the result is
    classified as unsafe when it is private, loopback, reserved, link-local
    or multicast. Any failure while resolving or parsing is treated as unsafe
    (fail closed).
    """
    try:
        resolved = socket.gethostbyname(host)
        address = ipaddress.ip_address(resolved)
    except Exception:
        # Unresolvable or malformed hosts are blocked rather than trusted.
        return True

    unsafe_flags = (
        address.is_private,
        address.is_loopback,
        address.is_reserved,
        address.is_link_local,
        address.is_multicast,
    )
    return any(unsafe_flags)
221+
222+
223+
def validate_url(url: str):
    """Check that ``url`` is safe to fetch and return its parsed form.

    :param url: the URL to validate.
    :return: the ``urlparse`` result for ``url``.
    :raises ValueError: when the URL is empty, uses a scheme other than
        http/https, has no host, or its host is rejected by ``is_private_ip``.
    """
    if not url:
        raise ValueError("URL is required")

    components = urlparse(url)

    # Only http and https are permitted.
    scheme_allowed = components.scheme in ("http", "https")
    if not scheme_allowed:
        raise ValueError("Only http and https are allowed")

    # A host name must be present.
    hostname = components.hostname
    if not hostname:
        raise ValueError("Invalid URL")

    # Refuse internal, reserved, loopback and cloud-metadata destinations.
    if is_private_ip(hostname):
        raise ValueError("Access to internal IP addresses is blocked")

    return components

apps/oss/views/file.py

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
11
# coding=utf-8
2-
import base64
3-
import ipaddress
4-
import socket
5-
from urllib.parse import urlparse
6-
7-
import requests
82
from django.utils.translation import gettext_lazy as _
93
from drf_spectacular.utils import extend_schema
104
from rest_framework.parsers import MultiPartParser
115
from rest_framework.views import APIView
126
from rest_framework.views import Request
13-
147
from common.auth import TokenAuth
158
from common.log.log import log
169
from common.result import result
1710
from knowledge.api.file import FileUploadAPI, FileGetAPI
18-
from oss.serializers.file import FileSerializer
11+
from oss.serializers.file import FileSerializer, get_url_content
1912

2013

2114
class FileRetrievalView(APIView):
@@ -84,71 +77,7 @@ class GetUrlView(APIView):
8477
operation_id=_('Get url'), # type: ignore
8578
tags=[_('Chat')] # type: ignore
8679
)
87-
def get(self, request: Request):
80+
def get(self, request: Request, application_id: str):
8881
url = request.query_params.get('url')
89-
parsed = validate_url(url)
90-
91-
response = requests.get(
92-
url,
93-
timeout=3,
94-
allow_redirects=False
95-
)
96-
final_host = urlparse(response.url).hostname
97-
if is_private_ip(final_host):
98-
raise ValueError("Blocked unsafe redirect to internal host")
99-
100-
# 返回状态码 响应内容大小 响应的contenttype 还有字节流
101-
content_type = response.headers.get('Content-Type', '')
102-
# 根据内容类型决定如何处理
103-
if 'text' in content_type or 'json' in content_type:
104-
content = response.text
105-
else:
106-
# 二进制内容使用Base64编码
107-
content = base64.b64encode(response.content).decode('utf-8')
108-
109-
return result.success({
110-
'status_code': response.status_code,
111-
'Content-Length': response.headers.get('Content-Length', 0),
112-
'Content-Type': content_type,
113-
'content': content,
114-
})
115-
116-
117-
def is_private_ip(host: str) -> bool:
118-
"""检测 IP 是否属于内网、环回、云 metadata 的危险地址"""
119-
try:
120-
ip = ipaddress.ip_address(socket.gethostbyname(host))
121-
return (
122-
ip.is_private or
123-
ip.is_loopback or
124-
ip.is_reserved or
125-
ip.is_link_local or
126-
ip.is_multicast
127-
)
128-
except Exception:
129-
return True
130-
131-
132-
def validate_url(url: str):
133-
"""验证 URL 是否安全"""
134-
if not url:
135-
raise ValueError("URL is required")
136-
137-
parsed = urlparse(url)
138-
139-
# 仅允许 http / https
140-
if parsed.scheme not in ("http", "https"):
141-
raise ValueError("Only http and https are allowed")
142-
143-
host = parsed.hostname
144-
path = parsed.path
145-
146-
# 域名不能为空
147-
if not host:
148-
raise ValueError("Invalid URL")
149-
150-
# 禁止访问内部、保留、环回、云 metadata
151-
if is_private_ip(host):
152-
raise ValueError("Access to internal IP addresses is blocked")
153-
154-
return parsed
82+
result_data = get_url_content(url, application_id)
83+
return result.success(result_data)

0 commit comments

Comments
 (0)