9
9
from datetime import datetime
10
10
from typing import Annotated , Any , Literal , Protocol , Union
11
11
12
+ import pydantic
12
13
import pydantic_core
13
14
from httpx import USE_CLIENT_DEFAULT , AsyncClient as AsyncHTTPClient , Response as HTTPResponse
14
- from pydantic import Discriminator , Field , Tag
15
15
from typing_extensions import NotRequired , TypedDict , TypeGuard , assert_never
16
16
17
- from .. import UnexpectedModelBehavior , _pydantic , _utils , exceptions , result
17
+ from .. import UnexpectedModelBehavior , _utils , exceptions , result
18
18
from ..messages import (
19
19
ArgsDict ,
20
20
Message ,
@@ -386,6 +386,7 @@ def timestamp(self) -> datetime:
386
386
# TypeAdapters take care of validation and serialization
387
387
388
388
389
+ @pydantic .with_config (pydantic .ConfigDict (defer_build = True ))
389
390
class _GeminiRequest (TypedDict ):
390
391
"""Schema for an API request to the Gemini API.
391
392
@@ -457,7 +458,7 @@ class _GeminiTextPart(TypedDict):
457
458
458
459
459
460
class _GeminiFunctionCallPart (TypedDict ):
460
- function_call : Annotated [_GeminiFunctionCall , Field (alias = 'functionCall' )]
461
+ function_call : Annotated [_GeminiFunctionCall , pydantic . Field (alias = 'functionCall' )]
461
462
462
463
463
464
def _function_call_part_from_call (tool : ToolCallPart ) -> _GeminiFunctionCallPart :
@@ -487,7 +488,7 @@ class _GeminiFunctionCall(TypedDict):
487
488
488
489
489
490
class _GeminiFunctionResponsePart (TypedDict ):
490
- function_response : Annotated [_GeminiFunctionResponse , Field (alias = 'functionResponse' )]
491
+ function_response : Annotated [_GeminiFunctionResponse , pydantic . Field (alias = 'functionResponse' )]
491
492
492
493
493
494
def _response_part_from_response (name : str , response : dict [str , Any ]) -> _GeminiFunctionResponsePart :
@@ -517,11 +518,11 @@ def _part_discriminator(v: Any) -> str:
517
518
# TODO discriminator
518
519
_GeminiPartUnion = Annotated [
519
520
Union [
520
- Annotated [_GeminiTextPart , Tag ('text' )],
521
- Annotated [_GeminiFunctionCallPart , Tag ('function_call' )],
522
- Annotated [_GeminiFunctionResponsePart , Tag ('function_response' )],
521
+ Annotated [_GeminiTextPart , pydantic . Tag ('text' )],
522
+ Annotated [_GeminiFunctionCallPart , pydantic . Tag ('function_call' )],
523
+ Annotated [_GeminiFunctionResponsePart , pydantic . Tag ('function_response' )],
523
524
],
524
- Discriminator (_part_discriminator ),
525
+ pydantic . Discriminator (_part_discriminator ),
525
526
]
526
527
527
528
@@ -531,7 +532,7 @@ class _GeminiTextContent(TypedDict):
531
532
532
533
533
534
class _GeminiTools (TypedDict ):
534
- function_declarations : list [Annotated [_GeminiFunction , Field (alias = 'functionDeclarations' )]]
535
+ function_declarations : list [Annotated [_GeminiFunction , pydantic . Field (alias = 'functionDeclarations' )]]
535
536
536
537
537
538
class _GeminiFunction (TypedDict ):
@@ -572,6 +573,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
572
573
allowed_function_names : list [str ]
573
574
574
575
576
+ @pydantic .with_config (pydantic .ConfigDict (defer_build = True ))
575
577
class _GeminiResponse (TypedDict ):
576
578
"""Schema for the response from the Gemini API.
577
579
@@ -581,8 +583,8 @@ class _GeminiResponse(TypedDict):
581
583
582
584
candidates : list [_GeminiCandidates ]
583
585
# usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
584
- usage_metadata : NotRequired [Annotated [_GeminiUsageMetaData , Field (alias = 'usageMetadata' )]]
585
- prompt_feedback : NotRequired [Annotated [_GeminiPromptFeedback , Field (alias = 'promptFeedback' )]]
586
+ usage_metadata : NotRequired [Annotated [_GeminiUsageMetaData , pydantic . Field (alias = 'usageMetadata' )]]
587
+ prompt_feedback : NotRequired [Annotated [_GeminiPromptFeedback , pydantic . Field (alias = 'promptFeedback' )]]
586
588
587
589
588
590
# TODO: Delete the next three functions once we've reworked streams to be more flexible
@@ -618,14 +620,14 @@ class _GeminiCandidates(TypedDict):
618
620
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
619
621
620
622
content : _GeminiContent
621
- finish_reason : NotRequired [Annotated [Literal ['STOP' , 'MAX_TOKENS' ], Field (alias = 'finishReason' )]]
623
+ finish_reason : NotRequired [Annotated [Literal ['STOP' , 'MAX_TOKENS' ], pydantic . Field (alias = 'finishReason' )]]
622
624
"""
623
625
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
624
626
but let's wait until we see them and know what they mean to add them here.
625
627
"""
626
- avg_log_probs : NotRequired [Annotated [float , Field (alias = 'avgLogProbs' )]]
628
+ avg_log_probs : NotRequired [Annotated [float , pydantic . Field (alias = 'avgLogProbs' )]]
627
629
index : NotRequired [int ]
628
- safety_ratings : NotRequired [Annotated [list [_GeminiSafetyRating ], Field (alias = 'safetyRatings' )]]
630
+ safety_ratings : NotRequired [Annotated [list [_GeminiSafetyRating ], pydantic . Field (alias = 'safetyRatings' )]]
629
631
630
632
631
633
class _GeminiUsageMetaData (TypedDict , total = False ):
@@ -634,10 +636,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
634
636
The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
635
637
"""
636
638
637
- prompt_token_count : Annotated [int , Field (alias = 'promptTokenCount' )]
638
- candidates_token_count : NotRequired [Annotated [int , Field (alias = 'candidatesTokenCount' )]]
639
- total_token_count : Annotated [int , Field (alias = 'totalTokenCount' )]
640
- cached_content_token_count : NotRequired [Annotated [int , Field (alias = 'cachedContentTokenCount' )]]
639
+ prompt_token_count : Annotated [int , pydantic . Field (alias = 'promptTokenCount' )]
640
+ candidates_token_count : NotRequired [Annotated [int , pydantic . Field (alias = 'candidatesTokenCount' )]]
641
+ total_token_count : Annotated [int , pydantic . Field (alias = 'totalTokenCount' )]
642
+ cached_content_token_count : NotRequired [Annotated [int , pydantic . Field (alias = 'cachedContentTokenCount' )]]
641
643
642
644
643
645
def _metadata_as_cost (response : _GeminiResponse ) -> result .Cost :
@@ -671,15 +673,15 @@ class _GeminiSafetyRating(TypedDict):
671
673
class _GeminiPromptFeedback (TypedDict ):
672
674
"""See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""
673
675
674
- block_reason : Annotated [str , Field (alias = 'blockReason' )]
675
- safety_ratings : Annotated [list [_GeminiSafetyRating ], Field (alias = 'safetyRatings' )]
676
+ block_reason : Annotated [str , pydantic . Field (alias = 'blockReason' )]
677
+ safety_ratings : Annotated [list [_GeminiSafetyRating ], pydantic . Field (alias = 'safetyRatings' )]
676
678
677
679
678
- _gemini_request_ta = _pydantic . LazyTypeAdapter (_GeminiRequest )
679
- _gemini_response_ta = _pydantic . LazyTypeAdapter (_GeminiResponse )
680
+ _gemini_request_ta = pydantic . TypeAdapter (_GeminiRequest )
681
+ _gemini_response_ta = pydantic . TypeAdapter (_GeminiResponse )
680
682
681
683
# steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
682
- _gemini_streamed_response_ta = _pydantic . LazyTypeAdapter (list [_GeminiResponse ])
684
+ _gemini_streamed_response_ta = pydantic . TypeAdapter (list [_GeminiResponse ], config = pydantic . ConfigDict ( defer_build = True ) )
683
685
684
686
685
687
class _GeminiJsonSchema :
0 commit comments