@@ -132,9 +132,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float out_scale,
+                                      const std::string& compute_dtype,
                                       MetaTensor* fmha_out,
                                       MetaTensor* qkv_out,
                                       MetaTensor* key_cache_out,
@@ -159,13 +173,74 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
           "The input_dims[1] must be equal to 3 * num_head * dim_head"));
 
   fmha_out->set_dims({input_dims[0], num_head * dim_head});
-  fmha_out->set_dtype(qkv.dtype());
   qkv_out->set_dims(qkv.dims());
-  qkv_out->set_dtype(qkv.dtype());
   key_cache_out->set_dims(key_cache_dims);
   key_cache_out->set_dtype(key_cache.dtype());
   value_cache_out->set_dims(key_cache_dims);
   value_cache_out->set_dtype(value_cache.dtype());
+
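+  // Helper that enforces check_tensor's dtype to match the compute_dtype
+  // attr ("bf16" / "fp16" / "fp32"); any other value is left unchecked.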
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (qkv.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
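+    // A positive out_scale means the attention output itself is quantized,
+    // so fmha_out carries INT8.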
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
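+      // Otherwise the output dtype is taken directly from compute_dtype.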
+      if (compute_dtype == "bf16") {
+        fmha_out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        fmha_out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        fmha_out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
+  } else {
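+    // Non-quantized input: the qkv dtype must still agree with
+    // compute_dtype when one is given explicitly.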
+    if (compute_dtype != "default") {
+      FBADtypeCheck(qkv, "qkv", compute_dtype);
+    }
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      fmha_out->set_dtype(qkv.dtype());
+    }
+  }
 }
 
 void Conv1dXPUInferMeta(const MetaTensor& x,