Skip to content

Commit 5784fec

Browse files
authored
fix: add response models (#118)
### What - Add return types to make fastapi generate correct OpenAPI typings
1 parent 8c5360b commit 5784fec

File tree

1 file changed

+72
-45
lines changed

1 file changed

+72
-45
lines changed

app/api.py

Lines changed: 72 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def create_flag(
286286
flag: FlagCreate,
287287
request: Request,
288288
_: Any = Depends(get_auth_dependency(UserStatus.isLoggedIn)),
289-
):
289+
) -> Flag:
290290
"""Create a flag for a product.
291291
292292
This function is used to create a flag for a product or an image.
@@ -338,25 +338,29 @@ def create_flag(
338338
ticket.save()
339339

340340
device_id = _get_device_id(request)
341-
return model_to_dict(
342-
FlagModel.create(ticket=ticket, device_id=device_id, **flag.model_dump())
343-
)
341+
return FlagModel.create(ticket=ticket, device_id=device_id, **flag.model_dump())
342+
343+
344+
class GetFlagsResponse(BaseModel):
345+
flags: list[Flag]
344346

345347

346348
@api_v1_router.get("/flags")
347-
def get_flags(_: Any = Depends(get_auth_dependency(UserStatus.isModerator))):
349+
def get_flags(
350+
_: Any = Depends(get_auth_dependency(UserStatus.isModerator)),
351+
) -> GetFlagsResponse:
348352
"""Get all flags.
349353
350354
This function is used to get all flags.
351355
"""
352356
with db:
353-
return {"flags": list(FlagModel.select().dicts().iterator())}
357+
return GetFlagsResponse(flags=list(FlagModel.select().dicts()))
354358

355359

356360
@api_v1_router.get("/flags/{flag_id}")
357361
def get_flag(
358362
flag_id: int, _: Any = Depends(get_auth_dependency(UserStatus.isModerator))
359-
):
363+
) -> Flag:
360364
"""Get a flag by ID.
361365
362366
This function is used to get a flag by its ID.
@@ -373,6 +377,13 @@ def _create_ticket(ticket: TicketCreate):
373377
return TicketModel.create(**ticket.model_dump())
374378

375379

380+
class GetTicketsResponse(BaseModel):
381+
"""Response model for get_tickets endpoint."""
382+
383+
tickets: list[Ticket]
384+
max_page: int
385+
386+
376387
@api_v1_router.get("/tickets")
377388
def get_tickets(
378389
status: TicketStatus | None = None,
@@ -381,7 +392,7 @@ def get_tickets(
381392
page: int = 1,
382393
page_size: int = 10,
383394
_: Any = Depends(get_auth_dependency(UserStatus.isModerator)),
384-
):
395+
) -> GetTicketsResponse:
385396
"""Get all tickets.
386397
387398
This function is used to get all tickets with status open.
@@ -404,18 +415,18 @@ def get_tickets(
404415
count = TicketModel.select().where(*where_clause).count()
405416
max_page = count // page_size + int(count % page_size != 0)
406417
if page > max_page:
407-
return {"tickets": [], "max_page": max_page}
408-
return {
409-
"tickets": list(
418+
return GetTicketsResponse(tickets=[], max_page=max_page)
419+
return GetTicketsResponse(
420+
tickets=list(
410421
TicketModel.select()
411422
.where(*where_clause)
412423
.order_by(TicketModel.created_at.desc())
413424
.offset(offset)
414425
.limit(page_size)
415426
.dicts()
416427
),
417-
"max_page": max_page,
418-
}
428+
max_page=max_page,
429+
)
419430

420431

421432
@api_v1_router.get("/tickets/{ticket_id}")
@@ -461,7 +472,7 @@ def update_ticket_status(
461472
ticket_id: int,
462473
status: TicketStatus,
463474
_: Any = Depends(get_auth_dependency(UserStatus.isModerator)),
464-
):
475+
) -> Ticket:
465476
"""Update the status of a ticket by ID.
466477
467478
This function is used to update the status of a ticket by its ID.
@@ -471,36 +482,51 @@ def update_ticket_status(
471482
ticket = TicketModel.get_by_id(ticket_id)
472483
ticket.status = status
473484
ticket.save()
474-
return model_to_dict(ticket)
485+
return ticket
475486
except DoesNotExist:
476487
raise HTTPException(status_code=404, detail="Not found")
477488

478489

490+
class StatsResponse(BaseModel):
491+
"""Response model for get_stats endpoint."""
492+
493+
total_tickets: int = Field(
494+
..., description="Total number of tickets in the database"
495+
)
496+
tickets_by_status: dict = Field(
497+
...,
498+
description="A dictionary with ticket status as keys and the count of tickets as values",
499+
)
500+
tickets_by_flavor: dict = Field(
501+
...,
502+
description="A dictionary with ticket flavor as keys and the count of tickets as values",
503+
)
504+
tickets_by_type: dict = Field(
505+
...,
506+
description="A dictionary with ticket type as keys and the count of tickets as values",
507+
)
508+
n_days: int = Field(
509+
...,
510+
description="The number of days for which the data is fetched",
511+
)
512+
start_date: str = Field(
513+
..., description="The start date of the data range in ISO format"
514+
)
515+
end_date: str = Field(
516+
..., description="The end date of the data range in ISO format"
517+
)
518+
519+
479520
@api_v1_router.get("/stats")
480521
def get_stats(
481522
n_days: int = 31,
482523
_: Any = Depends(get_auth_dependency(UserStatus.isModerator)),
483-
) -> dict:
524+
) -> StatsResponse:
484525
"""Get number of tickets by status for the last n days.
485526
486527
Args:
487528
n_days (int): The number of days from which to fetch ticket data.
488529
Default is 31 days.
489-
490-
Returns:
491-
dict: A dictionary containing the total number of tickets,
492-
tickets by status, tickets by flavor, and tickets by type.
493-
The keys are:
494-
- total_tickets: Total number of tickets.
495-
- tickets_by_status: A dictionary with ticket status as keys
496-
and the count of tickets as values.
497-
- tickets_by_flavor: A dictionary with ticket flavor as keys
498-
and the count of tickets as values.
499-
- tickets_by_type: A dictionary with ticket type as keys
500-
and the count of tickets as values.
501-
- n_days: The number of days for which the data is fetched.
502-
- start_date: The start date of the data range in ISO format.
503-
- end_date: The end date of the data range in ISO format.
504530
"""
505531
with db:
506532
# Return the total number of tickets
@@ -534,25 +560,26 @@ def get_stats(
534560
)
535561

536562
# Prepare the results
537-
result = {
538-
"total_tickets": total_tickets,
539-
"tickets_by_status": {ticket.status: ticket.count for ticket in tickets},
540-
"tickets_by_flavor": {
541-
ticket.flavor: ticket.count for ticket in tickets_by_flavor
542-
},
543-
"tickets_by_type": {ticket.type: ticket.count for ticket in tickets_by_type},
544-
"n_days": n_days,
545-
"start_date": start_date.isoformat(),
546-
"end_date": datetime.now(timezone.utc).isoformat(),
547-
}
548-
563+
result = StatsResponse(
564+
total_tickets=total_tickets,
565+
tickets_by_status={ticket.status: ticket.count for ticket in tickets},
566+
tickets_by_flavor={ticket.flavor: ticket.count for ticket in tickets_by_flavor},
567+
tickets_by_type={ticket.type: ticket.count for ticket in tickets_by_type},
568+
n_days=n_days,
569+
start_date=start_date.isoformat(),
570+
end_date=datetime.now(timezone.utc).isoformat(),
571+
)
549572
return result
550573

551574

575+
class StatusResponse(BaseModel):
576+
status: str = Field(..., description="Health status of the API")
577+
578+
552579
@api_v1_router.get("/status")
553-
def status():
580+
def status() -> StatusResponse:
554581
"""Health check endpoint."""
555-
return {"status": "ok"}
582+
return StatusResponse(status="ok")
556583

557584

558585
# Route only available in dev mode

0 commit comments

Comments
 (0)