Skip to content

Commit 3049c78

Browse files
authored
WOQ: fix reference kernel for binary fusion with odd M (#3490)
1 parent 7ceed47 commit 3049c78

File tree

2 files changed: +13 −3 lines changed

2 files changed

+13
-3
lines changed

csrc/cpu/aten/utils/woq.h

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3454,11 +3454,21 @@ static at::Tensor woq_gemm_ref_impl(
34543454
at::silu_(y);
34553455
} else if (fusion_type == WOQ_FUSE_ADD || fusion_type == WOQ_FUSE_ADD_ADD) {
34563456
for (auto& tin : others_list) {
3457-
y = at::add(y, tin.view(y.sizes()));
3457+
auto tin_view = tin.view({-1, y.size(-1)});
3458+
if (tin_view.size(0) < y.size(0)) {
3459+
tin_view = at::pad(
3460+
tin_view, {0, 0, 0, y.size(0) - tin_view.size(0)}, "constant", 0);
3461+
}
3462+
y = at::add(y, tin_view);
34583463
}
34593464
} else if (fusion_type == WOQ_FUSE_MUL) {
34603465
for (auto& tin : others_list) {
3461-
y = at::mul(y, tin.view(y.sizes()));
3466+
auto tin_view = tin.view({-1, y.size(-1)});
3467+
if (tin_view.size(0) < y.size(0)) {
3468+
tin_view = at::pad(
3469+
tin_view, {0, 0, 0, y.size(0) - tin_view.size(0)}, "constant", 0);
3470+
}
3471+
y = at::mul(y, tin_view);
34623472
}
34633473
} else {
34643474
TORCH_CHECK(

tests/cpu/test_quantization_default_recipe.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1207,7 +1207,7 @@ def forward(self, x, others):
12071207
]
12081208
bias_list = [False, True]
12091209
bf16_list = [False, True]
1210-
batch_size_list = [4, 1024]
1210+
batch_size_list = [4, 1024, 63]
12111211
cases = itertools.product(
12121212
weight_dtype_list, bias_list, bf16_list, batch_size_list
12131213
)

0 commit comments

Comments
 (0)