|
6 | 6 |
|
7 | 7 | from __future__ import annotations as _annotations
|
8 | 8 |
|
| 9 | +import base64 |
9 | 10 | from abc import ABC, abstractmethod
|
10 | 11 | from collections.abc import AsyncIterator, Iterator
|
11 | 12 | from contextlib import asynccontextmanager, contextmanager
|
12 | 13 | from dataclasses import dataclass, field, replace
|
13 | 14 | from datetime import datetime
|
14 | 15 | from functools import cache, cached_property
|
| 16 | +from typing import Generic, TypeVar, overload |
15 | 17 |
|
16 | 18 | import httpx
|
17 |
| -from typing_extensions import Literal, TypeAliasType |
| 19 | +from typing_extensions import Literal, TypeAliasType, TypedDict |
18 | 20 |
|
19 | 21 | from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
|
20 | 22 |
|
21 | 23 | from .._parts_manager import ModelResponsePartsManager
|
22 | 24 | from ..exceptions import UserError
|
23 |
| -from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent |
| 25 | +from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl |
24 | 26 | from ..profiles._json_schema import JsonSchemaTransformer
|
25 | 27 | from ..settings import ModelSettings
|
26 | 28 | from ..tools import ToolDefinition
|
@@ -611,6 +613,91 @@ def _cached_async_http_transport() -> httpx.AsyncHTTPTransport:
|
611 | 613 | return httpx.AsyncHTTPTransport()
|
612 | 614 |
|
613 | 615 |
|
| 616 | +DataT = TypeVar('DataT', str, bytes) |
| 617 | + |
| 618 | + |
| 619 | +class DownloadedItem(TypedDict, Generic[DataT]): |
| 620 | + """The downloaded data and its type.""" |
| 621 | + |
| 622 | + data: DataT |
| 623 | + """The downloaded data.""" |
| 624 | + |
| 625 | + data_type: str |
| 626 | + """The type of data that was downloaded. |
| 627 | +
|
| 628 | + Extracted from header "content-type", but defaults to the media type inferred from the file URL if content-type is "application/octet-stream". |
| 629 | + """ |
| 630 | + |
| 631 | + |
@overload
async def download_item(
    item: FileUrl,
    data_format: Literal['bytes'] = 'bytes',
    type_format: Literal['mime', 'extension'] = 'mime',
) -> DownloadedItem[bytes]: ...


@overload
async def download_item(
    item: FileUrl,
    data_format: Literal['base64', 'base64_uri', 'text'],
    type_format: Literal['mime', 'extension'] = 'mime',
) -> DownloadedItem[str]: ...


async def download_item(
    item: FileUrl,
    data_format: Literal['bytes', 'base64', 'base64_uri', 'text'] = 'bytes',
    type_format: Literal['mime', 'extension'] = 'mime',
) -> DownloadedItem[str] | DownloadedItem[bytes]:
    """Download an item by URL and return the content as a bytes object or a (base64-encoded) string.

    Args:
        item: The item to download.
        data_format: The format to return the content in:
            - `bytes`: The raw bytes of the content.
            - `base64`: The base64-encoded content.
            - `base64_uri`: The base64-encoded content as a data URI.
            - `text`: The content decoded as UTF-8 text.
        type_format: The format to return the media type in:
            - `mime`: The media type as a MIME type.
            - `extension`: The media type as an extension (the MIME subtype).

    Returns:
        A `DownloadedItem` holding the content in the requested format and its data type.

    Raises:
        UserError: If the URL points to a YouTube video or its protocol is gs://.
        httpx.HTTPStatusError: If the server responds with an error status code.
    """
    if item.url.startswith('gs://'):
        raise UserError('Downloading from protocol "gs://" is not supported.')
    elif isinstance(item, VideoUrl) and item.is_youtube:
        raise UserError('Downloading YouTube videos is not supported.')

    client = cached_async_http_client()
    response = await client.get(item.url, follow_redirects=True)
    response.raise_for_status()

    content_type = response.headers.get('content-type')
    if content_type:
        # Drop parameters such as `; charset=utf-8`. Media types are case-insensitive and may be
        # followed by optional whitespace, so normalise before comparing or returning the value.
        content_type = content_type.split(';')[0].strip().lower()
        if content_type == 'application/octet-stream':
            # A generic binary type carries no information; fall back to the URL-derived type.
            content_type = None

    media_type = content_type or item.media_type

    data_type = media_type
    if type_format == 'extension':
        # Use the subtype as the extension; tolerate a malformed type that has no slash
        # instead of failing with an IndexError.
        type_parts = media_type.split('/')
        if len(type_parts) > 1:
            data_type = type_parts[1]

    raw = response.content
    if data_format in ('base64', 'base64_uri'):
        encoded = base64.b64encode(raw).decode('utf-8')
        if data_format == 'base64_uri':
            # The data URI always carries the full MIME type, regardless of `type_format`.
            encoded = f'data:{media_type};base64,{encoded}'
        return DownloadedItem[str](data=encoded, data_type=data_type)
    elif data_format == 'text':
        return DownloadedItem[str](data=raw.decode('utf-8'), data_type=data_type)
    else:
        return DownloadedItem[bytes](data=raw, data_type=data_type)
| 699 | + |
| 700 | + |
614 | 701 | @cache
|
615 | 702 | def get_user_agent() -> str:
|
616 | 703 | """Get the user agent string for the HTTP client."""
|
|
0 commit comments