@@ -544,7 +544,7 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
544
544
#####################
545
545
546
546
547
- async def test_stream_structured_with_all_typd (allow_model_requests : None ):
547
+ async def test_stream_structured_with_all_type (allow_model_requests : None ):
548
548
class MyTypedDict (TypedDict , total = False ):
549
549
first : str
550
550
second : int
@@ -563,19 +563,19 @@ class MyTypedDict(TypedDict, total=False):
563
563
'", "second": 2' ,
564
564
),
565
565
text_chunk (
566
- '" , "bool_value": true' ,
566
+ ', "bool_value": true' ,
567
567
),
568
568
text_chunk (
569
- '" , "nullable_value": null' ,
569
+ ', "nullable_value": null' ,
570
570
),
571
571
text_chunk (
572
- '" , "array_value": ["A", "B", "C"]' ,
572
+ ', "array_value": ["A", "B", "C"]' ,
573
573
),
574
574
text_chunk (
575
- '" , "dict_value": {"A": "A", "B":"B"}' ,
575
+ ', "dict_value": {"A": "A", "B":"B"}' ,
576
576
),
577
577
text_chunk (
578
- '" , "dict_int_value": {"A": 1, "B":2}' ,
578
+ ', "dict_int_value": {"A": 1, "B":2}' ,
579
579
),
580
580
text_chunk ('}' ),
581
581
chunk ([]),
@@ -721,8 +721,8 @@ class MyTypedDict(TypedDict, total=False):
721
721
{'first' : 'One' },
722
722
{'first' : 'One' },
723
723
{'first' : 'One' },
724
- {'first' : 'One' , 'second' : '' },
725
- {'first' : 'One' , 'second' : '' },
724
+ {'first' : 'One' },
725
+ {'first' : 'One' },
726
726
{'first' : 'One' , 'second' : '' },
727
727
{'first' : 'One' , 'second' : 'T' },
728
728
{'first' : 'One' , 'second' : 'Tw' },
@@ -828,20 +828,21 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
828
828
v = [c async for c in result .stream (debounce_by = None )]
829
829
assert v == snapshot (
830
830
[
831
+ ['' ],
831
832
['f' ],
832
833
['fi' ],
833
834
['fir' ],
834
835
['firs' ],
835
836
['first' ],
836
837
['first' ],
837
838
['first' ],
838
- ['first' ],
839
+ ['first' , '' ],
839
840
['first' , 'O' ],
840
841
['first' , 'On' ],
841
842
['first' , 'One' ],
842
843
['first' , 'One' ],
843
844
['first' , 'One' ],
844
- ['first' , 'One' ],
845
+ ['first' , 'One' , '' ],
845
846
['first' , 'One' , 's' ],
846
847
['first' , 'One' , 'se' ],
847
848
['first' , 'One' , 'sec' ],
@@ -850,7 +851,7 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
850
851
['first' , 'One' , 'second' ],
851
852
['first' , 'One' , 'second' ],
852
853
['first' , 'One' , 'second' ],
853
- ['first' , 'One' , 'second' ],
854
+ ['first' , 'One' , 'second' , '' ],
854
855
['first' , 'One' , 'second' , 'T' ],
855
856
['first' , 'One' , 'second' , 'Tw' ],
856
857
['first' , 'One' , 'second' , 'Two' ],
@@ -869,10 +870,10 @@ async def test_stream_result_type_primitif_array(allow_model_requests: None):
869
870
assert result .usage ().response_tokens == len (stream )
870
871
871
872
872
- async def test_stream_result_type_basemodel (allow_model_requests : None ):
873
+ async def test_stream_result_type_basemodel_with_default_params (allow_model_requests : None ):
873
874
class MyTypedBaseModel (BaseModel ):
874
- first : str = '' # Note: Don't forget to set default values
875
- second : str = ''
875
+ first : str = '' # Note: Default, set value.
876
+ second : str = '' # Note: Default, set value.
876
877
877
878
# Given
878
879
stream = [
@@ -958,6 +959,79 @@ class MyTypedBaseModel(BaseModel):
958
959
assert result .usage ().response_tokens == len (stream )
959
960
960
961
962
+ async def test_stream_result_type_basemodel_with_required_params (allow_model_requests : None ):
963
+ class MyTypedBaseModel (BaseModel ):
964
+ first : str # Note: Required params
965
+ second : str # Note: Required params
966
+
967
+ # Given
968
+ stream = [
969
+ text_chunk ('{' ),
970
+ text_chunk ('"' ),
971
+ text_chunk ('f' ),
972
+ text_chunk ('i' ),
973
+ text_chunk ('r' ),
974
+ text_chunk ('s' ),
975
+ text_chunk ('t' ),
976
+ text_chunk ('"' ),
977
+ text_chunk (':' ),
978
+ text_chunk (' ' ),
979
+ text_chunk ('"' ),
980
+ text_chunk ('O' ),
981
+ text_chunk ('n' ),
982
+ text_chunk ('e' ),
983
+ text_chunk ('"' ),
984
+ text_chunk (',' ),
985
+ text_chunk (' ' ),
986
+ text_chunk ('"' ),
987
+ text_chunk ('s' ),
988
+ text_chunk ('e' ),
989
+ text_chunk ('c' ),
990
+ text_chunk ('o' ),
991
+ text_chunk ('n' ),
992
+ text_chunk ('d' ),
993
+ text_chunk ('"' ),
994
+ text_chunk (':' ),
995
+ text_chunk (' ' ),
996
+ text_chunk ('"' ),
997
+ text_chunk ('T' ),
998
+ text_chunk ('w' ),
999
+ text_chunk ('o' ),
1000
+ text_chunk ('"' ),
1001
+ text_chunk ('}' ),
1002
+ chunk ([]),
1003
+ ]
1004
+
1005
+ mock_client = MockMistralAI .create_stream_mock (stream )
1006
+ model = MistralModel ('mistral-large-latest' , client = mock_client )
1007
+ agent = Agent (model = model , result_type = MyTypedBaseModel )
1008
+
1009
+ # When
1010
+ async with agent .run_stream ('User prompt value' ) as result :
1011
+ # Then
1012
+ assert result .is_structured
1013
+ assert not result .is_complete
1014
+ v = [c async for c in result .stream (debounce_by = None )]
1015
+ assert v == snapshot (
1016
+ [
1017
+ MyTypedBaseModel (first = 'One' , second = '' ),
1018
+ MyTypedBaseModel (first = 'One' , second = 'T' ),
1019
+ MyTypedBaseModel (first = 'One' , second = 'Tw' ),
1020
+ MyTypedBaseModel (first = 'One' , second = 'Two' ),
1021
+ MyTypedBaseModel (first = 'One' , second = 'Two' ),
1022
+ MyTypedBaseModel (first = 'One' , second = 'Two' ),
1023
+ MyTypedBaseModel (first = 'One' , second = 'Two' ),
1024
+ ]
1025
+ )
1026
+ assert result .is_complete
1027
+ assert result .usage ().request_tokens == 34
1028
+ assert result .usage ().response_tokens == 34
1029
+ assert result .usage ().total_tokens == 34
1030
+
1031
+ # double check cost matches stream count
1032
+ assert result .usage ().response_tokens == len (stream )
1033
+
1034
+
961
1035
#####################
962
1036
## Completion Function call
963
1037
#####################
@@ -1693,6 +1767,6 @@ def test_generate_user_output_format_multiple():
1693
1767
),
1694
1768
],
1695
1769
)
1696
- def test_validate_required_json_shema (desc : str , schema : dict [str , Any ], data : dict [str , Any ], expected : bool ) -> None :
1770
+ def test_validate_required_json_schema (desc : str , schema : dict [str , Any ], data : dict [str , Any ], expected : bool ) -> None :
1697
1771
result = MistralStreamStructuredResponse ._validate_required_json_schema (data , schema ) # pyright: ignore[reportPrivateUsage]
1698
1772
assert result == expected , f'{ desc } — expected { expected } , got { result } '
0 commit comments