@@ -139,7 +139,7 @@ class AvxConfig:
 
 @dataclass
 class SpeculateConfig:
-    speculate_max_draft_token_num: int = (1,)
+    speculate_max_draft_token_num: int = 5
     speculate_method: str = None
 
 
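The old default, `(1,)`, is a one-element tuple: the trailing comma inside the parentheses is Python's tuple syntax, and dataclasses do not enforce type annotations, so the field silently held a tuple despite being declared `int`. A minimal sketch of the pitfall (illustrative, not from the patch):

```python
from dataclasses import dataclass


@dataclass
class Demo:
    broken: int = (1,)  # trailing comma makes a tuple; the annotation is not enforced
    fixed: int = 5


d = Demo()
print(type(d.broken))  # <class 'tuple'>
print(type(d.fixed))   # <class 'int'>
```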
@@ -1690,7 +1690,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
         self.quant_round_type = config.quant_round_type
         self.quant_max_bound = config.quant_max_bound
         self.quant_min_bound = config.quant_min_bound
-        # self.use_gemm_dequant = False
+        self.use_gemm_dequant = False
 
         self.qkv_out_scales = []
         self.linear_out_scales = []
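Uncommenting the attribute gives the layer a real `use_gemm_dequant` flag (off by default) that the matmul paths below can branch on. A hypothetical extension, not part of this patch: the flag could instead be set once at construction time by probing whether the custom op is importable, keeping the per-call code free of import handling:

```python
# Hypothetical helper, not from the repo: probe for the fused op once
# and cache the answer, instead of retrying the import on every forward.
def _gemm_dequant_available() -> bool:
    try:
        from paddlenlp_ops import gemm_dequant  # noqa: F401
    except ImportError:
        return False
    return True

# e.g. in __init__:
# self.use_gemm_dequant = _gemm_dequant_available()
```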
@@ -1928,7 +1928,6 @@ def compute_qkv_linear(self, ln_out, i):
         if paddle.is_compiled_with_rocm():
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i])
         else:
-            # TODO: add gemm_dequant after qkv_out
             qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
         return qkv_out
 
@@ -2033,13 +2032,13 @@ def compute_out_linear(self, fmha_out, i):
             out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i])
             out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant
 
                 out_linear_out = gemm_dequant(
                     fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype
                 )
-            except:
+            else:
                 out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
                 out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
         return out_linear_out
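Replacing `try`/bare `except` with an explicit flag fixes a real hazard: the bare `except:` caught not only a missing `paddlenlp_ops` install but any exception raised inside `gemm_dequant` itself, silently falling back to the unfused path and masking genuine bugs. A small illustrative sketch of the difference (generic names, not from the repo):

```python
def fast_path(x):
    # stand-in for gemm_dequant; imagine it hits a genuine bug
    raise ValueError("shape mismatch")

def slow_path(x):
    return x

def old_style(x):
    try:
        return fast_path(x)
    except:  # noqa: E722 -- catches *everything*, including real bugs
        return slow_path(x)  # silent fallback hides the ValueError

def new_style(x, use_fast):
    if use_fast:
        return fast_path(x)  # genuine errors now surface to the caller
    return slow_path(x)

print(old_style(0))  # prints 0; the ValueError is swallowed without a trace
```

The next hunk applies the identical gate to the FFN2 projection.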
@@ -2094,11 +2093,11 @@ def compute_ffn2(self, ffn1_out, i):
             ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i])
             ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         else:
-            try:
+            if self.use_gemm_dequant:
                 from paddlenlp_ops import gemm_dequant
 
                 ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype)
-            except:
+            else:
                 ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
                 ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
         return ffn2_out