
Commit 49804aa

Cherry pick vLLM related bug fixes for 2.8 release branch (#5732)
1) FP8 GEMM: support an input of shape [M, B, K] with the output shape kept as [M, B, N], and keep the output stride consistent with the input. In some scenarios the input has shape [M, B, K] with stride [K, M*K, 1]; the output stride must then be [N, M*N, 1] to keep the layouts consistent.

2) Fix the QWEN-32B int4 TP=8 bug. When running the QWEN-32B int4 model with TP=8, the weight shape is [80, 5120] while the int4 GEMM group_size is 128, which is larger than the GEMM in_features (80). The group_size was therefore rewritten to 80, which is erroneous.
1 parent 4d90735 commit 49804aa
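
As a rough, standalone illustration of fix 1) (not code from this repository; shapes chosen to match the new test, and assuming the usual innermost stride of 1): scaling every stride of the [M, B, K] input except the innermost by N/K yields the [N, M*N, 1] output stride, so the output keeps the input's memory layout.

import torch

# Illustrative sizes: M rows, B batch, K in-features, N out-features.
M, B, K, N = 8, 1, 2, 3

# An [M, B, K] view with stride [K, M*K, 1], as produced by transposing a
# contiguous [B, M, K] tensor.
a = torch.randn(B, M, K).transpose(0, 1)
assert a.stride() == (K, M * K, 1)

# Mirror of the stride rule: divide by K, multiply by N, keep the innermost stride.
out_stride = [s // K * N for s in a.stride()[:-1]] + [a.stride(-1)]
assert out_stride == [N, M * N, 1]

# Allocate the output with the matching layout (cf. at::empty_strided in the C++
# change); transposing it back gives a contiguous [B, M, N] tensor.
out = torch.empty_strided((M, B, N), out_stride)
assert out.transpose(0, 1).is_contiguous()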

File tree

4 files changed: 32 additions, 12 deletions

csrc/gpu/aten/operators/fp8/FP8Linear.cpp

Lines changed: 9 additions & 1 deletion
@@ -139,7 +139,15 @@ Tensor fp8_gemm_w8a16(
     TORCH_CHECK(false, "linear only support for 2D and 3D tensors!\n");
   }
 
-  at::Tensor result = at::empty(result_shape, A.options());
+  // deal with input shape [m, b, k] stride [k, m * k, 1]
+  auto k = A.size(A.dim() - 1);
+  auto n = result_shape.back();
+  auto res_stride = A.strides().vec();
+  for (int i = 0; i < res_stride.size() - 1; i++) {
+    res_stride[i] = res_stride[i] / k * n;
+  }
+
+  at::Tensor result = at::empty_strided(result_shape, res_stride, A.options());
 
   // check if nt format
   bool is_nt = true;

csrc/gpu/oneDNN/DnnlMatmulQuant.h

Lines changed: 14 additions & 3 deletions
@@ -738,7 +738,7 @@ static inline void dnnl_matmul_w8a16_fp8(
 
   const int m = std::reduce(
       src_sz.begin(), src_sz.end() - 1, 1, std::multiplies<int64_t>());
-  const int n = o_sz[1]; // presume channel last format
+  const int n = o_sz.back(); // presume channel last format
   const int k = *(src_sz.end() - 1);
 
   // get device, engine, stream
@@ -791,11 +791,22 @@ static inline void dnnl_matmul_w8a16_fp8(
     tt = dnnl::trans_type_t::nt;
   }
 
-  int64_t lda = mat1.strides()[mat1.dim() - 2];
+  // get lda ldb and ldc
+  auto mat1_strides = mat1.strides();
+  int64_t leading_dim = -1;
+  if (mat1.dim() == 2) {
+    leading_dim = 0;
+  } else if (mat1.dim() == 3) {
+    leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1;
+  } else {
+    TORCH_CHECK(
+        false, "Unsupported input dimension for fp8 matmul: ", mat1.dim());
+  }
+  int64_t lda = mat1_strides[leading_dim];
   int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1
       ? mat2.strides()[mat2.dim() - 2]
       : mat2.strides()[mat2.dim() - 1];
-  int64_t ldc = result.strides()[result.dim() - 2];
+  int64_t ldc = result.strides()[leading_dim];
 
   auto f_attr = [&](primitive_attr& pattr) {
 #ifdef USE_SCRATCHPAD_MODE
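
A small Python sketch of the leading-dimension choice introduced above (pick_leading_dim is a hypothetical name used only for illustration): for a 2D input the first stride is taken, while for a 3D input the smaller of the two outer strides is used, so both a contiguous [B, M, K] input and a transposed [M, B, K] view resolve lda (and ldc) to the same value.

# Sketch mirroring the leading-dimension selection in the C++ change.
def pick_leading_dim(strides, ndim):
    if ndim == 2:
        return 0
    if ndim == 3:
        # Pick whichever of the two outer dimensions has the smaller stride;
        # both contiguous [B, M, K] and transposed [M, B, K] inputs then
        # yield the same leading dimension value.
        return 0 if strides[0] < strides[1] else 1
    raise ValueError(f"Unsupported input dimension for fp8 matmul: {ndim}")

# Contiguous [B, M, K] with B=1, M=8, K=2 -> strides (16, 2, 1) -> lda = 2.
assert [16, 2, 1][pick_leading_dim([16, 2, 1], 3)] == 2
# Transposed [M, B, K] view -> strides (2, 16, 1) -> lda = 2 as well.
assert [2, 16, 1][pick_leading_dim([2, 16, 1], 3)] == 2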

intel_extension_for_pytorch/nn/utils/_quantize_convert.py

Lines changed: 1 addition & 5 deletions
@@ -316,11 +316,7 @@ def __init__(
         self.double_quant_scale_dtype = double_quant_scale_dtype
         self.compute_dtype = compute_dtype
         self.compress_statistics = compress_statistics
-        self.blocksize = (
-            blocksize
-            if blocksize != -1 and blocksize < self.in_features
-            else self.in_features
-        )
+        self.blocksize = blocksize if blocksize != -1 else self.in_features
         self.scheme = scheme
         self.weight_dtype = weight_dtype
         self.device = device
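
This is the change behind fix 2): the old expression also clamped the block size whenever it exceeded in_features, which is how the 128-element group size became 80 for QWEN-32B with TP=8. A before/after sketch (old_blocksize and new_blocksize are hypothetical names, illustration only):

def old_blocksize(blocksize, in_features):
    # Previous behavior: also clamps when blocksize exceeds in_features.
    return blocksize if blocksize != -1 and blocksize < in_features else in_features

def new_blocksize(blocksize, in_features):
    # Fixed behavior: only the -1 sentinel falls back to in_features.
    return blocksize if blocksize != -1 else in_features

# QWEN-32B int4 with TP=8 (per the commit message): in_features = 80, group_size = 128.
assert old_blocksize(128, 80) == 80   # erroneous rewrite
assert new_blocksize(128, 80) == 128  # group_size is preserved
assert new_blocksize(-1, 80) == 80    # -1 still defaults to in_features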

tests/gpu/examples/test_fp8_linear_v2.py

Lines changed: 8 additions & 3 deletions
@@ -66,11 +66,12 @@ def test_fp8_linear_v2(fp8_dtype, dtype, is_input_fp8, is_bias):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("is_bias", [True, False])
 @pytest.mark.parametrize("is_nt", [True, False])
-def test_fp8_linear_w8a16(fp8_dtype, dtype, is_bias, is_nt):
+@pytest.mark.parametrize("is_mbk", [True, False])
+def test_fp8_linear_w8a16(fp8_dtype, dtype, is_bias, is_nt, is_mbk):
     seed = 1234
     torch.manual_seed(seed)
 
-    input = torch.randn([8, 2], dtype=dtype, device=torch.device("xpu")) / 10.0
+    input = torch.randn([1, 8, 2], dtype=dtype, device=torch.device("xpu")) / 10.0
     weight = torch.rand([3, 2], dtype=dtype).xpu() / 10.0
 
     gemm_ref = torch.nn.Linear(2, 3, bias=is_bias).xpu().to(dtype)
@@ -105,10 +106,14 @@ def test_fp8_linear_w8a16(fp8_dtype, dtype, is_bias, is_nt):
         ),
     )
 
+    if is_mbk:
+        input = input.transpose(0, 1)
+
     output_fp8 = fp8_linear(input, gemm_ref.bias.data.clone() if is_bias else None)
+    output_fp8 = output_fp8.transpose(0, 1) if is_mbk else output_fp8
 
     torch.testing.assert_close(output_fp8, output_ref, atol=1e-2, rtol=1e-2)
 
 
 if __name__ == "__main__":
-    test_fp8_linear_w8a16(torch.float8_e5m2, torch.float16, True)
+    test_fp8_linear_w8a16(torch.float8_e5m2, torch.float16, True, True, True)
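
For reference, the new is_mbk parameter exercises exactly the strided layout targeted by the kernel fix: transposing the contiguous [1, 8, 2] test input yields an [M, B, K] view with stride [K, M*K, 1]. A minimal standalone check (CPU tensors here, only to show the layout):

import torch

# Same shapes as the test: B=1, M=8, K=2.
x = torch.randn(1, 8, 2)      # contiguous [B, M, K], stride (16, 2, 1)
x_mbk = x.transpose(0, 1)     # [M, B, K] view, stride (2, 16, 1) == (K, M*K, 1)
print(x_mbk.shape, x_mbk.stride())  # torch.Size([8, 1, 2]) (2, 16, 1)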
