| 17 | 17 |  | 
| 18 | 18 | import torch | 
| 19 | 19 |  | 
| 20 |  | -from ..utils import get_logger | 
|  | 20 | +from ..utils import get_logger, is_peft_available | 
| 21 | 21 | from .hooks import HookRegistry, ModelHook | 
| 22 | 22 |  | 
| 23 | 23 |  | 
|  | 24 | +if is_peft_available(): | 
|  | 25 | +    from peft.helpers import disable_lora_input_dtype_casting | 
|  | 26 | +    from peft.tuners.loha.layer import LoHaLayer | 
|  | 27 | +    from peft.tuners.lokr.layer import LoKrLayer | 
|  | 28 | +    from peft.tuners.lora.layer import LoraLayer | 
|  | 29 | +    from peft.tuners.tuners_utils import BaseTunerLayer | 
|  | 30 | + | 
|  | 31 | +    PEFT_ADAPTER_LAYERS_NAMES = () | 
|  | 32 | +    for layer_cls in (LoHaLayer, LoKrLayer, LoraLayer): | 
|  | 33 | +        PEFT_ADAPTER_LAYERS_NAMES += tuple(layer_cls.adapter_layer_names) | 
|  | 34 | + | 
|  | 35 | + | 
| 24 | 36 | logger = get_logger(__name__)  # pylint: disable=invalid-name | 
| 25 | 37 |  | 
| 26 | 38 |  | 
| 27 | 39 | # fmt: off | 
|  | 40 | +_LAYERWISE_CASTING_HOOK = "layerwise_casting" | 
|  | 41 | +_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable" | 
| 28 | 42 | SUPPORTED_PYTORCH_LAYERS = ( | 
| 29 | 43 |     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, | 
| 30 | 44 |     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, | 
| @@ -70,6 +84,32 @@ def post_forward(self, module: torch.nn.Module, output): | 
| 70 | 84 |         return output | 
| 71 | 85 |  | 
| 72 | 86 |  | 
|  | 87 | +class PeftInputAutocastDisableHook(ModelHook): | 
|  | 88 | +    r""" | 
|  | 89 | +    A hook that disables the casting of inputs to the module weight dtype for the forward pass. By default, PEFT casts | 
|  | 90 | +    the inputs to the weight dtype of the module, which can lead to precision loss. | 
|  | 91 | + | 
|  | 92 | +    The reasons for needing this are: | 
|  | 93 | +        - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the | 
|  | 94 | +          inputs will be cast to the (possibly lower-precision) storage dtype. Reference: | 
|  | 95 | +          https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706 | 
|  | 96 | +        - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure | 
|  | 97 | +          that the inputs are always correctly cast to the computation dtype. However, there are two goals we are | 
|  | 98 | +          hoping to achieve: | 
|  | 99 | +            1. Making forward implementations independent of device/dtype casting operations as much as possible. | 
|  | 100 | +            2. Performing inference without losing information from casting to different precisions. With the current | 
|  | 101 | +               PEFT implementation (as linked in the reference above), and assuming we run layerwise casting inference | 
|  | 102 | +               with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to | 
|  | 103 | +               torch.float8_e4m3fn in the LoRA layer. We will then upcast back to torch.bfloat16 when we continue the | 
|  | 104 | +               forward pass with PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from | 
|  | 105 | +               LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality. | 
|  | 106 | +    """ | 
|  | 107 | + | 
|  | 108 | +    def new_forward(self, module: torch.nn.Module, *args, **kwargs): | 
|  | 109 | +        with disable_lora_input_dtype_casting(module, disable=True): | 
|  | 110 | +            return self.fn_ref.original_forward(*args, **kwargs) | 
|  | 111 | + | 
|  | 112 | + | 
| 73 | 113 | def apply_layerwise_casting( | 
| 74 | 114 |     module: torch.nn.Module, | 
| 75 | 115 |     storage_dtype: torch.dtype, | 
| @@ -134,6 +174,7 @@ def apply_layerwise_casting( | 
| 134 | 174 |         skip_modules_classes, | 
| 135 | 175 |         non_blocking, | 
| 136 | 176 |     ) | 
|  | 177 | +    _disable_peft_input_autocast(module) | 
| 137 | 178 |  | 
| 138 | 179 |  | 
| 139 | 180 | def _apply_layerwise_casting( | 
| @@ -188,4 +229,13 @@ def apply_layerwise_casting_hook( | 
| 188 | 229 |     """ | 
| 189 | 230 |     registry = HookRegistry.check_if_exists_or_initialize(module) | 
| 190 | 231 |     hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking) | 
| 191 |  | -    registry.register_hook(hook, "layerwise_casting") | 
|  | 232 | +    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK) | 
|  | 233 | + | 
|  | 234 | + | 
|  | 235 | +def _disable_peft_input_autocast(module: torch.nn.Module) -> None: | 
|  | 236 | +    for submodule in module.modules(): | 
|  | 237 | +        if not isinstance(submodule, BaseTunerLayer): | 
|  | 238 | +            continue | 
|  | 239 | +        registry = HookRegistry.check_if_exists_or_initialize(submodule) | 
|  | 240 | +        hook = PeftInputAutocastDisableHook() | 
|  | 241 | +        registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK) | 
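
For context, here is a minimal usage sketch of the behavior this diff adds, not code from the PR itself. The pipeline class, model id, and LoRA path are placeholders, and it assumes `apply_layerwise_casting` is exposed from `diffusers.hooks` as in current diffusers releases. With this change, applying layerwise casting to a module that contains PEFT `BaseTunerLayer` submodules also registers `PeftInputAutocastDisableHook` on each of them, so LoRA inputs stay in the compute dtype instead of being downcast to the storage dtype.

```python
# Minimal sketch, assuming a LoRA-adapted Flux transformer; model id and LoRA
# path are placeholders.
import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_layerwise_casting

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA checkpoint

# Store weights in float8 but run the forward pass in bfloat16. With this diff,
# apply_layerwise_casting also walks the module, finds PEFT BaseTunerLayer
# submodules, and registers PeftInputAutocastDisableHook on them so PEFT does
# not cast the LoRA inputs down to the float8 storage dtype.
apply_layerwise_casting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
```

Where available, the model-level `enable_layerwise_casting` helper routes through this same function, so it should pick up the PEFT hook registration as well.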
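
The precision-loss argument in the docstring can also be seen with a tiny, self-contained check (again not part of the diff; it assumes a PyTorch build with float8 support): round-tripping a bfloat16 activation through `torch.float8_e4m3fn` and back quantizes it, which is what happens when PEFT downcasts LoRA inputs to the storage dtype and `LayerwiseCastingHook` later upcasts them again.

```python
import torch

# Round-trip a bfloat16 activation through the float8 storage dtype, mimicking
# PEFT casting LoRA inputs to the storage dtype before the layerwise casting
# hook upcasts them back to the compute dtype.
x = torch.randn(8, dtype=torch.bfloat16)
round_tripped = x.to(torch.float8_e4m3fn).to(torch.bfloat16)

print(x)
print(round_tripped)                    # values snapped to the float8 grid
print((x - round_tripped).abs().max())  # nonzero error from the round trip
```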