|
1 | | -from typing import List |
| 1 | +from typing import List, Annotated |
2 | 2 | from uuid import UUID |
3 | 3 |
|
4 | | -from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status |
| 4 | +from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, status, Request |
5 | 5 | from fastapi.responses import JSONResponse |
6 | 6 | from fastapi_pagination import Page |
7 | 7 | from fastapi_pagination import Params as PaginationParams |
|
18 | 18 | from src.db.models.scenario import Scenario |
19 | 19 | from src.db.models.scenario_feature import ScenarioFeature |
20 | 20 | from src.db.session import AsyncSession |
21 | | -from src.deps.auth import auth_z |
| 21 | +from src.deps.auth import auth_z, auth_z_lite |
22 | 22 | from src.endpoints.deps import get_db, get_scenario, get_user_id |
23 | 23 | from src.schemas.common import OrderEnum |
24 | 24 | from src.schemas.error import HTTPErrorHandler |
@@ -517,9 +517,9 @@ async def get_chart_data( |
517 | 517 | summary="Get aggregated statistics for a column", |
518 | 518 | response_model=dict, |
519 | 519 | status_code=200, |
520 | | - dependencies=[Depends(auth_z)], |
521 | 520 | ) |
522 | 521 | async def get_statistic_aggregation( |
| 522 | + request: Request, |
523 | 523 | async_session: AsyncSession = Depends(get_db), |
524 | 524 | project_id: UUID4 = Path( |
525 | 525 | ..., |
@@ -567,6 +567,18 @@ async def get_statistic_aggregation( |
567 | 567 | ): |
568 | 568 | """Get aggregated statistics for a numeric column based on the supplied group-by column and CQL-filter.""" |
569 | 569 |
|
| 570 | + # Check authorization status |
| 571 | + try: |
| 572 | + await auth_z_lite(request, async_session) |
| 573 | + except HTTPException as e: |
| 574 | + # Check publication status if unauthorized |
| 575 | + public_project = await crud_project.get_public_project( |
| 576 | + async_session=async_session, |
| 577 | + project_id=str(project_id), |
| 578 | + ) |
| 579 | + if not public_project: |
| 580 | + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") |
| 581 | + |
570 | 582 | # Ensure an operation or expression is specified |
571 | 583 | if operation is None and expression is None: |
572 | 584 | raise HTTPException( |
@@ -618,9 +630,9 @@ async def get_statistic_aggregation( |
618 | 630 | summary="Get histogram statistics for a column", |
619 | 631 | response_model=dict, |
620 | 632 | status_code=200, |
621 | | - dependencies=[Depends(auth_z)], |
622 | 633 | ) |
623 | 634 | async def get_statistic_histogram( |
| 635 | + request: Request, |
624 | 636 | async_session: AsyncSession = Depends(get_db), |
625 | 637 | project_id: UUID4 = Path( |
626 | 638 | ..., |
@@ -655,6 +667,18 @@ async def get_statistic_histogram( |
655 | 667 | ): |
656 | 668 | """Get histogram statistics for a numeric column based on the specified number of bins and CQL-filter.""" |
657 | 669 |
|
| 670 | + # Check authorization status |
| 671 | + try: |
| 672 | + await auth_z_lite(request, async_session) |
| 673 | + except HTTPException as e: |
| 674 | + # Check publication status if unauthorized |
| 675 | + public_project = await crud_project.get_public_project( |
| 676 | + async_session=async_session, |
| 677 | + project_id=str(project_id), |
| 678 | + ) |
| 679 | + if not public_project: |
| 680 | + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") |
| 681 | + |
658 | 682 | # Ensure the number of bins is not excessively large |
659 | 683 | if num_bins > 100: |
660 | 684 | raise HTTPException( |
|
0 commit comments