Skip to content

Commit c902c7b

Browse files
committed
feat: complete multi-process implementation
1 parent 461280f commit c902c7b

File tree

27 files changed

+916
-636
lines changed

27 files changed

+916
-636
lines changed

backend/db/db_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class Config:
3535

3636
env_file = ".env"
3737
case_sensitive = True
38+
extra = "ignore"
3839

3940

4041
@lru_cache()

backend/model/task.py

Lines changed: 161 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
Copyright (c) 2025, All Rights Reserved.
44
"""
55

6+
import re
67
from typing import Dict, List, Optional, Union
78

8-
from pydantic import BaseModel, Field
9+
from pydantic import BaseModel, Field, validator
910
from sqlalchemy import Column, DateTime, Float, Integer, String, Text, func
1011

1112
from db.mysql import Base
@@ -83,8 +84,12 @@ class HeaderItem(BaseModel):
8384
fixed: A boolean indicating if the header is fixed (not used currently).
8485
"""
8586

86-
key: str
87-
value: str
87+
key: str = Field(
88+
..., min_length=1, max_length=100, description="Header name (1-100 chars)"
89+
)
90+
value: str = Field(
91+
..., max_length=1000, description="Header value (max 1000 chars)"
92+
)
8893
fixed: bool = True
8994

9095

@@ -97,8 +102,12 @@ class CookieItem(BaseModel):
97102
value: The cookie value.
98103
"""
99104

100-
key: str
101-
value: str
105+
key: str = Field(
106+
..., min_length=1, max_length=100, description="Cookie name (1-100 chars)"
107+
)
108+
value: str = Field(
109+
..., max_length=1000, description="Cookie value (max 1000 chars)"
110+
)
102111

103112

104113
class CertConfig(BaseModel):
@@ -110,8 +119,12 @@ class CertConfig(BaseModel):
110119
key_file: Path to the SSL private key file.
111120
"""
112121

113-
cert_file: Optional[str] = Field(None, description="Path to the certificate file")
114-
key_file: Optional[str] = Field(None, description="Path to the private key file")
122+
cert_file: Optional[str] = Field(
123+
None, max_length=255, description="Path to the certificate file (max 255 chars)"
124+
)
125+
key_file: Optional[str] = Field(
126+
None, max_length=255, description="Path to the private key file (max 255 chars)"
127+
)
115128

116129

117130
class TaskStopReq(BaseModel):
@@ -130,44 +143,169 @@ class TaskCreateReq(BaseModel):
130143
Request model for creating a new performance testing task.
131144
"""
132145

133-
temp_task_id: str
134-
name: str = Field(..., description="Name of the task")
135-
target_host: str = Field(..., description="Target model API host")
136-
api_path: str = Field(default="/chat/completions", description="API path to test")
137-
model: Optional[str] = Field(default="", description="Name of the model to test")
146+
temp_task_id: str = Field(..., max_length=100, description="Temporary task ID")
147+
name: str = Field(..., min_length=1, max_length=100, description="Name of the task")
148+
target_host: str = Field(
149+
..., min_length=1, max_length=255, description="Target model API host"
150+
)
151+
api_path: str = Field(
152+
default="/chat/completions", max_length=255, description="API path to test"
153+
)
154+
model: Optional[str] = Field(
155+
default="", max_length=255, description="Name of the model to test"
156+
)
138157
duration: int = Field(
139-
default=300, ge=1, description="Duration of the test in seconds"
158+
default=300,
159+
ge=1,
160+
le=172800,
161+
description="Duration of the test in seconds (1-48 hours)",
162+
)
163+
concurrent_users: int = Field(
164+
..., ge=1, le=5000, description="Number of concurrent users (1-5000)"
165+
)
166+
spawn_rate: int = Field(
167+
ge=1, le=100, description="Number of users to spawn per second (1-100)"
140168
)
141-
concurrent_users: int = Field(..., ge=1, description="Number of concurrent users")
142-
spawn_rate: int = Field(ge=1, description="Number of users to spawn per second")
143169
chat_type: Optional[int] = Field(
144-
default=0, ge=0, description="Type of chat interaction"
170+
default=0,
171+
ge=0,
172+
le=1,
173+
description="Type of chat interaction (0=text, 1=multimodal)",
145174
)
146175
stream_mode: bool = Field(
147176
default=True, description="Whether to use streaming response"
148177
)
149178
headers: List[HeaderItem] = Field(
150-
default_factory=list, description="List of request headers"
179+
default_factory=list,
180+
description="List of request headers (max 50)",
151181
)
152182
cookies: List[CookieItem] = Field(
153-
default_factory=list, description="List of request cookies"
183+
default_factory=list,
184+
description="List of request cookies (max 50)",
154185
)
155186
cert_config: Optional[CertConfig] = Field(
156187
default=None, description="Certificate configuration"
157188
)
158-
system_prompt: Optional[str] = Field(
159-
default="", description="System prompt for the model"
160-
)
161189
request_payload: Optional[str] = Field(
162-
default="", description="Custom request payload for non-chat APIs (JSON string)"
190+
default="",
191+
max_length=50000,
192+
description="Custom request payload for non-chat APIs (JSON string, max 50000 chars)",
163193
)
164194
field_mapping: Optional[Dict[str, str]] = Field(
165195
default=None, description="Field mapping configuration for custom APIs"
166196
)
167197
test_data: Optional[str] = Field(
168-
default="", description="Custom test data in JSONL format or file path"
198+
default="",
199+
max_length=1000000,
200+
description="Custom test data in JSONL format or file path (max 1MB)",
169201
)
170202

203+
@validator("name")
204+
def validate_name(cls, v):
205+
if not v or not v.strip():
206+
raise ValueError("Name cannot be empty")
207+
if len(v.strip()) > 100:
208+
raise ValueError("Name length cannot exceed 100 characters")
209+
return v.strip()
210+
211+
@validator("target_host")
212+
def validate_target_host(cls, v):
213+
if not v or not v.strip():
214+
raise ValueError("API address cannot be empty")
215+
v = v.strip()
216+
if len(v) > 255:
217+
raise ValueError("API address length cannot exceed 255 characters")
218+
# 基本URL格式验证
219+
if not (v.startswith("http://") or v.startswith("https://")):
220+
raise ValueError("API address must start with http:// or https://")
221+
return v
222+
223+
@validator("api_path")
224+
def validate_api_path(cls, v):
225+
if not v or not v.strip():
226+
raise ValueError("API path cannot be empty")
227+
v = v.strip()
228+
if len(v) > 255:
229+
raise ValueError("API path length cannot exceed 255 characters")
230+
if not v.startswith("/"):
231+
raise ValueError("API path must start with /")
232+
return v
233+
234+
@validator("model")
235+
def validate_model(cls, v):
236+
if v and len(v.strip()) > 255:
237+
raise ValueError("Model name length cannot exceed 255 characters")
238+
return v.strip() if v else ""
239+
240+
@validator("request_payload")
241+
def validate_request_payload(cls, v, values):
242+
"""Ensure request_payload is never empty - auto-generate if needed"""
243+
# If request_payload is empty, generate default payload
244+
if not v or not v.strip():
245+
model = values.get("model", "your-model-name")
246+
stream_mode = values.get("stream_mode", True)
247+
248+
# Generate default payload for chat/completions API
249+
default_payload = {
250+
"model": model,
251+
"stream": stream_mode,
252+
"messages": [{"role": "user", "content": "Hi"}],
253+
}
254+
255+
import json
256+
257+
return json.dumps(default_payload)
258+
259+
# Validate length
260+
if len(v) > 50000:
261+
raise ValueError("Request payload length cannot exceed 50000 characters")
262+
263+
# Validate JSON format
264+
try:
265+
import json
266+
267+
json.loads(v.strip())
268+
except json.JSONDecodeError:
269+
raise ValueError("Request payload must be a valid JSON format")
270+
271+
return v.strip()
272+
273+
@validator("headers")
274+
def validate_headers(cls, v):
275+
if len(v) > 50:
276+
raise ValueError("Request header count cannot exceed 50")
277+
for header in v:
278+
if not header.key or not header.key.strip():
279+
raise ValueError("Request header name cannot be empty")
280+
if len(header.key.strip()) > 100:
281+
raise ValueError(
282+
"Request header name length cannot exceed 100 characters"
283+
)
284+
if len(header.value) > 1000:
285+
raise ValueError(
286+
"Request header value length cannot exceed 1000 characters"
287+
)
288+
return v
289+
290+
@validator("cookies")
291+
def validate_cookies(cls, v):
292+
if len(v) > 50:
293+
raise ValueError("Cookie count cannot exceed 50")
294+
for cookie in v:
295+
if not cookie.key or not cookie.key.strip():
296+
raise ValueError("Cookie name cannot be empty")
297+
if len(cookie.key.strip()) > 100:
298+
raise ValueError("Cookie name length cannot exceed 100 characters")
299+
if len(cookie.value) > 1000:
300+
raise ValueError("Cookie value length cannot exceed 1000 characters")
301+
return v
302+
303+
@validator("test_data")
304+
def validate_test_data(cls, v):
305+
if v and len(v) > 1000000: # 1MB
306+
raise ValueError("Test data size cannot exceed 1MB")
307+
return v
308+
171309

172310
class TaskResultItem(BaseModel):
173311
"""
@@ -320,7 +458,6 @@ class Task(Base):
320458
status = Column(String(32), nullable=False)
321459
target_host = Column(String(255), nullable=False)
322460
model = Column(String(100), nullable=True)
323-
system_prompt = Column(Text, nullable=True)
324461
stream_mode = Column(String(20), nullable=False)
325462
concurrent_users = Column(Integer, nullable=False)
326463
spawn_rate = Column(Integer, nullable=False)

backend/service/analysis_service.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async def analyze_tasks_svc(
6767
try:
6868
ai_config = await get_ai_service_config_internal_svc(request)
6969
except HTTPException as e:
70-
error_msg = "Failed to get AI service configuration. %s" % str(e)
70+
error_msg = "Failed to get AI service configuration."
7171
logger.error(error_msg, exc_info=True)
7272
return AnalysisResponse(
7373
task_ids=task_ids,
@@ -201,9 +201,8 @@ async def analyze_tasks_svc(
201201

202202
except Exception as e:
203203
# Handle other exceptions - only log error, don't update database
204-
error_message = "Analysis failed for tasks %s: %s" % (
205-
analysis_request.task_ids,
206-
str(e),
204+
error_message = (
205+
f"Analysis failed for tasks {analysis_request.task_ids}: {str(e)}"
207206
)
208207
logger.error(error_message, exc_info=True)
209208
return AnalysisResponse(
@@ -269,7 +268,7 @@ async def get_analysis_svc(request: Request, task_id: str) -> GetAnalysisRespons
269268
)
270269

271270
except Exception as e:
272-
error_msg = "Failed to retrieve analysis for task %s: %s" % (task_id, str(e))
271+
error_msg = f"Failed to retrieve analysis for task {task_id}: {str(e)}"
273272
logger.error(error_msg, exc_info=True)
274273
return GetAnalysisResponse(
275274
data=None,
@@ -323,18 +322,18 @@ async def _call_ai_service(
323322
model_info_str = json.dumps(model_info, ensure_ascii=False, indent=2)
324323
prompt = prompt_template.format(model_info=model_info_str)
325324
except (TypeError, ValueError) as e:
326-
error_msg = "Failed to serialize model_info: %s" % str(e)
325+
error_msg = f"Failed to serialize model_info: {str(e)}"
327326
logger.error(error_msg)
328327
# Try fallback serialization
329328
try:
330329
model_info_str = str(model_info)
331330
prompt = prompt_template.format(model_info=model_info_str)
332331
except Exception as fallback_error:
333-
logger.error("Fallback serialization failed: %s" % str(fallback_error))
334-
raise Exception("Failed to serialize model_info: %s" % str(e))
332+
logger.error(f"Fallback serialization failed: {str(fallback_error)}")
333+
raise Exception(f"Failed to serialize model_info: {str(e)}")
335334
except Exception as format_error:
336-
error_msg = "Failed to format prompt: %s" % str(format_error)
337-
logger.error("Prompt formatting error: %s" % error_msg)
335+
error_msg = f"Failed to format prompt: {str(format_error)}"
336+
logger.error(f"Prompt formatting error: {error_msg}")
338337
raise Exception(error_msg)
339338
else:
340339
error_msg = "model_info is required for task analysis"
@@ -373,22 +372,22 @@ async def _call_ai_service(
373372
raise Exception(error_msg)
374373

375374
except httpx.TimeoutException as e:
376-
error_msg = "AI service request timeout: %s" % str(e)
377-
logger.error("AI service timeout error: %s" % error_msg)
375+
error_msg = f"AI service request timeout: {str(e)}"
376+
logger.error(f"AI service timeout error: {error_msg}")
378377
raise Exception(error_msg)
379378
except httpx.ConnectError as e:
380-
error_msg = "AI service connection error: %s" % str(e)
381-
logger.error("AI service connection error: %s" % error_msg)
379+
error_msg = f"AI service connection error: {str(e)}"
380+
logger.error(f"AI service connection error: {error_msg}")
382381
raise Exception(error_msg)
383382
except httpx.HTTPStatusError as e:
384-
error_msg = "AI service HTTP error: %s - %s" % (e.response.status_code, str(e))
385-
logger.error("AI service HTTP error: %s" % error_msg)
383+
error_msg = f"AI service HTTP error: {e.response.status_code} - {str(e)}"
384+
logger.error(f"AI service HTTP error: {error_msg}")
386385
raise Exception(error_msg)
387386
except httpx.RequestError as e:
388-
error_msg = "AI service request failed: %s" % str(e)
389-
logger.error("AI service request error: %s" % error_msg)
387+
error_msg = f"AI service request failed: {str(e)}"
388+
logger.error(f"AI service request error: {error_msg}")
390389
raise Exception(error_msg)
391390
except Exception as e:
392-
error_msg = "AI service call failed: %s" % str(e)
393-
logger.error("AI service general error: %s" % error_msg)
391+
error_msg = f"AI service call failed: {str(e)}"
392+
logger.error(f"AI service general error: {error_msg}")
394393
raise Exception(error_msg)

backend/service/log_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def get_last_n_lines(file_path: str, n: int = 100) -> str:
9898
return result
9999

100100
except Exception as e:
101-
logger.error("Failed to read log file: %s" % str(e))
101+
logger.error(f"Failed to read log file: {str(e)}")
102102
return ""
103103

104104

@@ -155,7 +155,7 @@ async def get_service_log_svc(service_name: str, offset: int, tail: int):
155155
file_size = os.path.getsize(log_file_path)
156156
return LogContentResponse(content=content, file_size=file_size)
157157
except Exception as e:
158-
logger.error("Failed to read log file %s: %s" % (log_file_path, str(e)))
158+
logger.error(f"Failed to read log file {log_file_path}: {str(e)}")
159159
return ErrorResponse.internal_server_error(ErrorMessages.LOG_FILE_READ_FAILED)
160160

161161

@@ -182,5 +182,5 @@ async def get_task_log_svc(task_id: str, offset: int, tail: int):
182182
file_size = os.path.getsize(log_file_path)
183183
return LogContentResponse(content=content, file_size=file_size)
184184
except Exception as e:
185-
logger.error("Failed to read log file %s: %s" % (log_file_path, str(e)))
185+
logger.error(f"Failed to read log file {log_file_path}: {str(e)}")
186186
return ErrorResponse.internal_server_error(ErrorMessages.LOG_FILE_READ_FAILED)

0 commit comments

Comments
 (0)