1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2024 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
import inspect
17
+ import json
2
18
import string
3
19
import textwrap
4
20
from typing_extensions import TypedDict
@@ -22,6 +38,7 @@ class Person(TypedDict):
22
38
23
39
24
40
class UnitTests (parameterized .TestCase ):
41
+ maxDiff = None
25
42
@parameterized .named_parameters (
26
43
[
27
44
"protos.GenerationConfig" ,
@@ -416,24 +433,16 @@ def test_join_prompt_feedbacks(self):
416
433
],
417
434
"role" : "assistant" ,
418
435
},
419
- "citation_metadata" : {"citation_sources" : []},
420
436
"index" : 0 ,
421
- "finish_reason" : 0 ,
422
- "safety_ratings" : [],
423
- "token_count" : 0 ,
424
- "grounding_attributions" : [],
437
+ "citation_metadata" : {},
425
438
},
426
439
{
427
440
"content" : {
428
441
"parts" : [{"text" : "Tell me a story about a magic backpack" }],
429
442
"role" : "assistant" ,
430
443
},
431
444
"index" : 1 ,
432
- "citation_metadata" : {"citation_sources" : []},
433
- "finish_reason" : 0 ,
434
- "safety_ratings" : [],
435
- "token_count" : 0 ,
436
- "grounding_attributions" : [],
445
+ "citation_metadata" : {},
437
446
},
438
447
{
439
448
"content" : {
@@ -458,17 +467,13 @@ def test_join_prompt_feedbacks(self):
458
467
},
459
468
]
460
469
},
461
- "finish_reason" : 0 ,
462
- "safety_ratings" : [],
463
- "token_count" : 0 ,
464
- "grounding_attributions" : [],
465
470
},
466
471
]
467
472
468
473
def test_join_candidates (self ):
469
474
candidate_lists = [[protos .Candidate (c ) for c in cl ] for cl in self .CANDIDATE_LISTS ]
470
475
result = generation_types ._join_candidate_lists (candidate_lists )
471
- self .assertEqual (self .MERGED_CANDIDATES , [type (r ).to_dict (r ) for r in result ])
476
+ self .assertEqual (self .MERGED_CANDIDATES , [type (r ).to_dict (r , including_default_value_fields = False ) for r in result ])
472
477
473
478
def test_join_chunks (self ):
474
479
chunks = [protos .GenerateContentResponse (candidates = cl ) for cl in self .CANDIDATE_LISTS ]
@@ -480,6 +485,8 @@ def test_join_chunks(self):
480
485
],
481
486
)
482
487
488
+ chunks [- 1 ].usage_metadata = protos .GenerateContentResponse .UsageMetadata (prompt_token_count = 5 )
489
+
483
490
result = generation_types ._join_chunks (chunks )
484
491
485
492
expected = protos .GenerateContentResponse (
@@ -495,10 +502,17 @@ def test_join_chunks(self):
495
502
}
496
503
],
497
504
},
505
+ "usage_metadata" : {
506
+ "prompt_token_count" : 5
507
+ }
508
+
498
509
},
499
510
)
500
511
501
- self .assertEqual (type (expected ).to_dict (expected ), type (result ).to_dict (expected ))
512
+ expected = json .dumps (type (expected ).to_dict (expected , including_default_value_fields = False ), indent = 4 )
513
+ result = json .dumps (type (result ).to_dict (result , including_default_value_fields = False ), indent = 4 )
514
+
515
+ self .assertEqual (expected , result )
502
516
503
517
def test_generate_content_response_iterator_end_to_end (self ):
504
518
chunks = [protos .GenerateContentResponse (candidates = cl ) for cl in self .CANDIDATE_LISTS ]
0 commit comments