
Commit f05beb5

Fix linting, add paged attention kernels
Signed-off-by: Antoni Viros i Martin <[email protected]>
1 parent c931ad7 commit f05beb5

8 files changed: +577 additions, −327 deletions

fms_mo/aiu_addons/__init__.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+def _infer_quantization_config(quant_config: dict) -> dict | None:
+    """Construct linear_config dictionary carrying FP8 configuration for FMS.
+
+    There are many quantization packages compatible with HF.
+    We initially focus on llm-compressor, as it is the one used in FMS-MO.
+
+    llm-compressor saves its checkpoints with quant_method = compressed-tensors;
+    quantization_status tells us whether the model has already been quantized.
+    We only support loading already quantized models (compressed status).
+    """
+
+    if (
+        quant_config["quant_method"] == "compressed-tensors"
+        and quant_config["quantization_status"] == "compressed"
+    ):
+        # FP8 quantization will have FP8 weights
+        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
+        # num_bits and type tell us "float" with "8" bits, aka FP8
+        if (
+            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
+            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
+        ):
+            # This is used by get_linear to decide whether a linear layer
+            # will be quantized or not inside the model
+            def fp8_linear_type(name: str) -> str:
+                # We need to translate HF names to FMS names
+                translations = {
+                    "lm_head": "head",
+                }
+                for ignored_layer in quant_config["ignore"]:
+                    assert isinstance(ignored_layer, str)
+                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
+                    if name in fms_ign_layer:
+                        return "torch_linear"
+                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
+                    # Special case from llm-compressor that covers all linear layers
+                    # not in the ignore pattern
+                    assert isinstance(pattern, str)
+                    if pattern == "Linear":
+                        return "fp8"
+                    if name in translations.get(pattern, pattern):
+                        return "fp8"
+                return "torch_linear"
+
+            return {
+                "linear_type": fp8_linear_type,
+                "input_activations": quant_config["config_groups"]["group_0"][
+                    "input_activations"
+                ],
+                "output_activations": quant_config["config_groups"]["group_0"][
+                    "output_activations"
+                ],
+                "weights": quant_config["config_groups"]["group_0"]["weights"],
+            }
+    return None
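
For reference, a minimal usage sketch of the new helper (not part of the diff). The sample_quant_config dict below is hypothetical; its keys simply mirror the fields the function reads from an llm-compressor "compressed-tensors" checkpoint config, and the layer names passed to the returned callable are illustrative FMS module names:

# Hypothetical llm-compressor-style config, shaped like the fields read above
sample_quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
            "output_activations": None,
        }
    },
}

linear_config = _infer_quantization_config(sample_quant_config)
assert linear_config is not None
# "linear_type" is a callable consulted per layer (FMS module names):
assert linear_config["linear_type"]("head") == "torch_linear"  # lm_head is ignored
assert linear_config["linear_type"]("attn.query") == "fp8"     # matched by "Linear"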

fms_mo/aiu_addons/fp8/fp8_adapter.py

Lines changed: 21 additions & 10 deletions
@@ -15,6 +15,7 @@
 
 # Standard
 from typing import Any, Mapping
+import functools
 
 # Third Party
 from fms.modules.linear import get_linear_type
@@ -25,15 +26,15 @@
 # Retaining kwargs input arguments for consistency with other adapter steps.
 
 
-# NOTE: this adapter step must be registered before the adapter that uses it (such as
-# the llama adapter in fms.models.llama)
 # TODO: may be shared with gptq llama
-# TODO: generalize across architectures if possible
-def _hf_fp8_llama_check(
-    input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs
+def _hf_fp8_check(
+    input_sd: Mapping[str, Any],
+    model_config: ModelConfig | None = None,
+    checkpoint_is_fused: bool = False,
+    **kwargs,
 ) -> Mapping[str, Any]:
-    """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization
-    is in use, weights are unfused.
+    """Implementation of adapter step for FMS: ensure that when FP8 quantization
+    is in use, weights are fused like the model checkpoint.
     """
 
     has_fused_weights = True
@@ -44,16 +45,26 @@ def _hf_fp8_llama_check(
         if model_config.linear_config:
             linear_type = model_config.linear_config["linear_type"]
             if callable(linear_type):
-                # Calling this with "any" guarantees "fp8" to be returned
+                # Calling this function with "any" guarantees "fp8" to be returned
                 # when loading an HF fp8 checkpoint, and never in any other condition
                 linear_type = get_linear_type(model_config.linear_config, "any")
 
-    if "fp8" in linear_type and has_fused_weights:
+    if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused:
         raise ValueError(
             "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
         )
 
     return input_sd
 
 
-serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check)
+serialization.register_adapter_step(
+    "llama", "hf_fp8_check", functools.partial(_hf_fp8_check, checkpoint_is_fused=False)
+)
+serialization.extend_adapter("llama", "hf", ["hf_fp8_check"])
+
+serialization.register_adapter_step(
+    "granite",
+    "hf_fp8_check",
+    functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
+)
+serialization.extend_adapter("granite", "hf", ["hf_fp8_check"])

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 100 additions & 5 deletions
@@ -14,7 +14,7 @@
 """FMS registration of attention BMM operation using torch-registered scaled BMM."""
 
 # Standard
-from typing import NotRequired, Unpack
+from typing import NotRequired, Optional, Unpack
 import math
 
 # Third Party
@@ -23,6 +23,10 @@
     _sdpa_update_attn_kwargs,
     register_attention_op,
 )
+from fms.utils.spyre.paged import (
+    SpyrePagedAttentionKwargs,
+    __spyre_paged_validate_attn_kwargs_op,
+)
 import torch
 
 # Local
@@ -46,7 +50,7 @@ class MathFP8AttentionKwargs(AttentionKwargs):
 
 def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor:
     """Construct the custom object to save KV cache with its scales."""
-    return ScaledTensor(tensor, scale)
+    return ScaledTensor(tensor, scale, True)
 
 
 def _math_fp8_store_op(
@@ -58,13 +62,15 @@ def _math_fp8_store_op(
 ) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]:
     """Implement math of KV cache storing."""
 
+    # Grab scale from kv-cache if already there, compute dynamically otherwise
     if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
         k_scale = key_cache._scale
         v_scale = value_cache._scale
     else:
         k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32)
 
+    # Scale kv tensors for storage
     keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1)
     values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1)
 
@@ -83,6 +89,7 @@ def _math_fp8_store_op(
             key_cache,
             value_cache,
         )
+    # If it's a new kv cache, ensure it's contiguous for spyre use cases
     keys = _construct_fp8_cache(keys.contiguous(), k_scale)
     values = _construct_fp8_cache(values.contiguous(), v_scale)
     return (keys, values, keys, values)
@@ -98,35 +105,40 @@ def _math_fp8_compute_op(
     scale_factor: float | None,
     **attn_kwargs: Unpack[MathFP8AttentionKwargs],
 ) -> torch.Tensor:
-    """Implement computation of attention BMM, leveraging the custom scaled attention
-    BMM op that was pre-registered for torch.compile."""
+    """Implement computation of scaled dot product attention, leveraging
+    the custom scaled BMM op that was pre-registered for torch.compile."""
 
     orig_dtype = query.dtype
 
+    # Scaling the Q tensor is optional
     q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device)
     if attn_kwargs.get("do_scale_q", False):
         q_scale.copy_(torch.abs(query).max() / Q_RANGE)
         query = query / q_scale
 
     query = query.to(torch.float8_e4m3fn).transpose(2, 1)
 
+    # Grab kv cache and deal with cases where no store op was run
     if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor):
+        # Store op was run
         k_scale = key_cache._scale
         v_scale = value_cache._scale
         key_cache = key_cache._data
         value_cache = value_cache._data
     else:
+        # Store op wasn't run (e.g. encoders, use_cache=False)
         k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32)
         v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32)
         key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn)
         value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn)
 
-    # no longer transposing prior to store, so need to check this in case of no cache
+    # If store wasn't run, we need to transpose the tensors here
     # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here
     if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads:
         key_cache = key_cache.transpose(2, 1)
         value_cache = value_cache.transpose(2, 1)
 
+    # Most of the code that follows is a copy of PyTorch SDPA, with fp8 additions
     mask = attn_kwargs.get("mask", None)
     if mask is not None:
         # Our expected mask format is bs x q_len x k_len, so to make it broadcastable
@@ -187,3 +199,86 @@ def _math_fp8_compute_op(
     _math_fp8_compute_op,
     update_attn_kwargs_op=_sdpa_update_attn_kwargs,
 )
+
+
+def _spyre_scaled_paged_store_op(
+    keys: torch.Tensor,
+    values: torch.Tensor,
+    key_cache: Optional[torch.Tensor],
+    value_cache: Optional[torch.Tensor],
+    **attn_kwargs: Unpack[SpyrePagedAttentionKwargs],
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    # For paged store, we must have pre-allocated the kv-cache
+    assert key_cache is not None and isinstance(
+        key_cache, ScaledTensor
+    ), "kv cache must be preallocated"
+    assert value_cache is not None and isinstance(
+        value_cache, ScaledTensor
+    ), "kv cache must be preallocated"
+    if not key_cache._scaled:
+        key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32)
+        value_cache._scale = (torch.abs(values).max() / 100.0).to(dtype=torch.float32)
+
+    result_key_cache_data, result_value_cache_data = (
+        torch.ops.spyre.scaled_paged_attn_store(
+            keys,
+            values,
+            key_cache._data,
+            value_cache._data,
+            key_cache._scale,
+            value_cache._scale,
+            attn_kwargs["slot_mapping"],
+        )
+    )
+
+    result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale)
+    result_value_cache = _construct_fp8_cache(
+        result_value_cache_data, value_cache._scale
+    )
+
+    # for prefill, we want to return the original keys/values
+    if attn_kwargs.get("block_table", None) is None:
+        return keys, values, result_key_cache, result_value_cache
+    return (
+        result_key_cache,
+        result_value_cache,
+        result_key_cache,
+        result_value_cache,
+    )
+
+
+def _spyre_scaled_paged_compute_op(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    nheads: int,
+    kvheads: int,
+    p_dropout: float,
+    scale_factor: Optional[float],
+    **attn_kwargs,
+) -> torch.Tensor:
+    assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled"
+    assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled"
+    if scale_factor is None:
+        scale_factor = 1 / math.sqrt(query.shape[-1])
+    return torch.ops.spyre.scaled_paged_attn_compute(
+        query,
+        key_cache._data,
+        value_cache._data,
+        key_cache._scale,
+        value_cache._scale,
+        scale_factor,
+        attn_kwargs["current_tkv_mask"],
+        attn_kwargs["left_padded_prompt_mask"],
+        attn_kwargs["block_table"],
+    )
+
+
+register_attention_op(
+    "spyre_paged_attn_fp8",
+    _spyre_scaled_paged_store_op,
+    compute_op=_math_fp8_compute_op,
+    is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None,
+    compute_decode_op=_spyre_scaled_paged_compute_op,
+    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+)
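
As an aside (not part of the diff), the per-tensor dynamic scaling used by both the math and paged FP8 ops boils down to an abs-max scale followed by an FP8 cast. A minimal numeric sketch, with a stand-in range value in place of the module's K_RANGE / V_RANGE constants (the paged store op above hard-codes 200.0 and 100.0):

import torch

keys = torch.randn(2, 8, 64, 128, dtype=torch.float16)  # bs x kvheads x seq x head_dim
k_range = 200.0                                          # stand-in for K_RANGE
k_scale = (torch.abs(keys).max() / k_range).to(dtype=torch.float32)

# Quantize for storage: divide by the per-tensor scale, then cast to FP8 (e4m3)
keys_fp8 = (keys / k_scale).to(torch.float8_e4m3fn)

# Dequantize on use: cast back up and multiply the scale back in
keys_deq = keys_fp8.to(torch.float32) * k_scale
print((keys.float() - keys_deq).abs().max())             # small quantization error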
