3 files changed, +21 -12 lines changed
File 1:

```diff
@@ -6,6 +6,7 @@
 from torch import nn
 
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
 
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -41,14 +42,19 @@ def __init__(
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = config.mapping.clone()
-            mapping.world_size = overridden_tp_size * pp_size
-            mapping.tp_size = overridden_tp_size
-            mapping.pp_size = pp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
         else:
             mapping = config.mapping
 
```
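The interesting part is the '"Misuse" pp_size' trick: when a layer overrides the model-wide TP degree, the factor removed from tp_size is folded into pp_size, so world_size is unchanged while each all-reduce spans only overridden_tp_size ranks. Below is a minimal sketch of that arithmetic; the `tp_groups` helper and the TP-fastest rank layout are illustrative assumptions, not the real Mapping API:

```python
# Illustrative sketch, not tensorrt_llm code: how folding the removed TP
# factor into pp_size shrinks the groups an all-reduce spans.

def tp_groups(tp_size: int, pp_size: int, overridden_tp_size: int):
    assert tp_size % overridden_tp_size == 0
    # "Misuse" pp_size: absorb the leftover TP factor so that
    # world_size == new_tp * new_pp stays the same.
    new_tp = overridden_tp_size
    new_pp = pp_size * tp_size // overridden_tp_size
    # Assuming TP ranks vary fastest, consecutive ranks form each
    # (smaller) group that the layer's all-reduce will cover.
    return [list(range(g * new_tp, (g + 1) * new_tp)) for g in range(new_pp)]

# 8 GPUs configured as TP=8, PP=1; one layer overrides TP down to 2:
print(tp_groups(8, 1, 2))   # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Nothing pipeline-related actually happens here; pp_size is just the bookkeeping slot that makes the group math come out right. The second file below applies the same change to another MLP-style module.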
File 2:

```diff
@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -30,14 +32,19 @@ def __init__(self,
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
         if overridden_tp_size is not None:
             assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
             # "Misuse" pp_size here to perform all-reduce within smaller groups
             pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
-            mapping = config.mapping.clone()
-            mapping.world_size = overridden_tp_size * pp_size
-            mapping.tp_size = overridden_tp_size
-            mapping.pp_size = pp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
         else:
             mapping = config.mapping
 
```
File 3:

```diff
@@ -12,6 +12,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from enum import IntEnum
 from typing import List
@@ -239,9 +238,6 @@ def local_rank(self):
     def dp_size(self):
         return self.tp_size if self.enable_attention_dp else 1
 
-    def clone(self):
-        return copy.deepcopy(self)
-
     def has_cp_ulysses(self):
         return self.cp_size > 1 and self.cp_config.get(
             "cp_type") == CpType.ULYSSES
```