diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 43db873fb65..9bb5cee5b2a 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -125,18 +125,19 @@ def update(
 
 
 class StaticAttentionMask:
-    def __init__(self, input_len, cache_len, style):
+    def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
         self.input_len = input_len
         self.cache_len = cache_len
         assert style in ("shift_pointer", "smart_mask")
         self.style = style
+        self.mask_val = mask_val
         self.unmasked_len = 0
         self.tensor = torch.zeros(1, input_len, input_len + cache_len)
         self.reset()
 
     def reset(self):
         self.unmasked_len = 0
-        self.tensor[:, :, : self.cache_len] = float("-inf")
+        self.tensor[:, :, : self.cache_len] = self.mask_val
 
     def unmask(self, new_unmasked_len):
         if new_unmasked_len <= 0:
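
A minimal usage sketch of the new `mask_val` parameter, based only on the constructor and `reset` shown in the diff above. It assumes `StaticAttentionMask` from the patched module is already in scope (the exact import path may differ per checkout), and the sizes and the `-255.0` fill value are illustrative, not prescribed by the change.

```python
import torch

# Assumes StaticAttentionMask from the patched static_attention.py is in scope,
# e.g. via: from executorch.examples.models.llama.static_attention import StaticAttentionMask
# (hypothetical import path; adjust to your checkout).

input_len, cache_len = 32, 128

# Default behavior is unchanged: cache positions start masked with -inf.
default_mask = StaticAttentionMask(input_len, cache_len, style="smart_mask")
assert torch.isneginf(default_mask.tensor[:, :, :cache_len]).all()

# With mask_val, a finite fill value can be used instead, e.g. a large
# negative constant for backends or quantization flows that avoid infinities.
finite_mask = StaticAttentionMask(
    input_len, cache_len, style="smart_mask", mask_val=-255.0
)
assert (finite_mask.tensor[:, :, :cache_len] == -255.0).all()
```

Because `reset()` now reuses `self.mask_val`, any later reset of the mask keeps whatever fill value the caller chose at construction time.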