Skip to content

Commit 1a4c015

Browse files
authored
[Misc] Lazy import ep in mooncake_ep_buffer.py (#1014)
* Lazy import `ep` in `mooncake_ep_buffer.py`
* Fix
1 parent 7c22adb commit 1a4c015

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

mooncake-wheel/mooncake/mooncake_ep_buffer.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,6 @@
22
import torch.distributed as dist
33
from typing import Any, Callable, List, Tuple, Optional, Union
44

5-
# noinspection PyUnresolvedReferences
6-
from mooncake import ep
7-
85

96
class EventOverlap:
107
"""
@@ -15,7 +12,7 @@ class EventOverlap:
1512
extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph.
1613
"""
1714

18-
def __init__(self, event: Optional[ep.EventHandle] = None,
15+
def __init__(self, event: Optional["ep.EventHandle"] = None,
1916
extra_tensors: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
2017
"""
2118
Initialize the class.
@@ -63,6 +60,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
6360

6461
class Buffer:
6562
def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0):
63+
from mooncake import ep
6664
# Initialize the CPP runtime
6765
self.rank = group.rank()
6866
self.group_size = group.size()
@@ -120,6 +118,7 @@ def __init__(self, group: dist.ProcessGroup, num_ep_buffer_bytes: int = 0):
120118

121119
@staticmethod
122120
def get_ep_buffer_size_hint(num_max_dispatch_tokens_per_rank: int, hidden: int, num_ranks: int, num_experts: int) -> int:
121+
from mooncake import ep
123122
return ep.get_ep_buffer_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
124123

125124
# noinspection PyTypeChecker

0 commit comments

Comments
 (0)