|
38 | 38 | shard_base_linear, |
39 | 39 | ) |
40 | 40 | from fms.modules.tp import ShardType, TPModule |
| 41 | + |
| 42 | + # Register decomps for torchao >= 0.12 + AIU. |
| 43 | + # This import only succeeds if torchao is 0.12 or higher. |
| 44 | + try: |
| 45 | +     # Third Party |
| 46 | +     from torchao.quantization.quant_primitives import _expand_scale_to_tensor_shape |
| 47 | + |
| 48 | +     # This function is copied from _quantize_affine_float8 in |
| 49 | +     # torchao.quantization.quant_primitives, but without the wrapping |
| 50 | +     # that turns it into a custom PyTorch op. |
| 51 | +     def _quantize_affine_float8_custom( |
| 52 | +         tensor: torch.Tensor, |
| 53 | +         scale: torch.Tensor, |
| 54 | +         float8_dtype: torch.dtype = torch.float8_e4m3fn, |
| 55 | +     ) -> torch.Tensor: |
| 56 | +         """ |
| 57 | +         Quantizes the high-precision floating point tensor |
| 58 | +         to a float8 tensor, using the given scaling factor. |
| 59 | +         """ |
| 60 | +         tensor_fp32 = tensor.to(torch.float32) |
| 61 | + |
| 62 | +         # Expand scale to match tensor dimensions for block-wise quantization |
| 63 | +         scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape) |
| 64 | + |
| 65 | +         tensor_scaled = tensor_fp32 / scale_expanded |
| 66 | +         max_value = torch.finfo(float8_dtype).max |
| 67 | +         tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) |
| 68 | +         fp8_tensor = tensor_clamped.to(float8_dtype) |
| 69 | +         return fp8_tensor |
| 70 | + |
| 71 | +     quant_lib = torch.library.Library("torchao", "FRAGMENT") |
| 72 | +     quant_lib.impl( |
| 73 | +         "quantize_affine_float8", |
| 74 | +         _quantize_affine_float8_custom, |
| 75 | +         "CompositeImplicitAutograd", |
| 76 | +     ) |
| 77 | + except ImportError: |
| 78 | +     pass |
| 79 | + |
| 80 | + # Third Party |
41 | 81 | from torchao.dtypes.affine_quantized_tensor import ( |
42 | 82 | AffineQuantizedTensor, |
43 | 83 | to_affine_quantized_floatx, |
|
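For context on the registration mechanism in this diff: `torch.library.Library("torchao", "FRAGMENT")` attaches an additional implementation to an operator torchao has already defined, and registering it under the `CompositeImplicitAutograd` key makes that implementation a decomposition into plain ATen ops, so tracing and export can see the divide/clamp/cast sequence rather than an opaque `torchao::quantize_affine_float8` call. The sketch below reproduces the same pattern in isolation; the `demo` namespace, op name, and helper function are hypothetical and exist only for illustration (they are not part of torchao or FMS), and it assumes a PyTorch build that ships the float8 dtypes.

```python
# Minimal, self-contained sketch of the Library(...).impl pattern used above.
# The "demo" namespace and op are hypothetical; they do not touch torchao.
import torch

# "DEF" creates a new namespace; the PR uses "FRAGMENT" because the
# torchao namespace and op already exist and only gain a new impl.
demo_lib = torch.library.Library("demo", "DEF")
demo_lib.define("scale_and_cast(Tensor t, Tensor scale) -> Tensor")


def _scale_and_cast_impl(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Plain ATen ops only: divide by the scale, clamp to the fp8 range, cast.
    # Under CompositeImplicitAutograd, tracing sees this decomposition
    # instead of an opaque custom op.
    max_value = torch.finfo(torch.float8_e4m3fn).max
    scaled = (t.to(torch.float32) / scale).clamp(min=-max_value, max=max_value)
    return scaled.to(torch.float8_e4m3fn)


demo_lib.impl("scale_and_cast", _scale_and_cast_impl, "CompositeImplicitAutograd")

x = torch.randn(4, 8)
s = torch.tensor(2.0)
print(torch.ops.demo.scale_and_cast(x, s).dtype)  # torch.float8_e4m3fn
```

The same idea applied to torchao is what the diff does: it swaps the op's behavior for a composite implementation without changing the op's name or call sites, which is presumably what the AIU path needs from torchao >= 0.12.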