|
41 | 41 | ExpertConfig, |
42 | 42 | LinearConfig, |
43 | 43 | ModelConfig, |
44 | | - MOEConfig, |
45 | 44 | RelativeAttentionTableConfig, |
46 | 45 | ) |
47 | 46 | from .model_config_utils import pad_weights |
@@ -99,17 +98,6 @@ def _split_model_config_for_tp(merged_config, split_factor): |
99 | 98 | for i, config in enumerate(configs): |
100 | 99 | config.weight = weights[i] |
101 | 100 |
|
102 | | - elif isinstance(merged_config, MOEConfig): |
103 | | - split_expert_configs = _split_model_config_for_tp( |
104 | | - merged_config.experts, |
105 | | - split_factor, |
106 | | - ) |
107 | | - # TP for rounter of MoE is skipped for better performance |
108 | | - # See https://github.com/NVIDIA/TensorRT-LLM/pull/1091 for details |
109 | | - for i in range(split_factor): |
110 | | - configs[i].experts = split_expert_configs[i] |
111 | | - configs[i].router = merged_config.router |
112 | | - |
113 | 101 | elif isinstance(merged_config, ExpertConfig): |
114 | 102 | assert merged_config.proj.linear_type != LINEAR_COLUMN # row |
115 | 103 | assert merged_config.fc.linear_type == LINEAR_COLUMN # column |
@@ -199,6 +187,10 @@ def _split_model_config_for_tp(merged_config, split_factor): |
199 | 187 | "Do not support group linear TP merge or split" |
200 | 188 | ) |
201 | 189 |
|
| 190 | + # Do not do anything if we don't need to process TP. |
| 191 | + if not merged_config.tp: |
| 192 | + return configs |
| 193 | + |
202 | 194 | split_axis = 0 if merged_config.linear_type == LINEAR_COLUMN else 1 |
203 | 195 | if merged_config.linear_type == LINEAR_COLUMN: |
204 | 196 | merged_config.weight = pad_weights(merged_config.weight, split_factor) |
@@ -342,6 +334,10 @@ def _merge_model_configs_to_first_tp(config, ranks: list[int], group=None): |
342 | 334 |
|
343 | 335 | assert config.linear_type != LINEAR_GROUP, "Do not support group linear TP merge or split" |
344 | 336 |
|
| 337 | + # No merge is needed if tp is disabled. |
| 338 | + if not config.tp: |
| 339 | + return |
| 340 | + |
345 | 341 | # Handling constants |
346 | 342 | for field_name in [ |
347 | 343 | "activation_scaling_factor", |
@@ -758,41 +754,48 @@ def check_weight_shape_valid(config, inference_tensor_parallel=1, training_tenso |
758 | 754 | This function is recursive. |
759 | 755 | """ |
760 | 756 |
|
761 | | - def _check_merged_weight(merged_k): |
762 | | - assert merged_k % inference_tensor_parallel == 0, ( |
763 | | - f"Weights cannot be split into {inference_tensor_parallel} ranks." |
764 | | - ) |
| 757 | + def _check_merged_weight(merged_k, tp): |
| 758 | + assert merged_k % tp == 0, f"Weights with shape {merged_k} cannot be split into {tp} ranks." |
765 | 759 |
|
766 | | - def _check_merged_weight_scaling_factor(merged_k, awq_block_size): |
767 | | - if awq_block_size > 0 and (merged_k // inference_tensor_parallel) % awq_block_size != 0: |
| 760 | + def _check_merged_weight_scaling_factor(merged_k, tp, awq_block_size): |
| 761 | + if awq_block_size > 0 and (merged_k // tp) % awq_block_size != 0: |
768 | 762 | raise NotImplementedError( |
769 | | - "Weight shape is not divisible for block size for block quantization." |
| 763 | + f"Weight shape {merged_k} of each TP tp={tp} " |
| 764 | + f"is not divisible for block size {awq_block_size} for block quantization." |
770 | 765 | ) |
771 | 766 |
|
772 | | - def _check_merged_channel_is_valid(merged_k, awq_block_size): |
773 | | - _check_merged_weight(merged_k=merged_k) |
774 | | - _check_merged_weight_scaling_factor(merged_k=merged_k, awq_block_size=awq_block_size) |
| 767 | + def _check_merged_channel_is_valid(merged_k, tp, awq_block_size): |
| 768 | + _check_merged_weight(merged_k=merged_k, tp=tp) |
| 769 | + _check_merged_weight_scaling_factor(merged_k=merged_k, tp=tp, awq_block_size=awq_block_size) |
775 | 770 |
|
776 | 771 | if isinstance(config, LinearConfig): |
777 | 772 | # check weight shape |
| 773 | + if not config.tp: |
| 774 | + inference_tensor_parallel = 1 |
778 | 775 | if config.linear_type == LINEAR_COLUMN: |
779 | 776 | _, k = config.weight.shape |
780 | 777 | merged_k = k * training_tensor_parallel |
781 | | - _check_merged_channel_is_valid(merged_k, config.awq_block_size) |
| 778 | + _check_merged_channel_is_valid( |
| 779 | + merged_k, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size |
| 780 | + ) |
782 | 781 | elif config.linear_type == LINEAR_ROW: |
783 | 782 | k, m = config.weight.shape |
784 | 783 | merged_k = k * training_tensor_parallel |
785 | 784 | merged_m = m * training_tensor_parallel |
786 | 785 | # For int4_awq, weight scaling factors will be split as (k, (merged_m // TP) // block_size) |
787 | | - _check_merged_weight(merged_k=merged_k) |
788 | | - _check_merged_weight_scaling_factor(merged_m, config.awq_block_size) |
| 786 | + _check_merged_weight(merged_k=merged_k, tp=inference_tensor_parallel) |
| 787 | + _check_merged_weight_scaling_factor( |
| 788 | + merged_m, tp=inference_tensor_parallel, awq_block_size=config.awq_block_size |
| 789 | + ) |
789 | 790 |
|
790 | 791 | return |
791 | 792 |
|
792 | 793 | if isinstance(config, ExpertConfig): |
793 | 794 | _, _, k = config.fc.weight.shape |
794 | 795 | merged_k = k * training_tensor_parallel |
795 | | - _check_merged_channel_is_valid(merged_k, config.fc.awq_block_size) |
| 796 | + _check_merged_channel_is_valid( |
| 797 | + merged_k, tp=inference_tensor_parallel, awq_block_size=config.fc.awq_block_size |
| 798 | + ) |
796 | 799 | return |
797 | 800 |
|
798 | 801 | if is_dataclass(config): |
|
0 commit comments