Commit cfa344d

【Metax】 Add profile and weight (#1860)
1 parent 56b15ef commit cfa344d

File tree: 8 files changed (+1028 −17 lines)


backends/metax_gpu/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -165,6 +165,7 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/decode_jpeg_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/nvjpeg.cc
+  ${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cupti.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_with_scaled_gradient_grad_kernel_register.cu
@@ -284,8 +285,6 @@ file(
   ${CMAKE_SOURCE_DIR}/kernels/gpudnn/cudnn.cc
   ${CMAKE_SOURCE_DIR}/kernels/cross_entropy_kernel_register.cu
   ${CMAKE_SOURCE_DIR}/kernels/cross_entropy_grad_kernel_register.cu
-  # ${CMAKE_SOURCE_DIR}/kernels/c_softmax_with_cross_entropy_kernel.cu
-  # ${CMAKE_SOURCE_DIR}/kernels/c_softmax_with_cross_entropy_grad_kernel.cu
   ${CMAKE_SOURCE_DIR}/kernels/layer_norm_kernel_register.cu
   ${CMAKE_SOURCE_DIR}/kernels/layer_norm_grad_kernel_register.cu
   ${CMAKE_SOURCE_DIR}/kernels/flash_attn_grad_kernel.cu
@@ -362,6 +361,7 @@ target_link_libraries(
   ${PADDLE_CORE_LIB})
 target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so)
 target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so)
+target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcpti.so)
 include_directories(BEFORE ${PADDLE_SOURCE_DIR})

 target_compile_definitions(
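The cupti.cc dynload source and the libmcpti.so link added above wire Paddle's profiler tracing to Metax's CUPTI-compatible library (mcpti). Below is a minimal sketch of exercising that path from Python; it assumes the metax_gpu plugin is installed and that this Paddle build exposes ProfilerTarget.CUSTOM_DEVICE.

# Hedged sketch: exercise the profiler path enabled by the cupti/mcpti hookup.
# Assumes the metax_gpu plugin is installed in this Paddle build.
import paddle
import paddle.profiler as profiler

paddle.set_device("metax_gpu")  # device name registered by this backend

x = paddle.randn([64, 1024])
w = paddle.randn([1024, 1024])

prof = profiler.Profiler(
    targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE],
    scheduler=(2, 5),  # profile batches [2, 5), skipping warm-up steps
)
prof.start()
for _ in range(6):
    y = paddle.matmul(x, w)
    prof.step()
prof.stop()
prof.summary()  # device-side timings come from the mcpti-backed trace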
Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "mctlass/epilogue/thread/scale_type.h"
#include "mctlass/half.h"
#include "mctlass/layout/matrix.h"
#include "mctlass/mctlass_ex.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/weight_only_gemv.h"
#include "paddle/phi/kernels/weight_only_linear_kernel.h"

namespace phi {

template <typename T, typename Context>
void WeightOnlyLinearKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const DenseTensor& weight,
                            const paddle::optional<DenseTensor>& bias,
                            const DenseTensor& weight_scale,
                            const std::string& weight_dtype,
                            const int32_t arch,
                            const int32_t group_size,
                            DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  const T* x_data = x.data<T>();
  const int8_t* weight_data = weight.data<int8_t>();
  const T* bias_data = bias ? bias.get().data<T>() : nullptr;
  const T* weight_scale_data = weight_scale.data<T>();
  T* out_data = out->data<T>();
  const auto x_dims = x.dims();
  const auto w_dims = weight.dims();
  int n = group_size > 0 ? weight_scale.dims()[1] : weight_scale.dims()[0];
  int k = w_dims[1];
  int m = x.numel() / k;

  using ElementA = maca_bfloat16;
  using ElementB_w8a16 = int8_t;
  using ElementB_w4a16 = uint8_t;
  using ElementC = maca_bfloat16;
  using ElementCompute = float;
  using ElementOutput = ElementC;
  using LayoutA = mctlass::layout::RowMajor;
  using LayoutB = mctlass::layout::ColumnMajor;
  using LayoutC = mctlass::layout::RowMajor;
  using ArchTag = mctlass::arch::Sm80;

  using mctlassGemmScaleOp_w8a16_nobias =
      mctlassGemmScale<ElementA,
                       LayoutA,
                       ElementB_w8a16,
                       LayoutB,
                       ElementC,
                       LayoutC,
                       ElementCompute,
                       ArchTag,
                       mctlass::epilogue::thread::ScaleType::NoScaleAsBs>;

  using mctlassGemmScaleOp_w8a16_bias =
      mctlassGemmScale<ElementA,
                       LayoutA,
                       ElementB_w8a16,
                       LayoutB,
                       ElementC,
                       LayoutC,
                       ElementCompute,
                       ArchTag,
                       mctlass::epilogue::thread::ScaleType::ScaleOnlyBias>;

  using mctlassGemmScaleOp_w4a16_nobias =
      mctlassGemmScale<ElementA,
                       LayoutA,
                       ElementB_w4a16,
                       LayoutB,
                       ElementC,
                       LayoutC,
                       ElementCompute,
                       ArchTag,
                       mctlass::epilogue::thread::ScaleType::NoScaleAsBs>;

  using mctlassGemmScaleOp_w4a16_bias =
      mctlassGemmScale<ElementA,
                       LayoutA,
                       ElementB_w4a16,
                       LayoutB,
                       ElementC,
                       LayoutC,
                       ElementCompute,
                       ArchTag,
                       mctlass::epilogue::thread::ScaleType::ScaleOnlyBias>;

  mctlass::gemm::GemmCoord problem_size(m, n, k);

  if (weight_dtype == "int8") {
    if (bias_data == nullptr) {
      mctlassGemmScaleOp_w8a16_nobias mctlass_op;
      typename mctlassGemmScaleOp_w8a16_nobias::Arguments arguments{
          mctlass::gemm::GemmUniversalMode::kGemmQuantB,
          problem_size,
          1,
          mctlassGemmScaleOp_w8a16_nobias::epilogueParams(
              reinterpret_cast<const maca_bfloat16*>(bias_data)),
          mctlassGemmScaleOp_w8a16_nobias::quantscaleParams(
              1,
              group_size,
              reinterpret_cast<const maca_bfloat16*>(weight_scale_data)),
          reinterpret_cast<const maca_bfloat16*>(x_data),
          weight_data,
          reinterpret_cast<const maca_bfloat16*>(out_data),
          out_data,
          m * k,
          n * k,
          m * n,
          m * n,
          k,
          k,
          n,
          n};
      mctlass_op(arguments);
    } else {
      mctlassGemmScaleOp_w8a16_bias mctlass_op;
      typename mctlassGemmScaleOp_w8a16_bias::Arguments arguments{
          mctlass::gemm::GemmUniversalMode::kGemmQuantB,
          problem_size,
          1,
          mctlassGemmScaleOp_w8a16_bias::epilogueParams(
              reinterpret_cast<const maca_bfloat16*>(bias_data)),
          mctlassGemmScaleOp_w8a16_bias::quantscaleParams(
              1,
              group_size,
              reinterpret_cast<const maca_bfloat16*>(weight_scale_data)),
          reinterpret_cast<const maca_bfloat16*>(x_data),
          weight_data,
          reinterpret_cast<const maca_bfloat16*>(out_data),
          out_data,
          m * k,
          n * k,
          m * n,
          m * n,
          k,
          k,
          n,
          n};
      mctlass_op(arguments);
    }
  } else if (weight_dtype == "int4") {
    if (bias_data == nullptr) {
      mctlassGemmScaleOp_w4a16_nobias mctlass_op;
      typename mctlassGemmScaleOp_w4a16_nobias::Arguments arguments{
          mctlass::gemm::GemmUniversalMode::kGemmQuantB,
          problem_size,
          1,
          mctlassGemmScaleOp_w4a16_nobias::epilogueParams(
              reinterpret_cast<const maca_bfloat16*>(bias_data)),
          mctlassGemmScaleOp_w4a16_nobias::quantscaleParams(
              1,
              group_size,
              reinterpret_cast<const maca_bfloat16*>(weight_scale_data)),
          reinterpret_cast<const maca_bfloat16*>(x_data),
          weight_data,
          reinterpret_cast<const maca_bfloat16*>(out_data),
          out_data,
          m * k,
          n * k,
          m * n,
          m * n,
          k,
          k,
          n,
          n};
      mctlass_op(arguments);
    } else {
      mctlassGemmScaleOp_w4a16_bias mctlass_op;
      typename mctlassGemmScaleOp_w4a16_bias::Arguments arguments{
          mctlass::gemm::GemmUniversalMode::kGemmQuantB,
          problem_size,
          1,
          mctlassGemmScaleOp_w4a16_bias::epilogueParams(
              reinterpret_cast<const maca_bfloat16*>(bias_data)),
          mctlassGemmScaleOp_w4a16_bias::quantscaleParams(
              1,
              group_size,
              reinterpret_cast<const maca_bfloat16*>(weight_scale_data)),
          reinterpret_cast<const maca_bfloat16*>(x_data),
          weight_data,
          reinterpret_cast<const maca_bfloat16*>(out_data),
          out_data,
          m * k,
          n * k,
          m * n,
          m * n,
          k,
          k,
          n,
          n};
      mctlass_op(arguments);
    }
  }
}
}  // namespace phi

PD_REGISTER_PLUGIN_KERNEL(weight_only_linear,
                          metax_gpu,
                          ALL_LAYOUT,
                          phi::WeightOnlyLinearKernel,
                          phi::dtype::float16,
                          phi::dtype::bfloat16) {}
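The file above registers a metax_gpu implementation of weight_only_linear that dispatches to mctlass scale GEMMs for int8 and int4 weights, with and without bias. A minimal Python usage sketch follows; the paddle.nn.quant signatures and tensor shapes are assumptions based on mainline Paddle, not part of this commit.

# Hedged sketch: drive the weight_only_linear kernel registered above.
import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

paddle.set_device("metax_gpu")

x = paddle.randn([8, 4096]).astype("bfloat16")      # activations, [m, k]
w = paddle.randn([4096, 4096]).astype("bfloat16")   # dense weight, [k, n]

# Quantize once (per-channel, group_size=-1), then run the fused GEMM.
qw, scale = weight_quantize(w, algo="weight_only_int8")
y = weight_only_linear(x, qw, bias=None, weight_scale=scale,
                       weight_dtype="int8", group_size=-1)
print(y.shape)  # expect [8, 4096]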

backends/metax_gpu/kernels/cuda_kernels/weight_quantize_kernel_register.cu

Lines changed: 143 additions & 2 deletions
@@ -11,11 +11,152 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/common/enforce.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/datatype_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/gpu/weight_quantize_kernel.cu"  // NOLINT
+#include "paddle/phi/kernels/funcs/common_shape.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h"
 
-PD_CUSTOM_KERNEL_REGISTER(weight_quantize,
+namespace phi {
+
+template <typename T, typename Context>
+void WeightQuantizeKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const std::string& algo,
+                          const int32_t arch,
+                          const int32_t group_size,
+                          DenseTensor* out,
+                          DenseTensor* scale) {
+  PADDLE_ENFORCE_EQ(
+      ((group_size == -1) || (group_size == 64) || (group_size == 128)),
+      true,
+      common::errors::InvalidArgument(
+          "Currently, group_size only support -1(per-channel), 64 or 128."));
+
+  const int64_t m = x.dims()[0];
+  const int64_t n = x.dims()[1];
+  PADDLE_ENFORCE_LE(
+      m,
+      std::numeric_limits<int>::max(),
+      common::errors::InvalidArgument(
+          "Currently only supports x.shape[0] <= INT_MAX, but got %d", m));
+
+  DenseTensor quanted_x;
+  dev_ctx.template Alloc<int8_t>(out);
+  if (out->numel() == 0) {
+    if (algo == "llm.int8") {
+      dev_ctx.template Alloc<float>(scale);
+    } else {
+      dev_ctx.template Alloc<T>(scale);
+    }
+    return;
+  }
+  quanted_x.Resize({m, n});
+  dev_ctx.template Alloc<int8_t>(&quanted_x);
+  std::vector<int64_t> weight_shape{m, n};
+#ifndef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_EQ(
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
+      true,
+      common::errors::InvalidArgument(
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));
+#endif
+  if (algo == "llm.int8") {
+    dev_ctx.template Alloc<float>(scale);
+    std::vector<int> axis = {1, 0};
+    funcs::Transpose<Context, int8_t, 2> trans;
+    weight_quant_gpu<T, Context>(dev_ctx,
+                                 x.data<T>(),
+                                 quanted_x.data<int8_t>(),
+                                 scale->data<float>(),
+                                 weight_shape,
+                                 arch,
+                                 algo);
+    trans(dev_ctx, quanted_x, out, axis);
+  } else if (algo == "weight_only_int8") {
+    dev_ctx.template Alloc<T>(scale);
+
+    if (std::is_same<T, int8_t>::value) {
+      // Zkk: you are loading already quantized weight, so we skip doing
+      // quantize. and just copy!
+#ifdef PADDLE_WITH_CUDA
+      cudaMemcpy(quanted_x.data<int8_t>(),
+                 x.data<T>(),
+                 x.numel(),
+                 cudaMemcpyDeviceToDevice);
+#endif
+    } else {
+      weight_quant_gpu<T, Context>(dev_ctx,
+                                   x.data<T>(),
+                                   out->data<int8_t>(),
+                                   scale->data<T>(),
+                                   weight_shape,
+                                   arch,
+                                   algo);
+    }
+    out->Resize({m, n});
+#ifdef PADDLE_WITH_HIP
+    std::vector<int> axis = {1, 0};
+    funcs::Transpose<Context, int8_t, 2> trans;
+    trans(dev_ctx, quanted_x, out, axis);
+    // #else
+    //   weight_permute_gpu<Context>(dev_ctx,
+    //                               quanted_x.data<int8_t>(),
+    //                               out->data<int8_t>(),
+    //                               weight_shape,
+    //                               arch,
+    //                               algo);
+#endif
+  } else if (algo == "weight_only_int4") {
+    dev_ctx.template Alloc<T>(scale);
+    weight_quant_gpu<T, Context>(dev_ctx,
+                                 x.data<T>(),
+                                 quanted_x.data<int8_t>(),
+                                 scale->data<T>(),
+                                 weight_shape,
+                                 arch,
+                                 algo);
+#ifdef PADDLE_WITH_HIP
+    DenseTensor x_int_tmp(out->type());
+    x_int_tmp.Resize({m, n / 2});
+    dev_ctx.template Alloc<int8_t>(&x_int_tmp);
+    int8_t* x_int_tmp_data = x_int_tmp.data<int8_t>();
+    int8_t* quanted_x_data = quanted_x.data<int8_t>();
+    for (int i = 0; i < out->numel(); ++i) {
+      x_int_tmp_data[i] = quanted_x_data[i];
+    }
+    std::vector<int> axis = {1, 0};
+    funcs::Transpose<Context, int8_t, 2> trans;
+    trans(dev_ctx, x_int_tmp, out, axis);
+#else
+    weight_permute_gpu<Context>(dev_ctx,
+                                quanted_x.data<int8_t>(),
+                                out->data<int8_t>(),
+                                weight_shape,
+                                arch,
+                                algo);
+#endif
+  } else if (algo == "w4a8") {
+    weight_permute_gpu_w4a8<Context>(dev_ctx,
+                                     x.data<int8_t>(),
+                                     out->data<int8_t>(),
+                                     weight_shape,
+                                     arch,
+                                     algo);
+  } else {
+    PADDLE_FATAL(
+        "The algo must be in ['weight_only_int8', 'weight_only_int4', "
+        "'llm.int8', 'w4a8'], but got[%s]",
+        algo);
+  }
+}
+}  // namespace phi
+
+PD_REGISTER_PLUGIN_KERNEL(weight_quantize,
                           metax_gpu,
                           ALL_LAYOUT,
                           phi::WeightQuantizeKernel,