Skip to content

Commit 5506f27

Browse files
committed
Updates for consistency
1 parent 77e87b7 commit 5506f27

File tree

3 files changed

+23
-22
lines changed

3 files changed

+23
-22
lines changed

megatron/core/extensions/transformer_engine.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,19 @@
5858
is_torch_min_version,
5959
)
6060

61-
if TYPE_CHECKING:
62-
# For type checking, treat transformer_engine as always available.
61+
try:
6362
import transformer_engine as te
6463
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
6564

6665
HAVE_TE = True
67-
else:
68-
try:
66+
except ImportError:
67+
if TYPE_CHECKING:
68+
# For type checking, treat transformer_engine as always available.
6969
import transformer_engine as te
7070
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, fp8_autocast
7171

7272
HAVE_TE = True
73-
except ImportError:
73+
else:
7474
from unittest.mock import MagicMock
7575

7676
te = MagicMock()

megatron/core/transformer/attention.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@
113113
HAVE_FUSED_QKV_ROPE = False
114114

115115

116-
class LinearQkv(Protocol):
117-
"""Protocol for linear_qkv modules."""
116+
class LinearQkvInterface(Protocol):
117+
"""Interface required for linear_qkv modules."""
118118

119119
def forward(self, input: Tensor, /) -> tuple[Tensor, object]:
120120
"""Applies linear_qkv."""
@@ -142,11 +142,11 @@ def __call__(
142142
is_expert: bool,
143143
tp_comm_buffer_name: str,
144144
tp_group: torch.distributed.ProcessGroup | None = None,
145-
) -> LinearQkv: ...
145+
) -> LinearQkvInterface: ...
146146

147147

148-
class LinearLayer(Protocol):
149-
"""Protocol for linear_q and linear_kv modules."""
148+
class LinearLayerInterface(Protocol):
149+
"""Interface required for linear_q and linear_kv modules."""
150150

151151
def forward(self, input: Tensor, /) -> Tuple[Tensor, object]:
152152
"""Applies linear_q/linear_kv."""
@@ -168,23 +168,23 @@ def __call__(
168168
bias: bool,
169169
skip_bias_add: bool,
170170
is_expert: bool,
171-
) -> LinearLayer: ...
171+
) -> LinearLayerInterface: ...
172172

173173

174-
class CoreAttention(Protocol):
175-
"""Protocol for core_attention modules."""
174+
class CoreAttentionInterface(Protocol):
175+
"""Interface required for core_attention modules."""
176176

177177
def forward(
178178
self,
179179
query: Tensor,
180180
key: Tensor,
181181
value: Tensor,
182-
attention_mask: Optional[Tensor],
182+
attention_mask: Tensor | None,
183183
/,
184184
*,
185185
attn_mask_type: AttnMaskType,
186-
attention_bias: Optional[Tensor],
187-
packed_seq_params: Optional[PackedSeqParams],
186+
attention_bias: Tensor | None,
187+
packed_seq_params: PackedSeqParams | None,
188188
) -> Tensor:
189189
"""Applies dot product attention."""
190190
...
@@ -200,10 +200,10 @@ def __call__(
200200
layer_number: int,
201201
attn_mask_type: AttnMaskType,
202202
attention_type: str,
203-
cp_comm_type: Optional[str],
204-
softmax_scale: Optional[float],
205-
pg_collection: Optional[ProcessGroupCollection],
206-
) -> CoreAttention: ...
203+
cp_comm_type: str | None,
204+
softmax_scale: float | None,
205+
pg_collection: ProcessGroupCollection | None,
206+
) -> CoreAttentionInterface: ...
207207

208208

209209
@dataclass
@@ -1578,10 +1578,10 @@ def __init__(
15781578
def get_query_key_value_tensors(
15791579
self,
15801580
hidden_states: Tensor,
1581-
key_value_states: Optional[Tensor],
1581+
key_value_states: Tensor | None,
15821582
output_gate: bool = False,
15831583
split_qkv: bool = True,
1584-
) -> Tuple[Tensor, Tensor, Tensor]:
1584+
) -> tuple[Tensor, Tensor, Tensor]:
15851585
"""
15861586
Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
15871587
from `key_value_states`.

megatron/core/typed_torch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
"""Utilities for improved type hinting with torch interfaces."""
3+
from __future__ import annotations
34

45
from collections.abc import Callable
56
from typing import Generic, ParamSpec, Protocol, TypeVar

0 commit comments

Comments
 (0)