Skip to content

Commit 84c53c2

Browse files
committed
feat:custom dataset
1 parent 59a75b4 commit 84c53c2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+3724
-2472
lines changed

backend/config/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

backend/db/mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlalchemy.orm import declarative_base
1010
from sqlalchemy.pool import AsyncAdaptedQueuePool
1111

12-
from config.db_config import get_settings
12+
from db.db_config import get_settings
1313

1414
settings = get_settings()
1515

backend/model/task.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ class TaskCreateReq(BaseModel):
160160
system_prompt: Optional[str] = Field(
161161
default="", description="System prompt for the model"
162162
)
163-
user_prompt: Optional[str] = Field(
164-
default="", description="User prompt for the model"
165-
)
166163
request_payload: Optional[str] = Field(
167164
default="", description="Custom request payload for non-chat APIs (JSON string)"
168165
)
169166
field_mapping: Optional[Dict[str, str]] = Field(
170167
default=None, description="Field mapping configuration for custom APIs"
171168
)
169+
test_data: Optional[str] = Field(
170+
default="", description="Custom test data in JSONL format or file path"
171+
)
172172

173173

174174
class TaskResultItem(BaseModel):
@@ -315,7 +315,6 @@ class Task(Base):
315315
target_host = Column(String(255), nullable=False)
316316
model = Column(String(100), nullable=True)
317317
system_prompt = Column(Text, nullable=True)
318-
user_prompt = Column(Text, nullable=True)
319318
stream_mode = Column(String(20), nullable=False)
320319
concurrent_users = Column(Integer, nullable=False)
321320
spawn_rate = Column(Integer, nullable=False)
@@ -331,6 +330,7 @@ class Task(Base):
331330
request_payload = Column(Text, nullable=True)
332331
field_mapping = Column(Text, nullable=True)
333332
error_message = Column(Text, nullable=True)
333+
test_data = Column(Text, nullable=True)
334334
created_at = Column(DateTime, server_default=func.now())
335335
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
336336

backend/model/upload.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ class UploadFileRsp(BaseModel):
3232
task_id: The identifier for the associated task.
3333
files: A list of `UploadedFileInfo` objects for each uploaded file.
3434
cert_config: An optional dictionary for certificate configuration.
35+
test_data: An optional string for test data file path.
3536
"""
3637

3738
message: str
3839
task_id: str
3940
files: List[UploadedFileInfo]
4041
cert_config: Optional[dict] = None
42+
test_data: Optional[str] = None
4143

4244

4345
class UploadFileReq(BaseModel):

backend/pyproject.toml

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,34 @@ known_third_party = ["fastapi", "pydantic", "sqlalchemy", "loguru", "uvicorn"]
3232

3333
[tool.mypy]
3434
python_version = "3.9"
35+
# completely disable type checking
36+
ignore_errors = true
37+
ignore_missing_imports = true
38+
follow_imports = "skip"
3539
# base settings
3640
warn_return_any = false
37-
warn_unused_configs = true
38-
show_error_codes = true
41+
warn_unused_configs = false
42+
show_error_codes = false
3943

4044
# relax type check requirements
4145
disallow_untyped_defs = false
4246
disallow_incomplete_defs = false
4347
check_untyped_defs = false
4448
disallow_untyped_decorators = false
4549
no_implicit_optional = false
50+
disallow_any_generics = false
51+
disallow_any_unimported = false
52+
disallow_subclassing_any = false
4653

47-
# keep useful warnings
54+
# disable warnings
4855
warn_redundant_casts = false
4956
warn_unused_ignores = false
5057
warn_no_return = false
5158
warn_unreachable = false
5259
strict_equality = false
60+
strict_optional = false
61+
allow_redefinition = true
62+
implicit_reexport = true
5363

5464
# ignore common third-party libraries
5565
[[tool.mypy.overrides]]
@@ -82,6 +92,14 @@ disallow_incomplete_defs = false
8292
check_untyped_defs = false
8393
warn_return_any = false
8494
ignore_errors = true
95+
follow_imports = "skip"
96+
follow_imports_for_stubs = false
97+
98+
# SQLAlchemy specific overrides
99+
[[tool.mypy.overrides]]
100+
module = ["sqlalchemy.*"]
101+
ignore_missing_imports = true
102+
ignore_errors = true
85103

86104
[tool.pytest.ini_options]
87105
testpaths = ["tests"]

backend/service/log_service.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
from starlette.responses import JSONResponse
99

10-
from config.config import LOG_DIR
1110
from model.log import LogContentResponse
12-
from utils.logger import be_logger as logger
11+
from utils.be_config import LOG_DIR
12+
from utils.logger import logger
1313

1414

1515
def get_last_n_lines(file_path: str, n: int = 100) -> str:
@@ -36,15 +36,15 @@ def get_last_n_lines(file_path: str, n: int = 100) -> str:
3636
all_lines = f.readlines()
3737
return "".join(all_lines[-n:]) if all_lines else ""
3838

39-
# For larger files, use a more reliable approach
40-
# Start from end and read backwards in larger chunks
41-
buffer_size = 8192
42-
lines: list[str] = []
39+
# For larger files, use a more efficient approach
40+
# Read from end in chunks and collect lines
41+
lines = list[str]()
4342
buffer = ""
4443
position = file_size
44+
buffer_size = 8192
4545

4646
while position > 0 and len(lines) < n:
47-
# Calculate chunk size to read
47+
# Calculate how much to read
4848
chunk_size = min(buffer_size, position)
4949
position -= chunk_size
5050

@@ -56,26 +56,45 @@ def get_last_n_lines(file_path: str, n: int = 100) -> str:
5656
buffer = chunk + buffer
5757

5858
# Split buffer into lines
59-
lines_in_buffer = buffer.split("\n")
60-
61-
# Keep the first part (might be incomplete line) in buffer
62-
buffer = lines_in_buffer[0]
63-
64-
# Add complete lines to our lines list (in reverse order since we're reading backwards)
65-
for line in reversed(lines_in_buffer[1:]):
59+
split_lines = buffer.split("\n")
60+
61+
# If we haven't reached the beginning, keep the first part as incomplete line
62+
if position > 0:
63+
buffer = split_lines[0]
64+
# Process complete lines (skip the first incomplete one)
65+
complete_lines = split_lines[1:]
66+
else:
67+
# At the beginning of file, all lines are complete
68+
buffer = ""
69+
complete_lines = split_lines
70+
71+
# Add complete lines to the front of our lines list
72+
# (since we're reading backwards)
73+
for line in reversed(complete_lines):
6674
lines.insert(0, line)
6775
if len(lines) >= n:
6876
break
6977

70-
# If we've reached the beginning of file, add the remaining buffer as a line
71-
if position == 0 and buffer:
72-
lines.insert(0, buffer)
78+
# If we have enough lines, break
79+
if len(lines) >= n:
80+
break
7381

74-
# Take last n lines and join them with newlines
82+
# Take last n lines
7583
result_lines = lines[-n:] if len(lines) > n else lines
76-
return "\n".join(result_lines) + (
77-
"\n" if result_lines and not result_lines[-1].endswith("\n") else ""
78-
)
84+
85+
# Join lines and ensure proper ending
86+
if not result_lines:
87+
return ""
88+
89+
result = "\n".join(result_lines)
90+
# Add final newline if the original content had one and result doesn't end with one
91+
if result and not result.endswith("\n"):
92+
# Check if original file ends with newline
93+
f.seek(max(0, file_size - 1))
94+
if f.read(1) == "\n":
95+
result += "\n"
96+
97+
return result
7998

8099
except Exception as e:
81100
logger.error(f"Failed to read log file: {str(e)}")

backend/service/task_service.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from sqlalchemy import func, or_, select, text
1313
from starlette.responses import JSONResponse
1414

15-
from config.config import UPLOAD_FOLDER
1615
from model.task import (
1716
ComparisonMetrics,
1817
ComparisonRequest,
@@ -29,7 +28,8 @@
2928
TaskResultRsp,
3029
TaskStatusRsp,
3130
)
32-
from utils.logger import be_logger as logger
31+
from utils.be_config import UPLOAD_FOLDER
32+
from utils.logger import logger
3333

3434

3535
async def get_tasks_svc(
@@ -53,7 +53,7 @@ async def get_tasks_svc(
5353
Returns:
5454
A `TaskResponse` object containing the list of tasks and pagination details.
5555
"""
56-
task_list = []
56+
task_list: List[Dict] = []
5757
pagination = Pagination()
5858
try:
5959
db = request.state.db
@@ -100,27 +100,67 @@ async def get_tasks_svc(
100100
)
101101

102102
# Format the task data for the response.
103-
task_list = [
104-
{
103+
task_list = []
104+
for task in tasks:
105+
# Convert headers from JSON string back to a list of objects for the frontend.
106+
# headers_list = []
107+
# if task.headers:
108+
# try:
109+
# headers_dict = json.loads(task.headers)
110+
# headers_list = [
111+
# {"key": k, "value": v} for k, v in headers_dict.items()
112+
# ]
113+
# except json.JSONDecodeError:
114+
# logger.warning(
115+
# f"Could not parse headers JSON for task {task.id}: {task.headers}"
116+
# )
117+
118+
# Convert cookies from JSON string back to a list of objects for the frontend.
119+
# cookies_list = []
120+
# if task.cookies:
121+
# try:
122+
# cookies_dict = json.loads(task.cookies)
123+
# cookies_list = [
124+
# {"key": k, "value": v} for k, v in cookies_dict.items()
125+
# ]
126+
# except json.JSONDecodeError:
127+
# logger.warning(
128+
# f"Could not parse cookies JSON for task {task.id}: {task.cookies}"
129+
# )
130+
131+
# Parse field_mapping from JSON string back to dictionary
132+
field_mapping_dict = {}
133+
if task.field_mapping:
134+
try:
135+
field_mapping_dict = json.loads(task.field_mapping)
136+
except json.JSONDecodeError:
137+
logger.warning(
138+
f"Could not parse field_mapping JSON for task {task.id}: {task.field_mapping}"
139+
)
140+
141+
task_data = {
105142
"id": task.id,
106143
"name": task.name,
107144
"status": task.status,
108145
"target_host": task.target_host,
109146
"api_path": task.api_path,
110147
"model": task.model,
111148
"request_payload": task.request_payload,
112-
"field_mapping": task.field_mapping,
149+
"field_mapping": field_mapping_dict,
113150
"concurrent_users": task.concurrent_users,
114151
"duration": task.duration,
115152
"spawn_rate": task.spawn_rate,
116153
"chat_type": task.chat_type,
117154
"stream_mode": str(task.stream_mode).lower() == "true",
118-
"error_message": task.error_message,
155+
"headers": "",
156+
"cookies": "",
157+
"cert_config": "",
158+
"system_prompt": task.system_prompt or "",
159+
"test_data": task.test_data or "",
119160
"created_at": task.created_at.isoformat() if task.created_at else None,
120161
"updated_at": task.updated_at.isoformat() if task.updated_at else None,
121162
}
122-
for task in tasks
123-
]
163+
task_list.append(task_data)
124164
except Exception as e:
125165
logger.error(f"Error getting tasks: {e}", exc_info=True)
126166
return TaskResponse(data=[], pagination=Pagination(), status="error")
@@ -227,17 +267,19 @@ async def create_task_svc(request: Request, body: TaskCreateReq):
227267
# Convert absolute paths to relative paths for cross-service compatibility
228268
if cert_file:
229269
# Convert backend upload path to relative path that st_engine can access
230-
if cert_file.startswith("/app/upload_files/"):
231-
cert_file = cert_file.replace("/app/upload_files/", "")
232-
elif UPLOAD_FOLDER in cert_file:
270+
if cert_file.startswith(UPLOAD_FOLDER + "/"):
233271
cert_file = cert_file.replace(UPLOAD_FOLDER + "/", "")
272+
elif cert_file.startswith("/app/upload_files/"):
273+
# For backward compatibility with existing Docker paths
274+
cert_file = cert_file.replace("/app/upload_files/", "")
234275

235276
if key_file:
236277
# Convert backend upload path to relative path that st_engine can access
237-
if key_file.startswith("/app/upload_files/"):
238-
key_file = key_file.replace("/app/upload_files/", "")
239-
elif UPLOAD_FOLDER in key_file:
278+
if key_file.startswith(UPLOAD_FOLDER + "/"):
240279
key_file = key_file.replace(UPLOAD_FOLDER + "/", "")
280+
elif key_file.startswith("/app/upload_files/"):
281+
# For backward compatibility with existing Docker paths
282+
key_file = key_file.replace("/app/upload_files/", "")
241283

242284
# Convert headers from a list of objects to a dictionary, then to a JSON string.
243285
headers = {
@@ -278,12 +320,12 @@ async def create_task_svc(request: Request, body: TaskCreateReq):
278320
status="created",
279321
error_message="",
280322
system_prompt=body.system_prompt,
281-
user_prompt=body.user_prompt,
282323
cert_file=cert_file,
283324
key_file=key_file,
284325
api_path=body.api_path,
285326
request_payload=body.request_payload,
286327
field_mapping=field_mapping_json,
328+
test_data=body.test_data,
287329
)
288330

289331
db.add(new_task)
@@ -433,11 +475,11 @@ async def get_task_svc(request: Request, task_id: str):
433475
"headers": headers_list,
434476
"cookies": cookies_list,
435477
"cert_config": {"cert_file": task.cert_file, "key_file": task.key_file},
436-
"system_prompt": task.system_prompt,
437-
"user_prompt": task.user_prompt,
478+
"system_prompt": task.system_prompt or "",
438479
"api_path": task.api_path,
439480
"request_payload": task.request_payload,
440481
"field_mapping": field_mapping_dict,
482+
"test_data": task.test_data or "",
441483
"error_message": task.error_message,
442484
"created_at": task.created_at.isoformat() if task.created_at else None,
443485
"updated_at": task.updated_at.isoformat() if task.updated_at else None,
@@ -738,11 +780,12 @@ def _prepare_request_payload(body: TaskCreateReq) -> Dict:
738780
"""Prepare request payload based on API path and configuration."""
739781
if body.api_path == "/v1/chat/completions":
740782
# Use the traditional chat completions format
741-
messages = []
742-
if body.system_prompt:
743-
messages.append({"role": "system", "content": body.system_prompt})
744-
745-
messages.append({"role": "user", "content": body.user_prompt or "Hi"})
783+
messages = [
784+
{
785+
"role": "user",
786+
"content": "Hi",
787+
}
788+
]
746789

747790
return {
748791
"model": body.model,
@@ -931,7 +974,7 @@ async def _handle_streaming_response(response, full_url: str) -> Dict:
931974
}
932975

933976
# For testing purposes, we limit the time and data we collect
934-
max_chunks = 200 # max chunks to collect for testing
977+
max_chunks = 300 # max chunks to collect for testing
935978
max_duration = 15 # max duration to wait for testing
936979

937980
start_time = asyncio.get_event_loop().time()

0 commit comments

Comments
 (0)