|
105 | 105 | }, |
106 | 106 | "outputs": [], |
107 | 107 | "source": [ |
| 108 | + "# @title install Hugging Face Transformers v4.50+\n", |
108 | 109 | "! pip install -q 'transformers>=4.50.0'" |
109 | 110 | ] |
110 | 111 | }, |
|
116 | 117 | }, |
117 | 118 | "outputs": [], |
118 | 119 | "source": [ |
119 | | - "! huggingface-cli login" |
| 120 | + "# @title Authenticate with Hugging Face Hub\n", |
| 121 | + "# @markdown ShieldGemma is a gated model. To access the weights, you must accept\n", |
| 122 | + "# @markdown the license on Hugging Face Hub under your account and then provide\n", |
| 123 | + "# @markdown an [Access Token](https://huggingface.co/docs/hub/en/security-tokens)\n", |
| 124 | + "# @markdown to authenticate with the Hugging Face Hub API. If using Colab, the\n", |
| 125 | + "# @markdown easiest way to do this is by creating a read-only token specifically\n", |
| 126 | + "# @markdown for Colab and setting this as the value of the `HF_TOKEN` secret;\n", |
| 127 | + "# @markdown this token will then be reusable across all Colab notebooks. Other\n", |
| 128 | + "# @markdown Python notebook platforms may provide a similar mechanism. For those\n", |
| 129 | + "# @markdown that do not, un-comment the lines in this cell to install the\n", |
| 130 | + "# @markdown Hugging Face Hub CLI and log in interactively.\n", |
| 131 | + "# ! pip install -q 'huggingface_hub[cli]'\n", |
| 132 | + "# ! huggingface-cli login" |
120 | 133 | ] |
121 | 134 | }, |
122 | 135 | { |
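
If your notebook platform exposes secrets as environment variables rather than through a secrets panel or the CLI prompt above, a minimal non-interactive sketch of the same authentication step (the `HF_TOKEN` variable name and the explicit `login()` call are assumptions about your setup, not something this notebook defines):

    # Sketch: log in to the Hugging Face Hub programmatically.
    # Assumes a read-only token is available in an HF_TOKEN environment variable.
    import os
    from huggingface_hub import login

    login(token=os.environ["HF_TOKEN"])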
|
146 | 159 | "from PIL import Image\n", |
147 | 160 | "import requests\n", |
148 | 161 | "\n", |
149 | | - "# The image included in this Colab is benign and will results in the prediction\n", |
150 | | - "# of a `No` token for all policies, meanign the image does not violate any\n", |
151 | | - "# content policies. Change this URL or otherwise update this code to use an\n", |
152 | | - "# image that may be violative.\n", |
| 162 | + "# The image included in this Colab is benign and will not violate any of\n", |
| 163 | + "# ShieldGemma's built-in content policies. Change this URL or otherwise update\n", |
| 164 | + "# this code to use an image that may be violative.\n", |
153 | 165 | "url = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg\"\n", |
154 | 166 | "image = Image.open(requests.get(url, stream=True).raw)" |
155 | 167 | ] |
|
169 | 181 | "\n", |
170 | 182 | "# `scores` is a `ShieldGemma2ImageClassifierOutputWithNoAttention` instance\n", |
171 | 183 | "# continaing the logits and probabilities associated with the model predicting\n", |
172 | | - "# the `Yes` or `No` token as the response to the prompt batch, captured in the\n", |
| 184 | + "# the `Yes` or `No` tokens as the response to the prompt batch, captured in the\n", |
173 | 185 | "# following properties.\n", |
174 | 186 | "#\n", |
175 | 187 | "# * `logits` (`torch.Tensor` of shape `(batch_size, 2)`): The first position\n", |
|
187 | 199 | "print(scores.probabilities)\n", |
188 | 200 | "\n", |
189 | 201 | "# ShieldGemma prompts are constructed such that predicting the `Yes` token means\n", |
190 | | - "# the content does violate the policy. If you are only interested in the\n", |
191 | | - "# violative condition, use to extract that slice from the output tensors.\n", |
| 202 | + "# the content violates the policy. If you are only interested in the violative\n", |
| 203 | + "# condition, you can extract only that slice from the output tensors.\n", |
192 | 204 | "p_violated = scores.probabilities[:, 0]\n", |
193 | 205 | "print(p_violated)\n" |
194 | 206 | ] |
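
If a boolean per-policy decision is needed downstream, a minimal follow-on sketch, assuming the `p_violated` tensor from the cell above and an illustrative cutoff (the 0.5 threshold is an assumption for demonstration, not a value recommended by this notebook):

    # Sketch: turn per-policy violation probabilities into booleans.
    # The 0.5 cutoff is an arbitrary example threshold.
    is_violated = p_violated > 0.5
    print(is_violated)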
|