Skip to content

Commit ee5610f

Browse files
malfetpytorchmergebot
authored andcommitted
[BE] Check that swizzle arguments are passed to the call (pytorch#167869)
Otherwise is causes null pointer deref Pull Request resolved: pytorch#167869 Approved by: https://github.com/slayton58, https://github.com/Skylion007 ghstack dependencies: pytorch#167868
1 parent d0e7d2e commit ee5610f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

aten/src/ATen/native/cuda/GroupedBlas.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,8 @@ _scaled_grouped_mm_cuda_v2(
607607
// scale shape checks
608608
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
609609
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
610+
// swizze checks
611+
TORCH_CHECK_VALUE(swizzle_a_enum.size() == 1 && swizzle_b_enum.size() == 1, "Expected single swizzle argument");
610612
return _mx8_mx8_bf16_grouped_mm_fbgemm(
611613
mat_a,
612614
mat_b,

0 commit comments

Comments
 (0)