
Commit eabb554

A-nnonymous authored and maxiaolong001 committed
Enhancing fused_transpose_split_quant with fp8 capability. (PaddlePaddle#74471)
* Enhanced fused_transpose_split_quant with fp8 capability.
* optimize performance.
* Clean comment
* clean miscs
* Fix example
1 parent 5450354 commit eabb554

6 files changed: 176 additions, 84 deletions

paddle/phi/infermeta/fusion.cc

Lines changed: 6 additions & 5 deletions
@@ -2421,16 +2421,17 @@ void FusedMultiTransformerInt8InferMeta(
 }
 
 void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
+                                       const MetaTensor& input_scales,
                                        const IntArray& tokens_per_expert,
                                        bool pow_2_scales,
                                        std::vector<MetaTensor*> outs,
                                        std::vector<MetaTensor*> scales) {
   PADDLE_ENFORCE_EQ(
-      x.dtype(),
-      DataType::BFLOAT16,
-      common::errors::InvalidArgument(
-          "The dtype of Input(x) must be BFLOAT16, but received %s",
-          x.dtype()));
+      x.dtype() == DataType::BFLOAT16 || x.dtype() == DataType::FLOAT8_E4M3FN,
+      true,
+      common::errors::InvalidArgument("The dtype of Input(x) must be BFLOAT16 "
+                                      "or FLOAT8_E4M3FN, but received %s",
+                                      x.dtype()));
 
   auto x_dims = x.dims();

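For reference, a minimal sketch (not part of this commit) of how the relaxed dtype check behaves from the Python side: the updated PADDLE_ENFORCE_EQ accepts bfloat16 or float8_e4m3fn inputs and rejects everything else with the InvalidArgument message above. It assumes a CUDA build of Paddle that includes this change and that the op is exposed via paddle.incubate.nn.functional, as in the docstring example further down.

import paddle
import paddle.incubate.nn.functional as F

x_bad = paddle.randn([256, 512], dtype='float32')  # neither bf16 nor fp8
try:
    F.fused_transpose_split_quant(x_bad, None, [128, 128])
except Exception as err:
    # Expected to mention "must be BFLOAT16 or FLOAT8_E4M3FN"
    print(err)
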
paddle/phi/infermeta/fusion.h

Lines changed: 1 addition & 0 deletions
@@ -669,6 +669,7 @@ void FusedMultiTransformerInt8InferMeta(
                                         MetaTensor* out);
 
 void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
+                                       const MetaTensor& input_scales,
                                        const IntArray& tokens_per_expert,
                                        bool pow_2_scales,
                                        std::vector<MetaTensor*> outs,

paddle/phi/kernels/fusion/gpu/fused_transpose_split_quant_kernel.cu

Lines changed: 94 additions & 52 deletions
@@ -29,43 +29,62 @@ struct __align__(sizeof(T) * VecSize) VecType {
   }
 };
 
-template <int VecSize>
-__device__ void BlockLoad(const phi::bfloat16* input,
+template <typename InT, int VecSize>
+__device__ void BlockLoad(const InT* input,
+                          const float* input_scales,
                           __nv_bfloat16 x[8][4],
-                          size_t K) {
+                          size_t K,
+                          size_t k_scaled) {
+  constexpr bool need_dequant = std::is_same_v<InT, phi::dtype::float8_e4m3fn>;
+
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t off_m = blockIdx.x * size_t(128) + threadIdx.y + i * 16;
-    size_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
-    size_t offset = off_m * K + off_k;
+    const uint32_t local_off_M = threadIdx.y + i * 16;
+    const uint32_t off_m = blockIdx.x * 128 + local_off_M;
+    const uint32_t off_k = blockIdx.y * 128 + threadIdx.x * VecSize;
+    const size_t offset = off_m * K + off_k;
+
+    float scale;
+    if constexpr (need_dequant) {
+      const uint32_t m_base = blockIdx.x * 128;
+      const uint32_t m_stride = k_scaled;
+      scale = input_scales[off_m * m_stride + blockIdx.y];
+    }
 
+#pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
-      if (off_k + j * 32 < K) {
-        size_t idx = offset + j * 32;
-        using LoadT = VecType<__nv_bfloat16, VecSize>;
-        LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
-        for (uint32_t k = 0; k < VecSize; k++) {
-          x[i][j + k] = data[k];
+      const size_t idx = offset + j * 32;
+      using LoadT = VecType<InT, VecSize>;
+      LoadT data = *reinterpret_cast<const LoadT*>(input + idx);
+#pragma unroll
+      for (uint32_t k = 0; k < VecSize; k++) {
+        if constexpr (need_dequant) {
+          x[i][j + k] = __float2bfloat16(static_cast<float>(data[k]) * scale);
+        } else {
+          x[i][j + k] = (*reinterpret_cast<__nv_bfloat16*>(&data[k]));
        }
      }
    }
  }
}
-
 template <bool Pow2Scales>
 __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
-                                 float col_scale[128],
+                                 float scales[128],
                                  __nv_bfloat16* shm) {
   // reduce [(8), 16, 32, 4] => [16, 32, 4]
   __nv_bfloat16 warp_max[4];
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
-      __nv_bfloat16 t = BF16_ABS(x[i][j]);
+      const __nv_bfloat16 t = BF16_ABS(x[i][j]);
       warp_max[j] = i == 0 ? t : BF16_MAX(warp_max[j], t);
     }
   }
 
   // reduce [(16), 32, 4] => [8, 32, 4]
   if (threadIdx.y >= 8) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j++) {
       shm[(threadIdx.y - 8) * 128 + threadIdx.x + j * 32] = warp_max[j];
     }
@@ -75,8 +94,9 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
   // reduce [(8), 32, 4] => [32, 4]
   for (uint32_t offset = 8; offset > 0; offset /= 2) {
     if (threadIdx.y < offset) {
+#pragma unroll
       for (uint32_t j = 0; j < 4; j++) {
-        __nv_bfloat16 other =
+        const __nv_bfloat16 other =
             offset == 8
                 ? warp_max[j]
                 : shm[(threadIdx.y + offset) * 128 + threadIdx.x + j * 32];
@@ -85,7 +105,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
       if (offset > 1) {
         shm[threadIdx.y * 128 + threadIdx.x + j * 32] = next_val;
       } else {
-        col_scale[threadIdx.x + j * 32] =
+        scales[threadIdx.x + j * 32] =
             ComputeScale<__nv_bfloat16, __nv_fp8_e4m3, Pow2Scales>(
                 static_cast<float>(next_val), 0.0f);
       }
@@ -98,7 +118,7 @@ __device__ void BlockColumnScale(const __nv_bfloat16 x[8][4],
 template <typename OutT, int VecSize>
 __device__ void BlockStoreScale(float* scale,
                                 size_t off_m,
-                                float col_scale[128],
+                                float scales[128],
                                 size_t K) {
   if (threadIdx.y < 4) {
     uint32_t off = threadIdx.y * 32 + threadIdx.x;
@@ -107,10 +127,10 @@ __device__ void BlockStoreScale(float* scale,
     } else if constexpr (VecSize == 2) {
       off = (off / 64) * 64 + (off % 2) * 32 + (off % 64) / 2;
     }
-    float scale_out = 1.0f / col_scale[off];
-    size_t idx_y = blockIdx.x - off_m / 128;
-    size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
-    size_t idx = idx_y * K + idx_x;
+    float scale_out = 1.0f / scales[off];
+    const size_t idx_y = blockIdx.x - off_m / 128;
+    const size_t idx_x = blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x;
+    const size_t idx = idx_y * K + idx_x;
     if (idx_x < K) {
       scale[idx] = scale_out;
     }
@@ -123,14 +143,16 @@ __device__ void BlockStoreOut(OutT* out,
                               size_t cur_tokens,
                               const OutT shm[128][129],
                               size_t K) {
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
-    size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
-    size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
-    size_t idx = idx_k * cur_tokens + (idx_m - off_m);
+    const size_t idx_m = blockIdx.x * size_t(128) + threadIdx.x * 4;
+    const size_t idx_k = blockIdx.y * 128 + threadIdx.y + i * 16;
+    const size_t idx = idx_k * cur_tokens + (idx_m - off_m);
 
     if (idx_k < K) {
       using StoreT = VecType<OutT, VecSize>;
       StoreT data;
+#pragma unroll
       for (uint32_t j = 0; j < VecSize; j++) {
         data[j] = shm[i * 16 + threadIdx.y][threadIdx.x * 4 + j];
       }
@@ -139,23 +161,27 @@
   }
 }
 
-template <typename OutT, bool Pow2Scales, int VecSize>
+template <typename InT, typename OutT, bool Pow2Scales, int VecSize>
 __global__ void __launch_bounds__(512)
-    FusedTransposeSplitQuantKernel(const phi::bfloat16* __restrict__ input,
+    FusedTransposeSplitQuantKernel(const InT* __restrict__ input,
+                                   const float* __restrict__ input_scales,
                                    int64_t* __restrict__ meta,
                                    size_t num_experts,
-                                   size_t K) {
+                                   size_t K,
+                                   size_t k_scaled) {
   __shared__ OutT shm[128][129];
+  __shared__ size_t expert_info[2];
+  __shared__ float scales[128];  // May be reused? Is it worthy?
+
   int64_t* tokens_per_expert = meta;
   OutT** out_ptrs = reinterpret_cast<OutT**>(meta + num_experts);
   float** scale_ptrs = reinterpret_cast<float**>(meta + num_experts * 2);
 
   // 1. Load 128x128 elements from input
   __nv_bfloat16 x[8][4];
-  BlockLoad<VecSize>(input, x, K);
+  BlockLoad<InT, VecSize>(input, input_scales, x, K, k_scaled);
 
   // 2. Get expert index and offset of the current block
-  __shared__ size_t expert_info[2];
   if (threadIdx.x == 0 && threadIdx.y == 0) {
     size_t idx_m = blockIdx.x * size_t(128);
     size_t off_m = 0, next_off_m = 0;
@@ -172,21 +198,23 @@ __global__ void __launch_bounds__(512)
   }
 
   // 3. Calculate scale along the column
-  __shared__ float col_scale[128];
   BlockColumnScale<Pow2Scales>(
-      x, col_scale, reinterpret_cast<__nv_bfloat16*>(shm));
+      x, scales, reinterpret_cast<__nv_bfloat16*>(shm));
 
   // 4. Store scale
   const size_t expert_idx = expert_info[0];
   const size_t off_m = expert_info[1];
-  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, col_scale, K);
+  BlockStoreScale<OutT, VecSize>(scale_ptrs[expert_idx], off_m, scales, K);
 
-  // 5. Scale x and save into shared memory with transposed layout
+  // 5. Scale x and save into shared memory with transposed layout
+#pragma unroll
   for (uint32_t i = 0; i < 8; i++) {
+#pragma unroll
     for (uint32_t j = 0; j < 4; j += VecSize) {
+#pragma unroll
       for (uint32_t k = 0; k < VecSize; k++) {
         float x_fp32 = static_cast<float>(x[i][j + k]);
-        float x_scaled = x_fp32 * col_scale[threadIdx.x + (j + k) * 32];
+        float x_scaled = x_fp32 * scales[threadIdx.x + (j + k) * 32];
         shm[threadIdx.x * VecSize + j * 32 + k][i * 16 + threadIdx.y] =
             static_cast<OutT>(x_scaled);
       }
@@ -204,10 +232,11 @@ template <typename T, typename Context>
 void FusedTransposeSplitQuantKernel(
     const Context& dev_ctx,
     const DenseTensor& x,
+    const paddle::optional<DenseTensor>& input_scales,
     const std::vector<int64_t>& tokens_per_expert,
     bool pow_2_scales,
     std::vector<DenseTensor*> outs,
-    std::vector<DenseTensor*> scales) {
+    std::vector<DenseTensor*> output_scales) {
   auto x_dims = x.dims();
   const int64_t M = x_dims[0];
   const int64_t K = x_dims[1];
@@ -221,8 +250,8 @@ void FusedTransposeSplitQuantKernel(
     if (outs[i] != nullptr) {
       dev_ctx.template Alloc<phi::dtype::float8_e4m3fn>(outs[i]);
     }
-    if (scales[i] != nullptr) {
-      dev_ctx.template Alloc<float>(scales[i]);
+    if (output_scales[i] != nullptr) {
+      dev_ctx.template Alloc<float>(output_scales[i]);
     }
   }
 
@@ -245,8 +274,8 @@ void FusedTransposeSplitQuantKernel(
 
   for (size_t i = 0; i < num_experts; i++) {
     meta_ptr[num_experts * 2 + i] =
-        scales[i] != nullptr
-            ? reinterpret_cast<int64_t>(scales[i]->data<float>())
+        output_scales[i] != nullptr
+            ? reinterpret_cast<int64_t>(output_scales[i]->data<float>())
            : 0;
  }
 
@@ -255,23 +284,35 @@ void FusedTransposeSplitQuantKernel(
 
   auto stream = dev_ctx.stream();
 
-  dim3 grid(M / 128, (K + 127) / 128);
+  // pre-compute on CPU to reduce size_t division cost in kernel
+  const size_t k_scaled = (K + 127) / 128;
+  dim3 grid(M / 128, k_scaled);
   dim3 block(32, 16);
 
-#define LAUNCH_KERNEL(POW_2_SCALES, VEC_SIZE)                         \
-  FusedTransposeSplitQuantKernel<phi::dtype::float8_e4m3fn,           \
-                                 POW_2_SCALES,                        \
-                                 VEC_SIZE>                            \
-      <<<grid, block, 0, stream>>>(x.data<phi::dtype::bfloat16>(),    \
-                                   meta_gpu.data<int64_t>(),          \
-                                   num_experts,                       \
-                                   K);
+#define DTYPE_CASE(dtype, type) dtype == phi::DataType::type
+#define LAUNCH_KERNEL(T, POW_2_SCALES, VEC_SIZE)                        \
+  FusedTransposeSplitQuantKernel<T,                                     \
+                                 phi::dtype::float8_e4m3fn,             \
+                                 POW_2_SCALES,                          \
+                                 VEC_SIZE><<<grid, block, 0, stream>>>( \
+      x.data<T>(),                                                      \
+      input_scales ? input_scales.get_ptr()->data<float>() : nullptr,   \
+      meta_gpu.data<int64_t>(),                                         \
+      num_experts,                                                      \
+      K,                                                                \
+      k_scaled);
+#define DISPATCH_DATATYPE(POW_2_SCALES, VEC_SIZE)                \
+  if (DTYPE_CASE(x.dtype(), BFLOAT16)) {                         \
+    LAUNCH_KERNEL(phi::bfloat16, POW_2_SCALES, VEC_SIZE);        \
+  } else if (DTYPE_CASE(x.dtype(), FLOAT8_E4M3FN)) {             \
+    LAUNCH_KERNEL(phi::float8_e4m3fn, POW_2_SCALES, VEC_SIZE);   \
+  }
 
 #define LAUNCH_KERNEL_PARTIAL(VEC_SIZE) \
   if (pow_2_scales) {                   \
-    LAUNCH_KERNEL(true, VEC_SIZE);      \
+    DISPATCH_DATATYPE(true, VEC_SIZE);  \
   } else {                              \
-    LAUNCH_KERNEL(false, VEC_SIZE);     \
+    DISPATCH_DATATYPE(false, VEC_SIZE); \
  }
 
   if (K % 4 == 0) {
@@ -296,7 +337,8 @@ PD_REGISTER_KERNEL(fused_transpose_split_quant,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::bfloat16) {
+                   phi::dtype::bfloat16,
+                   phi::dtype::float8_e4m3fn) {
   kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT8_E4M3FN);
   kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
 }

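To make the new fp8 path easier to follow, here is a rough NumPy sketch (not part of the commit) of the dequantization that BlockLoad now performs before the unchanged transpose/split/quantize steps. The shape of input_scales, [M, ceil(K / 128)] with one float per (row, 128-column block), is an assumption inferred from the indexing input_scales[off_m * m_stride + blockIdx.y] with m_stride = k_scaled above; the bf16 path still reinterprets the loaded values directly, as in the else branch of the inner loop.

import numpy as np

def dequant_fp8_blockwise(x_fp8, input_scales, K):
    """x_fp8: [M, K] fp8 values already widened to float; input_scales: [M, ceil(K / 128)]."""
    k_scaled = (K + 127) // 128  # mirrors k_scaled pre-computed on the CPU side
    x_bf16 = np.empty_like(x_fp8, dtype=np.float32)
    for kb in range(k_scaled):   # one scale per 128-column block of each row
        cols = slice(kb * 128, min((kb + 1) * 128, K))
        x_bf16[:, cols] = x_fp8[:, cols] * input_scales[:, kb:kb + 1]
    return x_bf16
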
paddle/phi/ops/yaml/fused_ops.yaml

Lines changed: 2 additions & 1 deletion
@@ -916,12 +916,13 @@
   support_dygraph_mode : true
 
 - op: fused_transpose_split_quant
-  args: (Tensor x, IntArray tokens_per_expert, bool pow_2_scales=false)
+  args: (Tensor x, Tensor input_scales, IntArray tokens_per_expert, bool pow_2_scales=false)
   output: Tensor[](out){tokens_per_expert.size()}, Tensor[](scales){tokens_per_expert.size()}
   infer_meta:
     func: FusedTransposeSplitQuantInferMeta
   kernel:
     func: fused_transpose_split_quant
+    optional: input_scales
   support_dygraph_mode : true
 
 - op: fused_weighted_swiglu_act_quant

python/paddle/incubate/nn/functional/fp8.py

Lines changed: 5 additions & 3 deletions
@@ -173,7 +173,9 @@ def fused_swiglu_weighted_bwd(
     return _C_ops.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs)
 
 
-def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):
+def fused_transpose_split_quant(
+    x, input_scales, tokens_per_expert, pow_2_scales=False
+):
     """
     Applies fused transpose, split, and quantization operation for Mixture of Experts (MoE) models.
 
@@ -215,7 +217,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):
             >>> x = paddle.randn([384, 512], dtype='bfloat16')
             >>> x = paddle.clip(x, min=-50, max=50)
             >>> tokens_per_expert = [128, 128, 128]
-            >>> outs, scales = F.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=True)
+            >>> outs, scales = F.fused_transpose_split_quant(x,None, tokens_per_expert, pow_2_scales=True)
             >>> print(outs[0].shape)
             [512, 128]
             >>> print(scales[0].shape)
@@ -228,7 +230,7 @@ def fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales=False):
 
     if in_dynamic_or_pir_mode():
         return _C_ops.fused_transpose_split_quant(
-            x, tokens_per_expert, pow_2_scales
+            x, input_scales, tokens_per_expert, pow_2_scales
         )
 
 

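For completeness, a minimal usage sketch of the updated signature (assuming a CUDA device and that the op is importable from paddle.incubate.nn.functional, as in the docstring example above):

import paddle
import paddle.incubate.nn.functional as F

x = paddle.clip(paddle.randn([384, 512], dtype='bfloat16'), min=-50, max=50)
tokens_per_expert = [128, 128, 128]

# bf16 path: no per-block input scales are needed, so None is passed explicitly.
outs, scales = F.fused_transpose_split_quant(x, None, tokens_per_expert, pow_2_scales=True)
print(outs[0].shape)  # [512, 128], as in the docstring example

# fp8 path (assumption, not shown in the docstring): x would be a float8_e4m3fn tensor
# and input_scales a float32 tensor of shape [M, ceil(K / 128)] holding one
# dequantization scale per (row, 128-column block), matching the kernel indexing.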