Skip to content

Commit 907e5a2

Browse files
committed
format
Change-Id: I15fd5701dd5c4200461a32c968fa19e375403a7e
1 parent 6e00eed commit 907e5a2

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

google/generativeai/types/generation_types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,10 @@ def _join_prompt_feedbacks(
361361

362362
def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]):
363363
chunks = tuple(chunks)
364-
if 'usage_metadata' in chunks[-1]:
364+
if "usage_metadata" in chunks[-1]:
365365
usage_metadata = chunks[-1].usage_metadata
366366
else:
367-
usage_metadata=None
368-
367+
usage_metadata = None
369368

370369
return protos.GenerateContentResponse(
371370
candidates=_join_candidate_lists(c.candidates for c in chunks),

google/generativeai/types/model_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def idecode_time(parent: dict["str", Any], name: str):
143143

144144
def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel:
145145
if isinstance(tuned_model, protos.TunedModel):
146-
tuned_model = type(tuned_model).to_dict(tuned_model, including_default_value_fields=False) # pytype: disable=attribute-error
146+
tuned_model = type(tuned_model).to_dict(
147+
tuned_model, including_default_value_fields=False
148+
) # pytype: disable=attribute-error
147149
tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None))
148150

149151
base_model = tuned_model.pop("base_model", None)

tests/test_generation.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Person(TypedDict):
3939

4040
class UnitTests(parameterized.TestCase):
4141
maxDiff = None
42+
4243
@parameterized.named_parameters(
4344
[
4445
"protos.GenerationConfig",
@@ -473,7 +474,10 @@ def test_join_prompt_feedbacks(self):
473474
def test_join_candidates(self):
474475
candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS]
475476
result = generation_types._join_candidate_lists(candidate_lists)
476-
self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r, including_default_value_fields=False) for r in result])
477+
self.assertEqual(
478+
self.MERGED_CANDIDATES,
479+
[type(r).to_dict(r, including_default_value_fields=False) for r in result],
480+
)
477481

478482
def test_join_chunks(self):
479483
chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS]
@@ -485,7 +489,9 @@ def test_join_chunks(self):
485489
],
486490
)
487491

488-
chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(prompt_token_count=5)
492+
chunks[-1].usage_metadata = protos.GenerateContentResponse.UsageMetadata(
493+
prompt_token_count=5
494+
)
489495

490496
result = generation_types._join_chunks(chunks)
491497

@@ -502,15 +508,16 @@ def test_join_chunks(self):
502508
}
503509
],
504510
},
505-
"usage_metadata": {
506-
"prompt_token_count": 5
507-
}
508-
511+
"usage_metadata": {"prompt_token_count": 5},
509512
},
510513
)
511514

512-
expected = json.dumps(type(expected).to_dict(expected, including_default_value_fields=False), indent=4)
513-
result = json.dumps(type(result).to_dict(result, including_default_value_fields=False), indent=4)
515+
expected = json.dumps(
516+
type(expected).to_dict(expected, including_default_value_fields=False), indent=4
517+
)
518+
result = json.dumps(
519+
type(result).to_dict(result, including_default_value_fields=False), indent=4
520+
)
514521

515522
self.assertEqual(expected, result)
516523

0 commit comments

Comments
 (0)