
Commit a56eba6

rollback to mapping class
Signed-off-by: yechank <[email protected]>
1 parent f158ba4 commit a56eba6

3 files changed: +21 / -12 lines


tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 10 additions & 4 deletions

@@ -6,6 +6,7 @@
 from torch import nn
 
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -41,14 +42,19 @@ def __init__(
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = config.mapping.clone()
-            mapping.world_size = overridden_tp_size * pp_size
-            mapping.tp_size = overridden_tp_size
-            mapping.pp_size = pp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
         else:
             mapping = config.mapping
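
A worked example of the "misuse pp_size" arithmetic above, with illustrative sizes that are not taken from the commit: starting from a mapping with tp_size=8 and pp_size=1 (world_size=8), passing overridden_tp_size=2 gives pp_size = 1 * 8 // 2 = 4 and world_size = 2 * 4 = 8, so the same eight ranks are regrouped into four tensor-parallel groups of two and the MLP's all-reduce spans two ranks instead of eight.

# Illustrative sketch only: redoes the size arithmetic from __init__ above
# with assumed values (tp_size=8, pp_size=1, overridden_tp_size=2).
orig_tp_size, orig_pp_size = 8, 1
overridden_tp_size = 2

assert orig_tp_size % overridden_tp_size == 0
tp_size = overridden_tp_size
pp_size = orig_pp_size * orig_tp_size // overridden_tp_size  # 4
world_size = tp_size * pp_size                               # still 8

print(tp_size, pp_size, world_size)  # 2 4 8 -> all-reduce groups of 2 ranks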

tensorrt_llm/_torch/modules/mlp.py

Lines changed: 11 additions & 4 deletions

@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -30,14 +32,19 @@ def __init__(self,
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = config.mapping.clone()
-            mapping.world_size = overridden_tp_size * pp_size
-            mapping.tp_size = overridden_tp_size
-            mapping.pp_size = pp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
         else:
             mapping = config.mapping
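
This is the same change as in gated_mlp.py. The commit message only says "rollback to mapping class", so the motivation is not spelled out here; one plausible reading is that constructing a fresh Mapping lets its __init__ see consistent sizes, whereas deep-copying and then mutating size fields leaves any state derived at construction time stale. A minimal, hypothetical sketch of that pitfall (the ToyMapping class below is a stand-in, not tensorrt_llm.mapping.Mapping):

import copy

class ToyMapping:
    """Hypothetical stand-in, NOT tensorrt_llm.mapping.Mapping."""
    def __init__(self, world_size, tp_size, pp_size):
        self.world_size = world_size
        self.tp_size = tp_size
        self.pp_size = pp_size
        # Derived state computed once at construction time.
        self.tp_groups = [list(range(i * tp_size, (i + 1) * tp_size))
                          for i in range(world_size // tp_size)]

orig = ToyMapping(world_size=8, tp_size=8, pp_size=1)

# Clone-and-mutate: the size fields change, but derived state goes stale.
mutated = copy.deepcopy(orig)
mutated.tp_size, mutated.pp_size = 2, 4
print(len(mutated.tp_groups))   # still 1 group of 8 -> inconsistent

# Reconstructing through __init__ keeps everything consistent.
rebuilt = ToyMapping(world_size=8, tp_size=2, pp_size=4)
print(len(rebuilt.tp_groups))   # 4 groups of 2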

tensorrt_llm/mapping.py

Lines changed: 0 additions & 4 deletions

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from enum import IntEnum
 from typing import List
 
@@ -239,9 +238,6 @@ def local_rank(self):
     def dp_size(self):
         return self.tp_size if self.enable_attention_dp else 1
 
-    def clone(self):
-        return copy.deepcopy(self)
-
     def has_cp_ulysses(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.ULYSSES
