Commit f5ca96e

[LLM INFER] not use gemm_dequant default and fix bug (#9498)
* not use gemm_dequant default and fix bug
1 parent 4b02477 commit f5ca96e

File tree

2 files changed, +9 -7 lines

  paddlenlp/experimental/transformers/fused_transformer_layers.py
  paddlenlp/experimental/transformers/qwen2_moe/modeling.py

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 6 additions & 7 deletions
@@ -139,7 +139,7 @@ class AvxConfig:
 
 @dataclass
 class SpeculateConfig:
-    speculate_max_draft_token_num: int = (1,)
+    speculate_max_draft_token_num: int = 5
     speculate_method: str = None
 
 
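Note: the fix above addresses a classic dataclass default pitfall. The trailing comma in (1,) makes the default a one-element tuple, so the field annotated as int actually held a tuple. A minimal standalone sketch of the pitfall (field name copied from the diff, the surrounding script is illustrative):

    from dataclasses import dataclass

    @dataclass
    class BuggyConfig:
        # The parentheses plus trailing comma form a tuple literal, so the
        # default disagrees with the int annotation (annotations are not enforced).
        speculate_max_draft_token_num: int = (1,)

    @dataclass
    class FixedConfig:
        speculate_max_draft_token_num: int = 5

    print(BuggyConfig().speculate_max_draft_token_num)  # (1,), a tuple
    print(FixedConfig().speculate_max_draft_token_num)  # 5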
@@ -1690,7 +1690,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.quant_round_type = config.quant_round_type
         self.quant_max_bound = config.quant_max_bound
         self.quant_min_bound = config.quant_min_bound
-        # self.use_gemm_dequant = False
+        self.use_gemm_dequant = False
 
         self.qkv_out_scales = []
         self.linear_out_scales = []
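Note: un-commenting the assignment above matters because the compute_out_linear and compute_ffn2 changes below read self.use_gemm_dequant; if the attribute were never created, that lookup would raise AttributeError. A tiny illustration of the failure mode (class and method names are made up):

    class Layer:
        def __init__(self):
            # self.use_gemm_dequant = False   # left commented: attribute never created
            pass

        def forward(self):
            return "fused" if self.use_gemm_dequant else "fallback"

    try:
        Layer().forward()
    except AttributeError as err:
        print(err)  # 'Layer' object has no attribute 'use_gemm_dequant'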
@@ -1928,7 +1928,6 @@ def compute_qkv_linear(self, ln_out, i):
         if paddle.is_compiled_with_rocm():
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i])
         else:
-            # TODO: add gemm_dequant after qkv_out
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
         return qkv_out
 
@@ -2033,13 +2032,13 @@ def compute_out_linear(self, fmha_out, i):
             out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i])
             out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant
 
                 out_linear_out = gemm_dequant(
                     fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype
                 )
-            except:
+            else:
                 out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
                 out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         return out_linear_out
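Note: this hunk (and the matching one in compute_ffn2 below) replaces the try/except import fallback with an explicit use_gemm_dequant flag, which defaults to False in __init__, so the fused kernel runs only when deliberately enabled and a genuine failure inside it is no longer silently swallowed by the except branch. A minimal sketch of the same flag-gated dispatch, with NumPy standing in for the Paddle tensors and stub functions standing in for the fused paddlenlp_ops.gemm_dequant kernel and the separate dequant_int8 step (shapes and scale are illustrative):

    import numpy as np

    def gemm_dequant_stub(x, w, scale):
        # Stand-in for the fused kernel: GEMM and dequantization in one call.
        return (x @ w.T).astype("float32") * scale

    def dequant_int8_stub(y, scale):
        # Stand-in for the separate dequantization step on the fallback path.
        return y.astype("float32") * scale

    def compute_out_linear(fmha_out, weight, scale, use_gemm_dequant=False):
        if use_gemm_dequant:
            # Fast path: taken only when the flag is explicitly enabled.
            return gemm_dequant_stub(fmha_out, weight, scale)
        # Default path: plain matmul, then dequantize as a second step.
        out = fmha_out @ weight.T
        return dequant_int8_stub(out, scale)

    x = np.random.randint(-128, 128, size=(2, 4))
    w = np.random.randint(-128, 128, size=(8, 4))
    print(compute_out_linear(x, w, scale=0.05).shape)                   # (2, 8), fallback path
    print(compute_out_linear(x, w, 0.05, use_gemm_dequant=True).shape)  # (2, 8), fused path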
@@ -2094,11 +2093,11 @@ def compute_ffn2(self, ffn1_out, i):
             ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i])
             ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant
 
                 ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype)
-            except:
+            else:
                 ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
                 ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         return ffn2_out

paddlenlp/experimental/transformers/qwen2_moe/modeling.py

Lines changed: 3 additions & 0 deletions
@@ -292,7 +292,10 @@ def set_state_dict(self, state_dict):
         self.embed_tokens.weight.set_value(embed_tokens_weight)
         self.norm.weight.set_value(norm_weight)
 
+        if self.use_weight_only:
+            logger.info("weight only is enabled")
         for idx in range(self.num_layers):
+            logger.info(f"set state for layer {idx}")
             unfused_state_dict = {}
             ln_scale = paddle.to_tensor(state_dict["qwen2_moe.layers.{}.input_layernorm.weight".format(idx)]).cast(
                 self.transformer_block.ln_scales[idx].dtype
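Note: the qwen2_moe change only adds progress logging, which helps because setting the state dict layer by layer for a large checkpoint can run for a long time with no output. A minimal sketch of the same idea using the standard logging module (the flag value and layer count are illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    use_weight_only = True   # mirrors self.use_weight_only
    num_layers = 4           # illustrative layer count

    if use_weight_only:
        logger.info("weight only is enabled")
    for idx in range(num_layers):
        logger.info(f"set state for layer {idx}")
        # ... load and set this layer's weights here ...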
