
Commit 0b952af

[Hardware][Intel] Support compressed-tensor W8A8 for CPU backend (vllm-project#7257)
1 parent 3b7fea7 commit 0b952af

File tree

18 files changed, +686 -43 lines changed

.buildkite/run-cpu-test.sh

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ docker exec cpu-test bash -c "
   --ignore=tests/models/test_jamba.py \
   --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
 
+# Run compressed-tensor test
+docker exec cpu-test bash -c "
+  pytest -s -v \
+  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
+  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
+
 # online inference
 docker exec cpu-test bash -c "
   export VLLM_CPU_KVCACHE_SPACE=10

Dockerfile.cpu

Lines changed: 17 additions & 1 deletion
@@ -2,6 +2,10 @@
 
 FROM ubuntu:22.04 AS cpu-test-1
 
+ENV CCACHE_DIR=/root/.cache/ccache
+
+ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
+
 RUN --mount=type=cache,target=/var/cache/apt \
     apt-get update -y \
     && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
@@ -26,6 +30,19 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     pip install --upgrade pip && \
     pip install -r requirements-build.txt
 
+# install oneDNN
+RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git
+
+RUN --mount=type=cache,target=/root/.cache/ccache \
+    cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \
+    -DONEDNN_BUILD_DOC=OFF \
+    -DONEDNN_BUILD_EXAMPLES=OFF \
+    -DONEDNN_BUILD_TESTS=OFF \
+    -DONEDNN_BUILD_GRAPH=OFF \
+    -DONEDNN_ENABLE_WORKLOAD=INFERENCE \
+    -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
+    cmake --build ./oneDNN/build --target install --config Release
+
 FROM cpu-test-1 AS build
 
 WORKDIR /workspace/vllm
@@ -41,7 +58,6 @@ COPY ./ ./
 ARG VLLM_CPU_DISABLE_AVX512
 ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
 
-ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=cache,target=/root/.cache/ccache \
     VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \

cmake/cpu_extension.cmake

Lines changed: 12 additions & 6 deletions
@@ -1,4 +1,5 @@
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+set(CMAKE_CXX_STANDARD 17)
 
 #
 # Define environment variables for special configurations
@@ -83,12 +84,7 @@ endif()
 
 message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
 
-list(APPEND LIBS "numa")
-
-
-#
-# Define extension targets
-#
+list(APPEND LIBS dnnl numa)
 
 #
 # _C extension
@@ -102,6 +98,16 @@ set(VLLM_EXT_SRC
     "csrc/cpu/pos_encoding.cpp"
     "csrc/cpu/torch_bindings.cpp")
 
+if (AVX512_FOUND AND NOT AVX512_DISABLED)
+    set(VLLM_EXT_SRC
+        "csrc/cpu/quant.cpp"
+        ${VLLM_EXT_SRC})
+endif()
+
+#
+# Define extension targets
+#
+
 define_gpu_extension_target(
     _C
     DESTINATION vllm

csrc/cpu/cpu_types_x86.hpp

Lines changed: 60 additions & 2 deletions
@@ -24,8 +24,8 @@ namespace vec_op {
 #define CPU_KERNEL_GUARD_OUT(NAME)
 #else
 #define CPU_KERNEL_GUARD_IN(NAME) \
-  std::cout << #NAME << " invoked." << std::endl;
-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+  RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
+#define CPU_KERNEL_GUARD_OUT(NAME)
 #endif
 
 #define FORCE_INLINE __attribute__((always_inline)) inline
@@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
   explicit BF16Vec16(const FP32Vec16 &);
 
   void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
+
+  void save(void* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm256_mask_storeu_epi16(ptr, mask, reg);
+  }
 };
 
 #ifdef __AVX512F__
@@ -313,8 +319,28 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
     return FP32Vec16(_mm512_div_ps(reg, b.reg));
   }
 
+  FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
+    return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg)));
+  }
+
+  FP32Vec16 max(const FP32Vec16& b) const {
+    return FP32Vec16(_mm512_max_ps(reg, b.reg));
+  }
+
+  FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
+  }
+
+  FP32Vec16 abs() const {
+    return FP32Vec16(_mm512_abs_ps(reg));
+  }
+
   float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
 
+  float reduce_max() const { return _mm512_reduce_max_ps(reg); }
+
   template <int group_size> float reduce_sub_sum(int idx) {
     static_assert(VEC_ELEM_NUM % group_size == 0);
     constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
@@ -323,6 +349,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
   }
 
   void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
+
+  void save(float* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm512_mask_storeu_ps(ptr, mask, reg);
+  }
 };
 #else
 struct FP32Vec16 : public Vec<FP32Vec16> {
@@ -433,6 +465,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
 };
 #endif
 
+#ifdef __AVX512F__
+struct INT8Vec16: public Vec<INT8Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    __m128i reg;
+    int8_t values[VEC_ELEM_NUM];
+  };
+
+  __m128i reg;
+
+  explicit INT8Vec16(const FP32Vec16& vec) : reg(
+    _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
+  ) {}
+
+  void save(int8_t* ptr) const {
+    _mm_storeu_epi8(ptr, reg);
+  }
+
+  void save(int8_t* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm_mask_storeu_epi8(ptr, mask, reg);
+  }
+};
+#endif
+
 template <typename T> struct VecType { using vec_type = void; };
 
 template <typename T> using vec_t = typename VecType<T>::vec_type;
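
Note: the masked save()/max() overloads added above all build their lane mask the same way: shifting 0xFFFFFFFF right by (32 - elem_num) leaves exactly elem_num low bits set, and _cvtu32_mask16 truncates that to a 16-lane AVX-512 mask, so only the first elem_num lanes are touched. A minimal standalone sketch of that mask arithmetic (illustration only, not part of this commit; compiles without AVX-512):

// Illustration of the tail-mask trick used by the masked save()/max() overloads.
#include <cstdint>
#include <cstdio>

// For 1 <= elem_num <= 16, returns a mask with the low `elem_num` bits set,
// i.e. the lanes a masked store such as _mm512_mask_storeu_ps would write.
static uint16_t tail_mask(int elem_num) {
  constexpr uint32_t M = 0xFFFFFFFF;
  return static_cast<uint16_t>(M >> (32 - elem_num));
}

int main() {
  for (int n : {1, 4, 15, 16}) {
    std::printf("elem_num=%2d -> mask=0x%04x\n", n, tail_mask(n));
  }
  // Prints 0x0001, 0x000f, 0x7fff, 0xffff.
  return 0;
}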

csrc/cpu/dnnl_helper.hpp

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+#ifndef DNNL_HELPER_HPP
+#define DNNL_HELPER_HPP
+
+#include <c10/util/BFloat16.h>
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+namespace {
+template <typename T>
+struct DNNLType {
+  static constexpr dnnl::memory::data_type type =
+      dnnl::memory::data_type::undef;
+};
+
+template <>
+struct DNNLType<int8_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+
+template <>
+struct DNNLType<int32_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+
+template <>
+struct DNNLType<float> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+
+template <>
+struct DNNLType<c10::BFloat16> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+
+template <typename T>
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+  return DNNLType<std::decay_t<T>>::type;
+}
+};  // namespace
+
+template <bool InputNoScale>
+class DNNLPrimitiveHelper {
+ public:
+  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
+  // A: [M, K], row-major
+  // B: [K, N], column-major
+  // C: [M, N], row-major
+  // bias: [N], row-major, optional
+  // a_scales: [MS]
+  // b_scales: [NS]
+  // Note: Due to the limitation of oneDNN
+  // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
+  // not supported.
+  template <typename OutputT, typename BiasT>
+  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
+                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
+                            dnnl_dim_t K, const float* a_scales,
+                            const float* b_scales, dnnl_dim_t MS,
+                            dnnl_dim_t NS) {
+    auto&& OutputType = get_dnnl_type<OutputT>();
+    auto&& BiasType = get_dnnl_type<BiasT>();
+
+    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
+    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
+    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
+
+    dnnl::primitive_attr attr;
+    if constexpr (!InputNoScale) {
+      if (MS == 1) {
+        // per-tensor
+        attr.set_scales_mask(DNNL_ARG_SRC, 0);
+      } else {
+        // per-token
+        TORCH_CHECK(false, "per-token quantization is unsupported.");
+      }
+    }
+
+    if (NS == 1) {
+      // per-tensor
+      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+    } else {
+      // per-channel
+      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
+    }
+
+    dnnl::matmul::primitive_desc matmul_pd;
+    if (bias) {
+      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
+      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
+                                               bias_md, c_md, attr);
+    } else {
+      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
+                                               c_md, attr);
+    }
+    dnnl::matmul matmul(matmul_pd);
+
+    auto& engine = default_engine();
+
+    dnnl::memory a_m(a_md, engine, (void*)a);
+    dnnl::memory b_m(b_md, engine, (void*)b);
+    dnnl::memory c_m(c_md, engine, (void*)c);
+    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
+                            (void*)a_scales);
+    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
+                            (void*)b_scales);
+
+    auto& stream = default_stream();
+    if constexpr (InputNoScale) {
+      if (bias) {
+        dnnl::memory::desc bias_md({N}, BiasType, {1});
+        dnnl::memory bias_m(bias_md, engine, (void*)bias);
+        matmul.execute(
+            stream, {
+                {DNNL_ARG_SRC, a_m},
+                {DNNL_ARG_WEIGHTS, b_m},
+                {DNNL_ARG_BIAS, bias_m},
+                {DNNL_ARG_DST, c_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+            });
+      } else {
+        matmul.execute(
+            stream, {
+                {DNNL_ARG_SRC, a_m},
+                {DNNL_ARG_WEIGHTS, b_m},
+                {DNNL_ARG_DST, c_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+            });
+      }
+    } else {
+      if (bias) {
+        dnnl::memory::desc bias_md({N}, BiasType, {1});
+        dnnl::memory bias_m(bias_md, engine, (void*)bias);
+        matmul.execute(
+            stream, {
+                {DNNL_ARG_SRC, a_m},
+                {DNNL_ARG_WEIGHTS, b_m},
+                {DNNL_ARG_BIAS, bias_m},
+                {DNNL_ARG_DST, c_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+            });
+      } else {
+        matmul.execute(
+            stream, {
+                {DNNL_ARG_SRC, a_m},
+                {DNNL_ARG_WEIGHTS, b_m},
+                {DNNL_ARG_DST, c_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
+                {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+            });
+      }
+    }
+    stream.wait();
+  }
+
+ private:
+  static dnnl::engine& default_engine() {
+    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+    return engine;
+  }
+
+  static dnnl::stream& default_stream() {
+    static dnnl::stream stream(default_engine());
+    return stream;
+  }
+};
+
+#endif
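
For context, a hypothetical caller of the helper above might look like the following. It is a minimal sketch, not part of this commit, assuming per-tensor scales for both activations and weights (MS == NS == 1), B laid out column-major as the comment block requires, and that oneDNN plus the c10/torch headers this file depends on are available at build time:

// Hypothetical usage sketch of DNNLPrimitiveHelper<false>::gemm_s8s8_jit.
#include <vector>
#include "dnnl_helper.hpp"

int main() {
  const dnnl_dim_t M = 4, N = 8, K = 16;

  std::vector<int8_t> a(M * K, 1);   // A: [M, K], row-major (quantized activations)
  std::vector<int8_t> b(K * N, 1);   // B: [K, N], column-major (quantized weights)
  std::vector<float> c(M * N, 0.f);  // C: [M, N], row-major (dequantized output)
  std::vector<float> bias(N, 0.f);   // optional float bias, [N]

  float a_scale = 0.05f;  // per-tensor activation scale -> MS == 1
  float b_scale = 0.10f;  // per-tensor weight scale     -> NS == 1

  // InputNoScale == false: both the DNNL_ARG_SRC and DNNL_ARG_WEIGHTS scales apply.
  DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
      a.data(), b.data(), c.data(), bias.data(), M, N, K, &a_scale, &b_scale,
      /*MS=*/1, /*NS=*/1);

  return 0;
}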
