
Commit 69d3b2c

soumith and claude committed
Fix CUDA stream synchronization in custom kernels
This commit fixes GitHub issue pytorch/pytorch#157363 where custom CUDA kernels were not properly synchronized with PyTorch's CUDA stream when used with torch.compile in reduce-overhead mode.

Changes:
- Add #include <ATen/cuda/CUDAContext.h> for getCurrentCUDAStream()
- Use at::cuda::getCurrentCUDAStream() to get PyTorch's current CUDA stream
- Launch all kernels with the correct stream parameter

The issue occurred because custom kernels launched on the default CUDA stream while PyTorch operations (like nn.Linear) run on PyTorch's managed stream. This created race conditions where custom kernels would execute before PyTorch operations completed, resulting in incorrect output values.

With this fix, all custom kernels are properly synchronized with PyTorch's CUDA stream, ensuring correct execution order and preventing race conditions when used with torch.compile.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 725e7a4 commit 69d3b2c
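
For readers unfamiliar with the launch syntax, the sketch below isolates the pattern this commit applies. It is a minimal, hypothetical example (the kernel and function names are illustrative, not the ones in muladd.cu): a kernel launched as <<<grid, block>>> is enqueued on the default CUDA stream, while passing at::cuda::getCurrentCUDAStream() as the fourth launch argument enqueues it on the stream PyTorch is currently using, so it runs only after the PyTorch ops already queued there.

#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>

// Hypothetical elementwise kernel, named for illustration only.
__global__ void scale_kernel(int numel, const float* in, float* out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) out[idx] = 2.0f * in[idx];
}

void launch_scale(int numel, const float* in, float* out) {
  // Racy (pre-fix) pattern: the kernel goes onto the default stream and can
  // run before PyTorch ops enqueued on PyTorch's managed stream have finished:
  //   scale_kernel<<<(numel + 255) / 256, 256>>>(numel, in, out);

  // Fixed pattern: launch on the stream PyTorch is currently using.
  // The third launch parameter is dynamic shared memory (0 here),
  // the fourth is the stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(numel, in, out);
}

The diff below applies exactly this two-line change to the three kernel launches in muladd.cu.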


1 file changed: +7 −3 lines changed


extension_cpp/csrc/cuda/muladd.cu

Lines changed: 7 additions & 3 deletions
@@ -4,6 +4,7 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <ATen/cuda/CUDAContext.h>
 
 namespace extension_cpp {
 
@@ -26,7 +27,8 @@ at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
   float* result_ptr = result.data_ptr<float>();
 
   int numel = a_contig.numel();
-  muladd_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, c, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  muladd_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
   return result;
 }
 
@@ -48,7 +50,8 @@ at::Tensor mymul_cuda(const at::Tensor& a, const at::Tensor& b) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = result.data_ptr<float>();
   int numel = a_contig.numel();
-  mul_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  mul_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
   return result;
 }
 
@@ -73,7 +76,8 @@ void myadd_out_cuda(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
   const float* b_ptr = b_contig.data_ptr<float>();
   float* result_ptr = out.data_ptr<float>();
   int numel = a_contig.numel();
-  add_kernel<<<(numel+255)/256, 256>>>(numel, a_ptr, b_ptr, result_ptr);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  add_kernel<<<(numel+255)/256, 256, 0, stream>>>(numel, a_ptr, b_ptr, result_ptr);
 }
 
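
For context, this is roughly what one of the patched functions looks like once the hunks above are applied. It is a hedged reconstruction, not the verbatim file: the diff does not show the kernel body or the start of mymuladd_cuda, so those parts are filled in here as a plausible elementwise a * b + c implementation.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Assumed kernel body (not part of the diff): elementwise a * b + c.
__global__ void muladd_kernel(int numel, const float* a, const float* b,
                              float c, float* result) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) result[idx] = a[idx] * b[idx] + c;
}

at::Tensor mymuladd_cuda(const at::Tensor& a, const at::Tensor& b, double c) {
  // Prologue reconstructed from the pointer names visible in the diff.
  at::Tensor a_contig = a.contiguous();
  at::Tensor b_contig = b.contiguous();
  at::Tensor result = at::empty(a_contig.sizes(), a_contig.options());
  const float* a_ptr = a_contig.data_ptr<float>();
  const float* b_ptr = b_contig.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();

  int numel = a_contig.numel();
  // Launch on PyTorch's current CUDA stream so the kernel is ordered after
  // any preceding PyTorch ops (e.g. nn.Linear) enqueued on that stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  muladd_kernel<<<(numel + 255) / 256, 256, 0, stream>>>(numel, a_ptr, b_ptr, c, result_ptr);
  return result;
}

The same two-line change (query the current stream, then pass it as the fourth launch argument) appears in mymul_cuda and myadd_out_cuda in the hunks above.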
