Commit 1b1d883

aertoria and MarkDaoust authored
Restrict Harm category to the sublist only Gemini support (#295)
* Restrict Harm category to the sublist only Gemini support
* Update text.py
* Update safety_types.py
* Update safety_types.py
* Update safety_types.py
* split module
  Change-Id: Ia94b262d4e27511ca2e4eeb02cb5bd617a772463
* add palm safety
  Change-Id: Ia1cb199148619ebbc26638d5983b435245904971
* switch imports
  Change-Id: I2853a88d7acc51a78174c97e30bde8eb24e1d457

Co-authored-by: Mark Daoust <[email protected]>
1 parent 55cca2f commit 1b1d883

10 files changed: +418 additions, -161 deletions

google/generativeai/answer.py

Lines changed: 1 addition & 3 deletions
@@ -206,9 +206,7 @@ def _make_generate_answer_request(
     contents = content_types.to_contents(contents)

     if safety_settings:
-        safety_settings = safety_types.normalize_safety_settings(
-            safety_settings, harm_category_set="new"
-        )
+        safety_settings = safety_types.normalize_safety_settings(safety_settings)

     if inline_passages is not None and semantic_retriever is not None:
         raise ValueError(
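The call site above now normalizes safety settings without a harm_category_set argument, since normalize_safety_settings only recognizes the Gemini harm categories after this change. A minimal caller-side sketch, assuming the function still accepts a category-to-threshold mapping and that the string names below resolve to the Gemini enums (they are not taken from this diff):

from google.generativeai.types import safety_types

# Assumed input shape: Gemini-only categories mapped to block thresholds.
settings = {"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_ONLY_HIGH"}
normalized = safety_types.normalize_safety_settings(settings)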

google/generativeai/discuss.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from google.generativeai import string_utils
 from google.generativeai.types import discuss_types
 from google.generativeai.types import model_types
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types


 def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -521,7 +521,7 @@ def _build_chat_response(
     response = type(response).to_dict(response)
     response.pop("messages")

-    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
+    response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"])

     if response["candidates"]:
         last = response["candidates"][0]
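For the PaLM chat path, the enum conversion now lives in palm_safety_types. A small sketch of the new call; the payload shape below is illustrative only, assuming the service returns filters as dicts with an integer "reason" field:

from google.generativeai.types import palm_safety_types

raw_filters = [{"reason": 1}]  # integer enum value as it comes back from the service
filters = palm_safety_types.convert_filters_to_enums(raw_filters)
# each "reason" should now be an enum member rather than a bare int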

google/generativeai/generative_models.py

Lines changed: 3 additions & 5 deletions
@@ -79,9 +79,7 @@ def __init__(
         if "/" not in model_name:
             model_name = "models/" + model_name
         self._model_name = model_name
-        self._safety_settings = safety_types.to_easy_safety_dict(
-            safety_settings, harm_category_set="new"
-        )
+        self._safety_settings = safety_types.to_easy_safety_dict(safety_settings)
         self._generation_config = generation_types.to_generation_config_dict(generation_config)
         self._tools = content_types.to_function_library(tools)

@@ -149,10 +147,10 @@ def _prepare_request(
         merged_gc = self._generation_config.copy()
         merged_gc.update(generation_config)

-        safety_settings = safety_types.to_easy_safety_dict(safety_settings, harm_category_set="new")
+        safety_settings = safety_types.to_easy_safety_dict(safety_settings)
         merged_ss = self._safety_settings.copy()
         merged_ss.update(safety_settings)
-        merged_ss = safety_types.normalize_safety_settings(merged_ss, harm_category_set="new")
+        merged_ss = safety_types.normalize_safety_settings(merged_ss)

         return glm.GenerateContentRequest(
             model=self._model_name,
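The _prepare_request hunk keeps the existing merge order: per-request settings are laid over the model-level defaults, and the merged dict is normalized against the Gemini-only categories. A hedged usage sketch from the caller's side; the string-keyed settings and model name are assumptions, not part of this diff:

import google.generativeai as genai

model = genai.GenerativeModel(
    "gemini-pro",
    safety_settings={"HARM_CATEGORY_HARASSMENT": "BLOCK_ONLY_HIGH"},
)
# Per-request settings are merged over the constructor defaults before normalization.
response = model.generate_content(
    "Hello",
    safety_settings={"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE"},
)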

google/generativeai/text.py

Lines changed: 7 additions & 9 deletions
@@ -26,7 +26,7 @@
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
 from google.generativeai import models
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types

 DEFAULT_TEXT_MODEL = "models/text-bison-001"
 EMBEDDING_MAX_BATCH_SIZE = 100
@@ -81,7 +81,7 @@ def _make_generate_text_request(
     max_output_tokens: int | None = None,
     top_p: int | None = None,
     top_k: int | None = None,
-    safety_settings: safety_types.SafetySettingOptions | None = None,
+    safety_settings: palm_safety_types.SafetySettingOptions | None = None,
     stop_sequences: str | Iterable[str] | None = None,
 ) -> glm.GenerateTextRequest:
     """
@@ -108,9 +108,7 @@ def _make_generate_text_request(
     """
     model = model_types.make_model_name(model)
     prompt = _make_text_prompt(prompt=prompt)
-    safety_settings = safety_types.normalize_safety_settings(
-        safety_settings, harm_category_set="old"
-    )
+    safety_settings = palm_safety_types.normalize_safety_settings(safety_settings)
     if isinstance(stop_sequences, str):
         stop_sequences = [stop_sequences]
     if stop_sequences:
@@ -138,7 +136,7 @@ def generate_text(
     max_output_tokens: int | None = None,
     top_p: float | None = None,
     top_k: float | None = None,
-    safety_settings: safety_types.SafetySettingOptions | None = None,
+    safety_settings: palm_safety_types.SafetySettingOptions | None = None,
     stop_sequences: str | Iterable[str] | None = None,
     client: glm.TextServiceClient | None = None,
     request_options: dict[str, Any] | None = None,
@@ -240,11 +238,11 @@ def _generate_response(
     response = client.generate_text(request, **request_options)
     response = type(response).to_dict(response)

-    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
-    response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
+    response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"])
+    response["safety_feedback"] = palm_safety_types.convert_safety_feedback_to_enums(
         response["safety_feedback"]
     )
-    response["candidates"] = safety_types.convert_candidate_enums(response["candidates"])
+    response["candidates"] = palm_safety_types.convert_candidate_enums(response["candidates"])

     return Completion(_client=client, **response)
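text.py is a PaLM-only surface, so it switches wholesale to palm_safety_types, which keeps the legacy ("old") harm-category set that the Gemini-only safety_types no longer carries. A rough sketch of a call that relies on one of those legacy categories; the category and threshold strings are assumptions about what the PaLM path still accepts:

import google.generativeai as genai

completion = genai.generate_text(
    model="models/text-bison-001",
    prompt="Write a short greeting.",
    safety_settings=[
        # Legacy PaLM category; not part of the Gemini-only set in safety_types.
        {"category": "HARM_CATEGORY_TOXICITY", "threshold": "BLOCK_LOW_AND_ABOVE"},
    ],
)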

google/generativeai/types/discuss_types.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 import google.ai.generativelanguage as glm
 from google.generativeai import string_utils

-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 from google.generativeai.types import citation_types


@@ -169,7 +169,7 @@ class ChatResponse(abc.ABC):
     temperature: Optional[float]
     candidate_count: Optional[int]
     candidates: List[MessageDict]
-    filters: List[safety_types.ContentFilterDict]
+    filters: List[palm_safety_types.ContentFilterDict]
     top_p: Optional[float] = None
     top_k: Optional[float] = None
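The only change here is the annotation: ChatResponse.filters is now declared in terms of the PaLM variant of ContentFilterDict. A type-level sketch of what a consumer sees, assuming a typical chat() call:

import google.generativeai as genai

response = genai.chat(messages="Hi there")
filters = response.filters  # List[palm_safety_types.ContentFilterDict]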
