examples/models/llama2/source_transformation/torchtune/modules
1 file changed: +3 −4 lines changed

@@ -70,9 +70,8 @@ class MultiHeadAttention(nn.Module):
     max_seq_len (int): maximum sequence length supported by the model.
         This is needed to compute the RoPE Cache. Default: 4096.
     is_causal (bool): sets the default mask to causal when no mask is provided
-    attn_dropout (float): dropout value passed onto the
-        scaled_dot_product_attention function. This argument is ignored if the
-        self.training is False. Default value is 0.0.
+    attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+        This argument is ignored if self.training is False. Default value is 0.0.

 Raises:
     ValueError: If ``num_heads % num_kv_heads != 0``
@@ -147,7 +146,7 @@ def __init__(
             num_heads=self.num_heads,
             head_dim=self.head_dim,
             q_per_kv=self.q_per_kv,
-            attn_dropout=self.attn_dropout,
+            attn_dropout=self.attn_dropout if self.training else 0.0,
             is_causal=self.is_causal,
             attention_fn=self._attention_call,
             kv_cache=self.kv_cache,
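
For context, a minimal sketch of why the conditional matters (this is not the actual MultiHeadAttention module; the TinyAttention class and tensor shapes below are made up for illustration): torch.nn.functional.scaled_dot_product_attention applies dropout whenever dropout_p > 0, regardless of the module's train/eval mode, so the caller has to pass 0.0 at eval time, which is what this diff does.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyAttention(nn.Module):
        """Toy attention module mirroring the attn_dropout gating in the diff."""

        def __init__(self, attn_dropout: float = 0.1, is_causal: bool = True):
            super().__init__()
            self.attn_dropout = attn_dropout
            self.is_causal = is_causal

        def forward(self, q, k, v):
            # Only apply attention dropout while training; SDPA itself does not
            # check the module's train/eval state.
            dropout_p = self.attn_dropout if self.training else 0.0
            return F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=self.is_causal
            )

    q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq_len, head_dim)
    attn = TinyAttention()
    attn.eval()
    out = attn(q, k, v)  # deterministic: dropout is disabled outside training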