3 files changed (+12, −21 lines)
File 1:

@@ -6,7 +6,6 @@
 from torch import nn

 from tensorrt_llm.logger import logger
-from tensorrt_llm.mapping import Mapping

 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -42,19 +41,14 @@ def __init__(
         self.activation = activation

         config = config or ModelConfig()
-        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
-            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = Mapping(
-                world_size=tp_size * pp_size,
-                rank=self.mapping.rank,
-                gpus_per_node=self.mapping.gpus_per_node,
-                tp_size=tp_size,
-                pp_size=pp_size,
-            )
+            mapping = config.mapping.clone()
+            mapping.world_size = overridden_tp_size * pp_size
+            mapping.tp_size = overridden_tp_size
+            mapping.pp_size = pp_size
         else:
             mapping = config.mapping

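Note: both changed modules apply the same pattern; the parent mapping is cloned and only the parallel sizes are overridden, while world_size is preserved by folding the removed tensor-parallel factor into pp_size. A minimal sketch of the arithmetic, with illustrative sizes that are not taken from the diff:

    # Hypothetical sizes chosen for illustration; the only requirement in the diff
    # is that tp_size % overridden_tp_size == 0.
    tp_size = 8              # original tensor-parallel width
    pp_size = 1              # original pipeline-parallel depth
    overridden_tp_size = 2   # smaller all-reduce group requested by the caller

    # "Misuse" pp_size so world_size stays constant while the TP groups shrink.
    new_pp_size = pp_size * tp_size // overridden_tp_size   # 1 * 8 // 2 = 4
    new_world_size = overridden_tp_size * new_pp_size       # 2 * 4 = 8

    # The 8 ranks now form 4 groups of 2, so each all-reduce spans 2 ranks.
    assert new_world_size == tp_size * pp_size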
File 2:

@@ -4,8 +4,6 @@
 import torch
 from torch import nn

-from tensorrt_llm.mapping import Mapping
-
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -32,19 +30,14 @@ def __init__(self,
         self.activation = activation

         config = config or ModelConfig()
-        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
-            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = Mapping(
-                world_size=tp_size * pp_size,
-                rank=self.mapping.rank,
-                gpus_per_node=self.mapping.gpus_per_node,
-                tp_size=tp_size,
-                pp_size=pp_size,
-            )
+            mapping = config.mapping.clone()
+            mapping.world_size = overridden_tp_size * pp_size
+            mapping.tp_size = overridden_tp_size
+            mapping.pp_size = pp_size
         else:
             mapping = config.mapping

File 3:

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 from enum import IntEnum
 from typing import List

@@ -238,6 +239,9 @@ def local_rank(self):
     def dp_size(self):
         return self.tp_size if self.enable_attention_dp else 1

+    def clone(self):
+        return copy.deepcopy(self)
+
     def has_cp_ulysses(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.ULYSSES
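Note: clone() is a deep copy, so a caller can mutate the returned Mapping (as the hunks above do for world_size, tp_size, and pp_size) without disturbing the shared config.mapping. A minimal sketch of that behaviour using a stand-in class rather than the real Mapping; only the attribute names touched by the diff are assumed, everything else is illustrative:

    import copy

    class MappingSketch:
        # Stand-in for tensorrt_llm.mapping.Mapping, reduced to the fields used here.
        def __init__(self, world_size, tp_size, pp_size, cp_config=None):
            self.world_size = world_size
            self.tp_size = tp_size
            self.pp_size = pp_size
            self.cp_config = cp_config or {}

        def clone(self):
            # Deep copy so nested members such as cp_config are duplicated, not shared.
            return copy.deepcopy(self)

    original = MappingSketch(world_size=8, tp_size=8, pp_size=1)
    override = original.clone()
    override.tp_size = 2
    override.pp_size = 4

    assert original.tp_size == 8                           # the parent mapping is untouched
    assert override.cp_config is not original.cp_config    # nested state is independent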