Commit de04d9e

[Custom Device] Solved several problems for CUDA custom device backend (#74411)

* Solve several problems for the CUDA custom device backend
* Fix CPU kernel compilation bug
1 parent 039bf8f commit de04d9e
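
Both fixes follow the same conditional-compilation pattern: CPU fallback code is fenced off so that a build with both PADDLE_WITH_CUDA and PADDLE_WITH_CUSTOM_DEVICE defined still compiles. As a minimal standalone sketch (not Paddle code; the macro toggles and printf strings are illustrative only), the include-guard form used in the headers below behaves like this:

// guard_demo.cc - standalone sketch of the include-guard form in this commit.
// Toggle the two #defines to mimic a build configuration.
#include <cstdio>

#define PADDLE_WITH_CUDA 1
#define PADDLE_WITH_CUSTOM_DEVICE 1

int main() {
  // True unless BOTH backends are enabled, i.e. the CPU fallback headers are
  // pulled in for every configuration except a combined CUDA + custom-device
  // build.
#if !defined(PADDLE_WITH_CUDA) || !defined(PADDLE_WITH_CUSTOM_DEVICE)
  std::printf("CPU fallback headers: included\n");
#else
  std::printf("CPU fallback headers: excluded\n");
#endif
  return 0;
}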

File tree

8 files changed: +82 -47 lines changed


paddle/phi/kernels/funcs/elementwise/elementwise_op_function.h

Lines changed: 2 additions & 0 deletions

@@ -24,8 +24,10 @@
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/common/transform.h"
 #include "paddle/phi/core/dense_tensor.h"
+#if !defined(PADDLE_WITH_CUDA) || !defined(PADDLE_WITH_CUSTOM_DEVICE)
 #include "paddle/phi/kernels/cpu/elementwise.h"
 #include "paddle/phi/kernels/cpu/elementwise_grad.h"
+#endif
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"

paddle/phi/kernels/funcs/quant_dequant.h

Lines changed: 2 additions & 1 deletion

@@ -20,8 +20,9 @@ limitations under the License. */
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/transform.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
+#ifndef PADDLE_WITH_CUSTOM_DEVICE
 #include "paddle/phi/kernels/funcs/blas/blas.h"
-
+#endif
 namespace phi {

 using backends::gpu::GpuLaunchConfig;

paddle/phi/kernels/funcs/sparse/convolution.h

Lines changed: 4 additions & 42 deletions

@@ -18,7 +18,10 @@ limitations under the License. */
 #include "paddle/phi/core/kmap_cache.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+#if !defined(PADDLE_WITH_CUDA) || !defined(PADDLE_WITH_CUSTOM_DEVICE)
+#include "paddle/phi/kernels/funcs/sparse/convolution_blas.h"
+#endif

 namespace phi {
 namespace funcs {
@@ -154,47 +157,6 @@ inline void ResetSubmKernelSizeAndStrides(const DDim& kernel_dims,
   }
 }

-template <typename T, typename Context>
-inline void SubmPreProcess(const Context& dev_ctx,
-                           const SparseCooTensor& x,
-                           const DenseTensor& kernel,
-                           const DenseTensor& out_grad,
-                           const int in_channels,
-                           const int out_channels,
-                           const int half_kernel_size,
-                           DenseTensor* kernel_grad,
-                           DenseTensor* x_grad) {
-  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
-  const bool is_params_freezing = kernel_grad == nullptr;
-  if (!is_params_freezing) {
-    T* d_kernel_ptr = kernel_grad->data<T>();
-    blas.GEMM(CblasTrans,
-              CblasNoTrans,
-              x.non_zero_elements().dims()[1],
-              out_grad.dims()[1],
-              x.non_zero_elements().dims()[0],
-              static_cast<T>(1),
-              x.non_zero_elements().data<T>(),
-              out_grad.data<T>(),
-              static_cast<T>(0),
-              d_kernel_ptr + half_kernel_size * in_channels * out_channels);
-  }
-
-  // call gemm: d_x = out_grad * transpose(kernel)
-  // (n, out_channels) * (out_channels, in_channels)
-  T* x_grad_ptr = x_grad->data<T>();
-  blas.GEMM(CblasNoTrans,
-            CblasTrans,
-            out_grad.dims()[0],
-            in_channels,
-            out_grad.dims()[1],
-            static_cast<T>(1),
-            out_grad.data<T>(),
-            kernel.data<T>() + half_kernel_size * in_channels * out_channels,
-            static_cast<T>(0),
-            x_grad_ptr);
-}
-
 inline const std::vector<int> PoolResetKernel(
     const std::vector<int>& kernel_sizes,
     const int in_channels,
paddle/phi/kernels/funcs/sparse/convolution_blas.h (new file)

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/common/ddim.h"
+#include "paddle/phi/core/kmap_cache.h"
+#include "paddle/phi/core/tensor_utils.h"
+#include "paddle/phi/kernels/empty_kernel.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+namespace funcs {
+namespace sparse {
+
+template <typename T, typename Context>
+inline void SubmPreProcess(const Context& dev_ctx,
+                           const SparseCooTensor& x,
+                           const DenseTensor& kernel,
+                           const DenseTensor& out_grad,
+                           const int in_channels,
+                           const int out_channels,
+                           const int half_kernel_size,
+                           DenseTensor* kernel_grad,
+                           DenseTensor* x_grad) {
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  const bool is_params_freezing = kernel_grad == nullptr;
+  if (!is_params_freezing) {
+    T* d_kernel_ptr = kernel_grad->data<T>();
+    blas.GEMM(CblasTrans,
+              CblasNoTrans,
+              x.non_zero_elements().dims()[1],
+              out_grad.dims()[1],
+              x.non_zero_elements().dims()[0],
+              static_cast<T>(1),
+              x.non_zero_elements().data<T>(),
+              out_grad.data<T>(),
+              static_cast<T>(0),
+              d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+  }
+
+  // call gemm: d_x = out_grad * transpose(kernel)
+  // (n, out_channels) * (out_channels, in_channels)
+  T* x_grad_ptr = x_grad->data<T>();
+  blas.GEMM(CblasNoTrans,
+            CblasTrans,
+            out_grad.dims()[0],
+            in_channels,
+            out_grad.dims()[1],
+            static_cast<T>(1),
+            out_grad.data<T>(),
+            kernel.data<T>() + half_kernel_size * in_channels * out_channels,
+            static_cast<T>(0),
+            x_grad_ptr);
+}
+
+}  // namespace sparse
+}  // namespace funcs
+}  // namespace phi
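
For reference, a shape sketch of the two GEMM calls in SubmPreProcess, derived from the arguments above (writing n for x.non_zero_elements().dims()[0], the number of non-zero input elements):

  kernel_grad slice: x_nnz^T  * out_grad    ->  (in_channels x n) * (n x out_channels) = (in_channels x out_channels)
  x_grad:            out_grad * kernel^T    ->  (n x out_channels) * (out_channels x in_channels) = (n x in_channels)

Both calls index the kernel data at offset half_kernel_size * in_channels * out_channels, i.e. the central (in_channels x out_channels) slice of the kernel.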

paddle/phi/kernels/stride/as_complex_kernel.cc

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@ PD_REGISTER_KERNEL(
 }
 #endif

-#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA)
 PD_REGISTER_KERNEL(
     as_complex, Custom, STRIDED, phi::AsComplexStridedKernel, float, double) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));

paddle/phi/kernels/stride/as_real_kernel.cc

Lines changed: 1 addition & 1 deletion

@@ -71,7 +71,7 @@ PD_REGISTER_KERNEL(as_real,
 }
 #endif

-#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA)
 PD_REGISTER_KERNEL(as_real,
                    Custom,
                    STRIDED,

paddle/phi/kernels/stride/complex_grad_kernel.cc

Lines changed: 1 addition & 1 deletion

@@ -126,7 +126,7 @@ PD_REGISTER_KERNEL(imag_grad,
 }
 #endif

-#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA)
 PD_REGISTER_KERNEL(real_grad,
                    Custom,
                    STRIDED,

paddle/phi/kernels/stride/complex_kernel.cc

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ PD_REGISTER_KERNEL(imag,
 }
 #endif

-#ifdef PADDLE_WITH_CUSTOM_DEVICE
+#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA)
 PD_REGISTER_KERNEL(real,
                    Custom,
                    STRIDED,
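
The four stride kernel files above all make the same change: #ifdef PADDLE_WITH_CUSTOM_DEVICE becomes #if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA). How the two guards resolve per build configuration (the reading of the last row, skipping a second registration of the same kernels in a combined build, is inferred from the commit message rather than stated in the diff):

  Macros defined             old guard    new guard
  CUSTOM_DEVICE only         compiled     compiled
  CUDA only                  skipped      skipped
  CUDA + CUSTOM_DEVICE       compiled     skipped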
