|  | 
|  | 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | + | 
|  | 15 | +from typing import Optional | 
|  | 16 | + | 
|  | 17 | +import torch | 
|  | 18 | + | 
|  | 19 | +from ..utils import logging | 
|  | 20 | +from .activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU | 
|  | 21 | +from .attention import FeedForward | 
|  | 22 | + | 
|  | 23 | + | 
|  | 24 | +logger = logging.get_logger(__name__)  # pylint: disable=invalid-name | 
|  | 25 | + | 
|  | 26 | + | 
class _MemoryOptimizedFeedForward(torch.nn.Module):
    r"""
    See [`~models.attention.FeedForward`] parameter documentation. This class is a copy of the FeedForward class. The
    only difference is that this module is optimized for memory.

    This method achieves memory savings by applying the ideas of tensor-parallelism sequentially. Input projection
    layers are split column-wise and output projection layers are split row-wise. This allows for the computation of
    the feedforward pass to occur without ever materializing the full intermediate tensor. Typically, the intermediate
    tensor takes 4x-8x more memory than the input tensor. This method reduces that with a small performance tradeoff.

    Parameters:
        num_splits (`int`, *optional*, defaults to 4):
            Number of equally sized slices the inner dimension is cut into. Must be a positive divisor of
            `inner_dim` (which defaults to `int(dim * mult)` when not given).

    Raises:
        ValueError: If `num_splits` is not a positive divisor of `inner_dim`, or if `activation_fn` is not one of
            the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim: Optional[int] = None,
        bias: bool = True,
        num_splits: int = 4,
    ) -> None:
        super().__init__()

        if inner_dim is None:
            inner_dim = int(dim * mult)

        dim_out = dim_out if dim_out is not None else dim

        # Validate against `num_splits` itself. Checking against the derived slice width would accept e.g.
        # inner_dim=10, num_splits=4 (width 2 -> 5 slices) and would divide by zero when num_splits > inner_dim.
        if num_splits < 1 or inner_dim % num_splits != 0:
            raise ValueError(
                f"`inner_dim` ({inner_dim}) must be divisible by `num_splits` ({num_splits}), "
                "and `num_splits` must be positive."
            )
        dim_split = inner_dim // num_splits

        # Keep the constructor arguments around so the layer can be inspected or rebuilt later.
        self._dim = dim
        self._dim_out = dim_out
        self._mult = mult
        self._dropout = dropout
        self._activation_fn = activation_fn
        self._final_dropout = final_dropout
        self._inner_dim = inner_dim
        self._bias = bias
        self._num_splits = num_splits

        def get_activation_fn(dim_: int, inner_dim_: int) -> torch.nn.Module:
            # Build one input-projection + activation slice of width `inner_dim_`.
            if activation_fn == "gelu":
                return GELU(dim_, inner_dim_, bias=bias)
            if activation_fn == "gelu-approximate":
                return GELU(dim_, inner_dim_, approximate="tanh", bias=bias)
            if activation_fn == "geglu":
                return GEGLU(dim_, inner_dim_, bias=bias)
            if activation_fn == "geglu-approximate":
                return ApproximateGELU(dim_, inner_dim_, bias=bias)
            if activation_fn == "swiglu":
                return SwiGLU(dim_, inner_dim_, bias=bias)
            if activation_fn == "linear-silu":
                return LinearActivation(dim_, inner_dim_, bias=bias, activation="silu")
            # Fail with a clear message instead of an UnboundLocalError on an unknown name.
            raise ValueError(f"Unsupported activation function: {activation_fn!r}")

        # Split column-wise
        self.proj_in = torch.nn.ModuleList([get_activation_fn(dim, dim_split) for _ in range(num_splits)])

        self.dropout = torch.nn.Dropout(dropout)

        # Split row-wise. The slices carry no bias of their own; the output bias is added once after the reduction.
        self.proj_out = torch.nn.ModuleList(
            [torch.nn.Linear(dim_split, dim_out, bias=False) for _ in range(num_splits)]
        )

        self.bias = None
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(dim_out))

        self.final_dropout = None
        if final_dropout:
            self.final_dropout = torch.nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Run the feedforward one slice at a time and sum the partial outputs.

        Args:
            hidden_states (`torch.Tensor`): Input whose last dimension is fed to the input projections.

        Returns:
            `torch.Tensor`: Tensor with the leading dimensions of `hidden_states` and a last dimension of `dim_out`.
        """
        # Output tensor for "all_reduce" operation. Its last dimension is `dim_out`, which may differ from the
        # input's last dimension, so the input shape cannot be reused as-is.
        output = hidden_states.new_zeros((*hidden_states.shape[:-1], self._dim_out))

        # Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
        for proj_in, proj_out in zip(self.proj_in, self.proj_out):
            out = proj_in(hidden_states)
            out = self.dropout(out)
            out = proj_out(out)
            # Perform "all_reduce"
            output += out

        if self.bias is not None:
            output += self.bias
        if self.final_dropout is not None:
            output = self.final_dropout(output)

        return output
|  | 122 | + | 
|  | 123 | + | 
def _split_input_projection(tensor: torch.Tensor, inner_dim: int, num_splits: int) -> list:
    """Slice a `net.0.proj` weight or bias along dim 0 into `num_splits` per-slice tensors."""
    if tensor.shape[0] == 2 * inner_dim:
        # The projection has twice `inner_dim` rows, i.e. two fused halves. Each slice must receive the matching
        # rows of BOTH halves; a flat chunk would hand the first slices only rows of the first half.
        # NOTE(review): this assumes the gated activations (GEGLU/SwiGLU in `.activations`, not visible here)
        # split the projection output into two equal halves along the last dim - confirm.
        first_half, second_half = tensor.chunk(2, dim=0)
        return [
            torch.cat(pair, dim=0)
            for pair in zip(first_half.chunk(num_splits, dim=0), second_half.chunk(num_splits, dim=0))
        ]
    return list(tensor.chunk(num_splits, dim=0))


def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Optional[int] = None) -> torch.nn.Module:
    r"""
    Replace every [`~models.attention.FeedForward`] found in `module` with a `_MemoryOptimizedFeedForward` that
    reuses the original weights.

    Args:
        module (`torch.nn.Module`):
            Model to convert. Child layers are swapped in place on their parent module.
        num_splits (`int`, *optional*):
            Number of slices per layer. When `None`, each layer uses its own `_mult` attribute.

    Returns:
        `torch.nn.Module`: `module` itself, or the converted layer when `module` is itself a `FeedForward` (a root
        has no parent on which it could be replaced in place).

    Raises:
        ValueError: If a layer's inner dimension is not divisible by the number of splits used for it.
    """
    module_dict = dict(module.named_modules())
    root = module

    for name, submodule in module_dict.items():
        if not isinstance(submodule, FeedForward):
            continue

        logger.debug(f"Applying memory optimized feedforward to layer '{name}'")
        state_dict = submodule.state_dict()
        # Resolve the split count per layer. Overwriting `num_splits` here would leak the first layer's `_mult`
        # into every later layer.
        layer_splits = submodule._mult if num_splits is None else num_splits

        # The output projection weight is (dim_out, inner_dim), so its second dim gives the width being split.
        net_2_weight = state_dict.pop("net.2.weight")
        inner_dim = net_2_weight.shape[1]
        if layer_splits < 1 or inner_dim % layer_splits != 0:
            raise ValueError(
                f"Cannot split layer '{name}': inner dimension {inner_dim} is not divisible by "
                f"num_splits={layer_splits}."
            )

        # remap net.0.proj.weight
        net_0_proj = state_dict.pop("net.0.proj.weight")
        for i, chunk in enumerate(_split_input_projection(net_0_proj, inner_dim, layer_splits)):
            state_dict[f"proj_in.{i}.proj.weight"] = chunk

        # remap net.0.proj.bias
        if "net.0.proj.bias" in state_dict:
            net_0_proj_bias = state_dict.pop("net.0.proj.bias")
            for i, chunk in enumerate(_split_input_projection(net_0_proj_bias, inner_dim, layer_splits)):
                state_dict[f"proj_in.{i}.proj.bias"] = chunk

        # remap net.2.weight
        for i, chunk in enumerate(net_2_weight.chunk(layer_splits, dim=1)):
            state_dict[f"proj_out.{i}.weight"] = chunk

        # remap net.2.bias
        if "net.2.bias" in state_dict:
            state_dict["bias"] = state_dict.pop("net.2.bias")

        # Build the replacement on the meta device so no parameter memory is allocated; the tensors prepared above
        # are then assigned directly as its parameters.
        with torch.device("meta"):
            new_ff = _MemoryOptimizedFeedForward(
                dim=submodule._dim,
                dim_out=submodule._dim_out,
                mult=submodule._mult,
                dropout=submodule._dropout,
                activation_fn=submodule._activation_fn,
                final_dropout=submodule._final_dropout,
                inner_dim=submodule._inner_dim,
                bias=submodule._bias,
                num_splits=layer_splits,
            )

        new_ff.load_state_dict(state_dict, strict=True, assign=True)
        # A freshly built module is in training mode; mirror the replaced layer so dropout behaves the same.
        new_ff.train(submodule.training)

        if not name:
            # The root module itself is a FeedForward: hand back the converted layer instead of registering it
            # under an empty attribute name.
            root = new_ff
            continue

        parent_module_name, _, submodule_name = name.rpartition(".")
        parent_module = module_dict[parent_module_name]
        setattr(parent_module, submodule_name, new_ff)

    return root
0 commit comments