
Commit cf2082e

Make changes to work with fms and aftu
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent f05beb5 commit cf2082e

File tree

3 files changed (+7, −4 lines):
fms_mo/aiu_addons/__init__.py
fms_mo/aiu_addons/fp8/fp8_attn.py
fms_mo/aiu_addons/fp8/fp8_linear.py

fms_mo/aiu_addons/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
         and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
     ):
+        # First, import required FP8 linear classes from fms-mo
+        import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
+        import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
         # This is used by get_linear to decide whether a linear layer
         # will be quantized or not inside the model
         def fp8_linear_type(name: str) -> str:
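The two pylint-suppressed imports are imports for side effects: loading each module is expected to register its FP8 implementation with FMS, so the names are never referenced afterwards. A minimal sketch of that pattern, using a hypothetical registry and decorator rather than the actual fms-mo API:

# Hypothetical registry; fms-mo's real registration mechanism may differ.
_LINEAR_REGISTRY: dict[str, type] = {}

def register_linear(linear_type: str):
    """Class decorator that records an implementation at import time."""
    def wrap(cls: type) -> type:
        _LINEAR_REGISTRY[linear_type] = cls
        return cls
    return wrap

@register_linear("fp8")
class FP8Linear:  # placeholder body for the sketch
    pass

# Merely importing the defining module populates the registry, which is
# why the imports above are deliberately "unused".
assert _LINEAR_REGISTRY["fp8"] is FP8Linear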

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 3 additions & 3 deletions
@@ -251,9 +251,9 @@ def _spyre_scaled_paged_compute_op(
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
-    nheads: int,
-    kvheads: int,
-    p_dropout: float,
+    nheads: int,  # pylint: disable=unused-argument
+    kvheads: int,  # pylint: disable=unused-argument
+    p_dropout: float,  # pylint: disable=unused-argument
     scale_factor: Optional[float],
     **attn_kwargs,
 ) -> torch.Tensor:
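Suppressing unused-argument here, rather than dropping the parameters, keeps the op's signature compatible with the shared attention-op interface it is called through: callers pass the full argument list regardless of which backend handles it. A toy illustration of the same trade-off, with hypothetical names:

from typing import Optional

import torch

def toy_compute_op(
    query: torch.Tensor,
    nheads: int,  # pylint: disable=unused-argument
    p_dropout: float,  # pylint: disable=unused-argument
    scale_factor: Optional[float] = None,
    **attn_kwargs,
) -> torch.Tensor:
    # This backend ignores nheads and p_dropout, but the parameters must
    # remain so every registered op accepts the same argument list.
    scale = 1.0 if scale_factor is None else scale_factor
    return query * scale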

fms_mo/aiu_addons/fp8/fp8_linear.py

Lines changed: 1 addition & 1 deletion
@@ -193,7 +193,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs)

-       # Copied from torchao _linear_fp8_act_fp8_weight_impl
+        # Copied from torchao _linear_fp8_act_fp8_weight_impl
         # (with changes to support fp8 out)
         scaled_mm_config = Float8MMConfig(use_fast_accum=True)
         out_shape = get_out_shape(qx.shape, qweight.shape)
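For context, the torchao-style FP8 path the comment refers to ultimately bottoms out in a scaled matrix multiply. A hedged sketch of that core call, assuming a recent PyTorch where torch._scaled_mm takes per-tensor scales (the exact signature varies across versions, and this is not the fms-mo implementation):

import torch

def scaled_mm_sketch(
    qx: torch.Tensor,       # (M, K) float8_e4m3fn activations
    qweight: torch.Tensor,  # (N, K) float8_e4m3fn weights
    x_scale: torch.Tensor,  # per-tensor activation scale (float32)
    w_scale: torch.Tensor,  # per-tensor weight scale (float32)
) -> torch.Tensor:
    # _scaled_mm expects the second operand as (K, N) column-major, hence
    # the transpose; fast accumulation mirrors
    # Float8MMConfig(use_fast_accum=True) above.
    return torch._scaled_mm(
        qx,
        qweight.t(),
        scale_a=x_scale,
        scale_b=w_scale,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )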
