Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/attribution/attribution_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from typing import Annotated, Self, cast

from flask import current_app
from pydantic import AfterValidator, Field
from rank_bm25 import BM25Okapi # type: ignore
from werkzeug import exceptions
Expand All @@ -23,9 +24,8 @@
from src.attribution.infini_gram_api_client.models.available_infini_gram_index_id import (
AvailableInfiniGramIndexId,
)
from src.attribution.infini_gram_api_client.models.http_validation_error import (
HTTPValidationError,
)
from src.attribution.infini_gram_api_client.models.problem import Problem
from src.attribution.infini_gram_api_client.models.request_validation_error import RequestValidationError
from src.config.get_config import cfg
from src.util.pii_regex import does_contain_pii

Expand Down Expand Up @@ -196,10 +196,25 @@ def get_attribution(
msg = f"Something went wrong when calling the infini-gram API: {e.status_code} {e.content.decode()}"
raise exceptions.BadGateway(msg) from e

if isinstance(attribution_response, HTTPValidationError):
if isinstance(attribution_response, RequestValidationError):
current_app.logger.error(
"Validation error from infini-gram %s, errors %s",
attribution_response.title,
str(attribution_response.errors),
)
# validation error handling
raise exceptions.InternalServerError(
description=f"infini-gram API reported a validation error: {attribution_response.detail}\nThis is likely an error in olmo-api."
description=f"infini-gram API reported a validation error: {attribution_response.title}\nThis is likely an error in olmo-api."
)

if isinstance(attribution_response, Problem):
current_app.logger.error(
"Problem from infini-gram %s, detail %s",
attribution_response.title,
str(attribution_response.detail),
)
raise exceptions.InternalServerError(
description=f"infini-gram API reported an error: {attribution_response.title}"
)

if attribution_response is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from http import HTTPStatus
from typing import Any
from typing import Any, Optional, Union

import httpx

from src.attribution.infini_gram_api_client import errors
from src.attribution.infini_gram_api_client.client import AuthenticatedClient, Client
from src.attribution.infini_gram_api_client.models.available_infini_gram_index_id import AvailableInfiniGramIndexId
from src.attribution.infini_gram_api_client.models.http_validation_error import HTTPValidationError
from src.attribution.infini_gram_api_client.models.infini_gram_count_response import InfiniGramCountResponse
from src.attribution.infini_gram_api_client.types import UNSET, Response
from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.available_infini_gram_index_id import AvailableInfiniGramIndexId
from ...models.infini_gram_count_response import InfiniGramCountResponse
from ...models.request_validation_error import RequestValidationError
from ...types import UNSET, Response


def _get_kwargs(
Expand All @@ -32,22 +32,25 @@ def _get_kwargs(


def _parse_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
) -> HTTPValidationError | InfiniGramCountResponse | None:
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[Union[InfiniGramCountResponse, RequestValidationError]]:
if response.status_code == 200:
return InfiniGramCountResponse.from_dict(response.json())
response_200 = InfiniGramCountResponse.from_dict(response.json())

return response_200
if response.status_code == 422:
return HTTPValidationError.from_dict(response.json())
response_422 = RequestValidationError.from_dict(response.json())

return response_422
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
return None
else:
return None


def _build_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
) -> Response[HTTPValidationError | InfiniGramCountResponse]:
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[Union[InfiniGramCountResponse, RequestValidationError]]:
return Response(
status_code=HTTPStatus(response.status_code),
content=response.content,
Expand All @@ -59,9 +62,9 @@ def _build_response(
def sync_detailed(
index: AvailableInfiniGramIndexId,
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
query: str,
) -> Response[HTTPValidationError | InfiniGramCountResponse]:
) -> Response[Union[InfiniGramCountResponse, RequestValidationError]]:
"""Count

Args:
Expand All @@ -73,7 +76,7 @@ def sync_detailed(
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Union[HTTPValidationError, InfiniGramCountResponse]]
Response[Union[InfiniGramCountResponse, RequestValidationError]]
"""

kwargs = _get_kwargs(
Expand All @@ -91,9 +94,9 @@ def sync_detailed(
def sync(
index: AvailableInfiniGramIndexId,
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
query: str,
) -> HTTPValidationError | InfiniGramCountResponse | None:
) -> Optional[Union[InfiniGramCountResponse, RequestValidationError]]:
"""Count

Args:
Expand All @@ -105,7 +108,7 @@ def sync(
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Union[HTTPValidationError, InfiniGramCountResponse]
Union[InfiniGramCountResponse, RequestValidationError]
"""

return sync_detailed(
Expand All @@ -118,9 +121,9 @@ def sync(
async def asyncio_detailed(
index: AvailableInfiniGramIndexId,
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
query: str,
) -> Response[HTTPValidationError | InfiniGramCountResponse]:
) -> Response[Union[InfiniGramCountResponse, RequestValidationError]]:
"""Count

Args:
Expand All @@ -132,7 +135,7 @@ async def asyncio_detailed(
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Response[Union[HTTPValidationError, InfiniGramCountResponse]]
Response[Union[InfiniGramCountResponse, RequestValidationError]]
"""

kwargs = _get_kwargs(
Expand All @@ -148,9 +151,9 @@ async def asyncio_detailed(
async def asyncio(
index: AvailableInfiniGramIndexId,
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
query: str,
) -> HTTPValidationError | InfiniGramCountResponse | None:
) -> Optional[Union[InfiniGramCountResponse, RequestValidationError]]:
"""Count

Args:
Expand All @@ -162,7 +165,7 @@ async def asyncio(
httpx.TimeoutException: If the request takes longer than Client.timeout.

Returns:
Union[HTTPValidationError, InfiniGramCountResponse]
Union[InfiniGramCountResponse, RequestValidationError]
"""

return (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from http import HTTPStatus
from typing import Any
from typing import Any, Optional, Union

import httpx

from src.attribution.infini_gram_api_client import errors
from src.attribution.infini_gram_api_client.client import AuthenticatedClient, Client
from src.attribution.infini_gram_api_client.models.available_infini_gram_index_id import AvailableInfiniGramIndexId
from src.attribution.infini_gram_api_client.types import Response
from ... import errors
from ...client import AuthenticatedClient, Client
from ...models.available_infini_gram_index_id import AvailableInfiniGramIndexId
from ...types import Response


def _get_kwargs() -> dict[str, Any]:
Expand All @@ -19,8 +19,8 @@ def _get_kwargs() -> dict[str, Any]:


def _parse_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
) -> list[AvailableInfiniGramIndexId] | None:
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Optional[list[AvailableInfiniGramIndexId]]:
if response.status_code == 200:
response_200 = []
_response_200 = response.json()
Expand All @@ -32,11 +32,12 @@ def _parse_response(
return response_200
if client.raise_on_unexpected_status:
raise errors.UnexpectedStatus(response.status_code, response.content)
return None
else:
return None


def _build_response(
*, client: AuthenticatedClient | Client, response: httpx.Response
*, client: Union[AuthenticatedClient, Client], response: httpx.Response
) -> Response[list[AvailableInfiniGramIndexId]]:
return Response(
status_code=HTTPStatus(response.status_code),
Expand All @@ -48,7 +49,7 @@ def _build_response(

def sync_detailed(
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
) -> Response[list[AvailableInfiniGramIndexId]]:
"""Get Available Indexes

Expand All @@ -71,8 +72,8 @@ def sync_detailed(

def sync(
*,
client: AuthenticatedClient | Client,
) -> list[AvailableInfiniGramIndexId] | None:
client: Union[AuthenticatedClient, Client],
) -> Optional[list[AvailableInfiniGramIndexId]]:
"""Get Available Indexes

Raises:
Expand All @@ -90,7 +91,7 @@ def sync(

async def asyncio_detailed(
*,
client: AuthenticatedClient | Client,
client: Union[AuthenticatedClient, Client],
) -> Response[list[AvailableInfiniGramIndexId]]:
"""Get Available Indexes

Expand All @@ -111,8 +112,8 @@ async def asyncio_detailed(

async def asyncio(
*,
client: AuthenticatedClient | Client,
) -> list[AvailableInfiniGramIndexId] | None:
client: Union[AuthenticatedClient, Client],
) -> Optional[list[AvailableInfiniGramIndexId]]:
"""Get Available Indexes

Raises:
Expand Down
Loading