Commit cee8578

[xpu] support llama2 7b in xpu (#65036)
* [xpu] support block_multi_head_attention_xpu op (#64637)
* [xpu] fix block_multi_head_attention_kernel (#64926)
1 parent 09fc618 commit cee8578

File tree

14 files changed: +1627 -6 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -334,6 +334,8 @@ if(WITH_XPU)
                DEPS ${XPU_PASS_DEPS})
   pass_library(weight_only_linear_xpu_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(block_multihead_attention_xpu_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()

 cc_library(
paddle/fluid/framework/ir/xpu/block_multihead_attention_xpu_pass.cc

Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

class BlockMultiHeadAttentionXPUPass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  void InplaceBlockMultiHeadAttentionXPU(ir::Graph* graph) const;

  const std::string name_scope_{"block_multihead_attention_xpu_pass"};
};

void BlockMultiHeadAttentionXPUPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  InplaceBlockMultiHeadAttentionXPU(graph);
}

void BlockMultiHeadAttentionXPUPass::InplaceBlockMultiHeadAttentionXPU(
    ir::Graph* graph) const {
  const int64_t max_batch_size = 10;
  auto* scope = param_scope();
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "block_multihead_attention") {
      auto* op_desc = node->Op();
      op_desc->SetType("block_multihead_attention_xpu");
      phi::DenseTensor cache_k_per_batch_maxs;
      auto base_name = op_desc->Input("qkv")[0];
      int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1);
      std::string cache_k_per_batch_maxs_name = base_name + "_max_cache_k";
      VarDesc cache_k_per_batch_maxs_desc(cache_k_per_batch_maxs_name);
      cache_k_per_batch_maxs_desc.SetPersistable(true);
      cache_k_per_batch_maxs_desc.SetShape(
          {max_batch_size, static_cast<int64_t>(max_ptr_size)});
      cache_k_per_batch_maxs_desc.SetDataType(
          proto::VarType::Type::VarType_Type_FP32);
      Node* cache_k_per_batch_maxs_in =
          graph->CreateVarNode(&cache_k_per_batch_maxs_desc);
      phi::DenseTensor cpu_tensor;
      auto* cpu_ctx = static_cast<phi::CPUContext*>(
          platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
      cpu_tensor.set_type(phi::DataType::FLOAT32);
      cpu_tensor.Resize({max_batch_size, max_ptr_size});
      std::vector<float> tmp(max_batch_size * max_ptr_size, 0);
      memcpy(cpu_ctx->Alloc<float>(&cpu_tensor),
             tmp.data(),
             max_batch_size * max_ptr_size * sizeof(float));
      Assign(cpu_tensor,
             scope->Var(cache_k_per_batch_maxs_name)
                 ->GetMutable<phi::DenseTensor>());
      op_desc->SetInput("cache_k_per_batch_maxs",
                        {cache_k_per_batch_maxs_name});

      std::string cache_v_per_batch_maxs_name = base_name + "_max_cache_v";
      VarDesc cache_v_per_batch_maxs_desc(cache_v_per_batch_maxs_name);
      cache_v_per_batch_maxs_desc.SetPersistable(true);
      cache_v_per_batch_maxs_desc.SetShape(
          {max_batch_size, static_cast<int64_t>(max_ptr_size)});
      cache_v_per_batch_maxs_desc.SetDataType(
          proto::VarType::Type::VarType_Type_FP32);
      Node* cache_v_per_batch_maxs_in =
          graph->CreateVarNode(&cache_v_per_batch_maxs_desc);
      Assign(cpu_tensor,
             scope->Var(cache_v_per_batch_maxs_name)
                 ->GetMutable<phi::DenseTensor>());
      op_desc->SetInput("cache_v_per_batch_maxs",
                        {cache_v_per_batch_maxs_name});

      IR_NODE_LINK_TO(cache_k_per_batch_maxs_in, node);
      IR_NODE_LINK_TO(cache_v_per_batch_maxs_in, node);
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(block_multihead_attention_xpu_pass,
              paddle::framework::ir::BlockMultiHeadAttentionXPUPass);

REGISTER_PASS_CAPABILITY(block_multihead_attention_xpu_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "block_multihead_attention_xpu", 0));

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions

@@ -538,6 +538,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "group_norm_silu_xpu_fuse_pass",
       "embedding_with_eltwise_add_xpu_fuse_pass",
       "qk_qkv_attention_xpu_fuse_pass",
+      "block_multihead_attention_xpu_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_adaptive_seqlen_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 3 additions & 1 deletion

@@ -1055,7 +1055,8 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::INT64,
                     phi::DataType::BOOL,
                     phi::DataType::FLOAT64,
-                    phi::DataType::FLOAT32})},
+                    phi::DataType::FLOAT32,
+                    phi::DataType::FLOAT16})},
     {"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"transpose2_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
@@ -1248,6 +1249,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"sequence_unpad_xpu",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
+    {"block_multihead_attention_xpu", XPUKernelSet({phi::DataType::FLOAT16})},
   };

  return s_xpu2_kernels;
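The KL2 op list registers the new fused op as FP16-only and extends tile with FP16. As a simplified illustration of how such a map gates XPU dispatch; the types and lookup function below are stand-ins, not Paddle's actual dispatch code:

#include <iostream>
#include <map>
#include <set>
#include <string>

// Simplified stand-ins for Paddle's XPUOpMap / XPUKernelSet.
using KernelSet = std::set<std::string>;
using OpMap = std::map<std::string, KernelSet>;

// An op/dtype pair runs on XPU only if it appears in the map.
bool SupportedOnXPU(const OpMap& ops, const std::string& op,
                    const std::string& dtype) {
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dtype) > 0;
}

int main() {
  OpMap kl2_ops{
      {"block_multihead_attention_xpu", {"float16"}},
      {"tile", {"int64", "bool", "float64", "float32", "float16"}},
  };
  // Prints 1; a pair absent from the map would fall back to another backend.
  std::cout << SupportedOnXPU(kl2_ops, "block_multihead_attention_xpu",
                              "float16")
            << "\n";
  return 0;
}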

paddle/phi/infermeta/fusion.cc

Lines changed: 83 additions & 0 deletions

@@ -377,6 +377,89 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
   }
 }

+void BlockMultiheadAttentionInferXPUMeta(
+    const MetaTensor& qkv,
+    const MetaTensor& key_cache,
+    const MetaTensor& value_cache,
+    const MetaTensor& seq_lens_encoder,
+    const MetaTensor& seq_lens_decoder,
+    const MetaTensor& seq_lens_this_time,
+    const MetaTensor& padding_offsets,
+    const MetaTensor& cum_offsets,
+    const MetaTensor& cu_seqlens_q,
+    const MetaTensor& cu_seqlens_k,
+    const MetaTensor& cache_k_per_batch_maxs,
+    const MetaTensor& cache_v_per_batch_maxs,
+    const MetaTensor& block_tables,
+    const MetaTensor& pre_key_cache,
+    const MetaTensor& pre_value_cache,
+    const MetaTensor& rope_emb,
+    const MetaTensor& mask,
+    const MetaTensor& tgt_mask,
+    const MetaTensor& cache_k_quant_scales,
+    const MetaTensor& cache_v_quant_scales,
+    const MetaTensor& cache_k_dequant_scales,
+    const MetaTensor& cache_v_dequant_scales,
+    const MetaTensor& qkv_out_scale,
+    const MetaTensor& qkv_bias,
+    const MetaTensor& out_shift,
+    const MetaTensor& out_smooth,
+    const MetaTensor& max_enc_len_this_time,
+    const MetaTensor& max_dec_len_this_time,
+    int max_seq_len,
+    int block_size,
+    bool use_neox_style,
+    bool dynamic_cachekv_quant,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_scale,
+    const std::string& compute_dtype,
+    MetaTensor* fmha_out,
+    MetaTensor* qkv_out,
+    MetaTensor* key_cache_out,
+    MetaTensor* value_cache_out) {
+  BlockMultiheadAttentionInferMeta(qkv,
+                                   key_cache,
+                                   value_cache,
+                                   seq_lens_encoder,
+                                   seq_lens_decoder,
+                                   seq_lens_this_time,
+                                   padding_offsets,
+                                   cum_offsets,
+                                   cu_seqlens_q,
+                                   cu_seqlens_k,
+                                   block_tables,
+                                   pre_key_cache,
+                                   pre_value_cache,
+                                   rope_emb,
+                                   mask,
+                                   tgt_mask,
+                                   cache_k_quant_scales,
+                                   cache_v_quant_scales,
+                                   cache_k_dequant_scales,
+                                   cache_v_dequant_scales,
+                                   qkv_out_scale,
+                                   qkv_bias,
+                                   out_shift,
+                                   out_smooth,
+                                   max_enc_len_this_time,
+                                   max_dec_len_this_time,
+                                   max_seq_len,
+                                   block_size,
+                                   use_neox_style,
+                                   dynamic_cachekv_quant,
+                                   quant_round_type,
+                                   quant_max_bound,
+                                   quant_min_bound,
+                                   out_scale,
+                                   compute_dtype,
+                                   fmha_out,
+                                   qkv_out,
+                                   key_cache_out,
+                                   value_cache_out);
+}
+
 void Conv1dXPUInferMeta(const MetaTensor& x,
                         const MetaTensor& x_max,
                         const MetaTensor& filter,
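Note that the XPU variant simply forwards to the existing BlockMultiheadAttentionInferMeta: the two new cache_k_per_batch_maxs / cache_v_per_batch_maxs inputs are run-time bookkeeping buffers that are accepted but never passed through, since they do not influence any output shape, so no additional shape logic is needed.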

paddle/phi/infermeta/fusion.h

Lines changed: 43 additions & 0 deletions

@@ -128,6 +128,49 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       MetaTensor* key_cache_out,
                                       MetaTensor* value_cache_out);

+void BlockMultiheadAttentionInferXPUMeta(
+    const MetaTensor& qkv,
+    const MetaTensor& key_cache,
+    const MetaTensor& value_cache,
+    const MetaTensor& seq_lens_encoder,
+    const MetaTensor& seq_lens_decoder,
+    const MetaTensor& seq_lens_this_time,
+    const MetaTensor& padding_offsets,
+    const MetaTensor& cum_offsets,
+    const MetaTensor& cu_seqlens_q,
+    const MetaTensor& cu_seqlens_k,
+    const MetaTensor& cache_k_per_batch_maxs,
+    const MetaTensor& cache_v_per_batch_maxs,
+    const MetaTensor& block_tables,
+    const MetaTensor& pre_key_cache,
+    const MetaTensor& pre_value_cache,
+    const MetaTensor& rope_emb,
+    const MetaTensor& mask,
+    const MetaTensor& tgt_mask,
+    const MetaTensor& cache_k_quant_scales,
+    const MetaTensor& cache_v_quant_scales,
+    const MetaTensor& cache_k_dequant_scales,
+    const MetaTensor& cache_v_dequant_scales,
+    const MetaTensor& qkv_out_scale,
+    const MetaTensor& qkv_bias,
+    const MetaTensor& out_shift,
+    const MetaTensor& out_smooth,
+    const MetaTensor& max_enc_len_this_time,
+    const MetaTensor& max_dec_len_this_time,
+    int max_seq_len,
+    int block_size,
+    bool use_neox_style,
+    bool dynamic_cachekv_quant,
+    const int quant_round_type,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_scale,
+    const std::string& compute_dtype,
+    MetaTensor* fmha_out,
+    MetaTensor* qkv_out,
+    MetaTensor* key_cache_out,
+    MetaTensor* value_cache_out);
+
 void Conv1dXPUInferMeta(const MetaTensor& x,
                         const MetaTensor& x_max,
                         const MetaTensor& filter,

paddle/phi/kernels/cpu/tile_kernel.cc

Lines changed: 1 addition & 0 deletions

@@ -27,5 +27,6 @@ PD_REGISTER_KERNEL(tile,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
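This mirrors the FP16 entry added to tile in xpu2_op_list.cc above, presumably so that the FP16 graphs exercised by the XPU llama2 path remain runnable when tile executes on CPU.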
