
Commit f158ba4

address mapping clone
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
1 parent 958a0dd commit f158ba4

3 files changed: 12 additions, 21 deletions

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 4 additions & 10 deletions
```diff
@@ -6,7 +6,6 @@
 from torch import nn
 
 from tensorrt_llm.logger import logger
-from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -42,19 +41,14 @@ def __init__(
         self.activation = activation
 
         config = config or ModelConfig()
-        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
-            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = Mapping(
-                world_size=tp_size * pp_size,
-                rank=self.mapping.rank,
-                gpus_per_node=self.mapping.gpus_per_node,
-                tp_size=tp_size,
-                pp_size=pp_size,
-            )
+            mapping = config.mapping.clone()
+            mapping.world_size = overridden_tp_size * pp_size
+            mapping.tp_size = overridden_tp_size
+            mapping.pp_size = pp_size
         else:
             mapping = config.mapping
```
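
The new branch keeps every field of the original config.mapping (rank, gpus_per_node, and so on) and only overrides the parallelism sizes, instead of rebuilding a Mapping by hand. A minimal sketch of the arithmetic, using a simplified stand-in class and hypothetical sizes rather than the real tensorrt_llm.mapping.Mapping:

```python
import copy
from dataclasses import dataclass


# Simplified stand-in for tensorrt_llm.mapping.Mapping, only to illustrate the arithmetic.
@dataclass
class FakeMapping:
    world_size: int
    rank: int
    gpus_per_node: int
    tp_size: int
    pp_size: int

    def clone(self):
        return copy.deepcopy(self)


# Hypothetical starting point: 8 GPUs, all tensor-parallel.
config_mapping = FakeMapping(world_size=8, rank=3, gpus_per_node=8, tp_size=8, pp_size=1)

overridden_tp_size = 2
assert config_mapping.tp_size % overridden_tp_size == 0

# "Misuse" pp_size so the all-reduce runs within smaller groups of 2 ranks.
pp_size = config_mapping.pp_size * config_mapping.tp_size // overridden_tp_size  # 1 * 8 // 2 = 4

mapping = config_mapping.clone()
mapping.world_size = overridden_tp_size * pp_size  # 2 * 4 = 8
mapping.tp_size = overridden_tp_size               # 2
mapping.pp_size = pp_size                          # 4

print(mapping)         # tp_size=2, pp_size=4; rank and gpus_per_node carried over unchanged
print(config_mapping)  # original mapping is untouched
```

Compared with the removed Mapping(...) construction, the clone also carries along any other fields of the original mapping without having to enumerate them by hand, which is presumably the point of the change.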

tensorrt_llm/_torch/modules/mlp.py

Lines changed: 4 additions & 11 deletions
```diff
@@ -4,8 +4,6 @@
 import torch
 from torch import nn
 
-from tensorrt_llm.mapping import Mapping
-
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -32,19 +30,14 @@ def __init__(self,
         self.activation = activation
 
         config = config or ModelConfig()
-        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
-            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = Mapping(
-                world_size=tp_size * pp_size,
-                rank=self.mapping.rank,
-                gpus_per_node=self.mapping.gpus_per_node,
-                tp_size=tp_size,
-                pp_size=pp_size,
-            )
+            mapping = config.mapping.clone()
+            mapping.world_size = overridden_tp_size * pp_size
+            mapping.tp_size = overridden_tp_size
+            mapping.pp_size = pp_size
         else:
             mapping = config.mapping
```

tensorrt_llm/mapping.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 from enum import IntEnum
 from typing import List
 
@@ -238,6 +239,9 @@ def local_rank(self):
     def dp_size(self):
         return self.tp_size if self.enable_attention_dp else 1
 
+    def clone(self):
+        return copy.deepcopy(self)
+
     def has_cp_ulysses(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.ULYSSES
```

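Since clone() delegates to copy.deepcopy, the overrides applied in MLP and GatedMLP above cannot leak back into the shared config.mapping, even through nested mutable fields. A quick, self-contained illustration of why a deep copy (rather than a shallow one) matters here; the dict below only mimics a few Mapping-like fields and is not the real class:

```python
import copy

# A dict standing in for Mapping-like state, including a nested mutable field.
original = {"tp_size": 8, "pp_size": 1, "cp_config": {"cp_type": None}}

shallow = copy.copy(original)
deep = copy.deepcopy(original)

# A shallow copy shares nested objects, so mutating them leaks into the original.
shallow["cp_config"]["cp_type"] = "ULYSSES"
print(original["cp_config"]["cp_type"])  # 'ULYSSES' -- shared nested state was mutated

# The deep copy is fully independent: overriding sizes leaves the original alone.
deep["tp_size"], deep["pp_size"] = 2, 4
print(original["tp_size"], original["pp_size"])  # 8 1
```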