@@ -41,7 +41,6 @@
     ExpertConfig,
     LinearConfig,
     ModelConfig,
-    MOEConfig,
     RelativeAttentionTableConfig,
 )
 from .model_config_utils import pad_weights
@@ -99,17 +98,6 @@ def _split_model_config_for_tp(merged_config, split_factor):
         for i, config in enumerate(configs):
             config.weight = weights[i]
 
-    elif isinstance(merged_config, MOEConfig):
-        split_expert_configs = _split_model_config_for_tp(
-            merged_config.experts,
-            split_factor,
-        )
-        # TP for rounter of MoE is skipped for better performance
-        # See https://github.com/NVIDIA/TensorRT-LLM/pull/1091 for details
-        for i in range(split_factor):
-            configs[i].experts = split_expert_configs[i]
-            configs[i].router = merged_config.router
-
     elif isinstance(merged_config, ExpertConfig):
         assert merged_config.proj.linear_type != LINEAR_COLUMN  # row
         assert merged_config.fc.linear_type == LINEAR_COLUMN  # column
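For context on the `ExpertConfig` asserts kept above: `fc` is column-linear (output dimension sharded) and `proj` is row-linear (input dimension sharded). A minimal sketch of that sharding, assuming stacked 3-D expert weights of shape `(num_experts, out_features, in_features)`; the tensor names, sizes, and `split_factor` here are illustrative, not the module's code:

```python
import torch

num_experts, hidden, ffn = 4, 8, 16
split_factor = 2

# fc is column-linear: shard its output (ffn) dimension, dim 1.
fc_weight = torch.randn(num_experts, ffn, hidden)
fc_shards = torch.chunk(fc_weight, split_factor, dim=1)

# proj is row-linear: shard its input (ffn) dimension, dim 2.
proj_weight = torch.randn(num_experts, hidden, ffn)
proj_shards = torch.chunk(proj_weight, split_factor, dim=2)

assert fc_shards[0].shape == (num_experts, ffn // split_factor, hidden)
assert proj_shards[0].shape == (num_experts, hidden, ffn // split_factor)
```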
@@ -199,6 +187,10 @@ def _split_model_config_for_tp(merged_config, split_factor):
             "Do not support group linear TP merge or split"
         )
 
+        # Do not do anything if we don't need to process TP.
+        if not merged_config.tp:
+            return configs
+
         split_axis = 0 if merged_config.linear_type == LINEAR_COLUMN else 1
         if merged_config.linear_type == LINEAR_COLUMN:
             merged_config.weight = pad_weights(merged_config.weight, split_factor)
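The column path above pads the output dimension before splitting so every rank receives an equal shard. A hedged sketch of that pad-then-chunk step, assuming 2-D torch weights; `pad_dim0` is a hypothetical stand-in for the module's `pad_weights` helper:

```python
import torch
import torch.nn.functional as F

def pad_dim0(weight: torch.Tensor, split_factor: int) -> torch.Tensor:
    """Zero-pad dim 0 so it divides evenly across split_factor ranks."""
    pad = (-weight.shape[0]) % split_factor
    return F.pad(weight, (0, 0, 0, pad)) if pad else weight

split_factor = 4
w = torch.randn(10, 8)         # column-linear: dim 0 is the split axis
w = pad_dim0(w, split_factor)  # (12, 8) after padding
shards = torch.chunk(w, split_factor, dim=0)
assert all(s.shape == (3, 8) for s in shards)
```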
@@ -342,6 +334,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None):
 
         assert config.linear_type != LINEAR_GROUP, "Do not support group linear TP merge or split"
 
+        # No merge is needed if tp is disabled.
+        if not config.tp:
+            return
+
         # Handling constants
         for field_name in [
             "activation_scaling_factor",
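An illustrative counterpart to the merge above, not the library's implementation (which also handles scaling factors): the first rank concatenates per-rank weight shards along the split axis, and with `tp` disabled (the new early return) there is nothing to merge. The `merge_to_first` helper is hypothetical:

```python
import torch

def merge_to_first(shards: list[torch.Tensor], split_axis: int) -> torch.Tensor:
    # tp disabled: the weight is replicated, so a single shard is already merged.
    if len(shards) == 1:
        return shards[0]
    return torch.cat(shards, dim=split_axis)

merged = merge_to_first([torch.ones(3, 8), torch.ones(3, 8)], split_axis=0)
assert merged.shape == (6, 8)
```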
@@ -758,41 +754,48 @@ def check_weight_shape_valid(config, inference_tensor_parallel=1, training_tenso
     This function is recursive.
     """
 
-    def _check_merged_weight(merged_k):
-        assert merged_k % inference_tensor_parallel == 0, (
-            f"Weights cannot be split into {inference_tensor_parallel} ranks."
-        )
+    def _check_merged_weight(merged_k, tp):
+        assert merged_k % tp == 0, f"Weights with shape {merged_k} cannot be split into {tp} ranks."
 
-    def _check_merged_weight_scaling_factor(merged_k, awq_block_size):
-        if awq_block_size > 0 and (merged_k // inference_tensor_parallel) % awq_block_size != 0:
+    def _check_merged_weight_scaling_factor(merged_k, tp, awq_block_size):
+        if awq_block_size > 0 and (merged_k // tp) % awq_block_size != 0:
             raise NotImplementedError(
-                "Weight shape is not divisible for block size for block quantization."
+                f"Weight shape {merged_k} split across tp={tp} ranks "
+                f"is not divisible by block size {awq_block_size} for block quantization."
             )
 
-    def _check_merged_channel_is_valid(merged_k, awq_block_size):
-        _check_merged_weight(merged_k=merged_k)
-        _check_merged_weight_scaling_factor(merged_k=merged_k, awq_block_size=awq_block_size)
+    def _check_merged_channel_is_valid(merged_k, tp, awq_block_size):
+        _check_merged_weight(merged_k=merged_k, tp=tp)
+        _check_merged_weight_scaling_factor(merged_k=merged_k, tp=tp, awq_block_size=awq_block_size)
 
     if isinstance(config, LinearConfig):
         # check weight shape
+        if not config.tp:
+            inference_tensor_parallel = 1
         if config.linear_type == LINEAR_COLUMN:
             _, k = config.weight.shape
             merged_k = k * training_tensor_parallel
-            _check_merged_channel_is_valid(merged_k, config.awq_block_size)
+            _check_merged_channel_is_valid(
+                merged_k, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size
+            )
         elif config.linear_type == LINEAR_ROW:
             k, m = config.weight.shape
             merged_k = k * training_tensor_parallel
             merged_m = m * training_tensor_parallel
             # For int4_awq, weight scaling factors will be split as (k, (merged_m // TP) // block_size)
-            _check_merged_weight(merged_k=merged_k)
-            _check_merged_weight_scaling_factor(merged_m, config.awq_block_size)
+            _check_merged_weight(merged_k=merged_k, tp=inference_tensor_parallel)
+            _check_merged_weight_scaling_factor(
+                merged_m, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size
+            )
 
         return
 
     if isinstance(config, ExpertConfig):
         _, _, k = config.fc.weight.shape
         merged_k = k * training_tensor_parallel
-        _check_merged_channel_is_valid(merged_k, config.fc.awq_block_size)
+        _check_merged_channel_is_valid(
+            merged_k, tp=inference_tensor_parallel, awq_block_size=config.fc.awq_block_size
+        )
         return
 
     if is_dataclass(config):
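To see why both checks matter, here is a worked example with illustrative numbers under `int4_awq`: the merged channel count divides cleanly across ranks, yet the per-rank slice still fails the block-size check, which is exactly the `NotImplementedError` path above:

```python
merged_k = 11008        # merged input channels (e.g. an FFN dim)
tp = 4                  # inference_tensor_parallel
awq_block_size = 128

assert merged_k % tp == 0           # passes: 11008 / 4 = 2752 per rank
per_rank_k = merged_k // tp
print(per_rank_k % awq_block_size)  # 64 -> 2752 is not a whole number of
                                    # 128-wide AWQ blocks, so the check raises
```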