 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import decoder
+from ai_edge_torch.generative.examples.paligemma import decoder2
 from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -171,13 +173,9 @@ def test_amd_llama_135m(self):
     pytorch_model = amd_llama_135m.AmdLlama(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

-  @googletest.skipIf(
-      ai_edge_torch.config.in_oss,
-      reason="tests with custom ops are not supported in oss",
-  )
-  def disabled_test_paligemma(self):
-    config = paligemma.get_fake_model_config()
-    pytorch_model = paligemma.PaliGemma(config).eval()
+  def _test_paligemma_model(self, decoder_class, decoder_config, atol, rtol):
+    config = paligemma.get_fake_model_config(decoder_config)
+    pytorch_model = paligemma.PaliGemma(config, decoder_class).eval()

     image_embedding_config = config.image_encoder_config.image_embedding
     num_patches = (
@@ -215,11 +213,32 @@ def disabled_test_paligemma(self):
             kv,
             pixel_values=pixel_values,
             signature_name="prefill_pixel",
-            atol=1e-3,
-            rtol=1e-5,
+            atol=atol,
+            rtol=rtol,
         )
     )

+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma1(self):
+    self._test_paligemma_model(
+        decoder.Decoder, decoder.get_fake_decoder_config, atol=1e-3, rtol=1e-5
+    )
+
+  @googletest.skipIf(
+      ai_edge_torch.config.in_oss,
+      reason="tests with custom ops are not supported in oss",
+  )
+  def disabled_test_paligemma2(self):
+    self._test_paligemma_model(
+        decoder2.Decoder2,
+        decoder2.get_fake_decoder2_config,
+        atol=1e-3,
+        rtol=1e-5,
+    )
+
   @googletest.skipIf(
       ai_edge_torch.config.in_oss,
       reason="tests with custom ops are not supported in oss",
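The hunks above fold the PaliGemma test body into one `_test_paligemma_model` helper that takes the decoder class, its fake-config factory, and the comparison tolerances as arguments, so `disabled_test_paligemma1` and `disabled_test_paligemma2` differ only in which decoder they inject. A rough, self-contained sketch of that parameterization pattern follows; the `Decoder1`/`Decoder2` classes and config factories are hypothetical stand-ins (not the real ai_edge_torch modules), and plain `unittest` is used in place of absl's `googletest`.

import unittest


class Decoder1:
  """Hypothetical stand-in for a first decoder variant."""

  def __init__(self, config):
    self.config = config


class Decoder2:
  """Hypothetical stand-in for a second decoder variant."""

  def __init__(self, config):
    self.config = config


def get_fake_decoder1_config():
  # Tiny fake config; analogous in spirit only to the real fake-config helpers.
  return {"num_layers": 2, "embedding_dim": 16}


def get_fake_decoder2_config():
  return {"num_layers": 2, "embedding_dim": 16}


class ParameterizedDecoderTest(unittest.TestCase):

  def _test_model(self, decoder_class, config_factory, atol, rtol):
    # Shared body: build the model around the injected decoder. The real test
    # would convert the model and compare TFLite vs. PyTorch outputs within
    # the given atol/rtol.
    model = decoder_class(config_factory())
    self.assertIsNotNone(model.config)

  def test_decoder1(self):
    self._test_model(Decoder1, get_fake_decoder1_config, atol=1e-3, rtol=1e-5)

  def test_decoder2(self):
    self._test_model(Decoder2, get_fake_decoder2_config, atol=1e-3, rtol=1e-5)


if __name__ == "__main__":
  unittest.main()

Keeping the shared body in a private helper rather than a loop over cases means each decoder variant still reports as its own test, which keeps failures attributable to a single decoder.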