From a33e30a7089b111159c65ad1d9e9ab3159b5632a Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Wed, 26 Feb 2025 09:31:45 -0800
Subject: [PATCH] Add mask_val option to StaticAttentionMask

Summary: This is needed to support quantization, where a finite negative
constant should be used instead of negative infinity.

Differential Revision: D70255619
---
 examples/models/llama/static_attention.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 43db873fb65..9bb5cee5b2a 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -125,18 +125,19 @@ def update(
 
 
 class StaticAttentionMask:
-    def __init__(self, input_len, cache_len, style):
+    def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
         self.input_len = input_len
         self.cache_len = cache_len
         assert style in ("shift_pointer", "smart_mask")
         self.style = style
+        self.mask_val = mask_val
         self.unmasked_len = 0
         self.tensor = torch.zeros(1, input_len, input_len + cache_len)
         self.reset()
 
     def reset(self):
         self.unmasked_len = 0
-        self.tensor[:, :, : self.cache_len] = float("-inf")
+        self.tensor[:, :, : self.cache_len] = self.mask_val
 
     def unmask(self, new_unmasked_len):
         if new_unmasked_len <= 0:
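
Usage sketch (illustrative only, not part of the diff above), assuming
StaticAttentionMask is imported from examples/models/llama/static_attention.py;
the finite constant -255.0 is an arbitrary example of a quantization-friendly
mask value, not a value prescribed by this patch:

    import torch

    from examples.models.llama.static_attention import StaticAttentionMask

    # Default behavior: masked cache positions are filled with -inf.
    float_mask = StaticAttentionMask(input_len=32, cache_len=128, style="smart_mask")

    # For quantization, pass a large negative finite constant instead of -inf so
    # the masked positions remain representable after the scores are quantized.
    quant_mask = StaticAttentionMask(
        input_len=32, cache_len=128, style="smart_mask", mask_val=-255.0
    )

    # After reset(), the cache region of the mask holds mask_val.
    assert torch.all(quant_mask.tensor[:, :, :128] == -255.0)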