|
33 | 33 | # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 |
34 | 34 |
|
35 | 35 |
|
36 | | -### FP8 linear layers |
| 36 | +# Gated torchao imports for FP8 implementation |
37 | 37 | if find_spec("torchao"): |
38 | 38 | TORCHAO_INSTALLED = True |
39 | 39 |
|
40 | 40 | # Third Party |
41 | | - from torchao.dtypes.affine_quantized_tensor import ( # type: ignore |
| 41 | + from torchao.dtypes.affine_quantized_tensor import ( |
42 | 42 | AffineQuantizedTensor, |
43 | 43 | to_affine_quantized_floatx, |
44 | 44 | to_affine_quantized_floatx_static, |
45 | 45 | ) |
46 | | - from torchao.dtypes.floatx.float8_layout import ( # type: ignore |
| 46 | + from torchao.dtypes.floatx.float8_layout import ( |
47 | 47 | Float8AQTTensorImpl, |
48 | 48 | Float8Layout, |
49 | 49 | Float8MMConfig, |
50 | 50 | preprocess_data, |
51 | 51 | preprocess_scale, |
52 | 52 | ) |
53 | | - from torchao.dtypes.utils import get_out_shape # type: ignore |
54 | | - from torchao.float8.inference import ( # type: ignore |
| 53 | + from torchao.dtypes.utils import get_out_shape |
| 54 | + from torchao.float8.inference import ( |
55 | 55 | _is_rowwise_scaled, |
56 | 56 | addmm_float8_unwrapped_inference, |
57 | 57 | ) |
58 | | - from torchao.quantization.granularity import PerRow, PerTensor # type: ignore |
59 | | - from torchao.quantization.observer import get_block_size # type: ignore |
60 | | - from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore |
| 58 | + from torchao.quantization.granularity import PerRow, PerTensor |
| 59 | + from torchao.quantization.observer import get_block_size |
| 60 | + from torchao.quantization.quant_primitives import ZeroPointDomain |
61 | 61 | else: |
62 | 62 | TORCHAO_INSTALLED = False |
63 | 63 |
|
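For context, a minimal sketch of the gating pattern this hunk introduces: the optional torchao dependency is probed with `find_spec` and the resulting `TORCHAO_INSTALLED` flag guards the FP8 code paths. The `require_torchao` helper and its error message are assumptions for illustration, not code from this PR.

```python
# Standard
from importlib.util import find_spec

# Gate the optional torchao dependency; the flag is checked before any FP8 path runs.
TORCHAO_INSTALLED = find_spec("torchao") is not None


def require_torchao() -> None:
    """Hypothetical guard: fail loudly if an FP8 code path is reached without torchao."""
    if not TORCHAO_INSTALLED:
        raise ImportError(
            "FP8 linear layers require the optional 'torchao' package; "
            "install it before enabling FP8 quantization."
        )
```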
@@ -177,7 +177,8 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": |
177 | 177 | ) |
178 | 178 |
|
179 | 179 | def forward(self, x: torch.Tensor) -> torch.Tensor: |
180 | | - """If input quantization is active, compute FP8xFP8 addmm.""" |
| 180 | + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao |
| 181 | + functionalities. Otherwise compute non-quantized addmm.""" |
181 | 182 |
|
182 | 183 | # fp8 weight tensor for torchao |
183 | 184 | qweight: AffineQuantizedTensor = self._construct_qweight_structure() |
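A hedged sketch of the dispatch the updated docstring describes: when input quantization is active, the activation is quantized and `torch.nn.functional.linear` dispatches on the torchao `AffineQuantizedTensor` subclass to the FP8xFP8 addmm kernel; otherwise a plain addmm runs on the dequantized weight. The attribute names `input_activation_quant` and `bias` are assumptions, and the class skeleton omits construction details that live in `_construct_qweight_structure()`.

```python
import torch
import torch.nn.functional as F


class Fp8LinearSketch(torch.nn.Module):
    """Hypothetical skeleton mirroring the forward dispatch; not the PR's actual class."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fp8 weight wrapped as a torchao AffineQuantizedTensor
        qweight = self._construct_qweight_structure()

        if self.input_activation_quant is not None:
            # quantize the activation so F.linear dispatches to the FP8 x FP8 addmm path
            qx = self.input_activation_quant(x)
            return F.linear(qx, qweight, self.bias)

        # input quantization disabled: fall back to a non-quantized addmm
        return F.linear(x, qweight.dequantize(), self.bias)
```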
@@ -282,6 +283,7 @@ def shard_fp8_linear( |
282 | 283 | | input_scale | Y/N | 0/- | |
283 | 284 | | bias | 0 | - | |
284 | 285 | """ |
| 286 | + |
285 | 287 | param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} |
286 | 288 | for module_name, module_info in module_sharding_info.items(): |
287 | 289 | linear_mod: torch.nn.Module = module_info.linear_module |
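As a hedged illustration of the docstring table's `Y/N` / `0/-` entries for `input_scale`: with per-row granularity the scale tensor is sharded along the same dimension as the weight rows, while a per-tensor scale is a single value and is simply replicated. The helper below is hypothetical and only mirrors that rule using the `PerRow`/`PerTensor` granularities imported earlier in the file.

```python
from torchao.quantization.granularity import PerRow, PerTensor


def scale_shard_dim(granularity, weight_shard_dim: int) -> int | None:
    """Hypothetical helper: per-row scales shard with the weight rows,
    per-tensor scales are replicated (no shard dim)."""
    if isinstance(granularity, PerRow):
        return weight_shard_dim  # e.g. dim 0 for a column-parallel weight
    if isinstance(granularity, PerTensor):
        return None  # a scalar scale is replicated across ranks
    raise ValueError(f"unsupported granularity: {granularity!r}")
```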