|
22 | 22 |
|
23 | 23 | import copy |
24 | 24 | import importlib.metadata |
| 25 | +import inspect |
25 | 26 | import json |
26 | 27 | import os |
27 | 28 | from dataclasses import dataclass |
28 | 29 | from enum import Enum |
29 | | -from typing import Any, Dict, Union |
| 30 | +from functools import partial |
| 31 | +from typing import Any, Dict, List, Optional, Union |
30 | 32 |
|
31 | 33 | from packaging import version |
32 | 34 |
|
33 | | -from ..utils import is_torch_available, logging |
| 35 | +from ..utils import is_torch_available, is_torchao_available, logging |
34 | 36 |
|
35 | 37 |
|
36 | 38 | if is_torch_available(): |
|
41 | 43 |
|
42 | 44 | class QuantizationMethod(str, Enum): |
43 | 45 | BITS_AND_BYTES = "bitsandbytes" |
| 46 | + TORCHAO = "torchao" |
44 | 47 |
|
45 | 48 |
|
46 | 49 | @dataclass |
@@ -389,3 +392,241 @@ def to_diff_dict(self) -> Dict[str, Any]: |
389 | 392 | serializable_config_dict[key] = value |
390 | 393 |
|
391 | 394 | return serializable_config_dict |
| 395 | + |
| 396 | + |
| 397 | +@dataclass |
| 398 | +class TorchAoConfig(QuantizationConfigMixin): |
| 399 | + """This is a config class for torchao quantization/sparsity techniques. |
| 400 | +
|
| 401 | + Args: |
| 402 | + quant_type (`str`): |
| 403 | + The type of quantization we want to use. Supported values include the weight-only and dynamic-activation variants exposed by torchao, e.g. `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight`, `float8_weight_only`, `fpx_weight_only` and `uintx_weight_only`, as well as their shorthand aliases (`int4wo`, `int8wo`, `float8wo`, ...). |
| 404 | + modules_to_not_convert (`list`, *optional*, defaults to `None`): |
| 405 | + The list of modules to not quantize, useful for quantizing models that explicitly require some |
| 406 | + modules to be left in their original precision. |
| 407 | + kwargs (`Dict[str, Any]`, *optional*): |
| 408 | + The keyword arguments for the chosen type of quantization. For example, `int4_weight_only` quantization currently supports two |
| 409 | + keyword arguments: `group_size` and `inner_k_tiles`. More API examples and documentation of the arguments can be found at |
| 410 | + https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques |
| 411 | +
|
| 412 | + Example: |
| 413 | +
|
| 414 | + ```python |
| 415 | + TODO(aryan): update |
| 416 | + quantization_config = TorchAoConfig("int4_weight_only", group_size=32) |
| 417 | + # int4_weight_only quantization currently only works with the *torch.bfloat16* dtype |
| 418 | + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) |
| 419 | + ``` |
| 420 | + """ |
| 421 | + |
| 422 | + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): |
| 423 | + self.quant_method = QuantizationMethod.TORCHAO |
| 424 | + self.quant_type = quant_type |
| 425 | + self.modules_to_not_convert = modules_to_not_convert |
| 426 | + |
| 427 | + # When we load from serialized config, "quant_type_kwargs" will be the key |
| 428 | + if "quant_type_kwargs" in kwargs: |
| 429 | + self.quant_type_kwargs = kwargs["quant_type_kwargs"] |
| 430 | + else: |
| 431 | + self.quant_type_kwargs = kwargs |
| 432 | + |
| 433 | + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() |
| 434 | + if self.quant_type not in _STR_TO_METHOD: |
| 435 | + raise ValueError( |
| 436 | + f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the " |
| 437 | + f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues." |
| 438 | + ) |
| 439 | + |
| 440 | + method = _STR_TO_METHOD[self.quant_type] |
| 441 | + signature = inspect.signature(method) |
| 442 | + all_kwargs = { |
| 443 | + param.name |
| 444 | + for param in signature.parameters.values() |
| 445 | + if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD] |
| 446 | + } |
| 447 | + unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs) |
| 448 | + |
| 449 | + if len(unsupported_kwargs) > 0: |
| 450 | + raise ValueError( |
| 451 | + f"The quantization method \"{method}\" does not supported the following keyword arguments: " |
| 452 | + f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}." |
| 453 | + ) |
| 454 | + |
| 455 | + @classmethod |
| 456 | + def _get_torchao_quant_type_to_method(cls): |
| 457 | + r""" |
| 458 | + Returns supported torchao quantization types with all commonly used notations. |
| 459 | + """ |
| 460 | + |
| 461 | + if is_torchao_available(): |
| 462 | + from torchao.quantization import ( |
| 463 | + float8_dynamic_activation_float8_weight, |
| 464 | + float8_static_activation_float8_weight, |
| 465 | + float8_weight_only, |
| 466 | + fpx_weight_only, |
| 467 | + int4_weight_only, |
| 468 | + int8_dynamic_activation_int4_weight, |
| 469 | + int8_dynamic_activation_int8_weight, |
| 470 | + int8_weight_only, |
| 471 | + uintx_weight_only, |
| 472 | + ) |
| 473 | + # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers |
| 474 | + from torchao.quantization.observer import PerRow, PerTensor |
| 475 | + |
| 476 | + # TODO(aryan): Support autoquant and sparsify |
| 477 | + |
| 478 | + INT4_QUANTIZATION_TYPES = { |
| 479 | + # int4 weight + bfloat16/float16 activation |
| 480 | + "int4": int4_weight_only, |
| 481 | + "int4wo": int4_weight_only, |
| 482 | + "int4_weight_only": int4_weight_only, |
| 483 | + "int4_a16w4": int4_weight_only, |
| 484 | + # int4 weight + int8 activation |
| 485 | + "int4dq": int8_dynamic_activation_int4_weight, |
| 486 | + "int8_dynamic_activation_int4_weight": int8_dynamic_activation_int4_weight, |
| 487 | + "int4_a8w4": int8_dynamic_activation_int4_weight, |
| 488 | + } |
| 489 | + |
| 490 | + INT8_QUANTIZATION_TYPES = { |
| 491 | + # int8 weight + bfloat16/float16 activation |
| 492 | + "int8": int8_weight_only, |
| 493 | + "int8wo": int8_weight_only, |
| 494 | + "int8_weight_only": int8_weight_only, |
| 495 | + "int8_a16w8": int8_weight_only, |
| 496 | + # int8 weight + int8 activation |
| 497 | + "int8dq": int8_dynamic_activation_int8_weight, |
| 498 | + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, |
| 499 | + "int8_a8w8": int8_dynamic_activation_int8_weight, |
| 500 | + } |
| 501 | + |
| 502 | + def generate_float8dq_types(dtype: torch.dtype): |
| 503 | + name = "e5m2" if dtype == torch.float8_e5m2 else "e4m3" |
| 504 | + types = {} |
| 505 | + |
| 506 | + types[f"float8dq_{name}_a8w8"] = partial(float8_dynamic_activation_float8_weight, activation_dtype=dtype, weight_dtype=dtype) |
| 507 | + for activation_granularity_cls in [PerTensor, PerRow]: |
| 508 | + for weight_granularity_cls in [PerTensor, PerRow]: |
| 509 | + activation_name = "t" if activation_granularity_cls is PerTensor else "r" |
| 510 | + weight_name = "t" if weight_granularity_cls is PerTensor else "r" |
| 511 | + # The a{activation_name}w{weight_name} suffix is a made-up shorthand used for testing convenience. |
| 512 | + # It encodes the (activation granularity, weight granularity) pair: |
| 513 | + # - atwt: PerTensor(), PerTensor() |
| 514 | + # - atwr: PerTensor(), PerRow() |
| 515 | + # - arwt: PerRow(), PerTensor() |
| 516 | + # - arwr: PerRow(), PerRow() |
| 517 | + types[f"float8dq_{name}_a{activation_name}w{weight_name}"] = partial( |
| 518 | + float8_dynamic_activation_float8_weight, |
| 519 | + activation_dtype=dtype, |
| 520 | + weight_dtype=dtype, |
| 521 | + granularity=(activation_granularity_cls(), weight_granularity_cls()), |
| 522 | + ) |
| 523 | + types[f"float8dq_{name}_a{activation_name}w{weight_name}_a8w8"] = partial( |
| 524 | + float8_dynamic_activation_float8_weight, |
| 525 | + activation_dtype=dtype, |
| 526 | + weight_dtype=dtype, |
| 527 | + granularity=(activation_granularity_cls(), weight_granularity_cls()), |
| 528 | + ) |
| 529 | + |
| 530 | + return types |
| 531 | + |
| 532 | + def generate_fpx_quantization_types(bits: int): |
| 533 | + types = {} |
| 534 | + |
| 535 | + for ebits in range(1, bits): |
| 536 | + mbits = bits - ebits - 1 |
| 537 | + types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) |
| 538 | + types[f"fp{bits}_e{ebits}m{mbits}_a16w{bits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) |
| 539 | + |
| 540 | + non_sign_bits = bits - 1 |
| 541 | + default_ebits = (non_sign_bits + 1) // 2 |
| 542 | + default_mbits = non_sign_bits - default_ebits |
| 543 | + types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) |
| 544 | + |
| 545 | + return types |
| 546 | + |
| 547 | + # TODO(aryan): handle cuda capability and torch 2.2/2.3 |
| 548 | + FLOATX_QUANTIZATION_TYPES = { |
| 549 | + # float8_e5m2 weight + bfloat16/float16 activation |
| 550 | + "float8": float8_weight_only, |
| 551 | + "float8_weight_only": float8_weight_only, |
| 552 | + "float8wo": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), |
| 553 | + "float8_a16w8": float8_weight_only, |
| 554 | + "float8_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), |
| 555 | + "float8wo_e5m2": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), |
| 556 | + "float8_e5m2_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e5m2), |
| 557 | + # float8_e4m3 weight + bfloat16/float16 activation |
| 558 | + "float8_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), |
| 559 | + "float8wo_e4m3": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), |
| 560 | + "float8wo_e4m3_a16w8": partial(float8_weight_only, weight_dtype=torch.float8_e4m3fn), |
| 561 | + # float8_e5m2 weight + float8 activation (dynamic) |
| 562 | + "float8_dynamic_activation_float8_weight": float8_dynamic_activation_float8_weight, |
| 563 | + "float8dq": float8_dynamic_activation_float8_weight, |
| 564 | + "float8dq_e5m2": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e5m2, weight_dtype=torch.float8_e5m2), |
| 565 | + "float8_a8w8": float8_dynamic_activation_float8_weight, |
| 566 | + **generate_float8dq_types(torch.float8_e5m2), |
| 567 | + # float8_e4m3 weight + float8 activation (dynamic) |
| 568 | + "float8dq_e4m3": partial(float8_dynamic_activation_float8_weight, activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn), |
| 569 | + **generate_float8dq_types(torch.float8_e4m3fn), |
| 570 | + # float8 weight + float8 activation (static) |
| 571 | + "float8_static_activation_float8_weight": float8_static_activation_float8_weight, |
| 572 | + "float8sq": float8_static_activation_float8_weight, |
| 573 | + # For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly |
| 574 | + # fpx weight + bfloat16/float16 activation |
| 575 | + **generate_fpx_quantization_types(3), |
| 576 | + **generate_fpx_quantization_types(4), |
| 577 | + **generate_fpx_quantization_types(5), |
| 578 | + **generate_fpx_quantization_types(6), |
| 579 | + **generate_fpx_quantization_types(7), |
| 580 | + **generate_fpx_quantization_types(8), |
| 581 | + } |
| 582 | + |
| 583 | + UINTX_TO_DTYPE = { |
| 584 | + 1: torch.uint1, |
| 585 | + 2: torch.uint2, |
| 586 | + 3: torch.uint3, |
| 587 | + 4: torch.uint4, |
| 588 | + 5: torch.uint5, |
| 589 | + 6: torch.uint6, |
| 590 | + 7: torch.uint7, |
| 591 | + 8: torch.uint8, |
| 592 | + } |
| 593 | + |
| 594 | + def generate_uintx_quantization_types(bits: int): |
| 595 | + types = {} |
| 596 | + types[f"uint{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) |
| 597 | + types[f"uint{bits}wo"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) |
| 598 | + types[f"uint{bits}_a16w{bits}"] = partial(uintx_weight_only, dtype=UINTX_TO_DTYPE[bits]) |
| 599 | + return types |
| 600 | + |
| 601 | + UINTX_QUANTIZATION_DTYPES = { |
| 602 | + "uintx": uintx_weight_only, |
| 603 | + "uintx_weight_only": uintx_weight_only, |
| 604 | + **generate_uintx_quantization_types(1), |
| 605 | + **generate_uintx_quantization_types(2), |
| 606 | + **generate_uintx_quantization_types(3), |
| 607 | + **generate_uintx_quantization_types(4), |
| 608 | + **generate_uintx_quantization_types(5), |
| 609 | + **generate_uintx_quantization_types(6), |
| 610 | + **generate_uintx_quantization_types(7), |
| 611 | + **generate_uintx_quantization_types(8), |
| 612 | + } |
| 613 | + |
| 614 | + QUANTIZATION_TYPES = {} |
| 615 | + QUANTIZATION_TYPES.update(INT4_QUANTIZATION_TYPES) |
| 616 | + QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) |
| 617 | + QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) |
| 618 | + QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) |
| 619 | + |
| 620 | + return QUANTIZATION_TYPES |
| 621 | + else: |
| 622 | + raise ValueError( |
| 623 | + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" |
| 624 | + ) |
| 625 | + |
| 626 | + def get_apply_tensor_subclass(self): |
| 627 | + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() |
| 628 | + return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) |
| 629 | + |
| 630 | + def __repr__(self): |
| 631 | + config_dict = self.to_dict() |
| 632 | + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" |