|
23 | 23 | #include "cutlass/util/device_memory.h"
24 | 24 | #include "examples/common/helper.h"
25 | 25 | #include "paddle/phi/backends/gpu/gpu_context.h"
| 26 | +#include "paddle/phi/common/data_type.h"
26 | 27 | namespace phi {
27 | 28 | namespace sparse {
28 | 29 | typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx,
@@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx,
115 | 116 |   CUTLASS_CHECK(status);
116 | 117 |   gemm_op(dev_ctx.stream());
117 | 118 | }
| 119 | +static void dispatchKernel(const GPUContext& dev_ctx,
| 120 | +                           const void* const a,
| 121 | +                           const void* const b,
| 122 | +                           const void* const c,
| 123 | +                           void* const d,
| 124 | +                           const int m,
| 125 | +                           const int n,
| 126 | +                           const int k,
| 127 | +                           const void* a_indices,
| 128 | +                           const void* c_d_indices,
| 129 | +                           const bool cutlass,
| 130 | +                           const phi::DataType type) {
| 131 | +  if (!cutlass) return;
| 132 | +
| 133 | +  if (type == phi::DataType::FLOAT16) {
| 134 | +    fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k);
| 135 | +    gather_gemm_scatter(dev_ctx,
| 136 | +                        static_cast<const cutlass::half_t*>(a),
| 137 | +                        static_cast<const cutlass::half_t*>(b),
| 138 | +                        static_cast<const cutlass::half_t*>(c),
| 139 | +                        static_cast<cutlass::half_t*>(d),
| 140 | +                        m,
| 141 | +                        n,
| 142 | +                        k,
| 143 | +                        static_cast<const int32_t*>(a_indices),
| 144 | +                        static_cast<const int32_t*>(c_d_indices),
| 145 | +                        static_cast<cutlass::half_t>(1),
| 146 | +                        static_cast<cutlass::half_t>(1));
| 147 | +  } else if (type == phi::DataType::FLOAT32) {
| 148 | +    fp32_gather_gemm_scatter gather_gemm_scatter =
| 149 | +        getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability());
| 150 | +    gather_gemm_scatter(dev_ctx,
| 151 | +                        static_cast<const float*>(a),
| 152 | +                        static_cast<const float*>(b),
| 153 | +                        static_cast<const float*>(c),
| 154 | +                        static_cast<float*>(d),
| 155 | +                        m,
| 156 | +                        n,
| 157 | +                        k,
| 158 | +                        static_cast<const int32_t*>(a_indices),
| 159 | +                        static_cast<const int32_t*>(c_d_indices),
| 160 | +                        static_cast<float>(1),
| 161 | +                        static_cast<float>(1));
| 162 | +  } else if (type == phi::DataType::FLOAT64) {
| 163 | +    fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k);
| 164 | +    gather_gemm_scatter(dev_ctx,
| 165 | +                        static_cast<const double*>(a),
| 166 | +                        static_cast<const double*>(b),
| 167 | +                        static_cast<const double*>(c),
| 168 | +                        static_cast<double*>(d),
| 169 | +                        m,
| 170 | +                        n,
| 171 | +                        k,
| 172 | +                        static_cast<const int32_t*>(a_indices),
| 173 | +                        static_cast<const int32_t*>(c_d_indices),
| 174 | +                        static_cast<double>(1),
| 175 | +                        static_cast<double>(1));
| 176 | +  }
| 177 | +}
| 178 | +
118 | 179 | struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
119 | 180 |   using Gemm = cutlass::gemm::device::GemmUniversal<
120 | 181 |       cutlass::half_t,
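For reference, the sketch below shows how the new dispatchKernel helper might be invoked from a sparse kernel. It is a minimal, hypothetical call site: the function and variable names (RunGatherGemmScatterFp32, x, w, out, gather_idx, scatter_idx) and the include path are illustrative assumptions, not part of this patch, and all pointers are assumed to be device memory that the caller has already populated.

#include <cstdint>

#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"  // patched header (path assumed)

// Hypothetical call site, not part of the patch. Dispatches an FP32
// gather-GEMM-scatter on dev_ctx's stream: rows of `x` selected by
// `gather_idx` form the m x k A operand, `w` is the k x n B operand,
// and the result is accumulated into rows of `out` selected by
// `scatter_idx` (dispatchKernel passes alpha = beta = 1).
void RunGatherGemmScatterFp32(const phi::GPUContext& dev_ctx,
                              const float* x,              // device ptr, source of gathered rows
                              const float* w,              // device ptr, dense k x n weight
                              float* out,                  // device ptr, C/D operand
                              const int32_t* gather_idx,   // m row indices into x
                              const int32_t* scatter_idx,  // m row indices into out
                              int m,
                              int n,
                              int k) {
  phi::sparse::dispatchKernel(dev_ctx,
                              x,
                              w,
                              out,  // C operand
                              out,  // D operand, accumulated in place
                              m,
                              n,
                              k,
                              gather_idx,
                              scatter_idx,
                              /*cutlass=*/true,
                              phi::DataType::FLOAT32);
}

Passing cutlass = false makes the call a no-op, presumably so a caller can keep the existing non-CUTLASS path as a fallback.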
|
|