@@ -57,6 +57,7 @@ def forward(
     input_pos: torch.Tensor,
     kv_cache: kv_utils.KVCache,
     input_embeds: torch.Tensor = None,
+    mask: Optional[torch.Tensor] = None,
     export_config: Optional[model_builder.ExportConfig] = None,
     called_by_generate: bool = True,
 ) -> dict[torch.Tensor, kv_utils.KVCache]:
@@ -73,17 +74,21 @@ def forward(
         repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )

-    if called_by_generate:
-      # PaliGemma2 generate() uses a diagonal causal mask even with image embeds.
-      mask = [self.get_attention_mask(
-          self.config.block_config(i).attn_config.attn_type, input_pos
-      ) for i in range(self.config.num_layers)]
-    else:
-      # By default, don't mask image embeds with a diagonal causal mask.
-      embeds_len = input_embeds.shape[1]
-      mask = torch.zeros(embeds_len, self.config.kv_cache_max)
-      mask[:, embeds_len:] = float("-inf")
-      mask = [mask] * self.config.num_layers
+    if mask is None:
+      if called_by_generate:
+        # PaliGemma2 generate() uses a diagonal causal mask even with image embeds.
+        mask = [
+            self.get_attention_mask(
+                self.config.block_config(i).attn_config.attn_type, input_pos
+            )
+            for i in range(self.config.num_layers)
+        ]
+      else:
+        # By default, don't mask image embeds with a diagonal causal mask.
+        embeds_len = input_embeds.shape[1]
+        mask = torch.zeros(embeds_len, self.config.kv_cache_max)
+        mask[:, embeds_len:] = float("-inf")
+        mask = [mask] * self.config.num_layers

     return self._forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
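For context, here is a minimal sketch (not part of the diff) of the kind of per-layer mask the default `mask is None`, `called_by_generate=False` branch builds; a caller can construct a tensor the same way and hand it in through the new `mask` argument instead. `embeds_len`, `kv_cache_max`, and `num_layers` are assumed example values standing in for `input_embeds.shape[1]` and the model config.

```python
import torch

# Assumed example values; in the real model these come from input_embeds.shape[1]
# and the decoder config (kv_cache_max, num_layers).
embeds_len = 256      # number of image-embedding positions
kv_cache_max = 1024   # attention width, i.e. the KV cache size
num_layers = 26       # decoder layer count

# Embedding positions may attend to cache slots 0..embeds_len-1 and nothing
# beyond, mirroring the default else-branch above.
mask = torch.zeros(embeds_len, kv_cache_max)
mask[:, embeds_len:] = float("-inf")
per_layer_mask = [mask] * num_layers  # default path shares one mask across layers
```

A tensor built this way (or any alternative, e.g. a fully causal variant) supplied via the new `mask` keyword takes precedence over both default branches.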