|
| 1 | +# Copyright The FMS Model Optimizer Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""Implement FP8 linear module to be loaded via FMS.""" |
| 15 | + |
| 16 | +# Standard |
| 17 | +from importlib.util import find_spec |
| 18 | +from typing import Any, Mapping |
| 19 | + |
| 20 | +# Third Party |
| 21 | +from fms.modules.linear import ( |
| 22 | + LinearModuleShardingInfo, |
| 23 | + LinearParameterShardingInfo, |
| 24 | + register_linear_type_to_module_map, |
| 25 | + register_linear_type_to_sharding_map, |
| 26 | + shard_base_linear, |
| 27 | +) |
| 28 | +from fms.modules.tp import ShardType, TPModule |
| 29 | +import torch |
| 30 | + |
| 31 | +# pylint: disable=not-callable |
| 32 | +# torch.nn.functional.linear not recognized as callable |
| 33 | +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 |
| 34 | + |
| 35 | + |
| 36 | +### FP8 linear layers |
| 37 | +if find_spec("torchao"): |
| 38 | + TORCHAO_INSTALLED = True |
| 39 | + |
| 40 | + # Third Party |
| 41 | + from torchao.dtypes.affine_quantized_tensor import ( # type: ignore |
| 42 | + AffineQuantizedTensor, |
| 43 | + to_affine_quantized_floatx, |
| 44 | + to_affine_quantized_floatx_static, |
| 45 | + ) |
| 46 | + from torchao.dtypes.floatx.float8_layout import ( # type: ignore |
| 47 | + Float8AQTTensorImpl, |
| 48 | + Float8Layout, |
| 49 | + Float8MMConfig, |
| 50 | + preprocess_data, |
| 51 | + preprocess_scale, |
| 52 | + ) |
| 53 | + from torchao.dtypes.utils import get_out_shape # type: ignore |
| 54 | + from torchao.float8.inference import ( # type: ignore |
| 55 | + _is_rowwise_scaled, |
| 56 | + addmm_float8_unwrapped_inference, |
| 57 | + ) |
| 58 | + from torchao.quantization.granularity import PerRow, PerTensor # type: ignore |
| 59 | + from torchao.quantization.observer import get_block_size # type: ignore |
| 60 | + from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore |
| 61 | +else: |
| 62 | + TORCHAO_INSTALLED = False |
| 63 | + |
| 64 | + |
| 65 | +class FP8Linear(torch.nn.Module): |
| 66 | + """Class handles FP8 weights loading and uses torchao for the matmuls.""" |
| 67 | + |
| 68 | + def __init__( |
| 69 | + self, |
| 70 | + in_features: int, |
| 71 | + out_features: int, |
| 72 | + bias: bool, |
| 73 | + linear_config: Mapping[str, Any], |
| 74 | + ): |
| 75 | + super().__init__() |
| 76 | + |
| 77 | + self.in_features = in_features |
| 78 | + self.out_features = out_features |
| 79 | + self.has_bias = bias |
| 80 | + self.linear_config = linear_config |
| 81 | + |
| 82 | + assert ( |
| 83 | + self.linear_config["weights"] is not None |
| 84 | + ), "Weights must always be quantized for FP8Linear" |
| 85 | + assert self.linear_config["weights"][ |
| 86 | + "symmetric" |
| 87 | + ], "We only support symmetric weights for now" |
| 88 | + assert not self.linear_config["weights"][ |
| 89 | + "dynamic" |
| 90 | + ], "We only support pre-quantized weights for now" |
| 91 | + |
| 92 | + self.weight = torch.nn.Parameter( |
| 93 | + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), |
| 94 | + requires_grad=False, |
| 95 | + ) |
| 96 | + |
| 97 | + weight_scale_shape = ( |
| 98 | + (1,) |
| 99 | + if self.linear_config["weights"]["strategy"] == "tensor" |
| 100 | + else (out_features, 1) |
| 101 | + ) |
| 102 | + self.weight_scale = torch.nn.Parameter( |
| 103 | + torch.ones(weight_scale_shape), requires_grad=False |
| 104 | + ) |
| 105 | + |
| 106 | + self.has_bias = bias |
| 107 | + if self.has_bias: |
| 108 | + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) |
| 109 | + |
| 110 | + if ( |
| 111 | + self.linear_config["input_activations"] is not None |
| 112 | + and not self.linear_config["input_activations"]["dynamic"] |
| 113 | + ): |
| 114 | + input_scale_shape = ( |
| 115 | + (1,) |
| 116 | + if self.linear_config["input_activations"]["strategy"] == "tensor" |
| 117 | + else (out_features, 1) |
| 118 | + ) |
| 119 | + self.input_scale = torch.nn.Parameter( |
| 120 | + torch.ones(input_scale_shape), requires_grad=False |
| 121 | + ) |
| 122 | + |
| 123 | + def _input_activation_quant_func_fp8( |
| 124 | + self, |
| 125 | + x: torch.Tensor, |
| 126 | + activation_granularity, |
| 127 | + activation_dtype: torch.dtype, |
| 128 | + scale: torch.Tensor | None = None, |
| 129 | + ): |
| 130 | + """Quantize the input activation tensor for an aqt_float variant. |
| 131 | + If scale is not provided, it will be dynamically calculated, otherwise the |
| 132 | + provided scale will be used. |
| 133 | + """ |
| 134 | + |
| 135 | + block_size = get_block_size(x.shape, activation_granularity) |
| 136 | + if scale is None: |
| 137 | + activation = to_affine_quantized_floatx( |
| 138 | + input_float=x, |
| 139 | + block_size=block_size, |
| 140 | + target_dtype=activation_dtype, |
| 141 | + scale_dtype=torch.float32, |
| 142 | + _layout=Float8Layout(mm_config=None), # Config is stored on weight |
| 143 | + ) |
| 144 | + else: |
| 145 | + assert isinstance( |
| 146 | + activation_granularity, PerTensor |
| 147 | + ), "Static quantization only supports PerTensor granularity" |
| 148 | + activation = to_affine_quantized_floatx_static( |
| 149 | + input_float=x, |
| 150 | + block_size=block_size, |
| 151 | + scale=scale, |
| 152 | + target_dtype=activation_dtype, |
| 153 | + _layout=Float8Layout(mm_config=None), # Config is stored on weight |
| 154 | + ) |
| 155 | + return activation |
| 156 | + |
| 157 | + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": |
| 158 | + """Construct the torchao machinery for the fp8 matmul""" |
| 159 | + |
| 160 | + weight_granularity = ( |
| 161 | + PerTensor() |
| 162 | + if self.linear_config["weights"]["strategy"] == "tensor" |
| 163 | + else PerRow() |
| 164 | + ) |
| 165 | + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) |
| 166 | + return AffineQuantizedTensor( |
| 167 | + Float8AQTTensorImpl.from_plain( |
| 168 | + self.weight, |
| 169 | + self.weight_scale.squeeze().to(torch.float32), |
| 170 | + None, |
| 171 | + fp8_layout, |
| 172 | + ), |
| 173 | + get_block_size(self.weight.shape, weight_granularity), |
| 174 | + self.weight.shape, |
| 175 | + zero_point_domain=ZeroPointDomain.NONE, |
| 176 | + dtype=self.weight_scale.dtype, |
| 177 | + ) |
| 178 | + |
| 179 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 180 | + """If input quantization is active, compute FP8xFP8 addmm.""" |
| 181 | + |
| 182 | + # fp8 weight tensor for torchao |
| 183 | + qweight: AffineQuantizedTensor = self._construct_qweight_structure() |
| 184 | + |
| 185 | + if self.linear_config["input_activations"] is not None: |
| 186 | + # activations are also fp8, quantize as required by model |
| 187 | + act_granularity = ( |
| 188 | + PerTensor() |
| 189 | + if self.linear_config["input_activations"]["strategy"] == "tensor" |
| 190 | + else PerRow() |
| 191 | + ) |
| 192 | + input_quant_kwargs = { |
| 193 | + "activation_granularity": act_granularity, |
| 194 | + "activation_dtype": torch.float8_e4m3fn, |
| 195 | + } |
| 196 | + if not self.linear_config["input_activations"]["dynamic"]: |
| 197 | + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( |
| 198 | + torch.float32 |
| 199 | + ) |
| 200 | + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) |
| 201 | + |
| 202 | + # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out) |
| 203 | + scaled_mm_config = Float8MMConfig(use_fast_accum=True) |
| 204 | + out_shape = get_out_shape(qx.shape, qweight.shape) |
| 205 | + |
| 206 | + # Weight tensor preprocessing |
| 207 | + w_tensor_impl = qweight.tensor_impl |
| 208 | + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" |
| 209 | + w_data = w_tensor_impl.float8_data |
| 210 | + w_scale = w_tensor_impl.scale |
| 211 | + |
| 212 | + # Input tensor preprocessing |
| 213 | + inpt_data = qx.tensor_impl.float8_data |
| 214 | + input_scale = qx.tensor_impl.scale |
| 215 | + # Handle case where input tensor is more than 2D |
| 216 | + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) |
| 217 | + |
| 218 | + # Handle rowwise case |
| 219 | + if _is_rowwise_scaled(qweight): |
| 220 | + assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size" |
| 221 | + w_scale = w_scale.unsqueeze(-1).T |
| 222 | + input_scale = preprocess_scale(input_scale, qx.shape) |
| 223 | + |
| 224 | + # Preprocess data |
| 225 | + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) |
| 226 | + |
| 227 | + # Perform the computation |
| 228 | + return addmm_float8_unwrapped_inference( |
| 229 | + inpt_data, |
| 230 | + input_scale, |
| 231 | + w_data, |
| 232 | + w_scale, |
| 233 | + output_dtype=qx.dtype, |
| 234 | + bias=getattr(self, "bias", None), |
| 235 | + use_fast_accum=scaled_mm_config.use_fast_accum, |
| 236 | + ).reshape(out_shape) |
| 237 | + |
| 238 | + # activations not quantized, dequant fp8 weight and do regular matmul |
| 239 | + out = torch.nn.functional.linear( |
| 240 | + x, qweight.dequantize(), self.bias if self.has_bias else None |
| 241 | + ) |
| 242 | + return out |
| 243 | + |
| 244 | + def __repr__(self) -> str: |
| 245 | + return ( |
| 246 | + f"{self.__class__.__name__}" |
| 247 | + f"(in={self.in_features}, out={self.out_features}, " |
| 248 | + f"bias={self.has_bias}, fp8_config={self.linear_config})" |
| 249 | + ) |
| 250 | + |
| 251 | + |
| 252 | +def get_fp8_linear( |
| 253 | + in_features: int, |
| 254 | + out_features: int, |
| 255 | + bias: bool, |
| 256 | + linear_config: Mapping[str, Any], |
| 257 | +) -> FP8Linear: |
| 258 | + """Retrieve an FP8 Linear module""" |
| 259 | + |
| 260 | + if not TORCHAO_INSTALLED: |
| 261 | + raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!") |
| 262 | + |
| 263 | + return FP8Linear(in_features, out_features, bias, linear_config) |
| 264 | + |
| 265 | + |
| 266 | +def shard_fp8_linear( |
| 267 | + tensor_values: dict[str, torch.Tensor], |
| 268 | + tp_module: TPModule, |
| 269 | + module_sharding_info: dict[str, LinearModuleShardingInfo], |
| 270 | +) -> set | None: |
| 271 | + """ |
| 272 | + | GPU | |
| 273 | + sharding | param | shard | dim | |
| 274 | + ----------+----------------+-------+-----| |
| 275 | + colwise | weight | Y | 0 | |
| 276 | + | weight_scale | N | - | |
| 277 | + | input_scale | N | - | |
| 278 | + | bias | Y | 0 | |
| 279 | + ----------+----------------+-------+-----| |
| 280 | + rowwise | weight | Y | 1 | |
| 281 | + | weight_scale | Y/N | 0/- | |
| 282 | + | input_scale | Y/N | 0/- | |
| 283 | + | bias | 0 | - | |
| 284 | + """ |
| 285 | + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} |
| 286 | + for module_name, module_info in module_sharding_info.items(): |
| 287 | + linear_mod: torch.nn.Module = module_info.linear_module |
| 288 | + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ |
| 289 | + "strategy" |
| 290 | + ] |
| 291 | + # Scales are per-row or per-tensor |
| 292 | + # Only sharding needed when row parallel and per-row |
| 293 | + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 |
| 294 | + params: dict[str, LinearParameterShardingInfo] = { |
| 295 | + "weight": LinearParameterShardingInfo( |
| 296 | + module_info.sharding_dim, ShardType.SHARD |
| 297 | + ), |
| 298 | + "weight_scale": LinearParameterShardingInfo( |
| 299 | + module_info.sharding_dim, |
| 300 | + ShardType.SHARD if shard_scales else ShardType.CLONE, |
| 301 | + ), |
| 302 | + } |
| 303 | + if hasattr(linear_mod, "input_scale"): |
| 304 | + params["input_scale"] = LinearParameterShardingInfo( |
| 305 | + module_info.sharding_dim, |
| 306 | + ShardType.SHARD if shard_scales else ShardType.CLONE, |
| 307 | + ) |
| 308 | + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: |
| 309 | + params["bias"] = LinearParameterShardingInfo( |
| 310 | + module_info.sharding_dim, |
| 311 | + ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0, |
| 312 | + ) |
| 313 | + param_sharding_info[module_name] = params |
| 314 | + |
| 315 | + unused_keys = shard_base_linear( |
| 316 | + tensor_values, |
| 317 | + tp_module, |
| 318 | + module_sharding_info, |
| 319 | + param_sharding_info, |
| 320 | + ) |
| 321 | + return unused_keys |
| 322 | + |
| 323 | + |
| 324 | +register_linear_type_to_module_map("fp8", get_fp8_linear) |
| 325 | +register_linear_type_to_sharding_map("fp8", shard_fp8_linear) |
0 commit comments