Commit c4e5c0d

Adding ShieldGemma 2 notebook to Responsible AI Toolkit docs

1 parent 7868875 commit c4e5c0d
File tree

1 file changed: +214 -0 lines changed
@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "cLCmbOz_5tWH"
},
"source": [
"##### Copyright 2025 Google LLC"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vdPaBz5y5LHW"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3Zd1278P5wt_"
},
"source": [
"# Evaluating content safety with ShieldGemma 2 and Hugging Face Transformers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4IlgEYUj7xdW"
},
"source": [
"The **ShieldGemma 2** model is trained to detect the key harms detailed in the [model card](https://ai.google.dev/gemma/docs/shieldgemma/model_card_2). This guide demonstrates how to use ShieldGemma 2 with Hugging Face Transformers to help you build robust data and models.\n",
"\n",
"Note that `ShieldGemma 2` is trained to classify only one harm type at a time, so you will need to make a separate call to `ShieldGemma 2` for each harm type you want to check against. If you have additional harm types to check for, you can use model tuning techniques to adapt `ShieldGemma 2` to them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RhlnMQoK9fZG"
},
"source": [
"# Supported safety checks\n",
"\n",
"**ShieldGemma 2** is built on Gemma 3's 4B IT checkpoint and is trained to detect and predict violations of the key harm types listed below:\n",
"\n",
"* **Dangerous Content**: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).\n",
"\n",
"* **Sexually Explicit**: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).\n",
"\n",
"* **Violence/Gore**: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).\n",
"\n",
"These policies serve as a foundation, but users can provide customized safety policies as input to the model, allowing for fine-grained control and support for specific use-case requirements."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t3aq-ToeAmRM"
},
"source": [
"# Supported Use Case\n",
"\n",
"**We recommend using `ShieldGemma 2` as an input filter to vision language models, as an output filter of image generation systems, or both.** ShieldGemma 2 offers the following key advantages:\n",
"\n",
"* **Policy-Aware Classification**: ShieldGemma 2 accepts both a user-defined safety policy and an image as input, providing classifications for both real and generated images, tailored to the specific policy guidelines.\n",
"* **Probability-Based Output and Thresholding**: ShieldGemma 2 outputs a probability score for its predictions, allowing downstream users to flexibly tune the classification threshold based on their specific use cases and risk tolerance. This enables a more nuanced and adaptable approach to safety classification.\n",
"\n",
"The input/output format is as follows:\n",
"* **Input**: Image + prompt instruction with policy definition\n",
"* **Output**: Probability of the 'Yes'/'No' tokens, where 'Yes' means that the image violates the specified policy. The higher the score, the higher the model's confidence that the image violates the specified policy."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0WhRozADVJos"
},
"source": [
"# Usage example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K_XERopLUZhk"
},
"outputs": [],
"source": [
"! pip install -q 'transformers>=4.50.0'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qg-Hy0ffbwvE"
},
"outputs": [],
"source": [
"! huggingface-cli login"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "40Rm46Xt7wqW"
},
"outputs": [],
"source": [
"from transformers import AutoProcessor, AutoModelForImageClassification\n",
"import torch\n",
"\n",
"model_id = \"google/shieldgemma-2-4b-it\"\n",
"\n",
"processor = AutoProcessor.from_pretrained(model_id)\n",
"model = AutoModelForImageClassification.from_pretrained(model_id)\n",
"model.to(torch.device(\"cuda\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"import requests\n",
"\n",
"# The image included in this Colab is benign and will result in the prediction\n",
"# of a `No` token for all policies, meaning the image does not violate any\n",
"# content policies. Change this URL or otherwise update this code to use an\n",
"# image that may be violative.\n",
"url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg\"\n",
"image = Image.open(requests.get(url, stream=True).raw)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AK1PrHnYz4fv"
},
"outputs": [],
"source": [
"inputs = processor(images=[image], return_tensors=\"pt\").to(torch.device(\"cuda\"))\n",
"\n",
"with torch.no_grad():\n",
"    scores = model(**inputs)\n",
"\n",
"# `scores` is a `ShieldGemma2ImageClassifierOutputWithNoAttention` instance\n",
"# containing the logits and probabilities associated with the model predicting\n",
"# the `Yes` or `No` token as the response to the prompt batch, captured in the\n",
"# following properties.\n",
"#\n",
"# * `logits` (`torch.Tensor` of shape `(batch_size, 2)`): The first position\n",
"#   along dim=1 is the logit for the `Yes` token and the second position\n",
"#   along dim=1 is the logit for the `No` token.\n",
"# * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`): The first\n",
"#   position along dim=1 is the probability of predicting the `Yes` token\n",
"#   and the second position along dim=1 is the probability of predicting the\n",
"#   `No` token.\n",
"#\n",
"# When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to\n",
"# `len(images) * len(policies)`, and the order within the batch will be\n",
"# img1_policy1, ... img1_policyN, ... imgM_policyN.\n",
"print(scores.logits)\n",
"print(scores.probabilities)\n",
"\n",
"# ShieldGemma prompts are constructed such that predicting the `Yes` token means\n",
"# the content does violate the policy. If you are only interested in the\n",
"# violative condition, slice the first column from the output tensors.\n",
"p_violated = scores.probabilities[:, 0]\n",
"print(p_violated)\n"
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
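The final cell of the notebook flattens the batch as img1_policy1, ..., img1_policyN, ..., imgM_policyN. Mapping those flat scores back to per-image, per-policy verdicts is plain bookkeeping; here is a minimal sketch in pure Python. The policy names, image names, and the 0.5 threshold are illustrative placeholders, not part of the notebook, and the scores are made up rather than real model output.

```python
def unflatten_verdicts(p_violated, images, policies, threshold=0.5):
    """Map a flat list of P(Yes) scores, ordered img1_policy1 ... imgM_policyN,
    back to {image: {policy: bool}}, where True means "violates the policy"."""
    assert len(p_violated) == len(images) * len(policies)
    verdicts = {}
    for i, name in enumerate(images):
        # Each image owns a contiguous run of len(policies) scores.
        row = p_violated[i * len(policies):(i + 1) * len(policies)]
        verdicts[name] = {pol: p >= threshold for pol, p in zip(policies, row)}
    return verdicts

# Illustrative scores for 2 images x 3 policies (not real model output).
scores = [0.02, 0.91, 0.10, 0.40, 0.05, 0.77]
verdicts = unflatten_verdicts(
    scores,
    images=["img1", "img2"],
    policies=["dangerous", "sexually_explicit", "violence_gore"],
    threshold=0.5,
)
print(verdicts["img1"]["sexually_explicit"])  # True: 0.91 >= 0.5
```

In a real pipeline, `scores` would be `p_violated.tolist()` from the notebook's last cell, and the threshold would be tuned to your own risk tolerance, as the "Probability-Based Output and Thresholding" section suggests.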
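The output object exposes both `logits` and `probabilities` over the Yes/No token pair. The natural reading, which the field shapes suggest, is that each probability row is a softmax over the corresponding pair of logits; the sketch below illustrates that relationship with made-up logit values, and is not the library's internal code.

```python
import math

def yes_no_probabilities(logit_yes, logit_no):
    """Softmax over a two-token (Yes, No) logit pair."""
    m = max(logit_yes, logit_no)  # subtract the max for numerical stability
    e_yes = math.exp(logit_yes - m)
    e_no = math.exp(logit_no - m)
    total = e_yes + e_no
    return e_yes / total, e_no / total

# Illustrative logits; a large Yes-No gap yields a confident "violates" score.
p_yes, p_no = yes_no_probabilities(2.0, -1.0)
print(round(p_yes, 4))  # 0.9526 -- P(Yes), i.e. P(image violates the policy)
print(round(p_yes + p_no, 4))  # 1.0 -- the pair is a proper distribution
```

This is also why thresholding `probabilities[:, 0]` (rather than raw logits) is convenient: the values are directly comparable across images and policies.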
