 from tqdm import tqdm
 from transformers import PretrainedConfig

-from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, DeepseekAllReduce,
-                                             ParallelConfig, allgather,
-                                             reducescatter)
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.llmapi.utils import enable_llm_debug

-from ...llmapi.utils import enable_llm_debug
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
+                           DeepseekAllReduce, ParallelConfig, allgather,
+                           reducescatter)
 from ..model_config import ModelConfig
 from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
 from ..modules.attention import MLA
@@ -263,7 +262,7 @@ def __init__(self,
                  dtype: Optional[torch.dtype] = None,
                  tune_max_num_tokens: int = 8192,
                  model_config: ModelConfig = ModelConfig()):
-        from tensorrt_llm._torch.distributed import AllReduce
+        from ..distributed import AllReduce

         super().__init__()
         config = model_config.pretrained_config