|
| 1 | +# Copyright The FMS Model Optimizer Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" |
| 15 | + |
| 16 | +# Standard |
| 17 | +from importlib.util import find_spec |
| 18 | +from typing import NotRequired, Unpack |
| 19 | +import math |
| 20 | + |
| 21 | +# Third Party |
| 22 | +from fms.modules.attention import ( |
| 23 | + AttentionKwargs, |
| 24 | + _sdpa_update_attn_kwargs, |
| 25 | + register_attention_op, |
| 26 | +) |
| 27 | +from torch import Tensor |
| 28 | +import torch |
| 29 | + |
| 30 | +# Local |
| 31 | +import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import |
| 32 | + |
| 33 | +if find_spec("torchao"): |
| 34 | + TORCHAO_INSTALLED = True |
| 35 | + # Third Party |
| 36 | + from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor |
| 37 | + from torchao.dtypes.floatx.float8_layout import ( |
| 38 | + Float8AQTTensorImpl, |
| 39 | + Float8Layout, |
| 40 | + Float8MMConfig, |
| 41 | + ) |
| 42 | + from torchao.quantization.granularity import PerTensor |
| 43 | + from torchao.quantization.observer import get_block_size |
| 44 | + from torchao.quantization.quant_primitives import ZeroPointDomain |
| 45 | +else: |
| 46 | + TORCHAO_INSTALLED = False |
| 47 | + |
| 48 | + |
| 49 | +class MathFP8AttentionKwargs(AttentionKwargs): |
| 50 | + """TypedDict for FP8 attention.""" |
| 51 | + |
| 52 | + mask: NotRequired[Tensor] |
| 53 | + do_scale_q: bool |
| 54 | + is_causal_mask: bool |
| 55 | + |
| 56 | + |
| 57 | +# TODO: Doesn't quite work yet, more discussion needed |
| 58 | +Q_RANGE = 200.0 |
| 59 | +K_RANGE = 200.0 |
| 60 | +V_RANGE = 100.0 |
| 61 | + |
| 62 | + |
| 63 | +def _construct_fp8_cache( |
| 64 | + tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype |
| 65 | +) -> AffineQuantizedTensor: |
| 66 | + """Construct the torchao tensor to save kv cache with its scales.""" |
| 67 | + |
| 68 | + weight_granularity = PerTensor() |
| 69 | + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) |
| 70 | + return AffineQuantizedTensor( |
| 71 | + Float8AQTTensorImpl.from_plain( |
| 72 | + tensor, |
| 73 | + scale, |
| 74 | + None, |
| 75 | + fp8_layout, |
| 76 | + ), |
| 77 | + get_block_size(tensor.shape, weight_granularity), |
| 78 | + tensor.shape, |
| 79 | + zero_point_domain=ZeroPointDomain.NONE, |
| 80 | + dtype=orig_dtype, |
| 81 | + ) |
| 82 | + |
| 83 | + |
| 84 | +def _math_fp8_store_op( |
| 85 | + keys: Tensor, # pylint: disable=unused-argument |
| 86 | + values: Tensor, |
| 87 | + key_cache: Tensor | None, |
| 88 | + value_cache: Tensor | None, |
| 89 | + **attn_kwargs: Unpack[MathFP8AttentionKwargs], |
| 90 | +) -> tuple[Tensor, Tensor, Tensor, Tensor]: |
| 91 | + """Implement math of KV cache storing.""" |
| 92 | + |
| 93 | + orig_dtype = keys.dtype |
| 94 | + |
| 95 | + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( |
| 96 | + value_cache, AffineQuantizedTensor |
| 97 | + ): |
| 98 | + k_scale = key_cache.tensor_impl.scale |
| 99 | + v_scale = value_cache.tensor_impl.scale |
| 100 | + else: |
| 101 | + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) |
| 102 | + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) |
| 103 | + |
| 104 | + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) |
| 105 | + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) |
| 106 | + |
| 107 | + if ( |
| 108 | + isinstance(key_cache, AffineQuantizedTensor) |
| 109 | + and isinstance(value_cache, AffineQuantizedTensor) |
| 110 | + and value_cache.numel() > 0 |
| 111 | + ): |
| 112 | + key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2) |
| 113 | + value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2) |
| 114 | + key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype) |
| 115 | + value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype) |
| 116 | + return ( |
| 117 | + key_cache, |
| 118 | + value_cache, |
| 119 | + key_cache, |
| 120 | + value_cache, |
| 121 | + ) |
| 122 | + |
| 123 | + keys = _construct_fp8_cache(keys, k_scale, orig_dtype) |
| 124 | + values = _construct_fp8_cache(values, v_scale, orig_dtype) |
| 125 | + return (keys, values, keys, values) |
| 126 | + |
| 127 | + |
| 128 | +def _math_fp8_compute_op( |
| 129 | + query: Tensor, |
| 130 | + key_cache: Tensor, |
| 131 | + value_cache: Tensor, |
| 132 | + nheads: int, |
| 133 | + kvheads: int, |
| 134 | + p_dropout: float, |
| 135 | + scale_factor: float | None, |
| 136 | + **attn_kwargs: Unpack[MathFP8AttentionKwargs], |
| 137 | +) -> Tensor: |
| 138 | + """Implement computation of attention BMM, leveraging the custom scaled attention |
| 139 | + BMM op that was pre-registered for torch.compile.""" |
| 140 | + |
| 141 | + orig_dtype = query.dtype |
| 142 | + |
| 143 | + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) |
| 144 | + if attn_kwargs.get("do_scale_q", False): |
| 145 | + q_scale.copy_(torch.abs(query).max() / Q_RANGE) |
| 146 | + query = query / q_scale |
| 147 | + |
| 148 | + query = query.to(torch.float8_e4m3fn).transpose(2, 1) |
| 149 | + |
| 150 | + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( |
| 151 | + value_cache, AffineQuantizedTensor |
| 152 | + ): |
| 153 | + k_scale = key_cache.tensor_impl.scale |
| 154 | + v_scale = value_cache.tensor_impl.scale |
| 155 | + key_cache = key_cache.tensor_impl.float8_data |
| 156 | + value_cache = value_cache.tensor_impl.float8_data |
| 157 | + else: |
| 158 | + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) |
| 159 | + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) |
| 160 | + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) |
| 161 | + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) |
| 162 | + |
| 163 | + # no longer transposing prior to store, so need to check this in case of no cache |
| 164 | + # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here |
| 165 | + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: |
| 166 | + key_cache = key_cache.transpose(2, 1) |
| 167 | + value_cache = value_cache.transpose(2, 1) |
| 168 | + |
| 169 | + mask = attn_kwargs.get("mask", None) |
| 170 | + if mask is not None: |
| 171 | + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable |
| 172 | + # we need to create the nheads dimension |
| 173 | + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len |
| 174 | + mask = mask.unsqueeze(1) |
| 175 | + |
| 176 | + L, S = query.size(-2), key_cache.size(-2) |
| 177 | + scale_factor = ( |
| 178 | + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor |
| 179 | + ) |
| 180 | + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) |
| 181 | + if attn_kwargs.get("is_causal_mask", False): |
| 182 | + assert mask is None |
| 183 | + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) |
| 184 | + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) |
| 185 | + attn_bias.to(torch.float32) |
| 186 | + |
| 187 | + if mask is not None: |
| 188 | + if mask.dtype == torch.bool: |
| 189 | + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) |
| 190 | + else: |
| 191 | + attn_bias = mask + attn_bias |
| 192 | + |
| 193 | + expansion = nheads // kvheads |
| 194 | + if expansion > 1: |
| 195 | + key_cache = key_cache.repeat_interleave( |
| 196 | + query.size(-3) // key_cache.size(-3), -3 |
| 197 | + ) |
| 198 | + value_cache = value_cache.repeat_interleave( |
| 199 | + query.size(-3) // value_cache.size(-3), -3 |
| 200 | + ) |
| 201 | + |
| 202 | + attn_weight = ( |
| 203 | + torch.ops.sendnn.scaled_bmm( |
| 204 | + query, |
| 205 | + key_cache.transpose(-2, -1), |
| 206 | + q_scale, |
| 207 | + k_scale, |
| 208 | + out_dtype=orig_dtype, |
| 209 | + use_fast_accum=True, |
| 210 | + ) |
| 211 | + * scale_factor |
| 212 | + ) |
| 213 | + attn_weight += attn_bias |
| 214 | + attn_weight = torch.softmax(attn_weight, dim=-1) |
| 215 | + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) |
| 216 | + # Do matmul in orig_dtype |
| 217 | + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) |
| 218 | + |
| 219 | + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() |
| 220 | + return attn |
| 221 | + |
| 222 | + |
| 223 | +register_attention_op( |
| 224 | + "math_fp8", |
| 225 | + _math_fp8_store_op, |
| 226 | + _math_fp8_compute_op, |
| 227 | + update_attn_kwargs_op=_sdpa_update_attn_kwargs, |
| 228 | +) |
0 commit comments