Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def __init__(self, mod, parent, mod_extra_config, *args, **kwargs):
self.fp8_apc_fsdpa_impl = impl_mapping[qkv_slice_impl]

self.slice_causal = os.getenv("VLLM_HPU_FSDPA_SLICE_CAUSAL", "0") in ("1", "true")
self.with_mark_step = os.getenv("VLLM_HPU_FSDPA_SLICE_WITH_MARK_STEP", "0") in ("1", "true")
if self.with_mark_step:
import habana_frameworks.torch as ht
self.mark_step = ht.core.mark_step


def fp8_fsdpa_fwd(
self,
Expand Down Expand Up @@ -199,6 +204,9 @@ def fp8_apc_fsdpa_split_kv(
prefix_linv = prefix_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0)
prefix_out = self.dequant_output(prefix_out).to(torch.float32)

if self.with_mark_step:
self.mark_step()

# calculate the causal part
causal_k = k[..., prefix_len:, :]
causal_v = v[..., prefix_len:, :]
Expand Down Expand Up @@ -255,6 +263,9 @@ def fp8_apc_fsdpa_slice_causal(
prefix_linv = prefix_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0)
prefix_out = self.dequant_output(prefix_out).to(torch.float32)

if self.with_mark_step:
self.mark_step()

# calculate the causal part
chunk_outputs = []
num_chunks = (q_len + self.qkv_chunk_size - 1) // self.qkv_chunk_size
Expand Down Expand Up @@ -284,6 +295,19 @@ def fp8_apc_fsdpa_slice_causal(
if kv_chunk_idx == 0 and not is_causal_chunk
else None
)

if self.with_mark_step:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest to add some comments why need clone and mark_step here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

# mark_step() cannot break the tensor slicing, use clone to isolate the graph
q_chunk = q_chunk.clone()
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
if mask_chunk is not None:
mask_chunk = mask_chunk.clone()
last_out = last_out.clone()
last_m = last_m.clone()
last_linv = last_linv.clone()
self.mark_step()

chunk_res = self.fp8_fsdpa_fwd(
q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode
)
Expand All @@ -301,6 +325,9 @@ def fp8_apc_fsdpa_slice_causal(
chunk_linv_rescaled * last_linv
) * chunk_out
last_m = new_m

if self.with_mark_step:
self.mark_step()
chunk_outputs.append(last_out)
chunk_outputs = list(reversed(chunk_outputs))
return torch.cat(chunk_outputs, dim=-2)
Expand Down Expand Up @@ -352,6 +379,12 @@ def fp8_apc_fsdpa_slice_qkv(
k_chunk = k[..., kv_start:kv_end, :]
v_chunk = v[..., kv_start:kv_end, :]

if self.with_mark_step:
q_chunk = q_chunk.clone()
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
self.mark_step()

chunk_res = self.fp8_fsdpa_fwd(
q_chunk, k_chunk, v_chunk, None, dropout_p, scale, False, softmax_mode
)
Expand All @@ -373,6 +406,9 @@ def fp8_apc_fsdpa_slice_qkv(
chunk_linv_rescaled * last_linv
) * chunk_out
last_m = new_m

if self.with_mark_step:
self.mark_step()

for kv_chunk_idx in range(0, num_q_chunks - q_chunk_idx):
kv_start = prefix_len + q_end - (kv_chunk_idx + 1) * self.qkv_chunk_size
Expand All @@ -389,6 +425,15 @@ def fp8_apc_fsdpa_slice_qkv(
if kv_chunk_idx == 0 and not is_causal_chunk
else None
)

if self.with_mark_step:
q_chunk = q_chunk.clone()
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
if mask_chunk is not None:
mask_chunk = mask_chunk.clone()
self.mark_step()

chunk_res = self.fp8_fsdpa_fwd(
q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode
)
Expand All @@ -410,6 +455,10 @@ def fp8_apc_fsdpa_slice_qkv(
last_out = (last_linv_rescaled * last_linv) * last_out + \
(chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m

if self.with_mark_step:
self.mark_step()

chunk_outputs.append(last_out)
chunk_outputs = list(reversed(chunk_outputs))
return torch.cat(chunk_outputs, dim=-2)
Expand Down Expand Up @@ -469,6 +518,13 @@ def fp8_causal_fsdpa_slice_qkv(
mask_chunk = (1.0 - torch.tril(torch.ones(mask_shape, dtype=self.hp_dtype, device=q_chunk.device))) * -3e38
else:
mask_chunk = None

if self.with_mark_step:
q_chunk = q_chunk.clone()
k_chunk = k_chunk.clone()
v_chunk = v_chunk.clone()
self.mark_step()

chunk_res = self.fp8_fsdpa_fwd(
q_chunk, k_chunk, v_chunk, mask_chunk, dropout_p, scale, is_causal_chunk, softmax_mode
)
Expand All @@ -490,6 +546,10 @@ def fp8_causal_fsdpa_slice_qkv(
last_out = (last_linv_rescaled * last_linv) * last_out + \
(chunk_linv_rescaled * last_linv) * chunk_out
last_m = new_m

if self.with_mark_step:
self.mark_step()

chunk_outputs.append(last_out)
chunk_outputs = list(reversed(chunk_outputs))
return torch.cat(chunk_outputs, dim=-2)
Expand Down
Loading