|  | 
|  | 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +import inspect | 
|  | 16 | +from dataclasses import dataclass | 
|  | 17 | +from typing import Tuple, Union | 
|  | 18 | + | 
|  | 19 | +import torch | 
|  | 20 | + | 
|  | 21 | +from ..utils import get_logger | 
|  | 22 | +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS | 
|  | 23 | +from .hooks import HookRegistry, ModelHook | 
|  | 24 | +from .utils import _extract_return_information | 
|  | 25 | + | 
|  | 26 | + | 
|  | 27 | +logger = get_logger(__name__)  # pylint: disable=invalid-name | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" | 
|  | 31 | +_FBC_BLOCK_HOOK = "fbc_block_hook" | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +@dataclass | 
|  | 35 | +class FirstBlockCacheConfig: | 
|  | 36 | +    r""" | 
|  | 37 | +    Configuration for [First Block | 
|  | 38 | +    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). | 
|  | 39 | +
 | 
|  | 40 | +    Args: | 
|  | 41 | +        threshold (`float`, defaults to `0.05`): | 
|  | 42 | +            The threshold to determine whether or not a forward pass through all layers of the model is required. A | 
|  | 43 | +            higher threshold usually results in lower number of forward passes and faster inference, but might lead to | 
|  | 44 | +            poorer generation quality. A lower threshold may not result in significant generation speedup. The | 
|  | 45 | +            threshold is compared against the absmean difference of the residuals between the current and cached | 
|  | 46 | +            outputs from the first transformer block. If the difference is below the threshold, the forward pass is | 
|  | 47 | +            skipped. | 
|  | 48 | +    """ | 
|  | 49 | + | 
|  | 50 | +    threshold: float = 0.05 | 
|  | 51 | + | 
|  | 52 | + | 
|  | 53 | +class FBCSharedBlockState: | 
|  | 54 | +    def __init__(self) -> None: | 
|  | 55 | +        self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None | 
|  | 56 | +        self.head_block_residual: torch.Tensor = None | 
|  | 57 | +        self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None | 
|  | 58 | +        self.should_compute: bool = True | 
|  | 59 | + | 
|  | 60 | +    def reset(self): | 
|  | 61 | +        self.tail_block_residuals = None | 
|  | 62 | +        self.should_compute = True | 
|  | 63 | + | 
|  | 64 | +    def __repr__(self): | 
|  | 65 | +        return f"FirstBlockCacheSharedState(cache={self.cache})" | 
|  | 66 | + | 
|  | 67 | + | 
|  | 68 | +class FBCHeadBlockHook(ModelHook): | 
|  | 69 | +    _is_stateful = True | 
|  | 70 | + | 
|  | 71 | +    def __init__(self, shared_state: FBCSharedBlockState, threshold: float): | 
|  | 72 | +        self.shared_state = shared_state | 
|  | 73 | +        self.threshold = threshold | 
|  | 74 | + | 
|  | 75 | +    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: | 
|  | 76 | +        inputs = inspect.signature(module.__class__.forward) | 
|  | 77 | +        inputs_index_to_str = dict(enumerate(inputs.parameters.keys())) | 
|  | 78 | +        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()} | 
|  | 79 | + | 
|  | 80 | +        try: | 
|  | 81 | +            outputs = _extract_return_information(module.__class__.forward) | 
|  | 82 | +            outputs_index_to_str = dict(enumerate(outputs)) | 
|  | 83 | +            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()} | 
|  | 84 | +        except RuntimeError: | 
|  | 85 | +            logger.error(f"Failed to extract return information for {module.__class__}") | 
|  | 86 | +            raise NotImplementedError( | 
|  | 87 | +                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at " | 
|  | 88 | +                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example " | 
|  | 89 | +                f"in order for us to add support for this module." | 
|  | 90 | +            ) | 
|  | 91 | + | 
|  | 92 | +        self._inputs_index_to_str = inputs_index_to_str | 
|  | 93 | +        self._inputs_str_to_index = inputs_str_to_index | 
|  | 94 | +        self._outputs_index_to_str = outputs_index_to_str | 
|  | 95 | +        self._outputs_str_to_index = outputs_str_to_index | 
|  | 96 | +        return module | 
|  | 97 | + | 
|  | 98 | +    def new_forward(self, module: torch.nn.Module, *args, **kwargs): | 
|  | 99 | +        hs_input_idx = self._inputs_str_to_index.get("hidden_states") | 
|  | 100 | +        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None) | 
|  | 101 | +        original_hs = kwargs.get("hidden_states", None) | 
|  | 102 | +        original_ehs = kwargs.get("encoder_hidden_states", None) | 
|  | 103 | +        original_hs = original_hs if original_hs is not None else args[hs_input_idx] | 
|  | 104 | +        if ehs_input_idx is not None: | 
|  | 105 | +            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx] | 
|  | 106 | + | 
|  | 107 | +        hs_output_idx = self._outputs_str_to_index.get("hidden_states") | 
|  | 108 | +        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None) | 
|  | 109 | +        assert (ehs_input_idx is None) == (ehs_output_idx is None) | 
|  | 110 | + | 
|  | 111 | +        output = self.fn_ref.original_forward(*args, **kwargs) | 
|  | 112 | + | 
|  | 113 | +        hs_residual = None | 
|  | 114 | +        if isinstance(output, tuple): | 
|  | 115 | +            hs_residual = output[hs_output_idx] - original_hs | 
|  | 116 | +        else: | 
|  | 117 | +            hs_residual = output - original_hs | 
|  | 118 | + | 
|  | 119 | +        should_compute = self._should_compute_remaining_blocks(hs_residual) | 
|  | 120 | +        self.shared_state.should_compute = should_compute | 
|  | 121 | + | 
|  | 122 | +        hs, ehs = None, None | 
|  | 123 | +        if not should_compute: | 
|  | 124 | +            # Apply caching | 
|  | 125 | +            logger.info("Skipping forward pass through remaining blocks") | 
|  | 126 | +            hs = self.shared_state.tail_block_residuals[0] + output[hs_output_idx] | 
|  | 127 | +            if ehs_output_idx is not None: | 
|  | 128 | +                ehs = self.shared_state.tail_block_residuals[1] + output[ehs_output_idx] | 
|  | 129 | + | 
|  | 130 | +            if isinstance(output, tuple): | 
|  | 131 | +                return_output = [None] * len(output) | 
|  | 132 | +                return_output[hs_output_idx] = hs | 
|  | 133 | +                return_output[ehs_output_idx] = ehs | 
|  | 134 | +                return_output = tuple(return_output) | 
|  | 135 | +            else: | 
|  | 136 | +                return_output = hs | 
|  | 137 | +            return return_output | 
|  | 138 | +        else: | 
|  | 139 | +            logger.info("Computing forward pass through remaining blocks") | 
|  | 140 | +            if isinstance(output, tuple): | 
|  | 141 | +                head_block_output = [None] * len(output) | 
|  | 142 | +                head_block_output[0] = output[hs_output_idx] | 
|  | 143 | +                head_block_output[1] = output[ehs_output_idx] | 
|  | 144 | +            else: | 
|  | 145 | +                head_block_output = output | 
|  | 146 | +            self.shared_state.head_block_output = head_block_output | 
|  | 147 | +            self.shared_state.head_block_residual = hs_residual | 
|  | 148 | +            return output | 
|  | 149 | + | 
|  | 150 | +    def reset_state(self, module): | 
|  | 151 | +        self.shared_state.reset() | 
|  | 152 | +        return module | 
|  | 153 | + | 
|  | 154 | +    def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: | 
|  | 155 | +        if self.shared_state.head_block_residual is None: | 
|  | 156 | +            return True | 
|  | 157 | +        prev_hs_residual = self.shared_state.head_block_residual | 
|  | 158 | +        hs_absmean = (hs_residual - prev_hs_residual).abs().mean() | 
|  | 159 | +        prev_hs_mean = prev_hs_residual.abs().mean() | 
|  | 160 | +        diff = (hs_absmean / prev_hs_mean).item() | 
|  | 161 | +        logger.info(f"Diff: {diff}, Threshold: {self.threshold}") | 
|  | 162 | +        return diff > self.threshold | 
|  | 163 | + | 
|  | 164 | + | 
|  | 165 | +class FBCBlockHook(ModelHook): | 
|  | 166 | +    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): | 
|  | 167 | +        super().__init__() | 
|  | 168 | +        self.shared_state = shared_state | 
|  | 169 | +        self.is_tail = is_tail | 
|  | 170 | + | 
|  | 171 | +    def initialize_hook(self, module): | 
|  | 172 | +        inputs = inspect.signature(module.__class__.forward) | 
|  | 173 | +        inputs_index_to_str = dict(enumerate(inputs.parameters.keys())) | 
|  | 174 | +        inputs_str_to_index = {v: k for k, v in inputs_index_to_str.items()} | 
|  | 175 | + | 
|  | 176 | +        try: | 
|  | 177 | +            outputs = _extract_return_information(module.__class__.forward) | 
|  | 178 | +            outputs_index_to_str = dict(enumerate(outputs)) | 
|  | 179 | +            outputs_str_to_index = {v: k for k, v in outputs_index_to_str.items()} | 
|  | 180 | +        except RuntimeError: | 
|  | 181 | +            logger.error(f"Failed to extract return information for {module.__class__}") | 
|  | 182 | +            raise NotImplementedError( | 
|  | 183 | +                f"Module {module.__class__} is not supported with FirstBlockCache. Please open an issue at " | 
|  | 184 | +                f"https://github.com/huggingface/diffusers to notify us about the error with a minimal example " | 
|  | 185 | +                f"in order for us to add support for this module." | 
|  | 186 | +            ) | 
|  | 187 | + | 
|  | 188 | +        self._inputs_index_to_str = inputs_index_to_str | 
|  | 189 | +        self._inputs_str_to_index = inputs_str_to_index | 
|  | 190 | +        self._outputs_index_to_str = outputs_index_to_str | 
|  | 191 | +        self._outputs_str_to_index = outputs_str_to_index | 
|  | 192 | +        return module | 
|  | 193 | + | 
|  | 194 | +    def new_forward(self, module: torch.nn.Module, *args, **kwargs): | 
|  | 195 | +        hs_input_idx = self._inputs_str_to_index.get("hidden_states") | 
|  | 196 | +        ehs_input_idx = self._inputs_str_to_index.get("encoder_hidden_states", None) | 
|  | 197 | +        original_hs = kwargs.get("hidden_states", None) | 
|  | 198 | +        original_ehs = kwargs.get("encoder_hidden_states", None) | 
|  | 199 | +        original_hs = original_hs if original_hs is not None else args[hs_input_idx] | 
|  | 200 | +        if ehs_input_idx is not None: | 
|  | 201 | +            original_ehs = original_ehs if original_ehs is not None else args[ehs_input_idx] | 
|  | 202 | + | 
|  | 203 | +        hs_output_idx = self._outputs_str_to_index.get("hidden_states") | 
|  | 204 | +        ehs_output_idx = self._outputs_str_to_index.get("encoder_hidden_states", None) | 
|  | 205 | +        assert (ehs_input_idx is None) == (ehs_output_idx is None) | 
|  | 206 | + | 
|  | 207 | +        if self.shared_state.should_compute: | 
|  | 208 | +            output = self.fn_ref.original_forward(*args, **kwargs) | 
|  | 209 | +            if self.is_tail: | 
|  | 210 | +                hs_residual, ehs_residual = None, None | 
|  | 211 | +                if isinstance(output, tuple): | 
|  | 212 | +                    hs_residual = output[hs_output_idx] - self.shared_state.head_block_output[0] | 
|  | 213 | +                    ehs_residual = output[ehs_output_idx] - self.shared_state.head_block_output[1] | 
|  | 214 | +                else: | 
|  | 215 | +                    hs_residual = output - self.shared_state.head_block_output | 
|  | 216 | +                self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) | 
|  | 217 | +            return output | 
|  | 218 | + | 
|  | 219 | +        output_count = len(self._outputs_index_to_str.keys()) | 
|  | 220 | +        return_output = [None] * output_count if output_count > 1 else original_hs | 
|  | 221 | +        if output_count == 1: | 
|  | 222 | +            return_output = original_hs | 
|  | 223 | +        else: | 
|  | 224 | +            return_output[hs_output_idx] = original_hs | 
|  | 225 | +            return_output[ehs_output_idx] = original_ehs | 
|  | 226 | +        return return_output | 
|  | 227 | + | 
|  | 228 | + | 
|  | 229 | +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: | 
|  | 230 | +    shared_state = FBCSharedBlockState() | 
|  | 231 | +    remaining_blocks = [] | 
|  | 232 | + | 
|  | 233 | +    for name, submodule in module.named_children(): | 
|  | 234 | +        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): | 
|  | 235 | +            continue | 
|  | 236 | +        for block in submodule: | 
|  | 237 | +            remaining_blocks.append((name, block)) | 
|  | 238 | + | 
|  | 239 | +    head_block_name, head_block = remaining_blocks.pop(0) | 
|  | 240 | +    tail_block_name, tail_block = remaining_blocks.pop(-1) | 
|  | 241 | + | 
|  | 242 | +    logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") | 
|  | 243 | +    apply_fbc_head_block_hook(head_block, shared_state, config.threshold) | 
|  | 244 | + | 
|  | 245 | +    for name, block in remaining_blocks: | 
|  | 246 | +        logger.debug(f"Apply FBCBlockHook to '{name}'") | 
|  | 247 | +        apply_fbc_block_hook(block, shared_state) | 
|  | 248 | + | 
|  | 249 | +    logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") | 
|  | 250 | +    apply_fbc_block_hook(tail_block, shared_state, is_tail=True) | 
|  | 251 | + | 
|  | 252 | + | 
|  | 253 | +def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: | 
|  | 254 | +    registry = HookRegistry.check_if_exists_or_initialize(block) | 
|  | 255 | +    hook = FBCHeadBlockHook(state, threshold) | 
|  | 256 | +    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) | 
|  | 257 | + | 
|  | 258 | + | 
|  | 259 | +def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: | 
|  | 260 | +    registry = HookRegistry.check_if_exists_or_initialize(block) | 
|  | 261 | +    hook = FBCBlockHook(state, is_tail) | 
|  | 262 | +    registry.register_hook(hook, _FBC_BLOCK_HOOK) | 
0 commit comments