Skip to content

Commit 6e00eed

Browse files
committed
Fix tests
Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285
1 parent e09858b commit 6e00eed

File tree

3 files changed

+42
-23
lines changed

3 files changed

+42
-23
lines changed

google/generativeai/types/generation_types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def _join_code_execution_result(result_1, result_2):
321321

322322

323323
def _join_candidates(candidates: Iterable[protos.Candidate]):
324+
"""Joins stream chunks of a single candidate."""
324325
candidates = tuple(candidates)
325326

326327
index = candidates[0].index # These should all be the same.
@@ -336,6 +337,7 @@ def _join_candidates(candidates: Iterable[protos.Candidate]):
336337

337338

338339
def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]):
340+
"""Joins stream chunks where each chunk is a list of candidate chunks."""
339341
# Assuming that is a candidate ends, it is no longer returned in the list of
340342
# candidates and that's why candidates have an index
341343
candidates = collections.defaultdict(list)
@@ -359,10 +361,16 @@ def _join_prompt_feedbacks(
359361

360362
def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
361363
chunks = tuple(chunks)
364+
if 'usage_metadata' in chunks[-1]:
365+
usage_metadata = chunks[-1].usage_metadata
366+
else:
367+
usage_metadata=None
368+
369+
362370
return protos.GenerateContentResponse(
363371
candidates=_join_candidate_lists(c.candidates for c in chunks),
364372
prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks),
365-
usage_metadata=chunks[-1].usage_metadata,
373+
usage_metadata=usage_metadata,
366374
)
367375

368376

tests/test_generation.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2024 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
import inspect
17+
import json
218
import string
319
import textwrap
420
from typing_extensions import TypedDict
@@ -22,6 +38,7 @@ class Person(TypedDict):
2238

2339

2440
class UnitTests(parameterized.TestCase):
41+
maxDiff = None
2542
@parameterized.named_parameters(
2643
[
2744
"protos.GenerationConfig",
@@ -416,24 +433,16 @@ def test_join_prompt_feedbacks(self):
416433
],
417434
"role": "assistant",
418435
},
419-
"citation_metadata": {"citation_sources": []},
420436
"index": 0,
421-
"finish_reason": 0,
422-
"safety_ratings": [],
423-
"token_count": 0,
424-
"grounding_attributions": [],
437+
"citation_metadata": {},
425438
},
426439
{
427440
"content": {
428441
"parts": [{"text": "Tell me a story about a magic backpack"}],
429442
"role": "assistant",
430443
},
431444
"index": 1,
432-
"citation_metadata": {"citation_sources": []},
433-
"finish_reason": 0,
434-
"safety_ratings": [],
435-
"token_count": 0,
436-
"grounding_attributions": [],
445+
"citation_metadata": {},
437446
},
438447
{
439448
"content": {
@@ -458,17 +467,13 @@ def test_join_prompt_feedbacks(self):
458467
},
459468
]
460469
},
461-
"finish_reason": 0,
462-
"safety_ratings": [],
463-
"token_count": 0,
464-
"grounding_attributions": [],
465470
},
466471
]
467472

468473
def test_join_candidates(self):
469474
candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
470475
result = generation_types._join_candidate_lists(candidate_lists)
471-
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result])
476+
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r, including_default_value_fields=False) for r in result])
472477

473478
def test_join_chunks(self):
474479
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
@@ -480,6 +485,8 @@ def test_join_chunks(self):
480485
],
481486
)
482487

488+
chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5)
489+
483490
result = generation_types._join_chunks(chunks)
484491

485492
expected = protos.GenerateContentResponse(
@@ -495,10 +502,17 @@ def test_join_chunks(self):
495502
}
496503
],
497504
},
505+
"usage_metadata": {
506+
"prompt_token_count": 5
507+
}
508+
498509
},
499510
)
500511

501-
self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected))
512+
expected = json.dumps(type(expected).to_dict(expected, including_default_value_fields=False), indent=4)
513+
result = json.dumps(type(result).to_dict(result, including_default_value_fields=False), indent=4)
514+
515+
self.assertEqual(expected, result)
502516

503517
def test_generate_content_response_iterator_end_to_end(self):
504518
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]

tests/test_generative_models.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -935,8 +935,7 @@ def test_repr_for_streaming_start_to_finish(self):
935935
"citation_metadata": {}
936936
}
937937
],
938-
"prompt_feedback": {},
939-
"usage_metadata": {}
938+
"prompt_feedback": {}
940939
}),
941940
)"""
942941
)
@@ -964,8 +963,7 @@ def test_repr_for_streaming_start_to_finish(self):
964963
"citation_metadata": {}
965964
}
966965
],
967-
"prompt_feedback": {},
968-
"usage_metadata": {}
966+
"prompt_feedback": {}
969967
}),
970968
)"""
971969
)
@@ -1056,8 +1054,7 @@ def no_throw():
10561054
"citation_metadata": {}
10571055
}
10581056
],
1059-
"prompt_feedback": {},
1060-
"usage_metadata": {}
1057+
"prompt_feedback": {}
10611058
}),
10621059
),
10631060
error=ValueError()"""

0 commit comments

Comments
 (0)