1212
1313#ifdef CUDA_AVAILABLE
1414#include < executorch/backends/aoti/slim/c10/cuda/Exception.h>
15- #include < executorch/backends/aoti/slim/ cuda/Guard .h>
15+ #include < executorch/backends/cuda/runtime/guard .h>
1616#endif
1717
1818#include < executorch/backends/aoti/slim/c10/core/Device.h>
@@ -87,24 +87,53 @@ struct DeviceTraits<c10::DeviceType::CPU> {
8787#ifdef CUDA_AVAILABLE
8888// / CUDA specialization of DeviceTraits.
8989// / Provides CUDA memory allocation and copy operations using
90- // / cudaMalloc/cudaFree.
90+ // / cudaMallocAsync/cudaFreeAsync with proper stream handling.
91+ // /
92+ // / IMPORTANT: Callers are expected to set the correct CUDA device and stream
93+ // / using CUDAStreamGuard before calling these methods. This is consistent
94+ // / with PyTorch's CUDACachingAllocator design pattern where the allocator
95+ // / assumes the caller has already set the correct device context.
9196template <>
9297struct DeviceTraits <c10::DeviceType::CUDA> {
/// Allocates CUDA device memory asynchronously on the current stream.
///
/// Mirrors PyTorch's CUDACachingAllocator contract: the caller is expected
/// to have already selected the correct device and stream (e.g. via
/// CUDAStreamGuard). No device guard is created here.
///
/// @param nbytes Number of bytes to allocate.
/// @param device Target CUDA device; its index is used only to look up the
///               currently active stream for that device.
/// @return Pointer to the newly allocated device memory.
static void* allocate(size_t nbytes, const c10::Device& device) {
  // Fetch the stream currently associated with this device index
  // (set by CUDAStreamGuard, if any).
  auto stream_result =
      executorch::backends::cuda::getCurrentCUDAStream(device.index());
  ET_CHECK_MSG(
      stream_result.ok(),
      "Failed to get current CUDA stream for device %d",
      static_cast<int>(device.index()));

  cudaStream_t stream = stream_result.get();

  // Stream-ordered allocation (CUDA 11.2+ memory-pool allocator), ordered
  // with respect to other work already enqueued on `stream`.
  void* data = nullptr;
  ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream));
  return data;
}
103125
/// Releases CUDA device memory.
///
/// Prefers a stream-ordered free (cudaFreeAsync) on the stream currently
/// associated with the active device, falling back to a synchronous
/// cudaFree when no stream can be obtained. Failures are logged via
/// ET_CUDA_LOG_WARN rather than aborting.
///
/// @param ptr Pointer to device memory to free.
static void free(void* ptr) {
  // NOTE(review): -1 presumably selects the current device in
  // getCurrentCUDAStream — confirm against that function's declaration.
  auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
  if (!stream_result.ok()) {
    // No usable stream: take the synchronous free path instead.
    ET_CUDA_LOG_WARN(cudaFree(ptr));
    return;
  }
  ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
}
109138
110139 // / Copies memory between CPU and CUDA or CUDA and CUDA.
@@ -120,13 +149,11 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
120149 const c10::Device& dst_device,
121150 const c10::Device& src_device) {
122151 cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
123- c10::Device cuda_device = dst_device;
124152
125153 if (src_device.is_cpu ()) {
126154 direction = cudaMemcpyHostToDevice;
127155 } else if (dst_device.is_cpu ()) {
128156 direction = cudaMemcpyDeviceToHost;
129- cuda_device = src_device;
130157 } else {
131158 ET_CHECK_MSG (
132159 src_device.index () == dst_device.index (),
@@ -135,7 +162,6 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
135162 static_cast <int >(dst_device.index ()));
136163 }
137164
138- cuda::CUDAGuard guard (cuda_device);
139165 ET_CUDA_CHECK (cudaMemcpy (dst, src, nbytes, direction));
140166 }
141167};