|
19 | 19 | import itertools |
20 | 20 | import logging |
21 | 21 | import unittest |
22 | | -from collections import defaultdict, Iterable |
| 22 | +from collections import defaultdict |
| 23 | +from collections.abc import Iterable |
23 | 24 | from enum import Enum |
24 | 25 | from functools import partial |
25 | 26 | from typing import Union, Type |
26 | 27 |
|
27 | 28 | import openvino as ov |
28 | 29 | import pytest |
29 | | -import evaluate |
30 | 30 | import numpy as np |
31 | 31 | import torch |
32 | | -from datasets import load_dataset |
33 | 32 | from parameterized import parameterized |
34 | 33 | import nncf |
35 | 34 | from transformers import ( |
36 | 35 | AutoModelForQuestionAnswering, |
37 | | - AutoModelForSequenceClassification, |
38 | 36 | AutoTokenizer, |
39 | 37 | AutoProcessor, |
40 | | - TrainingArguments, |
41 | | - default_data_collator, |
42 | 38 | ) |
43 | 39 | from transformers.testing_utils import slow |
44 | 40 | from transformers.utils.quantization_config import QuantizationMethod |
@@ -116,9 +112,11 @@ class OVQuantizerTest(unittest.TestCase): |
116 | 112 | smooth_quant_alpha=0.95, |
117 | 113 | ), |
118 | 114 | [14, 22, 21] if is_transformers_version("<=", "4.36.0") else [14, 22, 25], |
119 | | - [{"int8": 14}, {"int8": 21}, {"int8": 17}] |
120 | | - if is_transformers_version("<=", "4.36.0") |
121 | | - else [{"int8": 14}, {"int8": 22}, {"int8": 18}], |
| 115 | + ( |
| 116 | + [{"int8": 14}, {"int8": 21}, {"int8": 17}] |
| 117 | + if is_transformers_version("<=", "4.36.0") |
| 118 | + else [{"int8": 14}, {"int8": 22}, {"int8": 18}] |
| 119 | + ), |
122 | 120 | ), |
123 | 121 | ( |
124 | 122 | OVModelForCausalLM, |
@@ -234,6 +232,77 @@ class OVQuantizerTest(unittest.TestCase): |
234 | 232 | {"f8e5m2": 2, "int4": 28}, |
235 | 233 | ], |
236 | 234 | ), |
| 235 | + ( |
| 236 | + OVStableDiffusionPipeline, |
| 237 | + "stable-diffusion", |
| 238 | + dict( |
| 239 | + weight_only=False, |
| 240 | + dataset="conceptual_captions", |
| 241 | + num_samples=1, |
| 242 | + processor=MODEL_NAMES["stable-diffusion"], |
| 243 | + trust_remote_code=True, |
| 244 | + ), |
| 245 | + [ |
| 246 | + 112, |
| 247 | + 0, |
| 248 | + 0, |
| 249 | + 0, |
| 250 | + ], |
| 251 | + [ |
| 252 | + {"int8": 121}, |
| 253 | + {"int8": 42}, |
| 254 | + {"int8": 34}, |
| 255 | + {"int8": 64}, |
| 256 | + ], |
| 257 | + ), |
| 258 | + ( |
| 259 | + OVStableDiffusionXLPipeline, |
| 260 | + "stable-diffusion-xl", |
| 261 | + dict( |
| 262 | + weight_only=False, |
| 263 | + dtype="f8e5m2", |
| 264 | + dataset="laion/220k-GPT4Vision-captions-from-LIVIS", |
| 265 | + num_samples=1, |
| 266 | + processor=MODEL_NAMES["stable-diffusion-xl"], |
| 267 | + trust_remote_code=True, |
| 268 | + ), |
| 269 | + [ |
| 270 | + 174, |
| 271 | + 0, |
| 272 | + 0, |
| 273 | + 0, |
| 274 | + 0, |
| 275 | + ], |
| 276 | + [ |
| 277 | + {"f8e5m2": 183}, |
| 278 | + {"int8": 42}, |
| 279 | + {"int8": 34}, |
| 280 | + {"int8": 64}, |
| 281 | + {"int8": 66}, |
| 282 | + ], |
| 283 | + ), |
| 284 | + ( |
| 285 | + OVLatentConsistencyModelPipeline, |
| 286 | + "latent-consistency", |
| 287 | + OVQuantizationConfig( |
| 288 | + dtype="f8e4m3", |
| 289 | + dataset="laion/filtered-wit", |
| 290 | + num_samples=1, |
| 291 | + trust_remote_code=True, |
| 292 | + ), |
| 293 | + [ |
| 294 | + 79, |
| 295 | + 0, |
| 296 | + 0, |
| 297 | + 0, |
| 298 | + ], |
| 299 | + [ |
| 300 | + {"f8e4m3": 84}, |
| 301 | + {"int8": 42}, |
| 302 | + {"int8": 34}, |
| 303 | + {"int8": 40}, |
| 304 | + ], |
| 305 | + ), |
237 | 306 | ] |
238 | 307 |
|
239 | 308 | @parameterized.expand(SUPPORTED_ARCHITECTURES_TORCH_MODEL) |
@@ -359,6 +428,11 @@ def test_ov_model_static_quantization_with_auto_dataset( |
359 | 428 | tokens = tokenizer("This is a sample input", return_tensors="pt") |
360 | 429 | outputs = ov_model(**tokens) |
361 | 430 | self.assertTrue("logits" in outputs) |
| 431 | + elif any( |
| 432 | + x == model_cls |
| 433 | + for x in (OVStableDiffusionPipeline, OVStableDiffusionXLPipeline, OVLatentConsistencyModelPipeline) |
| 434 | + ): |
| 435 | + submodels = ov_model.ov_submodels.values() |
362 | 436 | else: |
363 | 437 | raise Exception("Unexpected model class.") |
364 | 438 |
|
|
0 commit comments