
Commit 1a798a8

gemma3: fix accuracy issue caused by not skipping image on top right (#1635)
This PR makes two changes for gemma3: 1. bring back masking off the image tokens in the top-right (above-diagonal) region of the attention mask, with a memory optimization. 2. add an env var, VLLM_FUSEDSDPA_SLIDE_RIGHT, used to set the right side of the sliding window for fused SDPA. Both are targeted at improving performance/accuracy for longer sequences.
1 parent 0793bf1 commit 1a798a8

File tree

4 files changed: +74 −39 lines

README_GAUDI.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -367,6 +367,7 @@ batch size is often at its maximum, making large-batch HPU graphs critical to ca
 - `VLLM_HANDLE_TOPK_DUPLICATES`: if ``true`` - handles duplicates outside top-k. The default is `false`.
 - `VLLM_CONFIG_HIDDEN_LAYERS`: configures how many hidden layers to run in a HPUGraph for model splitting among hidden layers when TP is 1. It helps to improve throughput by reducing inter-token latency limitations in some models. The default is `1`.
 - `VLLM_SKIP_WARMUP`: if `true`, warm-up is skipped. The default is `false`.
+- `VLLM_FUSEDSDPA_SLIDE_RIGHT`: right sliding window size when FusedSDPA is used with a sliding window. It helps with memory and performance when a long context is used. The default is `0`. Example: for a sliding window of size 1024, set `VLLM_FUSEDSDPA_SLIDE_RIGHT=1024`.
 
 > [!TIP]
 > When a deployed workload does not utilize the full context that a model can handle, it is good practice to limit the maximum values upfront based on the input and output token lengths that will be generated after serving the vLLM server.
```
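As a quick illustration of how these knobs fit together, here is a minimal launch-side sketch in Python; the variable names come from this commit and the surrounding docs, while the concrete values (slice size 1024, right window 1024) are examples only.

```python
# Illustrative configuration only; set these before starting the vLLM server.
import os

# Pre-existing knobs: enable the sliced FusedSDPA prompt path.
os.environ["PT_HPU_SDPA_QKV_SLICE_MODE_FWD"] = "1"
os.environ["PT_HPU_QKV_SLICE_SEQ_LEN_THLD"] = "1024"  # slice size

# New in this commit: non-zero right sliding window for fused SDPA.
# Must be a multiple of the slice size above; 0 (the default) disables it.
os.environ["VLLM_FUSEDSDPA_SLIDE_RIGHT"] = "1024"
```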

vllm/attention/backends/hpu_attn.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -143,6 +143,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     window_block_usage: Optional[torch.Tensor] = None
     window_attn_bias: Optional[torch.Tensor] = None
     use_window_sdpa: Optional[bool] = None
+    sliding_window_right: Optional[int] = None
 
 
 @dataclass
@@ -425,6 +426,7 @@ def __init__(
 
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
+
         self.prompt_position_bias = None
         self.prev_attn = None
         self.alibi_slopes = None
@@ -569,7 +571,8 @@ def forward(
 
         if attn_metadata.use_window_sdpa:
             attn_bias = attn_metadata.attn_bias
-            window_size = (self.sliding_window, 0)
+            window_size = (self.sliding_window,
+                           attn_metadata.sliding_window_right)
             common_args['window_size'] = window_size
             # TODO: Currently HPU doesn't support GQA for FusedSDPA
             # with causal + window, so repeat KV so QKV are all the
```
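The behavioural change in forward() is small, so here is a standalone sketch of just the window tuple that is now handed to FusedSDPA (surrounding attention arguments elided; the helper name is made up for illustration).

```python
# Sketch mirroring the diff above: the right side of the window used to be
# hard-coded to 0 and now comes from attn_metadata.sliding_window_right,
# which the model runner fills from VLLM_FUSEDSDPA_SLIDE_RIGHT.
from typing import Optional, Tuple

def make_window_size(sliding_window: int,
                     sliding_window_right: Optional[int]) -> Tuple[int, int]:
    return (sliding_window, sliding_window_right or 0)

assert make_window_size(1024, None) == (1024, 0)     # previous behaviour / default
assert make_window_size(1024, 1024) == (1024, 1024)  # VLLM_FUSEDSDPA_SLIDE_RIGHT=1024
```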

vllm/model_executor/models/gemma3_mm.py

Lines changed: 57 additions & 25 deletions

```diff
@@ -664,6 +664,43 @@ def forward(self,
 
         return hidden_states
 
+    def hpu_build_mask(self, input_ids: torch.Tensor,
+                       mask_dtype: torch.dtype) -> torch.Tensor:
+        bs, seq_len = input_ids.shape
+        device = input_ids.device
+        img_tokens = self.config.mm_tokens_per_image
+        image_token_index = self.config.image_token_index
+        # bool causal mask (True == masked)
+        causal_bool = torch.triu(
+            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), 1)
+        mask_bool = causal_bool.unsqueeze(0).unsqueeze(0).expand(
+            bs, 1, -1, -1).clone()
+
+        # pre-compute a few broadcastable helpers
+        img_pos = (input_ids == image_token_index)  # [B,S]
+        img_row = img_pos.unsqueeze(1).unsqueeze(3)  # [B,1,S,1]
+        img_col = img_pos.unsqueeze(1).unsqueeze(2)  # [B,1,1,S]
+
+        img_pos_cum = torch.cumsum(img_pos, 1)
+        img_causal = torch.arange(seq_len, device=device).unsqueeze(0) \
+            - img_pos_cum + (img_pos_cum // img_tokens + 1) * img_tokens + 1
+        img_causal = torch.cat((img_causal[:, :1] - 1, img_causal[:, :-1]), 1) \
+            .clamp_(0, seq_len - 1) \
+            .unsqueeze(1).unsqueeze(3)  # [B,1,S,1]
+        ind = torch.arange(seq_len, device=device).view(1, 1, 1,
+                                                        -1)  # [1,1,1,S]
+
+        # positions we must *unmask* (row img ∧ col img
+        # ∧ col < img_causal)
+        allow = img_row & img_col & (ind < img_causal)
+        mask_bool &= ~allow  # flip to False
+
+        # final bf16/fp32 version
+        out = torch.zeros_like(mask_bool, dtype=mask_dtype) \
+            .masked_fill(mask_bool, float("-inf"))
+
+        return out
+
     def prepare_attn_masks(
         self,
         input_ids: torch.Tensor,
@@ -697,40 +734,35 @@ def prepare_attn_masks(
         local_attn_masks = []
         start_idx = 0
         for seq_len in seq_lens:
-            if not is_hpu:
+            if is_hpu:
+                global_attn_mask = self.hpu_build_mask(input_ids, mask_dtype)
+            else:
                 end_idx = start_idx + seq_len
                 input_token_ids = input_ids[start_idx:end_idx]
                 start_idx = end_idx
                 bs = 1
-            else:
-                input_token_ids = input_ids
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                bs,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
+                # Create a global causal mask.
+                global_attn_mask = torch.empty(
+                    bs,
+                    1,
+                    seq_len,
+                    seq_len,
+                    dtype=mask_dtype,
+                    device=input_ids.device,
+                )
+                global_attn_mask.fill_(float("-inf"))
+                # Fill the lower triangle with 0.
+                global_attn_mask = global_attn_mask.triu(diagonal=1)
 
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = (input_token_ids == self.config.image_token_index)
+                # Consider the bidirectional attention between image tokens.
+                img_mask = torch.zeros_like(global_attn_mask)
+                img_pos = (input_token_ids == self.config.image_token_index)
 
-            if not is_hpu:
                 img_mask[:, :, :, img_pos] += 1
                 img_mask[:, :, img_pos, :] += 1
-            else:
-                img_mask[img_pos.unsqueeze(1)] += 1
-                img_mask = img_mask.permute(0, 1, 3, 2)
-                img_mask[img_pos.unsqueeze(1)] += 1
-                img_mask = img_mask.permute(0, 1, 3, 2)
+                global_attn_mask = torch.where(img_mask == 2, 0,
+                                               global_attn_mask)
 
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
             if self.sliding_window is not None:
```
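To make the index arithmetic in hpu_build_mask easier to follow, below is a self-contained toy version of the same mask logic (a sketch, not the model code: the token ids, image-token id, and tokens-per-image value are invented). It shows the intended effect of the change: attention stays causal, tokens of the same image attend to each other bidirectionally, and image tokens in the top right (later images) stay masked.

```python
import torch

def build_mask(input_ids: torch.Tensor, image_token_index: int,
               img_tokens: int, mask_dtype=torch.float32) -> torch.Tensor:
    """Toy re-implementation of hpu_build_mask for a [B, S] batch."""
    bs, seq_len = input_ids.shape
    # Boolean causal mask, True == masked.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), 1)
    mask_bool = causal.unsqueeze(0).unsqueeze(0).expand(bs, 1, -1, -1).clone()

    img_pos = input_ids == image_token_index      # [B, S]
    img_row = img_pos.unsqueeze(1).unsqueeze(3)   # [B, 1, S, 1]
    img_col = img_pos.unsqueeze(1).unsqueeze(2)   # [B, 1, 1, S]

    # img_causal[b, i] is one past the last column that row i may see through
    # the image-block exception, i.e. the end of row i's own image block.
    cum = torch.cumsum(img_pos, 1)
    img_causal = torch.arange(seq_len).unsqueeze(0) - cum \
        + (cum // img_tokens + 1) * img_tokens + 1
    img_causal = torch.cat((img_causal[:, :1] - 1, img_causal[:, :-1]), 1) \
        .clamp_(0, seq_len - 1).unsqueeze(1).unsqueeze(3)
    ind = torch.arange(seq_len).view(1, 1, 1, -1)

    # Unmask image-row/image-column pairs inside the same image block.
    mask_bool &= ~(img_row & img_col & (ind < img_causal))
    return torch.zeros_like(mask_bool, dtype=mask_dtype) \
        .masked_fill(mask_bool, float("-inf"))

# Toy check: T = text, I = image token; two images of img_tokens=2 each.
T, I = 0, 9
ids = torch.tensor([[T, I, I, T, I, I, T]])
mask = build_mask(ids, image_token_index=I, img_tokens=2)[0, 0]
assert mask[1, 2] == 0                  # first image: bidirectional inside block
assert mask[4, 5] == 0                  # second image: bidirectional inside block
assert mask[2, 4] == float("-inf")      # no attention forward to a later image
```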

vllm/worker/hpu_model_runner.py

Lines changed: 12 additions & 13 deletions

```diff
@@ -354,13 +354,20 @@ def __init__(self, model, vllm_config, is_causal, sampler):
         self.use_window_sdpa = os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD",
                                          "false").strip().lower() in ("1",
                                                                       "true")
+        self.sliding_window_right = 0
         if self.use_window_sdpa:
             self.slice_size = int(
                 os.getenv("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", "1024"))
 
             os.environ["PT_HPU_SDPA_BC_FACTOR"] = str(self.slice_size)
             os.environ["PT_HPU_SDPA_BR_FACTOR"] = str(self.slice_size)
             os.environ["PT_HPU_QKV_SLICE_SEQ_LEN_THLD"] = str(self.slice_size)
+            self.sliding_window_right = int(
+                os.environ.get('VLLM_FUSEDSDPA_SLIDE_RIGHT', '0'))
+            assert self.sliding_window_right % self.slice_size == 0, \
+                f'VLLM_FUSEDSDPA_SLIDE_RIGHT({self.sliding_window_right}) '\
+                f'not supported because it is not a multiple of '\
+                f'PT_HPU_QKV_SLICE_SEQ_LEN_THLD({self.slice_size})!'
 
         # This applies exclusively to Qwen2/2.5-VL models
         # both use mrope. We wrap the visual and language
@@ -582,6 +589,8 @@ def _update_use_window_sdpa(self, attn_metadata, seq_len):
                 f"VLLM_PROMPT_SEQ_BUCKET_STEP: 1024 ")
 
         attn_metadata = attn_metadata._replace(use_window_sdpa=use_window_sdpa)
+        attn_metadata = attn_metadata._replace(
+            sliding_window_right=self.sliding_window_right)
         return attn_metadata
 
     def _update_metadata(self,
@@ -1484,17 +1493,6 @@ def move_to_device(self, tensor):
         return tensor if tensor is None else tensor.to(self.device,
                                                        non_blocking=True)
 
-    def _get_position_pad(self) -> int:
-        """
-        For gemma3 models,
-        due to the Hack in Gemma3ForConditionalGeneration::prepare_attn_masks,
-        '0' can't be used as pad for input position tensor.
-        In case, it might have '0's for bucketing, those '0' will be counted as
-        new sequence in the prepare_attn_masks() which is wrong.
-        """
-        model_type = getattr(self.model_config.hf_config, 'model_type', '')
-        return -1 if model_type == 'gemma3' else 0
-
     def add_vision_buckets_to_mrope_mm_optimized(self):
         if self.mm_registry is not None:
             model = self.get_model()
@@ -1750,11 +1748,11 @@ def _prepare_prompt(
                 make_mrope_positions_tensor_with_pad(input_positions=input_positions,
                                                      input_mrope_positions=input_mrope_positions,
                                                      max_prompt_len=max_prompt_len,
-                                                     pad=self._get_position_pad())
+                                                     pad=0)
         else:
             input_positions = make_cpu_tensor(input_positions,
                                               max_len=max_prompt_len,
-                                              pad=self._get_position_pad(),
+                                              pad=0,
                                               dtype=torch.long,
                                               flat=self.use_merged_prefill)
 
@@ -2663,6 +2661,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'window_block_groups',
             'window_attn_bias',
             'use_window_sdpa',
+            'sliding_window_right',
         ])
         return attention_metadata
 
```
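For completeness, a small sketch of the VLLM_FUSEDSDPA_SLIDE_RIGHT validation that the runner's __init__ now performs (the helper and values below are illustrative; only the environment-variable names come from the diff).

```python
import os

def read_slide_right(slice_size: int) -> int:
    """Read VLLM_FUSEDSDPA_SLIDE_RIGHT and enforce the multiple-of-slice rule."""
    sliding_window_right = int(os.environ.get("VLLM_FUSEDSDPA_SLIDE_RIGHT", "0"))
    assert sliding_window_right % slice_size == 0, (
        f"VLLM_FUSEDSDPA_SLIDE_RIGHT({sliding_window_right}) not supported: "
        f"not a multiple of PT_HPU_QKV_SLICE_SEQ_LEN_THLD({slice_size})!")
    return sliding_window_right

os.environ["VLLM_FUSEDSDPA_SLIDE_RIGHT"] = "2048"
assert read_slide_right(1024) == 2048   # 2048 is a multiple of 1024 -> accepted
# read_slide_right(768) would raise AssertionError: 2048 % 768 != 0
```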
