Skip to content

Commit 9bd202e

Browse files
authored
Add stride check for attn_mask on non-cpu device (pytorch#158618)
Add stride check for attn_mask on non-cpu device (pytorch#158424) Fixes pytorch#158374 Pull Request resolved: pytorch#158424 Approved by: https://github.com/Valentine233, https://github.com/drisspg, https://github.com/atalman
1 parent f2b69a0 commit 9bd202e

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

aten/src/ATen/native/transformers/sdp_utils_cpp.h

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <ATen/core/Tensor.h>
77
#include <ATen/core/grad_mode.h>
88
#include <ATen/native/DispatchStub.h>
9+
#include <c10/core/DeviceType.h>
910
#include <c10/core/ScalarType.h>
1011

1112
#include <c10/util/Exception.h>
@@ -503,17 +504,27 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool
503504
if (ignore_singleton_dim){
504505
qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
505506
}
506-
if (!qkv_strides_equal_1) {
507+
bool is_cpu = params.query.device().type() == c10::DeviceType::CPU;
508+
bool mask_stride_equal_1 = params.attn_mask.has_value()
509+
? params.attn_mask.value().sym_stride(-1) == 1
510+
: true;
511+
bool mask_stride_valid = is_cpu ? true : mask_stride_equal_1;
512+
if (!(qkv_strides_equal_1 && mask_stride_valid)) {
507513
if (debug) {
508-
TORCH_WARN(
509-
"All fused kernels require the last dimension of the input to have stride 1. ",
510-
"Got Query.stride(-1): ",
511-
params.query.sym_stride(-1),
512-
", Key.stride(-1): ",
513-
params.key.sym_stride(-1),
514-
", Value.stride(-1): ",
515-
params.value.sym_stride(-1),
516-
" instead.");
514+
std::ostringstream message;
515+
message
516+
<< "All fused kernels require the last dimension of the input to have stride 1. ";
517+
message << "Got Query.stride(-1): " << params.query.sym_stride(-1)
518+
<< ", Key.stride(-1): " << params.key.sym_stride(-1)
519+
<< ", Value.stride(-1): " << params.value.sym_stride(-1);
520+
521+
if (params.attn_mask.has_value()) {
522+
message
523+
<< ", Attn_mask.stride(-1): "
524+
<< params.attn_mask.value().sym_stride(-1)
525+
<< " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not).";
526+
}
527+
TORCH_WARN(message.str());
517528
}
518529

519530
return false;

test/inductor/test_fused_attention.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ def dot_prod_attention(
10231023
return attn_weights.matmul(value), key, value
10241024

10251025
tensor_shape = (4, 2, 16, 32)
1026-
attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device)
1026+
attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device)
10271027
args = [
10281028
torch.randn(tensor_shape, device=self.device),
10291029
torch.randn(tensor_shape, device=self.device),
@@ -1036,6 +1036,16 @@ def dot_prod_attention(
10361036
has_dropout=False,
10371037
check_train=False,
10381038
)
1039+
# test attn_mask with stride of last dim != 1
1040+
attn_mask_ = attn_mask.transpose(2, 3)
1041+
args[3] = attn_mask_
1042+
self._check_common(
1043+
dot_prod_attention,
1044+
args1=args,
1045+
has_dropout=False,
1046+
check_train=False,
1047+
contains=self.device == "cpu",
1048+
)
10391049

10401050
def _test_sdpa_rewriter_23(self):
10411051
def dot_prod_attention(

test/test_transformers.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,34 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
16181618
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
16191619
q, k, v, None, 0.0, False))
16201620

1621+
@onlyCUDA
1622+
@unittest.skipIf(
1623+
not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
1624+
or not PLATFORM_SUPPORTS_CUDNN_ATTENTION,
1625+
"Efficient or cuDNN Attention was not built for this system",
1626+
)
1627+
@parametrize("kernel", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION])
1628+
def test_mask_invalid_last_dim_stride(self, device, kernel):
1629+
with sdpa_kernel(backends=[kernel]):
1630+
dtype = torch.float16
1631+
make_tensor = partial(torch.rand, device=device, dtype=dtype)
1632+
size = SdpaShape(2, 2, 8, 8)
1633+
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
1634+
attn_mask = make_tensor((2, 2, 8, 8))
1635+
# Passing in an attn_mask with last dim stride not equal to 1 will error
1636+
attn_mask.as_strided_(size, [2, 2, 2, 2])
1637+
1638+
with self.assertWarnsRegex(
1639+
UserWarning,
1640+
"GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not",
1641+
):
1642+
self.assertRaises(
1643+
RuntimeError,
1644+
lambda: torch.nn.functional.scaled_dot_product_attention(
1645+
q, k, v, attn_mask, 0.0, False
1646+
),
1647+
)
1648+
16211649
@onlyCUDA
16221650
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
16231651
@parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])

0 commit comments

Comments
 (0)