@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_context_len)
+            seq_length = q.size(2)
+            # pyre-ignore: Incompatible parameter type [6]
+            mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
 
 
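For context, a minimal usage sketch of the updated entry point. This is illustrative only: the import path and the `build_llama_model` helper are assumptions, not part of this change; the only thing taken from the diff is the `replace_sdpa_with_custom_op(module, use_attention_mask=...)` signature and the fact that the `SDPA` modules it replaces expose `dim`, `max_context_len`, and `enable_dynamic_shape`.

```python
# Minimal usage sketch. The import path and build_llama_model are assumptions
# for illustration; only replace_sdpa_with_custom_op reflects the change above.
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

model = build_llama_model()  # hypothetical helper returning an eager llama module

# Swap every SDPA submodule for SDPACustom. With use_attention_mask=True the
# mask computed in forward() is passed to custom_sdpa (is_causal=False); with
# the default False, custom_sdpa runs in causal mode and no mask is passed.
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)
```

Note on the dynamic-shape branch: when `enable_dynamic_shape` is set, the current step's mask rows are taken with `mask.narrow(0, start_pos, seq_length)` instead of `mask[input_pos]`. For contiguous positions the two are equivalent, but the `narrow` path, combined with the `torch._check` calls bounding `start_pos` by `max_context_len`, keeps the slice friendly to export with dynamic shapes.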