Skip to content

Commit e19b8b0

Browse files
authored
Add dataclass prettyprinting. (#73)
* Add dataclass prettyprinting. * use reprlib.recursive_repr * format * Improve contractions.
1 parent d5326cb commit e19b8b0

File tree

11 files changed

+214
-30
lines changed

11 files changed

+214
-30
lines changed

google/generativeai/discuss.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from google.generativeai.client import get_default_discuss_client
2626
from google.generativeai.client import get_default_discuss_async_client
27+
from google.generativeai import string_utils
2728
from google.generativeai.types import discuss_types
2829
from google.generativeai.types import model_types
2930
from google.generativeai.types import safety_types
@@ -445,6 +446,7 @@ async def chat_async(
445446
DATACLASS_KWARGS = {}
446447

447448

449+
@string_utils.prettyprint
448450
@set_doc(discuss_types.ChatResponse.__doc__)
449451
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
450452
class ChatResponse(discuss_types.ChatResponse):

google/generativeai/docstring_utils.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

google/generativeai/operations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Iterator
1919

2020
from google.ai import generativelanguage as glm
21+
2122
from google.generativeai import client as client_lib
2223
from google.generativeai.types import model_types
2324
from google.api_core import operation as operation_lib

google/generativeai/string_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
import pprint
19+
import re
20+
import reprlib
21+
import textwrap
22+
23+
24+
def strip_oneof(docstring):
25+
lines = docstring.splitlines()
26+
lines = [line for line in lines if ".. _oneof:" not in line]
27+
lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
28+
return "\n".join(lines)
29+
30+
31+
def prettyprint(cls):
32+
cls.__str__ = _prettyprint
33+
cls.__repr__ = _prettyprint
34+
return cls
35+
36+
37+
repr = reprlib.Repr()
38+
39+
40+
@reprlib.recursive_repr()
41+
def _prettyprint(self):
42+
"""A dataclass prettyprint function you can use in __str__or __repr__.
43+
44+
Note: You can't set `__str__ = pprint.pformat` because it causes a recursion error.
45+
46+
Mostly identical to pprint but:
47+
48+
* This will contract long lists and dicts (> 10lines) to [...] and {...}.
49+
* This will contract long object reprs to ClassName(...).
50+
"""
51+
fields = []
52+
for f in dataclasses.fields(self):
53+
s = pprint.pformat(getattr(self, f.name))
54+
class_re = r"^(\w+)\(.*\)$"
55+
if s.count("\n") >= 10:
56+
if s.startswith("["):
57+
s = "[...]"
58+
elif s.startswith("{"):
59+
s = "{...}"
60+
elif re.match(class_re, s, flags=re.DOTALL):
61+
s = re.sub(class_re, r"\1(...)", s, flags=re.DOTALL)
62+
else:
63+
s = "..."
64+
else:
65+
width = len(f.name) + 1
66+
s = textwrap.indent(s, " " * width).lstrip(" ")
67+
fields.append(f"{f.name}={s}")
68+
attrs = ",\n".join(fields)
69+
70+
name = self.__class__.__name__
71+
width = len(name) + 1
72+
73+
attrs = textwrap.indent(attrs, " " * width).lstrip(" ")
74+
return f"{name}({attrs})"

google/generativeai/text.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import google.ai.generativelanguage as glm
2222

2323
from google.generativeai.client import get_default_text_client
24+
from google.generativeai import string_utils
2425
from google.generativeai.types import text_types
2526
from google.generativeai.types import model_types
2627
from google.generativeai.types import safety_types
@@ -175,6 +176,7 @@ def generate_text(
175176
return _generate_response(client=client, request=request)
176177

177178

179+
@string_utils.prettyprint
178180
@dataclasses.dataclass(init=False)
179181
class Completion(text_types.Completion):
180182
def __init__(self, **kwargs):

google/generativeai/types/citation_types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from typing import Optional, List
1818

1919
from google.ai import generativelanguage as glm
20-
from google.generativeai import docstring_utils
20+
from google.generativeai import string_utils
21+
2122
from typing import TypedDict
2223

2324
__all__ = [
@@ -32,10 +33,10 @@ class CitationSourceDict(TypedDict):
3233
uri: str | None
3334
license: str | None
3435

35-
__doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)
36+
__doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__)
3637

3738

3839
class CitationMetadataDict(TypedDict):
3940
citation_sources: List[CitationSourceDict | None]
4041

41-
__doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)
42+
__doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__)

google/generativeai/types/discuss_types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List
2020

2121
import google.ai.generativelanguage as glm
22+
from google.generativeai import string_utils
23+
2224
from google.generativeai.types import safety_types
2325
from google.generativeai.types import citation_types
2426

@@ -97,6 +99,7 @@ class ResponseDict(TypedDict):
9799
candidates: List[MessageDict]
98100

99101

102+
@string_utils.prettyprint
100103
@dataclasses.dataclass(init=False)
101104
class ChatResponse(abc.ABC):
102105
"""A chat response from the model.

google/generativeai/types/model_types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Any, Iterable, TypedDict, Union
2222

2323
import google.ai.generativelanguage as glm
24+
from google.generativeai import string_utils
2425

2526
__all__ = [
2627
"Model",
@@ -65,6 +66,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState:
6566
return _TUNED_MODEL_STATES[x]
6667

6768

69+
@string_utils.prettyprint
6870
@dataclasses.dataclass
6971
class Model:
7072
"""A dataclass representation of a `glm.Model`.
@@ -152,6 +154,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM
152154
return TunedModel(**tuned_model)
153155

154156

157+
@string_utils.prettyprint
155158
@dataclasses.dataclass
156159
class TunedModel:
157160
"""A dataclass representation of a `glm.TunedModel`."""
@@ -170,6 +173,7 @@ class TunedModel:
170173
tuning_task: TuningTask | None = None
171174

172175

176+
@string_utils.prettyprint
173177
@dataclasses.dataclass
174178
class TuningTask:
175179
start_time: datetime.datetime | None = None
@@ -208,6 +212,7 @@ def encode_tuning_example(example: TuningExampleOptions):
208212
return example
209213

210214

215+
@string_utils.prettyprint
211216
@dataclasses.dataclass
212217
class TuningSnapshot:
213218
step: int
@@ -216,6 +221,7 @@ class TuningSnapshot:
216221
compute_time: datetime.datetime
217222

218223

224+
@string_utils.prettyprint
219225
@dataclasses.dataclass
220226
class Hyperparameters:
221227
epoch_count: int = 0
@@ -246,6 +252,7 @@ def make_model_name(name: AnyModelNameOptions):
246252
TunedModelsIterable = Iterable[TunedModel]
247253

248254

255+
@string_utils.prettyprint
249256
@dataclasses.dataclass
250257
class TokenCount:
251258
"""A dataclass representation of a `glm.TokenCountResponse`.

google/generativeai/types/safety_types.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from collections.abc import Mapping
1818

1919
from google.ai import generativelanguage as glm
20-
from google.generativeai import docstring_utils
20+
from google.generativeai import string_utils
21+
2122
import typing
2223
from typing import Iterable, Dict, Iterable, List, TypedDict, Union
2324

@@ -134,7 +135,7 @@ class ContentFilterDict(TypedDict):
134135
reason: BlockedReason
135136
message: str
136137

137-
__doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
138+
__doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__)
138139

139140

140141
def convert_filters_to_enums(
@@ -153,7 +154,7 @@ class SafetyRatingDict(TypedDict):
153154
category: HarmCategory
154155
probability: HarmProbability
155156

156-
__doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
157+
__doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__)
157158

158159

159160
def convert_rating_to_enum(rating: dict) -> SafetyRatingDict:
@@ -174,7 +175,7 @@ class SafetySettingDict(TypedDict):
174175
category: HarmCategory
175176
threshold: HarmBlockThreshold
176177

177-
__doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
178+
__doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__)
178179

179180

180181
class LooseSafetySettingDict(TypedDict):
@@ -220,7 +221,7 @@ class SafetyFeedbackDict(TypedDict):
220221
rating: SafetyRatingDict
221222
setting: SafetySettingDict
222223

223-
__doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
224+
__doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__)
224225

225226

226227
def convert_safety_feedback_to_enums(

google/generativeai/types/text_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import dataclasses
1919
from typing import Any, Dict, List, TypedDict
2020

21+
from google.generativeai import string_utils
2122
from google.generativeai.types import safety_types
2223
from google.generativeai.types import citation_types
2324

@@ -39,6 +40,7 @@ class TextCompletion(TypedDict, total=False):
3940
citation_metadata: citation_types.CitationMetadataDict | None
4041

4142

43+
@string_utils.prettyprint
4244
@dataclasses.dataclass(init=False)
4345
class Completion(abc.ABC):
4446
"""The result returned by `generativeai.generate_text`.

0 commit comments

Comments
 (0)