
Commit 40130a5

[INTEL_HPU] enable tensor_wise_fp8 kernels (PaddlePaddle#2148)
* fuse MoE gate matmul to fused_gate_moe kernel
1 parent cd89e54 commit 40130a5

29 files changed (+4821, −706 lines)

backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc

Lines changed: 147 additions & 217 deletions
Large diffs are not rendered by default.

backends/intel_hpu/custom_ops/llama_infer/fused_fp8_sdpa.cc

Lines changed: 48 additions & 17 deletions
@@ -20,6 +20,10 @@
 #include "paddle/extension.h"
 #include "utils/utils.h"
 
+#define SDPA_SET_FLAGS(condition, flag_name)        \
+  if (condition) {                                  \
+    flags |= SdpaFlags_t::SDPA_FLAGS_##flag_name;   \
+  }
 #define SDPA_SET_INPUT_AND_FLAGS(ptr, flag_name) \
   if (ptr) {                                     \
     flags |= SdpaFlags_t::SDPA_FLAGS_##flag_name; \
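
The new SDPA_SET_FLAGS mirrors SDPA_SET_INPUT_AND_FLAGS but tests a boolean attribute rather than an optional tensor pointer, so it sets a flag without registering an extra input. As a self-contained sketch, the call site added later in this diff expands to the following (the enum bit value here is illustrative, not taken from the headers):

#include <cstdint>

// Illustrative stand-in for the real SdpaFlags_t; the point is the
// expansion of SDPA_SET_FLAGS(is_amax_s, AMAX_S).
enum SdpaFlags_t : uint32_t { SDPA_FLAGS_AMAX_S = 1u << 5 };

uint32_t BuildAmaxFlags(bool is_amax_s) {
  uint32_t flags = 0;
  if (is_amax_s) {  // <- body produced by the macro
    flags |= SdpaFlags_t::SDPA_FLAGS_AMAX_S;
  }
  return flags;
}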
@@ -35,7 +39,7 @@ struct SDPAParams {
 
 class FusedFp8Sdpa : public HpuOperator {
  public:
-  FusedFp8Sdpa() : HpuOperator("sdpa_recomp_fwd_hf8") {}
+  explicit FusedFp8Sdpa(std::string guid) : HpuOperator(guid) {}
   void AddNode(ConvertTensors& ct, SDPAParams& params) {
     auto inputs = ct.GetTensors();
     auto outputs = ct.GetTensors(false);
@@ -67,12 +71,24 @@ class FusedFp8Sdpa : public HpuOperator {
     }
 
     std::vector<synTensor> sync_outputs;
-    for (size_t i = 0; i < outputs.size(); i++) {
-      sync_outputs.push_back(createTensor(outputs[i].dims.size(),
-                                          outputs[i].type,
-                                          outputs[i].dims,
-                                          true,
-                                          outputs[i].name));
+    // [0] out, bf16
+    sync_outputs.push_back(createTensor(outputs[0].dims.size(),
+                                        outputs[0].type,
+                                        outputs[0].dims,
+                                        true,
+                                        outputs[0].name));
+    if (params.params.flags & SdpaFlags_t::SDPA_FLAGS_AMAX_S) {
+      // [1] m, bf16 [1]
+      sync_outputs.push_back(createTensor(1, syn_type_bf16, {1}, false, "m"));
+      // [2] linv, float32 [1]
+      sync_outputs.push_back(
+          createTensor(1, syn_type_float, {1}, false, "linv"));
+      // [3] seed, int32 [1]
+      sync_outputs.push_back(
+          createTensor(1, syn_type_int32, {1}, false, "seed"));
+      // [4] amax_s, float32 [1]
+      sync_outputs.push_back(
+          createTensor(1, syn_type_float, {1}, true, outputs[1].name));
     }
 
     status = synNodeCreate(graphHandle_,
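
In place of the old loop over every converter output, AddNode now declares a fixed slate of outputs when amax is requested. Only the first and last are marked persistent (the fourth createTensor argument); m, linv, and seed stay internal to the recipe. Summarized from the hunk above:

// Output slots declared when SDPA_FLAGS_AMAX_S is set:
//   [0] out     outputs[0].type  outputs[0].dims  persistent
//   [1] m       bf16             {1}              internal
//   [2] linv    float32          {1}              internal
//   [3] seed    int32            {1}              internal
//   [4] amax_s  float32          {1}              persistent (bound to outputs[1].name)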
@@ -105,9 +121,13 @@ void fused_fp8_sdpa(const Context& dev_ctx,
                     const paddle::optional<phi::DenseTensor>& d_scale_s,
                     float scale,
                     bool causal,
-                    phi::DenseTensor* out) {
+                    bool is_amax_s,
+                    phi::DenseTensor* out,
+                    phi::DenseTensor* amax) {
   // allocate memory on device.
   dev_ctx.template Alloc<T>(out);
+  dev_ctx.template Alloc<float>(amax);
+
   if (out->numel() == 0) {
     return;
   }
@@ -117,6 +137,7 @@ void fused_fp8_sdpa(const Context& dev_ctx,
   ct.Add(k);
   ct.Add(v);
 
+  std::string guid = "sdpa_recomp_fwd_hf8";
   unsigned int flags = 0;
 
   SDPA_SET_INPUT_AND_FLAGS(d_scale_q.get_ptr(), D_SCALE_Q)
@@ -125,6 +146,10 @@ void fused_fp8_sdpa(const Context& dev_ctx,
   SDPA_SET_INPUT_AND_FLAGS(q_scale_s.get_ptr(), Q_SCALE_S)
   SDPA_SET_INPUT_AND_FLAGS(q_scale_o.get_ptr(), Q_SCALE_O)
   SDPA_SET_INPUT_AND_FLAGS(d_scale_s.get_ptr(), D_SCALE_S)
+  if (flags == 0) {
+    guid = "sdpa_recomp_fwd_bf16";
+  }
+  SDPA_SET_FLAGS(is_amax_s, AMAX_S)
 
   SDPAParams params{};
 
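Guid selection happens before the AMAX_S bit is folded in, so requesting amax statistics alone does not re-enable the fp8 recipe; only the presence of at least one scale tensor does. A self-contained condensation (the helper name is hypothetical; the real code inlines this logic):

#include <string>

// Any fp8 scale input keeps the hf8 recipe; with no scale tensors the
// kernel falls back to plain bf16 SDPA. AMAX_S is ORed in afterwards
// and never affects the guid.
std::string SelectSdpaGuid(unsigned int scale_flags) {
  return scale_flags == 0 ? "sdpa_recomp_fwd_bf16" : "sdpa_recomp_fwd_hf8";
}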
@@ -141,6 +166,8 @@ void fused_fp8_sdpa(const Context& dev_ctx,
   params.params.flags = flags;
 
   ct.Add(*out, false);
+  ct.Add(*amax, false);
+
   std::vector<DIMS> inputs_dims = ct.GetDims();
 
   OpCacheOperator op_info;
@@ -149,7 +176,7 @@ void fused_fp8_sdpa(const Context& dev_ctx,
   auto recipe = op_info.GetRecipe();
 
   if (recipe == nullptr) {
-    FusedFp8Sdpa op;
+    FusedFp8Sdpa op(guid);
     op.AddNode(ct, params);
     op.Compile();
     op_info.setOp(op);
@@ -175,7 +202,8 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
     const paddle::optional<paddle::Tensor>& q_scale_o,
     const paddle::optional<paddle::Tensor>& d_scale_s,
     bool causal,
-    float scale) {
+    float scale,
+    bool is_amax_s) {
   auto dev_ctx = static_cast<const phi::CustomContext*>(
       paddle::experimental::DeviceContextPool::Instance().Get(q.place()));
 
@@ -242,6 +270,9 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
   auto out_tensor = std::make_shared<phi::DenseTensor>();
   out_tensor->Resize(q_tensor->dims());
 
+  auto amax_tensor = std::make_shared<phi::DenseTensor>();
+  amax_tensor->Resize({1});
+
   custom_kernel::fused_fp8_sdpa<phi::dtype::bfloat16>(
       *dev_ctx,
       *q_tensor,
@@ -256,11 +287,11 @@ std::vector<paddle::Tensor> FusedFp8SdpaForward(
       d_scale_s ? *d_scale_s_tensor : paddle::optional<phi::DenseTensor>(),
       scale,
       causal,
-      out_tensor.get());
-
-  paddle::Tensor out(out_tensor);
+      is_amax_s,
+      out_tensor.get(),
+      amax_tensor.get());
 
-  return {out};
+  return {paddle::Tensor(out_tensor), paddle::Tensor(amax_tensor)};
 }
 
 std::vector<std::vector<int64_t>> FusedFp8SdpaForwardShape(
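
The wrapper now returns a pair unconditionally: amax is allocated with shape {1} even when is_amax_s is false (the kernel Allocs it before its early return), so only its contents, not its existence, depend on the attribute. A hypothetical call-site sketch (not in the source):

#include <vector>
#include "paddle/extension.h"

// Hypothetical consumer of the two-tensor result.
void ConsumeSdpaResult(const std::vector<paddle::Tensor>& outs,
                       bool is_amax_s) {
  const paddle::Tensor& out = outs[0];  // bf16 attention output
  if (is_amax_s) {
    // float32, shape {1}; meaningful only when is_amax_s was set,
    // e.g. to feed the next step's scale calibration.
    const paddle::Tensor& amax = outs[1];
    (void)amax;
  }
  (void)out;
}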
@@ -271,7 +302,7 @@ std::vector<std::vector<int64_t>> FusedFp8SdpaForwardShape(
   int64_t num_heads = query_states_shape[1];
   int64_t seq_len = query_states_shape[2];
   int head_dim = query_states_shape[3];
-  return {{bsz, num_heads, seq_len, head_dim}};
+  return {{bsz, num_heads, seq_len, head_dim}, {1}};
 }
 
 std::vector<paddle::DataType> FusedFp8SdpaForwardDtype(
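
The infer-shape change pairs the attention output, which keeps the query layout, with the single-element amax. A worked instance of the return statement above (the concrete sizes are only an example):

#include <cstdint>
#include <vector>

// For a query of shape {bsz=2, num_heads=32, seq_len=128, head_dim=128}
// the shape function now yields {out_shape, amax_shape}.
std::vector<std::vector<int64_t>> ExampleSdpaShapes() {
  return {{2, 32, 128, 128}, {1}};
}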
@@ -294,8 +325,8 @@ PD_BUILD_OP(fused_fp8_sdpa)
         paddle::Optional("q_scale_o"),
         paddle::Optional("d_scale_s"),
     })
-    .Attrs({"causal: bool", "scaling_factor: float"})
-    .Outputs({"out"})
+    .Attrs({"causal: bool", "scaling_factor: float", "is_amax_s: bool"})
+    .Outputs({"out", "amax"})
     .SetKernelFn(PD_KERNEL(FusedFp8SdpaForward))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedFp8SdpaForwardShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedFp8SdpaForwardDtype));

0 commit comments

Comments
 (0)