Skip to content

Commit 79bfb18

Browse files
authored
multihead_matmul op: support codegen and move the kernel to phi (#56846)
1 parent 7fd6ffb commit 79bfb18

File tree

13 files changed

+1035
-1026
lines changed

13 files changed

+1035
-1026
lines changed

paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
#include "paddle/fluid/framework/tensor_util.h"
2323
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
2424
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
25-
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
2625
#include "paddle/fluid/platform/device_context.h"
2726
#include "paddle/phi/kernels/funcs/blas/blas.h"
27+
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
2828

2929
namespace paddle {
3030
namespace inference {
@@ -254,7 +254,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
254254
platform::CUDAPlace(device_id)));
255255

256256
const phi::GPUContext &dev_ctx = *device_ctx;
257-
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
257+
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
258258
multihead_compute_func(dev_ctx,
259259
batch,
260260
seq_len,
@@ -341,7 +341,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
341341
tptr, static_cast<half>(scale_), n_q);
342342

343343
const phi::GPUContext &dev_ctx = *device_ctx;
344-
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
344+
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
345345
multihead_compute_func(dev_ctx,
346346
batch,
347347
seq_len,

paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
2525
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
2626
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
27-
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
2827
#include "paddle/fluid/platform/device_context.h"
2928
#include "paddle/phi/kernels/funcs/blas/blas.h"
29+
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
3030

3131
namespace paddle {
3232
namespace inference {
@@ -396,7 +396,7 @@ int QkvToContextPluginDynamic::enqueue(
396396
platform::CUDAPlace(device_id)));
397397

398398
const phi::GPUContext &dev_ctx = *device_ctx;
399-
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
399+
phi::funcs::MultiheadGPUComputeFunctor<float> multihead_compute_func;
400400
multihead_compute_func(dev_ctx,
401401
batch,
402402
seq_len,
@@ -506,7 +506,7 @@ int QkvToContextPluginDynamic::enqueue(
506506
tptr, static_cast<half>(scale_), n_q);
507507

508508
const phi::GPUContext &dev_ctx = *device_ctx;
509-
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
509+
phi::funcs::MultiheadGPUComputeFunctor<half> multihead_compute_func;
510510
multihead_compute_func(dev_ctx,
511511
batch,
512512
seq_len,

paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ register_operators(
1010
fusion_transpose_flatten_concat_op
1111
fusion_conv_inception_op
1212
fused_fc_elementwise_layernorm_op
13-
multihead_matmul_op
1413
self_dp_attention_op
1514
skip_layernorm_op
1615
yolo_box_head_op
@@ -74,8 +73,6 @@ if(WITH_GPU OR WITH_ROCM)
7473
endif()
7574
# fused_fc_elementwise_layernorm_op
7675
op_library(fused_fc_elementwise_layernorm_op)
77-
# multihead_matmul_op
78-
op_library(multihead_matmul_op)
7976
op_library(skip_layernorm_op)
8077
op_library(yolo_box_head_op)
8178
op_library(yolo_box_post_op)

paddle/fluid/operators/fused/multihead_matmul_op.cc

Lines changed: 0 additions & 116 deletions
This file was deleted.

0 commit comments

Comments
 (0)