Skip to content

Commit 9233922

Browse files
authored
Refactor Code of cpp backend and flatten head (#220)
* Refactor test code and flatten-head logic
* Refactor `is_cpp_backend_enable` logic to fix a partial-import problem
* Change `test_pipeline` timeout to 900
1 parent 51a603d commit 9233922

19 files changed

+117
-165
lines changed

magi_attention/__init__.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@
1616
import os
1717
import warnings
1818

19+
from . import comm, config, functional
20+
from .dist_attn_runtime_mgr import (
21+
init_dist_attn_runtime_key,
22+
init_dist_attn_runtime_mgr,
23+
)
24+
25+
if importlib.util.find_spec("magi_attention._version") is None:
26+
warnings.warn(
27+
"You are using magi_attention without installing it. This may cause some unexpected errors."
28+
)
29+
version = None
30+
else:
31+
from ._version import __version__ as git_version
32+
33+
version = git_version
34+
35+
__version__: str | None = version
36+
1937

2038
def is_sanity_check_enable() -> bool:
2139
"""
@@ -85,17 +103,6 @@ def is_profile_mode_enable() -> bool:
85103
return os.environ.get("MAGI_ATTENTION_PROFILE_MODE", "0") == "1"
86104

87105

88-
def is_cpp_backend_enable() -> bool:
89-
"""
90-
Toggle this env variable to ``1`` to enable C++ backend
91-
for core data structures (AttnRange, AttnMaskType, etc.)
92-
otherwise, fall back to the Python implementation.
93-
94-
Default value is ``0``
95-
"""
96-
return os.environ.get("MAGI_ATTENTION_CPP_BACKEND", "0") == "1"
97-
98-
99106
def dist_attn_runtime_dict_size() -> int:
100107
"""
101108
Set the value of this env variable to control
@@ -106,31 +113,13 @@ def dist_attn_runtime_dict_size() -> int:
106113
return int(os.environ.get("MAGI_ATTENTION_DIST_ATTN_RUNTIME_DICT_SIZE", "1000"))
107114

108115

109-
from . import comm, config, functional # noqa: E402
110-
from .dist_attn_runtime_mgr import ( # noqa: E402
111-
init_dist_attn_runtime_key,
112-
init_dist_attn_runtime_mgr,
113-
)
114-
115-
if importlib.util.find_spec("magi_attention._version") is None:
116-
warnings.warn(
117-
"You are using magi_attention without installing it. This may cause some unexpected errors."
118-
)
119-
version = None
120-
else:
121-
from ._version import __version__ as git_version
122-
123-
version = git_version
124-
125-
__version__: str | None = version
126-
127116
__all__ = [
128117
"init_dist_attn_runtime_key",
129118
"init_dist_attn_runtime_mgr",
130119
"is_sanity_check_enable",
131120
"is_flatten_head_groups_enable",
132121
"is_cuda_device_max_connections_one",
133-
"is_cpp_backend_enable",
122+
"dist_attn_runtime_dict_size",
134123
"config",
135124
"comm",
136125
"functional",

magi_attention/common/__init__.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from magi_attention import is_cpp_backend_enable
15+
import os
1616

17-
from . import enum, jit, range_op
18-
from .mask import AttnMask
19-
from .range import AttnRange, RangeError
20-
from .ranges import AttnRanges
21-
from .rectangle import AttnRectangle
22-
from .rectangles import AttnRectangles
17+
18+
def is_cpp_backend_enable() -> bool:
19+
"""
20+
Toggle this env variable to ``1`` to enable C++ backend
21+
for core data structures (AttnRange, AttnMaskType, etc.)
22+
otherwise, fall back to the Python implementation.
23+
24+
Default value is ``0``
25+
"""
26+
    return os.environ.get("MAGI_ATTENTION_CPP_BACKEND", "0") == "1"
27+
28+
29+
from . import enum, jit, range_op # noqa: E402
30+
from .mask import AttnMask # noqa: E402
31+
from .range import AttnRange, RangeError # noqa: E402
32+
from .ranges import AttnRanges # noqa: E402
33+
from .rectangle import AttnRectangle # noqa: E402
34+
from .rectangles import AttnRectangles # noqa: E402
2335

2436
# Try to use C++ extensions for core data structures to avoid Python overhead
2537
# The submodules (range, ranges, rectangle, rectangles, enum) already handle
@@ -47,4 +59,5 @@
4759
"AttnRectangles",
4860
"range_op",
4961
"USE_CPP_BACKEND",
62+
"is_cpp_backend_enable",
5063
]

magi_attention/common/enum.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from . import is_cpp_backend_enable
21+
2022
GroupReduceOp: TypeAlias = Literal["sum", "avg", "lse"]
2123

2224
OutMaybeWithLSE: TypeAlias = torch.Tensor | tuple[torch.Tensor, torch.Tensor]
@@ -124,8 +126,6 @@ class DynamicAttnAlgType(Enum):
124126
BINARY_GREEDY_PARALLEL = "binary_greedy_parallel"
125127

126128

127-
from magi_attention import is_cpp_backend_enable # noqa: E402
128-
129129
if is_cpp_backend_enable():
130130
try:
131131
from magi_attention.magi_attn_ext import AttnMaskType as _AttnMaskType

magi_attention/common/range.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from typing import Any, TypeAlias, Union
1616

17+
from . import is_cpp_backend_enable
18+
1719
NaiveRange: TypeAlias = tuple[int, int] | list[int]
1820

1921

@@ -183,8 +185,6 @@ def __repr__(self) -> str:
183185
return f"[{self._start}, {self._end})"
184186

185187

186-
from magi_attention import is_cpp_backend_enable # noqa: E402
187-
188188
if is_cpp_backend_enable():
189189
try:
190190
from magi_attention.magi_attn_ext import AttnRange as _AttnRange

magi_attention/common/ranges.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from magi_attention.utils import nvtx
2121

22+
from . import is_cpp_backend_enable
2223
from .range import AttnRange, NaiveRange, RangeError
2324

2425
NaiveRanges: TypeAlias = Sequence[NaiveRange]
@@ -782,8 +783,6 @@ def __repr__(self) -> str:
782783
return f"{self._ranges}"
783784

784785

785-
from magi_attention import is_cpp_backend_enable # noqa: E402
786-
787786
if is_cpp_backend_enable():
788787
try:
789788
from magi_attention.magi_attn_ext import AttnRanges as _AttnRanges

magi_attention/common/rectangle.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Any, Union
1616

17+
from . import is_cpp_backend_enable
1718
from .enum import AttnMaskType
1819
from .range import AttnRange
1920

@@ -510,8 +511,6 @@ def __repr__(self) -> str:
510511
return f"{self._q_range} x {self._k_range} x {self._d_range}"
511512

512513

513-
from magi_attention import is_cpp_backend_enable # noqa: E402
514-
515514
if is_cpp_backend_enable():
516515
try:
517516
from magi_attention.magi_attn_ext import AttnRectangle as _AttnRectangle

magi_attention/common/rectangles.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from typing import Any, Iterator, Sequence, TypeAlias, Union
1616

17+
from . import is_cpp_backend_enable
1718
from .enum import AttnMaskType
1819
from .range import AttnRange, NaiveRange
1920
from .ranges import AttnRanges
@@ -252,8 +253,6 @@ def __repr__(self) -> str:
252253
return f"{self._rects}"
253254

254255

255-
from magi_attention import is_cpp_backend_enable # noqa: E402
256-
257256
if is_cpp_backend_enable():
258257
try:
259258
from magi_attention.magi_attn_ext import AttnRectangles as _AttnRectangles

magi_attention/meta/_make_attn_meta.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def make_attn_meta_from_dispatch_meta(
8686
q_ranges=q_ranges,
8787
k_ranges=k_ranges,
8888
attn_mask_type=attn_mask_type,
89-
flatten_head_groups=magi_attention.is_flatten_head_groups_enable(),
9089
)
9190
# only for debug: visualize the buckets
9291
# if cp_group.rank() == 0:
@@ -107,7 +106,6 @@ def make_attn_meta_from_dispatch_meta(
107106
attn_mask_type=attn_mask_type,
108107
dispatch_meta_q=dispatch_meta_q,
109108
dispatch_meta_k=dispatch_meta_k,
110-
flatten_head_groups=magi_attention.is_flatten_head_groups_enable(),
111109
)
112110

113111
assert attn_solver.is_solved

magi_attention/meta/algorithms/binary_greedy_parallel.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
import torch.distributed as dist
1919

20-
from magi_attention import is_cpp_backend_enable
21-
from magi_attention.common import AttnRange, AttnRanges, AttnRectangles
20+
from magi_attention.common import (
21+
AttnRange,
22+
AttnRanges,
23+
AttnRectangles,
24+
is_cpp_backend_enable,
25+
)
2226
from magi_attention.common.enum import DynamicAttnAlgType
2327

2428
from .base import DynamicAttnAlgorithm

magi_attention/meta/container/transfer_table.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import dataclass
1616
from typing import Any, Iterator
1717

18+
from magi_attention.common import is_cpp_backend_enable
1819
from magi_attention.common.range import AttnRange
1920
from magi_attention.common.ranges import AttnRanges
2021
from magi_attention.utils import nvtx
@@ -81,8 +82,6 @@ def __iter__(self) -> Iterator[AttnRangeWithRank]:
8182
return iter(self._ranges)
8283

8384

84-
from magi_attention import is_cpp_backend_enable # noqa: E402
85-
8685
if is_cpp_backend_enable():
8786
try:
8887
from magi_attention.magi_attn_ext import AttnRangeWithRank as _AttnRangeWithRank

0 commit comments

Comments
 (0)