 HAVE_FUSED_QKV_ROPE = False


-class LinearQkv(Protocol):
-    """Protocol for linear_qkv modules."""
+class LinearQkvInterface(Protocol):
+    """Interface required for linear_qkv modules."""

     def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
         """Applies linear_qkv."""
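These are typing.Protocol classes, so conformance is structural: any module whose forward matches the signature satisfies LinearQkvInterface, with no inheritance required. A minimal sketch of that, assuming only torch and the protocol as defined above; ToyLinearQkv and its dimensions are hypothetical, not part of this change:

import torch
from torch import Tensor

class ToyLinearQkv(torch.nn.Module):
    """Hypothetical module; satisfies LinearQkvInterface structurally."""

    def __init__(self, hidden: int, qkv: int) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(hidden, qkv, bias=False)

    def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
        # The protocol expects an (output, bias) pair; no bias here, so None.
        return self.proj(input), None

layer = ToyLinearQkv(hidden=8, qkv=24)
out, bias = layer.forward(torch.randn(2, 8))  # out: (2, 24), bias: None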
@@ -142,11 +142,11 @@ def __call__(
         is_expert: bool,
         tp_comm_buffer_name: str,
         tp_group: torch.distributed.ProcessGroup | None = None,
-    ) -> LinearQkv: ...
+    ) -> LinearQkvInterface: ...


-class LinearLayer(Protocol):
-    """Protocol for linear_q and linear_kv modules."""
+class LinearLayerInterface(Protocol):
+    """Interface required for linear_q and linear_kv modules."""

     def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
         """Applies linear_q/linear_kv."""
@@ -168,23 +168,23 @@ def __call__(
         bias: bool,
         skip_bias_add: bool,
         is_expert: bool,
-    ) -> LinearLayer: ...
+    ) -> LinearLayerInterface: ...


-class CoreAttention(Protocol):
-    """Protocol for core_attention modules."""
+class CoreAttentionInterface(Protocol):
+    """Interface required for core_attention modules."""

     def forward(
         self,
         query: Tensor,
         key: Tensor,
         value: Tensor,
-        attention_mask: Optional[Tensor],
+        attention_mask: Tensor | None,
         /,
         *,
         attn_mask_type: AttnMaskType,
-        attention_bias: Optional[Tensor],
-        packed_seq_params: Optional[PackedSeqParams],
+        attention_bias: Tensor | None,
+        packed_seq_params: PackedSeqParams | None,
     ) -> Tensor:
         """Applies dot product attention."""
         ...
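The Optional[...] rewrites in these hunks are purely cosmetic: on Python 3.10+, the PEP 604 union X | None compares equal to typing.Optional[X] at runtime, so no behavior changes. A quick illustrative check (not from this diff):

from typing import Optional
from torch import Tensor

assert (Tensor | None) == Optional[Tensor]  # typing.Optional[X] is Union[X, None]
assert (str | None) == Optional[str]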
@@ -200,10 +200,10 @@ def __call__(
         layer_number: int,
         attn_mask_type: AttnMaskType,
         attention_type: str,
-        cp_comm_type: Optional[str],
-        softmax_scale: Optional[float],
-        pg_collection: Optional[ProcessGroupCollection],
-    ) -> CoreAttention: ...
+        cp_comm_type: str | None,
+        softmax_scale: float | None,
+        pg_collection: ProcessGroupCollection | None,
+    ) -> CoreAttentionInterface: ...


 @dataclass
@@ -1578,10 +1578,10 @@ def __init__(
     def get_query_key_value_tensors(
         self,
         hidden_states: Tensor,
-        key_value_states: Optional[Tensor],
+        key_value_states: Tensor | None,
         output_gate: bool = False,
         split_qkv: bool = True,
-    ) -> Tuple[Tensor, Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor, Tensor]:
         """
         Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
         from `key_value_states`.
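The __call__-only protocols in the hunks above type module builders rather than modules: a spec field can then be annotated as "callable that constructs something satisfying the interface". A compact hypothetical sketch of the same pattern (the Toy* names and make_toy_attention are illustrative, not from this change):

from typing import Protocol

import torch
from torch import Tensor

class ToyAttentionInterface(Protocol):
    def forward(self, query: Tensor, key: Tensor, value: Tensor, /) -> Tensor: ...

class ToyAttentionBuilder(Protocol):
    # Builders are typed by __call__; invoking one returns a module that
    # satisfies the interface, mirroring the protocols in this diff.
    def __call__(self, *, layer_number: int) -> ToyAttentionInterface: ...

def make_toy_attention(*, layer_number: int) -> ToyAttentionInterface:
    class _Impl(torch.nn.Module):
        def forward(self, query: Tensor, key: Tensor, value: Tensor, /) -> Tensor:
            # layer_number is unused in this toy; a real core_attention might
            # use it for per-layer scaling or debugging.
            return torch.nn.functional.scaled_dot_product_attention(query, key, value)

    return _Impl()

builder: ToyAttentionBuilder = make_toy_attention  # type-checks structurally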