|
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-from contextlib import contextmanager
 from dataclasses import dataclass, field
 from io import StringIO
 
@@ -70,7 +69,7 @@ class DuoAttentionPress(BasePress):
     sink_size: int = field(init=False, default=None)
     streaming_mask: torch.Tensor = field(init=False, default=None)
 
-    def __post_init_from_model__(self, model):
+    def post_init_from_model(self, model):
         """
         Initialize sink_size, recent_size, and streaming_mask from a model
         """
@@ -101,7 +100,7 @@ def compress(self, module, hidden_states, keys, values, attentions, kwargs):
         assert module.config._attn_implementation != "eager", "eager mode not supported"
         if self.streaming_mask is None:
             raise ValueError(
-                "Streaming mask not initialized. Make sure to call __post_init_from_model__ to initialize this press."
+                "Streaming mask not initialized. Make sure to call post_init_from_model to initialize this press."
             )
         k_len = keys.shape[2]
 
@@ -141,12 +140,6 @@ def load_attention_pattern(model):
 
         return config["sink_size"], config["recent_size"], head_scores
 
-    @contextmanager
-    def __call__(self, model):
-        self.__post_init_from_model__(model)
-        with super().__call__(model):
-            yield
-
 
 @cached(cache, key=lambda model, num_samples=50, q_len=500: (model.config.name_or_path, num_samples, q_len))
 def duo_attention_on_the_fly(model, num_samples=50, q_len=500):
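
With the @contextmanager __call__ override removed, entering the press no longer initializes it implicitly: post_init_from_model(model) (the renamed __post_init_from_model__) must now be called explicitly before compression, or compress() raises the ValueError shown above. A minimal usage sketch of that flow follows; the kvpress import path, the checkpoint name, and the attention implementation are illustrative assumptions, not taken from this commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from kvpress import DuoAttentionPress  # assumed public export of the class in this diff

    model_name = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",  # compress() asserts the implementation is not "eager"
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    press = DuoAttentionPress()        # constructor arguments are unchanged by this diff
    press.post_init_from_model(model)  # builds sink_size, recent_size and streaming_mask

    inputs = tokenizer("some long context ...", return_tensors="pt").to(model.device)
    with press(model):                 # base-class context manager; no implicit init anymore
        model(**inputs)                # KV cache is compressed during prefill

The @cached decorator in the last hunk is unchanged: duo_attention_on_the_fly stays memoized per (model.config.name_or_path, num_samples, q_len), so its scores are computed at most once per model and setting.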
|