Skip to content

Commit 6fe9dbf

Browse files
authored
Merge pull request #2 from MigoXLab/feat/custom_api
Feat/custom api
2 parents 9cd6277 + 82c49d9 commit 6fe9dbf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+10646
-2182
lines changed

backend/api/api_task.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_tasks_status_svc,
2828
get_tasks_svc,
2929
stop_task_svc,
30+
test_api_endpoint_svc,
3031
)
3132

3233
# Create an API router for task-related endpoints
@@ -175,3 +176,18 @@ async def compare_performance(request: Request, comparison_request: ComparisonRe
175176
ComparisonResponse: A response object containing comparison metrics.
176177
"""
177178
return await compare_performance_svc(request, comparison_request)
179+
180+
181+
@router.post("/test", response_model=Dict[str, Any])
182+
async def test_api_endpoint(request: Request, task_create: TaskCreateReq):
183+
"""
184+
Test the API endpoint with the provided configuration.
185+
186+
Args:
187+
request (Request): The incoming request.
188+
task_create (TaskCreateReq): The data for testing the API endpoint.
189+
190+
Returns:
191+
Dict[str, Any]: A response containing the test result.
192+
"""
193+
return await test_api_endpoint_svc(request, task_create)

backend/db/mysql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from typing import AsyncGenerator
77

8-
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
9-
from sqlalchemy.orm import declarative_base, sessionmaker
8+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
9+
from sqlalchemy.orm import declarative_base
1010
from sqlalchemy.pool import AsyncAdaptedQueuePool
1111

1212
from config.db_config import get_settings
@@ -28,7 +28,7 @@
2828
)
2929

3030
# Create a factory for asynchronous database sessions
31-
async_session_factory = sessionmaker(
31+
async_session_factory = async_sessionmaker(
3232
engine,
3333
class_=AsyncSession,
3434
expire_on_commit=False,

backend/model/task.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ class HeaderItem(BaseModel):
8888
fixed: bool = True
8989

9090

91+
class CookieItem(BaseModel):
92+
"""
93+
Represents a single HTTP cookie item for a request.
94+
95+
Attributes:
96+
key: The cookie name.
97+
value: The cookie value.
98+
"""
99+
100+
key: str
101+
value: str
102+
103+
91104
class CertConfig(BaseModel):
92105
"""
93106
Configuration for SSL/TLS certificates.
@@ -123,19 +136,24 @@ class TaskCreateReq(BaseModel):
123136
api_path: str = Field(
124137
default="/v1/chat/completions", description="API path to test"
125138
)
126-
model: str = Field(..., description="Name of the model to test")
139+
model: Optional[str] = Field(default="", description="Name of the model to test")
127140
duration: int = Field(
128141
default=300, ge=1, description="Duration of the test in seconds"
129142
)
130143
concurrent_users: int = Field(..., ge=1, description="Number of concurrent users")
131144
spawn_rate: int = Field(ge=1, description="Number of users to spawn per second")
132-
chat_type: int = Field(ge=0, description="Type of chat interaction")
145+
chat_type: Optional[int] = Field(
146+
default=0, ge=0, description="Type of chat interaction"
147+
)
133148
stream_mode: bool = Field(
134149
default=True, description="Whether to use streaming response"
135150
)
136151
headers: List[HeaderItem] = Field(
137152
default_factory=list, description="List of request headers"
138153
)
154+
cookies: List[CookieItem] = Field(
155+
default_factory=list, description="List of request cookies"
156+
)
139157
cert_config: Optional[CertConfig] = Field(
140158
default=None, description="Certificate configuration"
141159
)
@@ -145,6 +163,12 @@ class TaskCreateReq(BaseModel):
145163
user_prompt: Optional[str] = Field(
146164
default="", description="User prompt for the model"
147165
)
166+
request_payload: Optional[str] = Field(
167+
default="", description="Custom request payload for non-chat APIs (JSON string)"
168+
)
169+
field_mapping: Optional[Dict[str, str]] = Field(
170+
default=None, description="Field mapping configuration for custom APIs"
171+
)
148172

149173

150174
class TaskResultItem(BaseModel):
@@ -289,20 +313,23 @@ class Task(Base):
289313
name = Column(String(255), nullable=False)
290314
status = Column(String(32), nullable=False)
291315
target_host = Column(String(255), nullable=False)
292-
model = Column(String(100), nullable=False)
316+
model = Column(String(100), nullable=True)
293317
system_prompt = Column(Text, nullable=True)
294318
user_prompt = Column(Text, nullable=True)
295319
stream_mode = Column(String(20), nullable=False)
296320
concurrent_users = Column(Integer, nullable=False)
297321
spawn_rate = Column(Integer, nullable=False)
298322
duration = Column(Integer, nullable=False)
299-
chat_type = Column(Integer, nullable=False)
323+
chat_type = Column(Integer, nullable=True)
300324
log_file = Column(Text, nullable=True)
301325
result_file = Column(Text, nullable=True)
302326
cert_file = Column(String(255), nullable=True)
303327
key_file = Column(String(255), nullable=True)
304328
headers = Column(Text, nullable=True)
305-
# api_path = Column(String(255), nullable=True)
329+
cookies = Column(Text, nullable=True)
330+
api_path = Column(String(255), nullable=True)
331+
request_payload = Column(Text, nullable=True)
332+
field_mapping = Column(Text, nullable=True)
306333
error_message = Column(Text, nullable=True)
307334
created_at = Column(DateTime, server_default=func.now())
308335
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
@@ -326,10 +353,10 @@ class TaskResult(Base):
326353
p90_latency = Column(Float, nullable=False)
327354
rps = Column(Float, nullable=False)
328355
avg_content_length = Column(Float, nullable=False)
329-
total_tps = Column(Float, nullable=False)
330-
completion_tps = Column(Float, nullable=False)
331-
avg_total_tokens_per_req = Column(Float, nullable=False)
332-
avg_completion_tokens_per_req = Column(Float, nullable=False)
356+
total_tps = Column(Float, nullable=True, default=0.0)
357+
completion_tps = Column(Float, nullable=True, default=0.0)
358+
avg_total_tokens_per_req = Column(Float, nullable=True, default=0.0)
359+
avg_completion_tokens_per_req = Column(Float, nullable=True, default=0.0)
333360
created_at = Column(DateTime, server_default=func.now())
334361
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now())
335362

@@ -349,8 +376,18 @@ def to_task_result_item(self) -> TaskResultItem:
349376
rps=self.rps,
350377
avg_content_length=self.avg_content_length,
351378
created_at=self.created_at.isoformat() if self.created_at else "",
352-
total_tps=self.total_tps,
353-
completion_tps=self.completion_tps,
354-
avg_total_tokens_per_req=self.avg_total_tokens_per_req,
355-
avg_completion_tokens_per_req=self.avg_completion_tokens_per_req,
379+
total_tps=self.total_tps if self.total_tps is not None else 0.0,
380+
completion_tps=(
381+
self.completion_tps if self.completion_tps is not None else 0.0
382+
),
383+
avg_total_tokens_per_req=(
384+
self.avg_total_tokens_per_req
385+
if self.avg_total_tokens_per_req is not None
386+
else 0.0
387+
),
388+
avg_completion_tokens_per_req=(
389+
self.avg_completion_tokens_per_req
390+
if self.avg_completion_tokens_per_req is not None
391+
else 0.0
392+
),
356393
)

backend/mypy.ini

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
[mypy]
2+
python_version = 3.11
3+
warn_return_any = False
4+
warn_unused_configs = True
5+
show_error_codes = True
6+
7+
# ignore SQLAlchemy errors
8+
disallow_untyped_defs = False
9+
disallow_incomplete_defs = False
10+
check_untyped_defs = False
11+
12+
# ignore SQLAlchemy module import errors
13+
[mypy-sqlalchemy.*]
14+
ignore_missing_imports = True
15+
ignore_errors = True
16+
17+
# ignore specific SQLAlchemy type errors
18+
[mypy-model.*]
19+
ignore_errors = True
20+
21+
# ignore other dependencies
22+
[mypy-fastapi.*]
23+
ignore_missing_imports = True
24+
25+
[mypy-pydantic.*]
26+
ignore_missing_imports = True
27+
28+
[mypy-uvicorn.*]
29+
ignore_missing_imports = True
30+
31+
[mypy-starlette.*]
32+
ignore_missing_imports = True

backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,4 @@ markers = [
103103

104104
[tool.bandit]
105105
exclude_dirs = ["tests"]
106-
skips = ["B101", "B601"]
106+
skips = ["B101", "B601", "B501"]

backend/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ pydantic_settings
88
aiomysql
99
greenlet
1010
werkzeug
11-
python-multipart
11+
python-multipart
12+
httpx

backend/service/log_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_last_n_lines(file_path: str, n: int = 100) -> str:
3030
# Move to the end of the file
3131
f.seek(0, os.SEEK_END)
3232
block_size = 1024
33-
lines_found = deque()
33+
lines_found: deque[str] = deque()
3434

3535
while f.tell() > 0 and len(lines_found) <= n:
3636
# Calculate the position and size of the next block to read
@@ -59,7 +59,7 @@ def get_last_n_lines(file_path: str, n: int = 100) -> str:
5959

6060
return "\n".join(list(lines_found)[-n:])
6161
except Exception as e:
62-
logger.error(f"Failed to read log file:- {str(e)}")
62+
logger.error(f"Failed to read log file: {str(e)}")
6363
return ""
6464

6565

@@ -110,12 +110,12 @@ async def get_service_log_svc(service_name: str, offset: int, tail: int):
110110
log_file_path = os.path.join(LOG_DIR, f"{service_name}.log")
111111

112112
if not os.path.exists(log_file_path):
113-
logger.warning("Log file not found.")
113+
logger.warning(f"Log file not found: {log_file_path}")
114114
return JSONResponse(
115115
status_code=404,
116116
content={
117117
"status": "error",
118-
"error": f"Log file for service '{service_name}' not found at {log_file_path}",
118+
"error": f"Log file for service '{service_name}' not found",
119119
},
120120
)
121121
try:

0 commit comments

Comments
 (0)