@@ -110,6 +110,7 @@ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
110
110
}
111
111
112
112
ffi::Error mha_bwd_impl (cudaStream_t stream, ffi::ScratchAllocator scratch,
113
+ int32_t device,
113
114
ffi::AnyBuffer dout, // batch_size x seqlen_q x num_heads x head_size_og
114
115
ffi::AnyBuffer q, // batch_size x seqlen_q x num_heads x head_size
115
116
ffi::AnyBuffer k, // batch_size x seqlen_k x num_heads_k x head_size
@@ -121,9 +122,8 @@ ffi::Error mha_bwd_impl(cudaStream_t stream, ffi::ScratchAllocator scratch,
121
122
ffi::Result<ffi::AnyBuffer> dv, // batch_size x seqlen_k x num_heads_k x head_size
122
123
double softmax_scale, bool is_causal,
123
124
int64_t window_size_left, int64_t window_size_right) {
124
- int device, major, minor, sm_count;
125
- FFI_CUDA_CHECK (cudaStreamGetDevice (stream, &device));
126
- FFI_CUDA_CHECK (cudaDeviceGetAttribute (&major, cudaDevAttrComputeCapabilityMajor, device));
125
+ int major, minor, sm_count;
126
+ FFI_CUDA_CHECK (cudaDeviceGetAttribute (&major, cudaDevAttrComputeCapabilityMajor, device));
127
127
FFI_CUDA_CHECK (cudaDeviceGetAttribute (&minor, cudaDevAttrComputeCapabilityMinor, device));
128
128
FFI_CUDA_CHECK (cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device));
129
129
@@ -249,6 +249,7 @@ ffi::Error
249
249
mha_varlen_bwd_impl (
250
250
cudaStream_t stream,
251
251
ffi::ScratchAllocator scratch,
252
+ int32_t device,
252
253
ffi::AnyBuffer dout, // total_q x num_heads x head_size
253
254
ffi::AnyBuffer q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
254
255
ffi::AnyBuffer k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -270,8 +271,7 @@ mha_varlen_bwd_impl(
270
271
bool deterministic) {
271
272
272
273
if (is_causal) { window_size_right = 0 ; }
273
- int device, major, minor, sm_count;
274
- FFI_CUDA_CHECK (cudaStreamGetDevice (stream, &device));
274
+ int major, minor, sm_count;
275
275
FFI_CUDA_CHECK (cudaDeviceGetAttribute (&major, cudaDevAttrComputeCapabilityMajor, device));
276
276
FFI_CUDA_CHECK (cudaDeviceGetAttribute (&minor, cudaDevAttrComputeCapabilityMinor, device));
277
277
FFI_CUDA_CHECK (cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device));
0 commit comments