
Commit b8b2ecb

enable_qnn_masked_softmax
Differential Revision: D81248699
Pull Request resolved: pytorch#13788
1 parent dc57e56 commit b8b2ecb

4 files changed (+21, -5 lines)


examples/models/llama/attention.py

Lines changed: 1 addition & 0 deletions
@@ -331,6 +331,7 @@ def __init__(
         args: ModelArgs,
         layer_id: int,
         rope: Rope,
+        **_kwargs: Any,
     ):
         """
         Multi-head attention layer.
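
Why this otherwise unused catch-all: once the registry call sites splat args.attention_kwargs (see llama_transformer.py below), every class reachable through ATTENTION_REGISTRY has to tolerate keyword options it does not consume. A minimal sketch, reusing the names already imported in attention.py (the class itself is hypothetical, not part of this commit):

# Illustrative only: an attention variant that ignores the new options still
# accepts them, so cls(args, layer_id, rope, **args.attention_kwargs) succeeds
# for every registered attention type.
class PlainAttention(Attention):  # hypothetical registry entry
    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope, **_kwargs: Any):
        super().__init__()
        # _kwargs (e.g. enable_qnn_masked_softmax) is deliberately ignored here.

    # forward() omitted for brevity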

examples/models/llama/llama_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
         cls = ATTENTION_REGISTRY[args.attention_type]
-        attention = cls(args, layer_id, rope)
+        attention = cls(args, layer_id, rope, **args.attention_kwargs)
         return TransformerBlock(args, attention)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
@@ -255,7 +255,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     layers = torch.nn.ModuleList()
     cls = ATTENTION_REGISTRY[model_args.attention_type]
     for layer_id in range(model_args.n_layers):
-        attention = cls(model_args, layer_id, rope)
+        attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
         transformer_block = TransformerBlock(model_args, attention)
         layers.append(transformer_block)

examples/models/llama/model_args.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
+import dataclasses
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Any, Dict, Optional


 @dataclass
@@ -69,6 +70,7 @@ class ModelArgs:
     kv_io_bit_width: Optional[int] = (
         None  # KV cache bit width. This is for QNN backend only for now.
     )
+    attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

     def __post_init__(self):
         if self.n_kv_heads is None:
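
A minimal usage sketch of the new field (every other value below is a placeholder, not taken from this commit): anything placed in attention_kwargs is expanded into keyword arguments of the attention constructor when the transformer is built, and the default_factory=dict default means configs that never set it splat to nothing, leaving existing behavior unchanged.

# Hypothetical config; the registry key "static" is an assumption, not verified here.
model_args = ModelArgs(
    attention_type="static",
    attention_kwargs={"enable_qnn_masked_softmax": True},
)
# construct_transformer() in llama_transformer.py now builds each layer via
#   cls(model_args, layer_id, rope, enable_qnn_masked_softmax=True)
transformer = construct_transformer(model_args)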

examples/models/llama/static_attention.py

Lines changed: 15 additions & 2 deletions
@@ -658,7 +658,12 @@ class StaticAttention(Attention):
     """

     def __init__(
-        self, config: ModelArgs, layer_id: int, rope: Rope, split_mha: bool = True
+        self,
+        config: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        split_mha: bool = True,
+        **kwargs: Any,
     ):
         super().__init__()
         self.n_heads = config.n_heads
@@ -676,6 +681,7 @@ def __init__(
         self.qk_norm_before_rope = config.qk_norm_before_rope
         self.split_mha = split_mha
         self.use_conv2d = False
+        self.enable_qnn_masked_softmax = kwargs.get("enable_qnn_masked_softmax", False)

         if self.split_mha:
             self.wqs = nn.ModuleList(
@@ -857,7 +863,14 @@ def _forward_sha(
             kv_idx = i // self.n_heads_per_kv_group
             attn = new_qs[i] @ all_ks[kv_idx].transpose(-2, -1)
             attn = attn * self.inv_scale
-            attn = attn + mask
+            if self.enable_qnn_masked_softmax:
+                attn_min = torch.amin(attn, dim=-1, keepdim=True)
+                minus_value = -20
+                attn = torch.where(
+                    mask == 0, attn, attn_min + minus_value
+                )  # pyre-ignore
+            else:
+                attn = attn + mask
             attn = F.softmax(attn, dim=-1)
             heads.append(attn @ all_vs[kv_idx])
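
For intuition, a standalone sketch (not part of the commit) comparing the original additive mask with the bounded rewrite. Masked scores are pushed to about 20 below the row minimum rather than to -inf, presumably to keep the pre-softmax values in a range that quantizes well on the QNN backend; their softmax weight ends up around e^-20 of any kept score, so the two paths agree to well within float32 precision.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
attn = torch.randn(1, 1, 8)  # raw attention scores for one query position
neg_inf = float("-inf")
mask = torch.tensor([[[0.0, 0.0, 0.0, 0.0, 0.0, neg_inf, neg_inf, neg_inf]]])

# Baseline: additive mask, masked logits go to -inf.
ref = F.softmax(attn + mask, dim=-1)

# QNN-friendly variant: clamp masked logits to (row min - 20) instead of -inf.
attn_min = torch.amin(attn, dim=-1, keepdim=True)
qnn = F.softmax(torch.where(mask == 0, attn, attn_min - 20), dim=-1)

print((ref - qnn).abs().max())  # tiny (~1e-7 or smaller in float32)

The offset of 20 mirrors minus_value in the diff: it is large enough that masked positions contribute roughly 2e-9 of the attention weight, yet small enough that the masked logits stay within the same numeric range as the real scores.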
