@@ -15,10 +15,12 @@
 from __future__ import annotations
 
 import enum
+from collections.abc import Mapping
+
 from google.ai import generativelanguage as glm
 from google.generativeai import docstring_utils
 import typing
-from typing import Iterable, List, TypedDict
+from typing import Dict, Iterable, List, TypedDict, Union
 
 __all__ = [
     "HarmCategory",
@@ -37,6 +39,69 @@
 HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
 BlockedReason = glm.ContentFilter.BlockedReason
 
+HarmCategoryOptions = Union[str, int, HarmCategory]
+
+_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = {
+    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+    0: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
+    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmCategory.HARM_CATEGORY_DEROGATORY,
+    1: HarmCategory.HARM_CATEGORY_DEROGATORY,
+    "derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY,
+    HarmCategory.HARM_CATEGORY_TOXICITY: HarmCategory.HARM_CATEGORY_TOXICITY,
+    2: HarmCategory.HARM_CATEGORY_TOXICITY,
+    "toxicity": HarmCategory.HARM_CATEGORY_TOXICITY,
+    "toxic": HarmCategory.HARM_CATEGORY_TOXICITY,
+    HarmCategory.HARM_CATEGORY_VIOLENCE: HarmCategory.HARM_CATEGORY_VIOLENCE,
+    3: HarmCategory.HARM_CATEGORY_VIOLENCE,
+    "violence": HarmCategory.HARM_CATEGORY_VIOLENCE,
+    "violent": HarmCategory.HARM_CATEGORY_VIOLENCE,
+    HarmCategory.HARM_CATEGORY_SEXUAL: HarmCategory.HARM_CATEGORY_SEXUAL,
+    4: HarmCategory.HARM_CATEGORY_SEXUAL,
+    "sexual": HarmCategory.HARM_CATEGORY_SEXUAL,
+    "sex": HarmCategory.HARM_CATEGORY_SEXUAL,
+    HarmCategory.HARM_CATEGORY_MEDICAL: HarmCategory.HARM_CATEGORY_MEDICAL,
+    5: HarmCategory.HARM_CATEGORY_MEDICAL,
+    "medical": HarmCategory.HARM_CATEGORY_MEDICAL,
+    "med": HarmCategory.HARM_CATEGORY_MEDICAL,
+    HarmCategory.HARM_CATEGORY_DANGEROUS: HarmCategory.HARM_CATEGORY_DANGEROUS,
+    6: HarmCategory.HARM_CATEGORY_DANGEROUS,
+    "danger": HarmCategory.HARM_CATEGORY_DANGEROUS,
+    "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS,
+}
+
+
+def to_harm_category(x: HarmCategoryOptions) -> HarmCategory:
+    if isinstance(x, str):
+        x = x.lower()
+    return _HARM_CATEGORIES[x]
+
+
+HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
+
+_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
+    HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
+    0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
+    HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+    1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+    "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+    HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+    2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+    "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+    "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+    HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    "high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
+    HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
+    4: HarmBlockThreshold.BLOCK_NONE,
+    "block_none": HarmBlockThreshold.BLOCK_NONE,
+}
+
+
+def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
+    if isinstance(x, str):
+        x = x.lower()
+    return _BLOCK_THRESHOLDS[x]
+
 
 class ContentFilterDict(TypedDict):
     reason: BlockedReason
@@ -83,6 +148,35 @@ class SafetySettingDict(TypedDict):
     __doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
 
 
+class LooseSafetySettingDict(TypedDict):
+    category: HarmCategoryOptions
+    threshold: HarmBlockThresholdOptions
+
+
+EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
+
+SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+
+
+def normalize_safety_settings(
+    settings: SafetySettingOptions,
+) -> list[SafetySettingDict] | None:
+    if settings is None:
+        return None
+    if isinstance(settings, Mapping):
+        return [
+            {"category": to_harm_category(key), "threshold": to_block_threshold(value)}
+            for key, value in settings.items()
+        ]
+    return [
+        {
+            "category": to_harm_category(d["category"]),
+            "threshold": to_block_threshold(d["threshold"]),
+        }
+        for d in settings
+    ]
+
+
 def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
     return {
         "category": HarmCategory(setting["category"]),
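For reference, a minimal usage sketch of the helpers added in this diff. It is not part of the change itself, and it assumes the file ships as google.generativeai.types.safety_types (inferred from the surrounding imports, not stated in the diff):

from google.generativeai.types import safety_types

# Enum members, raw ints, and case-insensitive string aliases all resolve to the same enum value.
assert safety_types.to_harm_category("toxic") == safety_types.to_harm_category(2)
assert safety_types.to_block_threshold("HIGH") == safety_types.to_block_threshold(3)

# A {category: threshold} mapping and an iterable of loose dicts normalize to the
# same canonical list of SafetySettingDict entries.
from_mapping = safety_types.normalize_safety_settings({"danger": "low", "medical": 3})
from_list = safety_types.normalize_safety_settings(
    [
        {"category": "danger", "threshold": "low"},
        {"category": "med", "threshold": "high"},
    ]
)
assert from_mapping == from_list

The point of the lookup tables is that callers can pass whichever spelling is most convenient, while downstream code only ever sees the canonical glm enum values.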