Skip to content

Commit 913f40e

Browse files
author
zhangkaihuo
authored
[cherry-pick] remove if constexpr(), which is not supported on gcc54 (#50421)
att, cherry-pick #48563
1 parent eb61074 commit 913f40e

File tree

2 files changed

+73
-54
lines changed

2 files changed

+73
-54
lines changed

paddle/phi/kernels/sparse/gpu/conv_kernel.cu

Lines changed: 12 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
150150
const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
151151
const IntT* scatter_indices =
152152
rulebook_ptr + rulebook_len + h_offsets_ptr[i];
153-
154-
if constexpr (std::is_same<T, phi::dtype::float16>::value &&
155-
std::is_same<IntT, int32_t>::value) {
156-
fp16_gather_gemm_scatter gather_gemm_scatter =
157-
getBestFp16Kernel(M, N, K);
158-
gather_gemm_scatter(
159-
dev_ctx,
160-
reinterpret_cast<const cutlass::half_t*>(
161-
x.non_zero_elements().data<T>()),
162-
reinterpret_cast<const cutlass::half_t*>(tmp_kernel_ptr),
163-
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
164-
reinterpret_cast<cutlass::half_t*>(out_values_ptr),
165-
M,
166-
N,
167-
K,
168-
static_cast<const int32_t*>(gather_indices),
169-
static_cast<const int32_t*>(scatter_indices),
170-
static_cast<cutlass::half_t>(1),
171-
static_cast<cutlass::half_t>(1));
172-
}
173-
if constexpr (std::is_same<T, float>::value &&
174-
std::is_same<IntT, int32_t>::value) {
175-
fp32_gather_gemm_scatter gather_gemm_scatter =
176-
getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
177-
gather_gemm_scatter(dev_ctx,
178-
x.non_zero_elements().data<T>(),
179-
tmp_kernel_ptr,
180-
out_values_ptr,
181-
out_values_ptr,
182-
M,
183-
N,
184-
K,
185-
gather_indices,
186-
scatter_indices,
187-
static_cast<T>(1),
188-
static_cast<T>(1));
189-
}
190-
if constexpr (std::is_same<T, double>::value &&
191-
std::is_same<IntT, int32_t>::value) {
192-
fp64_gather_gemm_scatter gather_gemm_scatter =
193-
getBestFp64Kernel(M, N, K);
194-
gather_gemm_scatter(dev_ctx,
195-
x.non_zero_elements().data<T>(),
196-
tmp_kernel_ptr,
197-
out_values_ptr,
198-
out_values_ptr,
199-
M,
200-
N,
201-
K,
202-
gather_indices,
203-
scatter_indices,
204-
static_cast<T>(1),
205-
static_cast<T>(1));
206-
}
153+
dispatchKernel(dev_ctx,
154+
x.non_zero_elements().data<T>(),
155+
tmp_kernel_ptr,
156+
out_values_ptr,
157+
out_values_ptr,
158+
M,
159+
N,
160+
K,
161+
gather_indices,
162+
scatter_indices,
163+
cutlass,
164+
x.dtype());
207165
}
208166
} else {
209167
#endif

paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "cutlass/util/device_memory.h"
2424
#include "examples/common/helper.h"
2525
#include "paddle/phi/backends/gpu/gpu_context.h"
26+
#include "paddle/phi/common/data_type.h"
2627
namespace phi {
2728
namespace sparse {
2829
typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
115116
CUTLASS_CHECK(status);
116117
gemm_op(dev_ctx.stream());
117118
}
119+
static void dispatchKernel(const GPUContext& dev_ctx,
120+
const void* const a,
121+
const void* const b,
122+
const void* const c,
123+
void* const d,
124+
const int m,
125+
const int n,
126+
const int k,
127+
const void* a_indices,
128+
const void* c_d_indices,
129+
const bool cutlass,
130+
const phi::DataType type) {
131+
if (!cutlass) return;
132+
133+
if (type == phi::DataType::FLOAT16) {
134+
fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
135+
gather_gemm_scatter(dev_ctx,
136+
static_cast<const cutlass::half_t*>(a),
137+
static_cast<const cutlass::half_t*>(b),
138+
static_cast<const cutlass::half_t*>(c),
139+
static_cast<cutlass::half_t*>(d),
140+
m,
141+
n,
142+
k,
143+
static_cast<const int32_t*>(a_indices),
144+
static_cast<const int32_t*>(c_d_indices),
145+
static_cast<cutlass::half_t>(1),
146+
static_cast<cutlass::half_t>(1));
147+
} else if (type == phi::DataType::FLOAT32) {
148+
fp32_gather_gemm_scatter gather_gemm_scatter =
149+
getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
150+
gather_gemm_scatter(dev_ctx,
151+
static_cast<const float*>(a),
152+
static_cast<const float*>(b),
153+
static_cast<const float*>(c),
154+
static_cast<float*>(d),
155+
m,
156+
n,
157+
k,
158+
static_cast<const int32_t*>(a_indices),
159+
static_cast<const int32_t*>(c_d_indices),
160+
static_cast<float>(1),
161+
static_cast<float>(1));
162+
} else if (type == phi::DataType::FLOAT64) {
163+
fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
164+
gather_gemm_scatter(dev_ctx,
165+
static_cast<const double*>(a),
166+
static_cast<const double*>(b),
167+
static_cast<const double*>(c),
168+
static_cast<double*>(d),
169+
m,
170+
n,
171+
k,
172+
static_cast<const int32_t*>(a_indices),
173+
static_cast<const int32_t*>(c_d_indices),
174+
static_cast<double>(1),
175+
static_cast<double>(1));
176+
}
177+
}
178+
118179
struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
119180
using Gemm = cutlass::gemm::device::GemmUniversal<
120181
cutlass::half_t,

0 commit comments

Comments
 (0)