
Commit 268933b

Refactor imports inside tensorrt_llm._torch. (NVIDIA#3015)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent e68749c · commit 268933b

File tree: 21 files changed, +171 −180 lines

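Every hunk in this commit applies the same pattern: modules inside the tensorrt_llm._torch package stop importing sibling subpackages through the absolute tensorrt_llm._torch.* path and use package-relative imports instead. A minimal before/after sketch of the idea, using the allgather import that also appears in the first hunk below (the "before" line here is illustrative, not a line deleted in this commit):

# Before: a module under tensorrt_llm/_torch/attention_backend/ reaches a
# sibling subpackage through the full absolute path.
from tensorrt_llm._torch.distributed import allgather

# After: the same import written relative to the current package;
# ".." climbs from attention_backend/ up to the _torch package root.
from ..distributed import allgather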

tensorrt_llm/_torch/attention_backend/star_flashinfer.py

Lines changed: 9 additions & 1 deletion
@@ -1,9 +1,17 @@
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+import flashinfer
 import numpy as np
 import torch

+from tensorrt_llm.functional import AttentionMaskType
+from tensorrt_llm.models.modeling_utils import QuantConfig
+
 from ..distributed import allgather
 from ..modules.linear import ParallelConfig
-from .flashinfer import *
+from .flashinfer import FlashInferAttentionMetadata, PlanParams
+from .interface import AttentionBackend, AttentionMask, PredefinedAttentionMask
 from .vanilla import VanillaAttention

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 5 additions & 4 deletions
@@ -5,15 +5,16 @@

 import torch

-from tensorrt_llm._torch.attention_backend.interface import (
-    AttentionBackend, AttentionMask, AttentionMetadata, MLAParams,
-    PositionalEmbeddingParams, PredefinedAttentionMask, RopeParams)
-from tensorrt_llm._torch.attention_backend.vanilla import VanillaAttention
 from tensorrt_llm.functional import (AttentionMaskType, RopeEmbeddingUtils,
                                      RotaryScalingType)
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig

+from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
+                        MLAParams, PositionalEmbeddingParams,
+                        PredefinedAttentionMask, RopeParams)
+from .vanilla import VanillaAttention
+

 # The type of requests in qkv passed to attention
 # Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp

tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@

 import torch

-import tensorrt_llm._torch.auto_deploy.distributed.common as dist
-import tensorrt_llm._torch.auto_deploy.distributed.trtllm as trtllm_dist
+from ..distributed import common as dist
+from ..distributed import trtllm as trtllm_dist


 @torch.library.custom_op("dist::all_gather", mutates_args=(), device_types="cuda")

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F

-from tensorrt_llm._torch.modules.fused_moe import FusedMoE  # noqa: F401
+from ...modules.fused_moe import FusedMoE  # noqa: F401


 @torch.library.custom_op("moe::torch_moe", mutates_args=())

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 2 additions & 6 deletions
@@ -4,12 +4,8 @@

 # use trtllm distributed ops to improve TP performance if possible
 try:
-    from tensorrt_llm._torch.distributed import AllReduce, allgather
-    from tensorrt_llm._torch.modules.linear import (
-        AllReduceFusionOp,
-        AllReduceParams,
-        ParallelConfig,
-    )
+    from ...distributed import AllReduce, allgather
+    from ...modules.linear import AllReduceFusionOp, AllReduceParams, ParallelConfig

     def trtllm_allgather(tensor, dim):
         rank, world_size = get_rank_world_size()

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 )

 try:
-    from tensorrt_llm._torch.quantization.utils import float4_sf_dtype
+    from ...quantization.utils import float4_sf_dtype
 except ImportError:
     float4_sf_dtype = None

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
                                      register_replacement)

 import tensorrt_llm
-from tensorrt_llm._torch.distributed import AllReduceFusionOp
+
+from ...distributed import AllReduceFusionOp

 aten = torch.ops.aten
 from tensorrt_llm.mapping import Mapping

tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@
                                      register_replacement)

 import tensorrt_llm
-from tensorrt_llm._torch.distributed import AllReduceFusionOp, AllReduceStrategy
+
+from ...distributed import AllReduceFusionOp, AllReduceStrategy

 aten = torch.ops.aten
 from tensorrt_llm.mapping import Mapping

tensorrt_llm/_torch/models/modeling_bert.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,6 @@
 from torch import nn
 from transformers import BertConfig

-from tensorrt_llm._torch.modules.linear import Linear
 from tensorrt_llm.llmapi.utils import print_colored_debug
 from tensorrt_llm.logger import logger

@@ -14,6 +13,7 @@
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
+from ..modules.linear import Linear
 from .modeling_utils import register_auto_model

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 5 additions & 6 deletions
@@ -9,15 +9,14 @@
 from tqdm import tqdm
 from transformers import PretrainedConfig

-from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, DeepseekAllReduce,
-                                             ParallelConfig, allgather,
-                                             reducescatter)
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.llmapi.utils import enable_llm_debug

-from ...llmapi.utils import enable_llm_debug
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
+                           DeepseekAllReduce, ParallelConfig, allgather,
+                           reducescatter)
 from ..model_config import ModelConfig
 from ..models.modeling_utils import MissingLayer, ModelConfig, support_pp
 from ..modules.attention import MLA
@@ -263,7 +262,7 @@ def __init__(self,
                  dtype: Optional[torch.dtype] = None,
                  tune_max_num_tokens: int = 8192,
                  model_config: ModelConfig = ModelConfig()):
-        from tensorrt_llm._torch.distributed import AllReduce
+        from ..distributed import AllReduce

         super().__init__()
         config = model_config.pretrained_config
