
Commit 2522bf8

[INFER] Fix tune_cublaslt_int8_gemm.py and remove dist_config (#9520)
1 parent: eae8d9f

2 files changed: +2 -26 lines

csrc/utils/tune_cublaslt_int8_gemm.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 n1 = [6144, 4096, 28672, 4096]

 # llama3.1-405b mp=8
-k2 = [16384, 16384, 16384, 6656]
+k2 = [16384, 2048, 16384, 6656]
 n2 = [2560, 16384, 13312, 16384]

 # qwen2-1.5b
@@ -43,5 +43,5 @@

 # shape calculation formula
 # [qkv, out_linear, ffn1, ffn2]
-# k = [hidden_size, hidden_size, hidden_size, intermediate_size//mp_size]
+# k = [hidden_size, hidden_size//mp_size, hidden_size, intermediate_size//mp_size]
 # n = [((num_attention_heads//mp_size)+2*(num_key_value_heads//mp_size))*(hidden_size//num_attention_heads), hidden_size, 2*(intermediate_size//mp_size), hidden_size]
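The corrected formula can be sanity-checked with a small script. Below is a minimal sketch, assuming illustrative config values (hidden_size=16384, intermediate_size=53248, num_attention_heads=128, num_key_value_heads=16, mp_size=8) chosen only so the output reproduces the k2/n2 lists in this file; the helper name gemm_shapes is hypothetical and not part of the repository.

def gemm_shapes(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, mp_size):
    head_dim = hidden_size // num_attention_heads
    # k (input) dims for [qkv, out_linear, ffn1, ffn2]
    k = [
        hidden_size,                   # qkv reads the full hidden state
        hidden_size // mp_size,        # out_linear input is sharded across mp ranks
        hidden_size,                   # ffn1 reads the full hidden state
        intermediate_size // mp_size,  # ffn2 input is the sharded intermediate
    ]
    # n (output) dims for [qkv, out_linear, ffn1, ffn2]
    n = [
        (num_attention_heads // mp_size + 2 * (num_key_value_heads // mp_size)) * head_dim,
        hidden_size,
        2 * (intermediate_size // mp_size),  # gated FFN: up and gate projections fused
        hidden_size,
    ]
    return k, n

k2, n2 = gemm_shapes(16384, 53248, 128, 16, 8)
print(k2)  # [16384, 2048, 16384, 6656] -- out_linear k is 16384 // 8 = 2048, the value fixed in this commit
print(n2)  # [2560, 16384, 13312, 16384]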

llm/predict/predictor.py

Lines changed: 0 additions & 24 deletions

@@ -673,18 +673,6 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         config.enable_use_gpu(100, device_id)
         config.enable_new_executor()

-        if self.tensor_parallel_degree > 1:
-            trainer_endpoints = fleet.worker_endpoints()
-            current_endpoint = trainer_endpoints[self.tensor_parallel_rank]
-
-            dist_config = config.dist_config()
-            dist_config.set_ranks(self.tensor_parallel_degree, self.tensor_parallel_rank)
-            dist_config.set_endpoints(trainer_endpoints, current_endpoint)
-            dist_config.enable_dist_model(True)
-
-            dist_config.set_comm_init_config(os.path.join(predictor_args.model_name_or_path, "rank_mapping.csv"))
-            config.set_dist_config(dist_config)
-
         predictor = paddle.inference.create_predictor(config)
         return predictor

@@ -1178,18 +1166,6 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         pass_builder = config.pass_builder()
         passes.addPasses(pass_builder, self.model_config.model_type, self.model_config.quant_type)

-        if self.tensor_parallel_degree > 1:
-            trainer_endpoints = fleet.worker_endpoints()
-            current_endpoint = trainer_endpoints[self.tensor_parallel_rank]
-
-            dist_config = config.dist_config()
-            dist_config.set_ranks(self.tensor_parallel_degree, self.tensor_parallel_rank)
-            dist_config.set_endpoints(trainer_endpoints, current_endpoint)
-            dist_config.enable_dist_model(True)
-
-            dist_config.set_comm_init_config(os.path.join(predictor_args.model_name_or_path, "rank_mapping.csv"))
-            config.set_dist_config(dist_config)
-
         self.predictor = paddle.inference.create_predictor(config)

     def predict(self, input_texts: list[str], return_tokens=False):
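For context, after this change the predictor setup no longer wires a per-rank Paddle Inference DistConfig. A minimal sketch of the surviving code path follows; only enable_use_gpu, enable_new_executor, and create_predictor appear in the context lines above, while the helper name and the Config-based model loading are assumptions for illustration.

import paddle

def create_predictor_after_change(model_dir: str, device_id: int = 0):
    # Hypothetical standalone helper mirroring the kept code path; the real
    # logic lives in _create_predictor in llm/predict/predictor.py.
    config = paddle.inference.Config(model_dir)  # model loading is assumed, not shown in this diff
    config.enable_use_gpu(100, device_id)        # kept: 100 MB initial GPU memory pool on device_id
    config.enable_new_executor()                 # kept: enable the new executor
    # The per-rank DistConfig wiring (set_ranks / set_endpoints / enable_dist_model /
    # set_comm_init_config with rank_mapping.csv) deleted by this commit is no longer applied.
    return paddle.inference.create_predictor(config)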
