@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
import torch.nn as nn
import types
@@ -1296,6 +1297,10 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
self.register_scale("descale_amax", mod_extra_config.scale.inputs[3].type(torch.float32), self.scale_format)
self.register_scale("scale_output", 1 / mod_extra_config.scale.outputs[0].type(torch.float32), self.scale_format)
self.register_scale("scale_amax", 1 / self.descale_amax, self.scale_format)
self.qkv_slice_thld = int(os.getenv("VLLM_FUSEDSPA_QKV_SLICE_SEQ_LEN_THLD", 8192))
if self.qkv_slice_thld > 0:
self.q_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE", self.qkv_slice_thld))
self.kv_chunk_size = int(os.getenv("VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE", self.qkv_slice_thld))

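The new slicing path is configured entirely through environment variables read in __init__ above. A minimal usage sketch (variable names taken verbatim from this patch; the values are only illustrative):

```python
import os

# Slice Q/KV once the KV sequence length exceeds 4096 tokens.
os.environ["VLLM_FUSEDSPA_QKV_SLICE_SEQ_LEN_THLD"] = "4096"
# Optional per-tensor chunk sizes; both default to the threshold above.
os.environ["VLLM_FUSEDSDPA_Q_SLICE_CHUNK_SIZE"] = "2048"
os.environ["VLLM_FUSEDSDPA_KV_SLICE_CHUNK_SIZE"] = "2048"
```

These must be set before the module is constructed, since the values are read once in __init__.
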
def forward_qdq(
self,
@@ -1330,6 +1335,41 @@
seq_padding_type,
)
return results

def fp8_fsdpa_fwd(self,
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
softmax_mode,
):
results = torch.ops.hpu.fp8_sdpa_recomp_fwd(
q,
k,
v,
attn_mask,
dropout_p,
scale,
is_causal,
True, # requires_backward
softmax_mode, # softmax_mode
self.scale_q, # d_scale_q
self.scale_k, # d_scale_k
self.scale_v, # d_scale_v
self.scale_amax, # q_scale_s
self.scale_output, # q_scale_o
self.descale_amax, # d_scale_s
False, # is_amax_s
False, # is_amax_o
None, # valid_seq_len
"right", # seq_padding_type
(-1, -1), # window_size
None, # sink
)
return results

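forward_quant below slices Q and K/V into chunks for long prefix-caching prefills and skips KV chunks that lie entirely above the causal diagonal. A standalone sketch of that chunk-selection logic, mirroring the skip and is_causal choices made in the loop (the helper name is illustrative, and the sketch omits the 1024-alignment mask fallback handled in the real code):

```python
def plan_chunks(q_len, kv_len, q_chunk_size, kv_chunk_size):
    """Enumerate the (q, kv) chunk pairs computed for a prefix-cached causal
    prefill, mirroring the skip/is_causal logic in forward_quant."""
    ctx_len = kv_len - q_len  # length of the cached prefix
    plan = []
    for q_start in range(0, q_len, q_chunk_size):
        q_end = min(q_start + q_chunk_size, q_len)
        for kv_start in range(0, kv_len, kv_chunk_size):
            kv_end = min(kv_start + kv_chunk_size, kv_len)
            if kv_start > ctx_len + q_end:
                # KV chunk lies entirely above the causal diagonal
                continue
            # only the chunk starting exactly at the prefix boundary runs causally
            is_causal = kv_start == ctx_len
            plan.append((q_start, q_end, kv_start, kv_end, is_causal))
    return plan

# e.g. 4k new tokens on top of a 12k cached prefix, 2k chunks
for step in plan_chunks(4096, 16384, 2048, 2048):
    print(step)
```
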
def forward_quant(
self,
@@ -1345,32 +1385,98 @@ def forward_quant(
valid_seq_len=None,
seq_padding_type="None",
):
sm_mode = softmax_mode if softmax_mode == "fp32" else "None"
sm_mode = softmax_mode if softmax_mode == "fp32" else "none"
qinput = self.quant_q(q).detach()
kinput = self.quant_k(k).detach()
vinput = self.quant_v(v).detach()
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out
q_len = q.shape[-2]
kv_len = kinput.size(-2)

# for prefill with prefix caching
if self.qkv_slice_thld > 0 and q_len != 1 and q_len != kv_len \
and kv_len > self.qkv_slice_thld:
assert attn_mask is not None, "Attention mask is required for FSDPA with prefix caching."
ctx_len = kv_len - q_len
from habana_frameworks.torch.hpex.kernels.Fp8FusedSDPA import is_gqa, gqa_input_reshape_fwd, gqa_output_reshape
gqa = is_gqa(qinput, kinput)
if gqa:
qinput, kinput, vinput, attn_mask = gqa_input_reshape_fwd(qinput, kinput, vinput, attn_mask)

num_q_chunks = (q_len + self.q_chunk_size - 1) // self.q_chunk_size
num_kv_chunks = (kv_len + self.kv_chunk_size - 1) // self.kv_chunk_size
chunk_outputs = []
for q_chunk_idx in range(num_q_chunks):
q_start = q_chunk_idx * self.q_chunk_size
q_end = min((q_chunk_idx + 1) * self.q_chunk_size, q_len)
q_chunk = qinput[..., q_start:q_end, :]

last_out = None
last_m = None
last_linv = None
for kv_chunk_idx in range(num_kv_chunks):
kv_start = kv_chunk_idx * self.kv_chunk_size
kv_end = min((kv_chunk_idx + 1) * self.kv_chunk_size, kv_len)
k_chunk = kinput[..., kv_start:kv_end, :]
v_chunk = vinput[..., kv_start:kv_end, :]

# skip the upper triangular part for causal attention
if kv_start > ctx_len + q_end:
continue

is_causal = kv_start == ctx_len

# the current chunk sizes must be multiples of 1024 to get correct m/linv from the kernel
if kv_end - ctx_len == 0 and ((q_end - q_start) % 1024 != 0 or (kv_end - kv_start) % 1024 != 0):
is_causal = False
attn_mask_chunk = attn_mask[..., kv_start:kv_end]
else:
attn_mask_chunk = None

chunk_res = self.fp8_fsdpa_fwd(q_chunk, k_chunk, v_chunk, attn_mask_chunk, dropout_p, scale, is_causal, sm_mode)
chunk_out, chunk_m, chunk_linv = (gqa_output_reshape(x) for x in (chunk_res[:3])) if gqa else chunk_res[:3]

chunk_m = chunk_m.to(torch.float32)
chunk_linv = chunk_linv.to(torch.float32) * 128.0 if softmax_mode != "fp32" else chunk_linv.to(torch.float32)
chunk_out = self.dequant_output(chunk_out).to(torch.float32)

if kv_chunk_idx == 0:
last_out = chunk_out
last_m = chunk_m
last_linv = chunk_linv
else:
new_m = torch.maximum(last_m, chunk_m)
last_linv_rescaled = (1.0 / last_linv) * torch.exp(last_m - new_m)
chunk_linv_rescaled = (1.0 / chunk_linv) * torch.exp(chunk_m - new_m)
last_linv = 1.0 / (last_linv_rescaled + chunk_linv_rescaled)
last_out = (last_linv_rescaled * last_linv) * last_out + (
chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m
chunk_outputs.append(last_out)
output = torch.cat(chunk_outputs, dim=-2)
return output.to(q.dtype)
else:
results = self.fp8_fused_sdpa(
qinput,
kinput,
vinput,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
softmax_mode=sm_mode,
d_scale_q=self.scale_q,
d_scale_k=self.scale_k,
d_scale_v=self.scale_v,
q_scale_s=self.scale_amax,
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type,
)
output = results[0]
d_out = self.dequant_output(output)
return d_out

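The inner loop above merges per-chunk results with the standard online-softmax recombination: each chunk returns a normalized output together with its row max m and (inverse, scaled) denominator linv, and the running result is rescaled to the common max before being summed. A minimal self-contained sketch of that recombination, written with plain (non-inverted, unscaled) denominators for readability; function and variable names are illustrative, not part of this patch:

```python
import torch

def merge_partial_attention(out1, m1, l1, out2, m2, l2):
    # out_i: per-chunk attention output, already normalized by its own denominator l_i
    # m_i:   per-row max of the attention scores seen in chunk i
    # l_i:   per-row softmax denominator sum(exp(score - m_i)) for chunk i
    m = torch.maximum(m1, m2)
    l1r = l1 * torch.exp(m1 - m)  # rescale both denominators to the common max
    l2r = l2 * torch.exp(m2 - m)
    l = l1r + l2r
    out = (l1r / l) * out1 + (l2r / l) * out2
    return out, m, l

# Self-check against unchunked attention on random tensors.
torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 6, 8)
v = torch.randn(1, 2, 6, 8)
scale = 8 ** -0.5

def chunk_attention(q, k, v):
    s = (q @ k.transpose(-2, -1)) * scale
    m = s.amax(dim=-1, keepdim=True)
    p = torch.exp(s - m)
    l = p.sum(dim=-1, keepdim=True)
    return (p / l) @ v, m, l

out1, m1, l1 = chunk_attention(q, k[..., :3, :], v[..., :3, :])
out2, m2, l2 = chunk_attention(q, k[..., 3:, :], v[..., 3:, :])
merged, _, _ = merge_partial_attention(out1, m1, l1, out2, m2, l2)
ref = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1) @ v
assert torch.allclose(merged, ref, atol=1e-5)
```
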
def forward_measure(
self,