Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.terminal.activateEnvironment": false
}
2 changes: 1 addition & 1 deletion agents/rag/.vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"version": "0.2.0",
"configurations": [
{
"name": "agent-form",
"name": "agent-rag",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/src/rag/agent.py",
Expand Down
3 changes: 1 addition & 2 deletions agents/rag/src/rag/tools/files/file_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from typing import List, Literal

from agentstack_sdk.platform import File
from beeai_framework.emitter import Emitter
from beeai_framework.tools import (
JSONToolOutput,
Tool,
ToolRunOptions,
)
from agentstack_sdk.platform import File
from pydantic import BaseModel, Field, create_model

from rag.tools.files.utils import File, format_size
Expand Down Expand Up @@ -116,7 +116,6 @@ async def _run(self, input: FileReadInputBase, options, context) -> FileReaderTo
# pull the first (only) MessagePart from the async-generator
async with file.load_text_content() as loaded_file:
content = loaded_file.text
content_type = loaded_file.content_type

if content is None:
raise ValueError(f"File content is None for {filename}.")
Expand Down
4 changes: 0 additions & 4 deletions agentstack.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
"name": "agent-chat",
"path": "agents/chat"
},
{
"name": "agent-form",
"path": "agents/form"
},
{
"name": "agent-rag",
"path": "agents/rag"
Expand Down
46 changes: 45 additions & 1 deletion apps/agentstack-sdk-py/src/agentstack_sdk/platform/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
from agentstack_sdk.util.file import LoadedFile, LoadedFileWithUri, PlatformFileUrl
from agentstack_sdk.util.utils import filter_dict

ExtractionFormatLiteral = typing.Literal["markdown", "vendor_specific_json"]


class ExtractedFileInfo(pydantic.BaseModel):
"""Information about an extracted file."""

file_id: str
format: ExtractionFormatLiteral | None


class Extraction(pydantic.BaseModel):
id: str
file_id: str
extracted_file_id: str | None = None
extracted_files: list[ExtractedFileInfo] = pydantic.Field(default_factory=list)
status: typing.Literal["pending", "in_progress", "completed", "failed", "cancelled"] = "pending"
job_id: str | None = None
error_message: str | None = None
Expand Down Expand Up @@ -152,9 +161,43 @@ async def load_text_content(
await response.aread()
yield LoadedFileWithUri(response=response, content_type=file.content_type, filename=file.filename)

@asynccontextmanager
async def load_json_content(
self: File | str,
*,
stream: bool = False,
client: PlatformClient | None = None,
context_id: str | None | Literal["auto"] = "auto",
) -> AsyncIterator[LoadedFile]:
# `self` has a weird type so that you can call both `instance.load_json_content()` to create an extraction for an instance, or `File.load_json_content("123")`
file_id = self if isinstance(self, str) else self.id
async with client or get_platform_client() as platform_client:
context_id = platform_client.context_id if context_id == "auto" else context_id

file = await File.get(file_id, client=client, context_id=context_id) if isinstance(self, str) else self
extraction = await file.get_extraction(client=client, context_id=context_id)

for extracted_file_info in extraction.extracted_files:
if extracted_file_info.format != "vendor_specific_json":
continue
extracted_json_file_id = extracted_file_info.file_id
async with platform_client.stream(
"GET",
url=f"/api/v1/files/{extracted_json_file_id}/content",
params=context_id and {"context_id": context_id},
) as response:
response.raise_for_status()
if not stream:
await response.aread()
yield LoadedFileWithUri(response=response, content_type=file.content_type, filename=file.filename)
return

raise ValueError("No extracted JSON content available for this file.")

async def create_extraction(
self: File | str,
*,
formats: list[ExtractionFormatLiteral] | None = None,
client: PlatformClient | None = None,
context_id: str | None | Literal["auto"] = "auto",
) -> Extraction:
Expand All @@ -167,6 +210,7 @@ async def create_extraction(
await platform_client.post(
url=f"/api/v1/files/{file_id}/extraction",
params=context_id and {"context_id": context_id},
json={"settings": {"formats": formats}} if formats else None,
)
)
.raise_for_status()
Expand Down
1 change: 1 addition & 0 deletions apps/agentstack-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"opentelemetry-instrumentation-httpx>=0.59b0",
"opentelemetry-instrumentation-fastapi>=0.59b0",
"limits[async-redis]>=5.3.0",
"ijson>=3.4.0.post0",
]

[dependency-groups]
Expand Down
47 changes: 42 additions & 5 deletions apps/agentstack-server/src/agentstack_server/api/routes/files.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

import logging
from contextlib import AsyncExitStack
from typing import Annotated
Expand All @@ -14,9 +15,17 @@
RequiresContextPermissions,
)
from agentstack_server.api.schema.common import EntityModel
from agentstack_server.api.schema.files import FileListQuery
from agentstack_server.api.schema.files import FileListQuery, TextExtractionRequest
from agentstack_server.domain.models.common import PaginatedResult
from agentstack_server.domain.models.file import AsyncFile, ExtractionStatus, File, TextExtraction
from agentstack_server.domain.models.file import (
AsyncFile,
Backend,
ExtractionFormat,
ExtractionStatus,
File,
TextExtraction,
TextExtractionSettings,
)
from agentstack_server.domain.models.permissions import AuthorizedUser
from agentstack_server.service_layer.services.files import FileService

Expand Down Expand Up @@ -92,12 +101,32 @@ async def get_text_file_content(
user: Annotated[AuthorizedUser, Depends(RequiresContextPermissions(files={"read"}))],
) -> StreamingResponse:
extraction = await file_service.get_extraction(file_id=file_id, user=user.user, context_id=user.context_id)
if not extraction.status == ExtractionStatus.COMPLETED or not extraction.extracted_file_id:
if not extraction.status == ExtractionStatus.COMPLETED or not extraction.extracted_files:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Extraction is not completed (status {extraction.status})",
)
return await _stream_file(file_service=file_service, user=user, file_id=extraction.extracted_file_id)

if extraction.extraction_metadata is not None and extraction.extraction_metadata.backend == Backend.IN_PLACE:
# Fallback to the original file for in-place extraction
original_file_id = extraction.find_file_by_format(format=None)
if not original_file_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Original file not found in extraction results",
)
file_to_stream_id = original_file_id
else:
# Find the markdown file from extracted files
markdown_file_id = extraction.find_file_by_format(format=ExtractionFormat.MARKDOWN)
if not markdown_file_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Markdown file not found in extraction results",
)
file_to_stream_id = markdown_file_id

return await _stream_file(file_service=file_service, user=user, file_id=file_to_stream_id)


@router.delete("/{file_id}", status_code=fastapi.status.HTTP_204_NO_CONTENT)
Expand All @@ -114,6 +143,7 @@ async def create_text_extraction(
file_id: UUID,
file_service: FileServiceDependency,
user: Annotated[AuthorizedUser, Depends(RequiresContextPermissions(files={"write", "extract"}))],
request: TextExtractionRequest | None = None,
) -> EntityModel[TextExtraction]:
"""Create or return text extraction for a file.

Expand All @@ -122,8 +152,15 @@ async def create_text_extraction(
- If extraction is pending/in-progress, returns current status
- If no extraction exists, creates a new one
"""
if request is None:
request = TextExtractionRequest()

settings = request.settings if request.settings is not None else TextExtractionSettings()

return EntityModel(
await file_service.create_extraction(file_id=file_id, user=user.user, context_id=user.context_id)
await file_service.create_extraction(
file_id=file_id, user=user.user, context_id=user.context_id, settings=settings
)
)


Expand Down
10 changes: 10 additions & 0 deletions apps/agentstack-server/src/agentstack_server/api/schema/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, Field

from agentstack_server.api.schema.common import PaginationQuery
from agentstack_server.domain.models.file import TextExtractionSettings


class FileResponse(BaseModel):
Expand Down Expand Up @@ -36,3 +37,12 @@ class FileListQuery(PaginationQuery):
description="Case-insensitive partial match search on filename (e.g., 'doc' matches 'my_document.pdf')",
)
order_by: str = Field(default_factory=lambda: "created_at", pattern="^created_at|filename|file_size_bytes$")


class TextExtractionRequest(BaseModel):
"""Request schema for text extraction."""

settings: TextExtractionSettings | None = Field(
default=None,
description="Additional options for text extraction",
)
89 changes: 80 additions & 9 deletions apps/agentstack-server/src/agentstack_server/domain/models/file.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterator, Awaitable, Callable
from enum import StrEnum
from typing import Self
from uuid import UUID, uuid4

from pydantic import AwareDatetime, BaseModel, Field
Expand All @@ -23,8 +24,24 @@ class ExtractionStatus(StrEnum):
CANCELLED = "cancelled"


class ExtractionFormat(StrEnum):
MARKDOWN = "markdown"
VENDOR_SPECIFIC_JSON = "vendor_specific_json"


class TextExtractionSettings(BaseModel):
formats: list[ExtractionFormat] = Field(
default_factory=lambda: [ExtractionFormat.MARKDOWN, ExtractionFormat.VENDOR_SPECIFIC_JSON]
)


class Backend(StrEnum):
IN_PLACE = "in-place"


class ExtractionMetadata(BaseModel, extra="allow"):
backend: str
backend: str | None = None
settings: TextExtractionSettings | None = None


class FileMetadata(BaseModel, extra="allow"):
Expand All @@ -39,6 +56,36 @@ class AsyncFile(BaseModel):
read: Callable[[int], Awaitable[bytes]]
size: int | None = None

@classmethod
def from_async_iterator(cls, iterator: AsyncIterator[bytes], filename: str, content_type: str) -> Self:
buffer = b""

async def read(size: int = 8192) -> bytes:
nonlocal buffer
while len(buffer) < size:
try:
buffer += await anext(iterator)
except StopAsyncIteration:
break

result = buffer[:size]
buffer = buffer[size:]
return result

return cls(filename=filename, content_type=content_type, read=read)

@classmethod
def from_bytes(cls, content: bytes, filename: str, content_type: str) -> Self:
pos = 0

async def read(size: int = 8192) -> bytes:
nonlocal pos
result = content[pos : pos + size]
pos += len(result)
return result

return cls(filename=filename, content_type=content_type, read=read, size=len(content))


class File(BaseModel):
id: UUID = Field(default_factory=uuid4)
Expand All @@ -52,10 +99,17 @@ class File(BaseModel):
context_id: UUID | None = None


class ExtractedFileInfo(BaseModel):
"""Information about an extracted file."""

file_id: UUID
format: ExtractionFormat | None = None


class TextExtraction(BaseModel):
id: UUID = Field(default_factory=uuid4)
file_id: UUID
extracted_file_id: UUID | None = None
extracted_files: list[ExtractedFileInfo] = Field(default_factory=list)
status: ExtractionStatus = ExtractionStatus.PENDING
job_id: str | None = None
error_message: str | None = None
Expand All @@ -64,20 +118,30 @@ class TextExtraction(BaseModel):
finished_at: AwareDatetime | None = None
created_at: AwareDatetime = Field(default_factory=utc_now)

def set_started(self, job_id: str) -> None:
"""Mark extraction as started with job ID."""
def set_started(self, job_id: str, backend: str) -> None:
"""Mark extraction as started with job ID and backend name."""
self.status = ExtractionStatus.IN_PROGRESS
self.job_id = job_id
self.started_at = utc_now()
self.error_message = None

def set_completed(self, extracted_file_id: UUID, metadata: ExtractionMetadata | None = None) -> None:
"""Mark extraction as completed with extracted file ID."""
# Create extraction_metadata if it doesn't exist
if self.extraction_metadata is None:
self.extraction_metadata = ExtractionMetadata()

# Set the backend name
self.extraction_metadata.backend = backend

def set_completed(
self, extracted_files: list[ExtractedFileInfo], metadata: ExtractionMetadata | None = None
) -> None:
"""Mark extraction as completed with extracted files and their formats."""
self.status = ExtractionStatus.COMPLETED
self.extracted_file_id = extracted_file_id
self.extracted_files = extracted_files
self.finished_at = utc_now()
self.extraction_metadata = metadata
self.error_message = None
if metadata is not None:
self.extraction_metadata = metadata

def set_failed(self, error_message: str) -> None:
"""Mark extraction as failed with error message."""
Expand All @@ -97,3 +161,10 @@ def reset_for_retry(self) -> None:
self.started_at = None
self.finished_at = None
self.job_id = None

def find_file_by_format(self, format: ExtractionFormat | None) -> UUID | None:
"""Find an extracted file by format from the extracted files list."""
for extracted_file_info in self.extracted_files:
if extracted_file_info.format == format:
return extracted_file_info.file_id
return None
Loading