@@ -144,12 +144,13 @@ def forward(
     attn_config = self.config.block_config(0).attn_config
     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
     rope = rotary_pos_emb.build_rope(input_pos, n_elem, attn_config.rotary_base)
-    mask = [
-        self.get_attention_mask(
-            self.config.block_config(i).attn_config.attn_type, input_pos
-        )
-        for i in range(self.config.num_layers)
-    ]
+    if mask is None:
+      mask = [
+          self.get_attention_mask(
+              self.config.block_config(i).attn_config.attn_type, input_pos
+          )
+          for i in range(self.config.num_layers)
+      ]

     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
@@ -159,7 +160,7 @@ def _forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
+      mask: torch.Tensor | List[torch.Tensor],
       input_pos: torch.Tensor,
       kv_cache: kv_utils.KVCache,
       export_config: Optional[model_builder.ExportConfig] = None,
@@ -174,17 +175,10 @@ def _forward_with_embeds(
       input_embeds = input_embeds * self.config.embedding_scale
     x = input_embeds
     updated_kv_entries = []
-    mask_input = mask is not None
     for i, block in enumerate(self.transformer_blocks):
-      mask = (
-          mask
-          if mask_input
-          else self.get_attention_mask(
-              block.config.attn_config.attn_type, input_pos
-          )
-      )
+      mask_entry = mask[i] if isinstance(mask, list) else mask
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask_entry, input_pos, kv_entry)
       if kv_entry:
         updated_kv_entries.append(kv_entry)
     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
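The change lets a caller pass `mask` either as a single tensor shared by every transformer block or as a per-layer list; only when `mask` is None does `forward` fall back to building the per-layer masks itself. Below is a minimal standalone sketch of the new per-block dispatch (`mask_entry = mask[i] if isinstance(mask, list) else mask`); `num_layers`, the dummy mask shapes, and the `select_mask` helper are illustrative only and not part of the diff.

# Standalone sketch of the mask dispatch introduced above (illustrative names).
from typing import List, Union
import torch

num_layers = 4
shared_mask = torch.zeros(1, 1, 8, 8)                               # one mask reused by all blocks
per_layer = [torch.zeros(1, 1, 8, 8) for _ in range(num_layers)]    # one mask per block

def select_mask(mask: Union[torch.Tensor, List[torch.Tensor]], i: int) -> torch.Tensor:
  # A plain tensor is broadcast to every block; a list is indexed per block,
  # mirroring the `isinstance(mask, list)` check in _forward_with_embeds.
  return mask[i] if isinstance(mask, list) else mask

for i in range(num_layers):
  assert select_mask(shared_mask, i) is shared_mask
  assert select_mask(per_layer, i) is per_layer[i]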