Skip to content

Commit 7d75dfa

Browse files
committed
Add filters and safety settings to text
1 parent 55d4fce commit 7d75dfa

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

google/generativeai/text.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from google.generativeai.client import get_default_text_client
2525
from google.generativeai.types import text_types
2626
from google.generativeai.types import model_types
27+
from google.generativeai.types import safety_types
2728

2829

2930
def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
@@ -44,6 +45,7 @@ def _make_generate_text_request(
4445
max_output_tokens: Optional[int] = None,
4546
top_p: Optional[int] = None,
4647
top_k: Optional[int] = None,
48+
safety_settings: Optional[List[safety_types.SafetySetting]] = None,
4749
stop_sequences: Union[str, Iterable[str]] = None,
4850
) -> glm.GenerateTextRequest:
4951
model = model_types.make_model_name(model)
@@ -61,6 +63,7 @@ def _make_generate_text_request(
6163
max_output_tokens=max_output_tokens,
6264
top_p=top_p,
6365
top_k=top_k,
66+
safety_settings=safety_settings,
6467
stop_sequences=stop_sequences,
6568
)
6669

@@ -74,6 +77,7 @@ def generate_text(
7477
max_output_tokens: Optional[int] = None,
7578
top_p: Optional[float] = None,
7679
top_k: Optional[float] = None,
80+
safety_settings: Optional[Iterable[safety.SafetySetting]] = None,
7781
stop_sequences: Union[str, Iterable[str]] = None,
7882
client: Optional[glm.TextServiceClient] = None,
7983
) -> text_types.Completion:
@@ -103,6 +107,15 @@ def generate_text(
103107
For example, if the sorted probabilities are
104108
`[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
105109
as `[0.625, 0.25, 0.125, 0, 0, 0].
110+
safety_settings: A list of unique `types.SafetySetting` instances for blocking unsafe content.
111+
These will be enforced on the `prompt` and
112+
`candidates`. There should not be more than one
113+
setting for each `types.SafetyCategory` type. The API will block any prompts and
114+
responses that fail to meet the thresholds set by these settings. This list
115+
overrides the default settings for each `SafetyCategory` specified in the
116+
safety_settings. If there is no `types.SafetySetting` for a given
117+
`SafetyCategory` provided in the list, the API will use the default safety
118+
setting for that category.
106119
stop_sequences: A set of up to 5 character sequences that will stop output generation.
107120
If specified, the API will stop at the first appearance of a stop
108121
sequence. The stop sequence will not be included as part of the response.
@@ -119,6 +132,7 @@ def generate_text(
119132
max_output_tokens=max_output_tokens,
120133
top_p=top_p,
121134
top_k=top_k,
135+
safety_settings=safety_settings,
122136
stop_sequences=stop_sequences,
123137
)
124138

google/generativeai/types/text_types.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from google.generativeai.types import citation_types
2222

2323

24-
__all__ = ["TextResponse"]
24+
__all__ = ["Completion"]
2525

2626

2727
class TextCompletion(TypedDict, total=False):
@@ -31,15 +31,18 @@ class TextCompletion(TypedDict, total=False):
3131

3232

3333
@dataclasses.dataclass(init=False)
34-
class TextResponse(abc.ABC):
35-
"""A text completion given a prompt from the model.
34+
class Completion(abc.ABC):
35+
"""The result of the `1 given a prompt from the model.
3636
3737
Use `GenerateTextResponse.candidates` to access all the completions generated by the model.
3838
3939
Attributes:
4040
candidates: A list of candidate text completions generated by the model.
41+
result: The output of the first candidate,
42+
filters: Indicates the reasons why content may have been blocked
43+
Either Unspecified, Safety, or Other. See `types.ContentFilter`.
44+
safety_feedback: Indicates which safety settings blocked content in this result.
4145
"""
42-
4346
candidates: List[TextCompletion]
4447
result: Optional[str]
4548
filters: Optional[list[safety_types.ContentFilter]]

0 commit comments

Comments
 (0)