Commit adbbf16

[MLU] fuse adam & upgrade interface & scale

1 parent f5dec55 · commit adbbf16

File tree: 9 files changed, +191 −201 lines

backends/mlu/kernels/adam_kernel.cc

Lines changed: 30 additions & 19 deletions
@@ -115,9 +115,7 @@ void AdamKernel(const Context& dev_ctx,
                         "value is:%d.",
                         beta2_pow_out->numel()));
 
-  const phi::DenseTensor* beta1_tensor = nullptr;
-  const phi::DenseTensor* beta2_tensor = nullptr;
-  const phi::DenseTensor* epsilon_tensor = nullptr;
+  Tensor beta1_tensor;
 
   phi::DenseTensor beta1_tmp;
   phi::DenseTensor beta2_tmp;
@@ -128,19 +126,30 @@ void AdamKernel(const Context& dev_ctx,
   epsilon_tmp.Resize({1});
 
   MPDType beta1 = beta1_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&beta1_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, beta1, &beta1_tmp);
-  beta1_tensor = &beta1_tmp;
 
   MPDType beta2 = beta2_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&beta2_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, beta2, &beta2_tmp);
-  beta2_tensor = &beta2_tmp;
 
   MPDType epsilon = epsilon_in.to<MPDType>();
-  dev_ctx.template Alloc<MPDType>(&epsilon_tmp);
-  FillMLUTensorWithHostValue<MPDType>(dev_ctx, epsilon, &epsilon_tmp);
-  epsilon_tensor = &epsilon_tmp;
+
+  std::vector<MPDType> parameter_list;
+  parameter_list.push_back(beta1);
+  parameter_list.push_back(beta2);
+  parameter_list.push_back(epsilon);
+
+  Tensor dst;
+  dst.Resize({3});
+  auto dst_place = phi::CustomPlace();
+  C_Device_st device{dst_place.GetDeviceId()};
+  void* dst_ptr = dev_ctx.template Alloc<MPDType>(&dst);
+  auto src_ptr = static_cast<void*>(parameter_list.data());
+  MemCpyH2D(&device, dst_ptr, src_ptr, parameter_list.size() * sizeof(MPDType));
+
+  const void* beta1_tensor_ptr = nullptr;
+  const void* beta2_tensor_ptr = nullptr;
+  const void* epsilon_tensor_ptr = nullptr;
+  beta1_tensor_ptr = dst_ptr,
+  beta2_tensor_ptr = static_cast<char*>(dst_ptr) + sizeof(MPDType);
+  epsilon_tensor_ptr = static_cast<char*>(dst_ptr) + 2 * sizeof(MPDType);
 
   Tensor t_param_in_out, t_grad;
   t_param_in_out.Resize(param.dims());
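This hunk carries the fusion named in the commit title: instead of one device allocation plus one FillMLUTensorWithHostValue per scalar, the kernel packs beta1, beta2, and epsilon into a single host vector, uploads it with one MemCpyH2D into a {3}-element device tensor, and recovers the three device pointers as byte offsets into that buffer. A minimal host-only sketch of the same packing idea, with std::memcpy standing in for the H2D copy and PackAndUpload as a hypothetical helper name:

// Host-only sketch: pack the three Adam scalars contiguously, copy once,
// then address each scalar at a fixed byte offset. In the kernel the
// destination is a device Tensor and the copy is MemCpyH2D.
#include <cstring>
#include <vector>

template <typename T>
std::vector<unsigned char> PackAndUpload(const std::vector<T>& scalars) {
  std::vector<unsigned char> device_buf(scalars.size() * sizeof(T));
  // One contiguous copy instead of one alloc-and-fill per scalar.
  std::memcpy(device_buf.data(), scalars.data(), device_buf.size());
  return device_buf;
}

int main() {
  std::vector<float> parameter_list = {0.9f, 0.999f, 1e-8f};  // beta1, beta2, epsilon
  auto buf = PackAndUpload(parameter_list);
  const void* beta1_ptr = buf.data();
  const void* beta2_ptr = buf.data() + sizeof(float);
  const void* epsilon_ptr = buf.data() + 2 * sizeof(float);
  (void)beta1_ptr;
  (void)beta2_ptr;
  (void)epsilon_ptr;
  return 0;
}

The payoff is one host-to-device transfer per kernel invocation instead of three separate allocation-and-fill round trips.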
@@ -198,11 +207,11 @@ void AdamKernel(const Context& dev_ctx,
                     grad_desc.get(),
                     GetBasePtr(&t_grad),
                     GetBasePtr(&learning_rate),
-                    GetBasePtr(beta1_tensor),
-                    GetBasePtr(beta2_tensor),
+                    beta1_tensor_ptr,
+                    beta2_tensor_ptr,
                     GetBasePtr(beta1_pow),
                     GetBasePtr(beta2_pow),
-                    GetBasePtr(epsilon_tensor),
+                    epsilon_tensor_ptr,
                     /*use_nesterov*/ false);
 
   if (param.dtype() != phi::DataType::FLOAT32) {
@@ -221,7 +230,6 @@ void AdamKernel(const Context& dev_ctx,
                   param_out_desc.get(),
                   GetBasePtr(param_out));
   }
-
   if (!use_global_beta_pow) {
     if (beta1_pow->place().GetType() == phi::AllocationType::CPU &&
         beta2_pow->place().GetType() == phi::AllocationType::CPU) {
@@ -235,7 +243,10 @@ void AdamKernel(const Context& dev_ctx,
       dev_ctx.template Alloc<MPDType>(beta1_pow_out);
       dev_ctx.template Alloc<MPDType>(beta2_pow_out);
 
-      MLUCnnlTensorDesc beta1_desc(*beta1_tensor);
+      beta1_tensor.Resize({1});
+      MLUCnnlTensorDesc beta1_desc(
+          beta1_tensor, CNNL_LAYOUT_ARRAY, ToCnnlDataType<MPDType>());
+
       MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL,
                                       ToCnnlDataType<MPDType>(),
                                       CNNL_NOT_PROPAGATE_NAN);
@@ -245,7 +256,7 @@ void AdamKernel(const Context& dev_ctx,
                         beta1_desc.get(),
                         GetBasePtr(beta1_pow),
                         beta1_desc.get(),
-                        GetBasePtr(beta1_tensor),
+                        beta1_tensor_ptr,
                         beta1_desc.get(),
                         GetBasePtr(beta1_pow_out),
                         ToCnnlDataType<MPDType>());
@@ -255,7 +266,7 @@ void AdamKernel(const Context& dev_ctx,
                         beta1_desc.get(),
                         GetBasePtr(beta2_pow),
                         beta1_desc.get(),
-                        GetBasePtr(beta2_tensor),
+                        beta2_tensor_ptr,
                         beta1_desc.get(),
                         GetBasePtr(beta2_pow_out),
                         ToCnnlDataType<MPDType>());
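The trailing hunks feed the packed pointers into the beta_pow updates: each OpTensor(CNNL_OP_TENSOR_MUL) call maintains the running power of its beta, and the single {1}-shaped beta1_desc is reused for every operand since they all share shape and dtype (beta1_tensor itself is never allocated; it appears to serve only as shape metadata for the descriptor). In scalar form the recurrence the two multiplies implement is just Adam's bias-correction bookkeeping:

// Scalar view of what the two OpTensor(MUL) calls maintain across steps:
// the accumulated powers beta1^t and beta2^t used for bias correction.
#include <cstdio>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f;
  float beta1_pow = 1.0f, beta2_pow = 1.0f;
  for (int t = 1; t <= 3; ++t) {
    beta1_pow *= beta1;  // beta1_pow_out = beta1_pow * beta1
    beta2_pow *= beta2;  // beta2_pow_out = beta2_pow * beta2
    std::printf("t=%d beta1^t=%.6f beta2^t=%.6f\n", t, beta1_pow, beta2_pow);
  }
  return 0;
}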

backends/mlu/kernels/funcs/mlu_baseop.cc

Lines changed: 54 additions & 26 deletions
@@ -1788,8 +1788,18 @@ NormalizeDesc::~NormalizeDesc() {
     beta_ptr = static_cast<const void*>(&beta_int64);
   }
 
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize(
-      handle, a_desc, b_desc, output_desc, &workspace_size));
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetOpTensorWorkspaceSize_v2(handle,
+                                                             op_tensor_desc,
+                                                             alpha1_ptr,
+                                                             a_desc,
+                                                             a,
+                                                             alpha2_ptr,
+                                                             b_desc,
+                                                             b,
+                                                             beta_ptr,
+                                                             output_desc,
+                                                             output,
+                                                             &workspace_size));
 
   Tensor workspace;
   workspace.Resize({static_cast<int64_t>(workspace_size)});
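The v1 query sized the workspace from the tensor descriptors alone; the _v2 variant also receives the op descriptor, the alpha/beta scaling pointers, and the actual operand pointers, so the returned size can account for the concrete call. A compilable stub sketch of the two shapes; Desc, OpDesc, and both function names are hypothetical stand-ins for the cnnl types and calls, not real API:

#include <cstddef>

struct Desc {};    // stand-in for cnnlTensorDescriptor_t
struct OpDesc {};  // stand-in for cnnlOpTensorDescriptor_t

// v1 shape: descriptors only.
size_t GetWorkspaceSizeV1(const Desc*, const Desc*, const Desc*) { return 0; }

// v2 shape: op descriptor, scaling pointers, and operand pointers too.
size_t GetWorkspaceSizeV2(const OpDesc*,
                          const void* /*alpha1*/, const Desc*, const void* /*a*/,
                          const void* /*alpha2*/, const Desc*, const void* /*b*/,
                          const void* /*beta*/, const Desc*, const void* /*output*/) {
  return 0;
}

int main() { return 0; }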
@@ -1931,16 +1941,16 @@ NormalizeDesc::~NormalizeDesc() {
 
 /* static */ void MLUCnnl::StridedSlice(
     const Context& ctx,
-    const int begin[],
-    const int end[],
-    const int strides[],
+    const int64_t begin[],
+    const int64_t end[],
+    const int64_t strides[],
     const cnnlTensorDescriptor_t input_desc,
     const void* input,
     const cnnlTensorDescriptor_t output_desc,
     void* output) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
 
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlStridedSlice(
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlStridedSlice_v2(
       handle, input_desc, input, begin, end, strides, output_desc, output));
 }
 
@@ -2312,14 +2322,23 @@ NormalizeDesc::~NormalizeDesc() {
     void* index) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
 
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlAdaptivePoolingForward(handle,
-                                                        input_desc,
-                                                        input,
-                                                        pool_mode,
-                                                        output_desc,
-                                                        output,
-                                                        index_desc,
-                                                        index));
+  size_t workspace_size = 0;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetAdaptivePoolingForwardWorkspaceSize(
+      handle, input_desc, pool_mode, output_desc, &workspace_size));
+  Tensor workspace;
+  workspace.Resize({static_cast<int64_t>(workspace_size)});
+  void* workspace_ptr = ctx.Alloc(&workspace, DataType::INT8, workspace_size);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlAdaptivePoolingForward_v2(handle,
+                                                           input_desc,
+                                                           input,
+                                                           pool_mode,
+                                                           workspace_ptr,
+                                                           workspace_size,
+                                                           output_desc,
+                                                           output,
+                                                           index_desc,
+                                                           index));
 }
 
 /* static */ void MLUCnnl::Pool3D(const Context& ctx,
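Both upgraded call sites follow the same CNNL convention: query the required workspace size, allocate that many bytes as an INT8 tensor, then hand the pointer and size to the _v2 compute call. A generic sketch of the pattern; QueryWorkspaceSize and ComputeV2 are hypothetical stand-ins for the paired cnnl calls:

#include <cstddef>
#include <vector>

// e.g. cnnlGetAdaptivePoolingForwardWorkspaceSize
size_t QueryWorkspaceSize() { return 256; }

// e.g. cnnlAdaptivePoolingForward_v2: uses the scratch buffer internally.
void ComputeV2(void* workspace, size_t workspace_size) {
  (void)workspace;
  (void)workspace_size;
}

int main() {
  size_t workspace_size = QueryWorkspaceSize();
  std::vector<unsigned char> workspace(workspace_size);  // INT8 tensor in the kernel
  ComputeV2(workspace.data(), workspace_size);
  return 0;
}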
@@ -3280,7 +3299,6 @@ NormalizeDesc::~NormalizeDesc() {
     const cnnlTensorDescriptor_t output_desc,
     void* output) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
-
   PADDLE_ENFORCE_MLU_SUCCESS(cnnlInterp_v2(handle,
                                            align_corners,
                                            half_pixel_centers,
@@ -3295,7 +3313,7 @@ NormalizeDesc::~NormalizeDesc() {
 
 /* static */ void MLUCnnl::InterpBackward(
     const Context& ctx,
-    const cnnlInterpBackwardMode_t mode,
+    const cnnlInterpMode_t mode,
     const bool align_corners,
     const bool half_pixel_centers,
     const cnnlTensorDescriptor_t input_desc,
@@ -3304,16 +3322,26 @@ NormalizeDesc::~NormalizeDesc() {
     void* output) {
   cnnlHandle_t handle = GetHandleFromCTX(ctx);
 
-  PADDLE_ENFORCE_MLU_SUCCESS(cnnlInterpBackward_v2(handle,
-                                                   align_corners,
-                                                   half_pixel_centers,
-                                                   mode,
-                                                   NULL,
-                                                   true,
-                                                   input_desc,
-                                                   input,
-                                                   output_desc,
-                                                   output));
+  cnnlInterpDescriptor_t interp_desc;
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateInterpDescriptor(&interp_desc));
+
+  cnnlInterpAlgo_t algo;
+  if (align_corners == false && half_pixel_centers == false) {
+    algo = CNNL_INTERP_ALGO_0;
+  } else if (align_corners == false && half_pixel_centers == true) {
+    algo = CNNL_INTERP_ALGO_1;
+  } else if (align_corners == true && half_pixel_centers == false) {
+    algo = CNNL_INTERP_ALGO_3;
+  } else if (align_corners == true && half_pixel_centers == true) {
+    algo = CNNL_INTERP_ALGO_4;
+  }
+  PADDLE_ENFORCE_MLU_SUCCESS(
+      cnnlSetInterpDescriptor_v2(interp_desc, input_desc, mode, algo, NULL));
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlInterpBackward_v3(
+      handle, interp_desc, input_desc, input, output_desc, output));
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroyInterpDescriptor(interp_desc));
 }
 
 /* static */ void MLUCnnl::Cast(const Context& ctx,
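The rewritten backward path introduces an explicit create / set / use / destroy descriptor lifecycle, with the (align_corners, half_pixel_centers) pair selecting among CNNL_INTERP_ALGO_0/1/3/4. The kernel destroys the descriptor manually; a hedged RAII sketch of the same lifecycle, where ScopedDesc and the stub functions are hypothetical rather than CNNL API:

struct RawDesc {};
RawDesc* CreateDesc() { return new RawDesc; }  // cf. cnnlCreateInterpDescriptor
void DestroyDesc(RawDesc* d) { delete d; }     // cf. cnnlDestroyInterpDescriptor

class ScopedDesc {
 public:
  ScopedDesc() : desc_(CreateDesc()) {}
  ~ScopedDesc() { DestroyDesc(desc_); }  // runs even on early return or throw
  ScopedDesc(const ScopedDesc&) = delete;
  ScopedDesc& operator=(const ScopedDesc&) = delete;
  RawDesc* get() const { return desc_; }

 private:
  RawDesc* desc_;
};

int main() {
  ScopedDesc desc;
  // set the descriptor, then run the backward interp with desc.get() ...
  return 0;
}

If a PADDLE_ENFORCE_MLU_SUCCESS between the create and destroy calls ever threw, the manual version would leak the descriptor; a wrapper of this shape guarantees cleanup.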

backends/mlu/kernels/funcs/mlu_baseop.h

Lines changed: 4 additions & 4 deletions
@@ -1120,9 +1120,9 @@ class MLUCnnl {
                      void* indices_out);
 
   static void StridedSlice(const Context& ctx,
-                           const int begin[],
-                           const int end[],
-                           const int strides[],
+                           const int64_t begin[],
+                           const int64_t end[],
+                           const int64_t strides[],
                            const cnnlTensorDescriptor_t input_desc,
                            const void* input,
                            const cnnlTensorDescriptor_t output_desc,
@@ -1807,7 +1807,7 @@ class MLUCnnl {
                    void* output);
 
   static void InterpBackward(const Context& ctx,
-                             const cnnlInterpBackwardMode_t mode,
+                             const cnnlInterpMode_t mode,
                              const bool align_corners,
                              const bool half_pixel_centers,
                              const cnnlTensorDescriptor_t input_desc,

backends/mlu/kernels/gather_nd_kernel.cc

Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ PD_REGISTER_PLUGIN_KERNEL(gather_nd,
                           mlu,
                           ALL_LAYOUT,
                           custom_kernel::GatherNdKernel,
+                          int,
                           int64_t,
                           float,
                           phi::dtype::float16) {}

backends/mlu/kernels/interpolate_kernel.cc

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ void InterpolateGradKernel(
                                 CNNL_LAYOUT_NHWC,
                                 ToCnnlDataType(transformed_input_grad.dtype()));
   MLUCnnl::InterpBackward(dev_ctx,
-                          GetMLUCnnlInterpBackwardMode(interp_method),
+                          GetMLUCnnlInterpMode(interp_method),
                           align_corners,
                           align_center,
                           input_desc.get(),
