@@ -200,6 +200,7 @@ class ET_EXPERIMENTAL CudaBackend final
       DelegateHandle* handle_,
       Span<EValue*> args) const override {
     AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
+    cudaStream_t stream = static_cast<cudaStream_t>(handle->cuda_stream);
 
     size_t n_inputs;
     handle->get_num_inputs(handle->container_handle, &n_inputs);
@@ -215,106 +216,147 @@ class ET_EXPERIMENTAL CudaBackend final
         n_outputs,
         args.size())
 
-    // NOTE: ExecuTorch tensors are always on CPU/host memory
-    // We need to create GPU copies for CUDA kernel execution
-    std::vector<AOTITensorHandle> gpu_inputs(
-        n_inputs); // GPU copies for kernel execution
-    std::vector<AOTITensorHandle> gpu_outputs(
-        n_outputs); // GPU tensors for kernel output
+    // NOTE: ExecuTorch tensors are always on CPU/host memory (pageable)
+    // We use pinned staging buffers for efficient async transfers:
+    // CPU (pageable) -> Pinned -> GPU -> Pinned -> CPU (pageable)
+    std::vector<AOTITensorHandle> pinned_inputs(n_inputs);
+    std::vector<AOTITensorHandle> gpu_inputs(n_inputs);
+    std::vector<AOTITensorHandle> gpu_outputs(n_outputs);
+    std::vector<AOTITensorHandle> pinned_outputs(n_outputs);
 
-    // Process input tensors: ExecuTorch provides CPU tensors, create GPU
-    // copies
+    // Process input tensors: create pinned staging buffers and GPU tensors
     for (int i = 0; i < n_inputs; i++) {
-      // Get tensor dimensions and properties from ExecuTorch CPU tensor
       auto cpu_tensor = &(args[i]->toTensor());
       auto sizes = cpu_tensor->sizes();
       auto scalar_type = cpu_tensor->scalar_type();
-
-      // Create GPU tensor with same shape
       std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
 
-      AOTITensorHandle gpu_input_handle;
-      Error create_err = aoti_torch_empty_strided(
-          sizes_vec.size(),
-          sizes_vec.data(),
-          nullptr, // use default strides
-          static_cast<int32_t>(scalar_type),
-          1, // device_type = cuda
-          0, // device_index = 0
-          &gpu_input_handle);
+      // Create pinned staging buffer
+      AOTITensorHandle pinned_input_handle;
+      ET_CHECK_OR_RETURN_ERROR(
+          aoti_torch_empty_strided_pinned(
+              sizes_vec.size(),
+              sizes_vec.data(),
+              nullptr, // use default strides
+              static_cast<int32_t>(scalar_type),
+              &pinned_input_handle) == Error::Ok,
+          Internal,
+          "Failed to create pinned staging buffer for input %d",
+          i);
+      pinned_inputs[i] = pinned_input_handle;
 
+      // Create GPU tensor
+      AOTITensorHandle gpu_input_handle;
       ET_CHECK_OR_RETURN_ERROR(
-          create_err == Error::Ok,
+          aoti_torch_empty_strided(
+              sizes_vec.size(),
+              sizes_vec.data(),
+              nullptr, // use default strides
+              static_cast<int32_t>(scalar_type),
+              1, // device_type = cuda
+              0, // device_index = 0
+              &gpu_input_handle) == Error::Ok,
          Internal,
          "Failed to create GPU tensor for input %d",
          i);
-
       gpu_inputs[i] = gpu_input_handle;
 
-      // Copy data from CPU to GPU
+      // Copy from ExecuTorch CPU to pinned buffer (fast memcpy)
+      std::memcpy(
+          pinned_inputs[i]->mutable_data_ptr(),
+          cpu_tensor->data_ptr(),
+          cpu_tensor->nbytes());
+
+      // Async copy from pinned to GPU (truly async with DMA)
       ET_CHECK_OR_RETURN_ERROR(
-          aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
+          aoti_torch_copy_async(gpu_inputs[i], pinned_inputs[i], stream) ==
+              Error::Ok,
          Internal,
-          "Failed to copy input %d from CPU to GPU",
+          "Failed to async copy input %d from pinned to GPU",
          i);
     }
-    // Process output tensors: create GPU counterparts for ExecuTorch CPU
-    // tensors
+
+    // Process output tensors: create GPU tensors and pinned staging buffers
     for (int i = 0; i < n_outputs; i++) {
-      // Get output tensor dimensions from ExecuTorch CPU tensor
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
       auto sizes = cpu_output_tensor->sizes();
       auto scalar_type = cpu_output_tensor->scalar_type();
-
-      // Create GPU tensor with same shape for kernel output
       std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
 
+      // Create GPU tensor for kernel output
       AOTITensorHandle gpu_output_handle;
-      Error create_err = aoti_torch_empty_strided(
-          sizes_vec.size(),
-          sizes_vec.data(),
-          nullptr, // use default strides
-          static_cast<int32_t>(scalar_type),
-          1, // device_type = cuda
-          0, // device_index = 0
-          &gpu_output_handle);
-
       ET_CHECK_OR_RETURN_ERROR(
-          create_err == Error::Ok,
+          aoti_torch_empty_strided(
+              sizes_vec.size(),
+              sizes_vec.data(),
+              nullptr, // use default strides
+              static_cast<int32_t>(scalar_type),
+              1, // device_type = cuda
+              0, // device_index = 0
+              &gpu_output_handle) == Error::Ok,
          Internal,
          "Failed to create GPU tensor for output %d",
          i);
-
       gpu_outputs[i] = gpu_output_handle;
+
+      // Create pinned staging buffer for output
+      AOTITensorHandle pinned_output_handle;
+      ET_CHECK_OR_RETURN_ERROR(
+          aoti_torch_empty_strided_pinned(
+              sizes_vec.size(),
+              sizes_vec.data(),
+              nullptr, // use default strides
+              static_cast<int32_t>(scalar_type),
+              &pinned_output_handle) == Error::Ok,
+          Internal,
+          "Failed to create pinned staging buffer for output %d",
+          i);
+      pinned_outputs[i] = pinned_output_handle;
     }
+
     // Run AOTI container with GPU tensors
+    // Note: kernel is queued on the same stream as H2D copies,
+    // so it will automatically wait for copies to complete
     AOTIRuntimeError error = handle->run(
         handle->container_handle,
-        gpu_inputs.data(), // Use GPU input tensors
+        gpu_inputs.data(),
         n_inputs,
-        gpu_outputs.data(), // Use GPU output tensors
+        gpu_outputs.data(),
         n_outputs,
-        handle->cuda_stream, // Pass the actual CUDA stream
-        nullptr); // proxy_executor_handle can remain nullptr
+        handle->cuda_stream,
+        nullptr);
 
     ET_CHECK_OR_RETURN_ERROR(
         error == Error::Ok,
         Internal,
         "AOTInductorModelContainerRun failed with error code %d",
         error);
 
-    // Copy GPU output results back to CPU output tensors
+    // Async copy GPU outputs to pinned staging buffers (truly async with DMA)
+    for (int i = 0; i < n_outputs; i++) {
+      ET_CHECK_OR_RETURN_ERROR(
+          aoti_torch_copy_async(pinned_outputs[i], gpu_outputs[i], stream) ==
+              Error::Ok,
+          Internal,
+          "Failed to async copy GPU output %d to pinned buffer",
+          i);
+    }
+
+    // Synchronize stream to ensure all async operations complete
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamSynchronize(stream));
+
+    // Copy from pinned buffers to ExecuTorch CPU output tensors (fast memcpy)
     for (int i = 0; i < n_outputs; i++) {
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
       // For DYNAMIC_BOUND tensors we try to resize
       ET_CHECK_OK_OR_RETURN_ERROR(
           resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
           "Error resizing tensor at output index %d",
           i);
-      ET_CHECK_OK_OR_RETURN_ERROR(
-          aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
-          "Failed to copy GPU output %d back to CPU",
-          i);
+      std::memcpy(
+          cpu_output_tensor->mutable_data_ptr(),
+          pinned_outputs[i]->data_ptr(),
+          pinned_outputs[i]->nbytes());
     }
 
     return Error::Ok;
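
For readers unfamiliar with the staging pattern this diff introduces, here is a minimal sketch of the same pageable -> pinned -> GPU -> pinned -> pageable flow written directly against the CUDA runtime API. It assumes aoti_torch_empty_strided_pinned is backed by something like cudaHostAlloc and aoti_torch_copy_async by cudaMemcpyAsync (both assumptions; see the accompanying shim changes for the actual implementation). CUDA_CHECK and staged_transfer are hypothetical names used only for this sketch.

// Sketch only (not part of this PR): the staging pattern in plain CUDA.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#define CUDA_CHECK(expr)                                                     \
  do {                                                                       \
    cudaError_t err_ = (expr);                                               \
    if (err_ != cudaSuccess) {                                               \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_));    \
      std::abort();                                                          \
    }                                                                        \
  } while (0)

void staged_transfer(
    const float* pageable_src,
    float* pageable_dst,
    size_t n,
    cudaStream_t stream) {
  const size_t bytes = n * sizeof(float);

  // Pinned (page-locked) host buffers: cudaMemcpyAsync only runs as a true
  // async DMA transfer when the host side is pinned; with pageable memory it
  // degrades to an effectively synchronous staged copy.
  float* pinned_in = nullptr;
  float* pinned_out = nullptr;
  CUDA_CHECK(cudaHostAlloc(
      reinterpret_cast<void**>(&pinned_in), bytes, cudaHostAllocDefault));
  CUDA_CHECK(cudaHostAlloc(
      reinterpret_cast<void**>(&pinned_out), bytes, cudaHostAllocDefault));

  float* device_buf = nullptr;
  CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&device_buf), bytes));

  // Pageable -> pinned is an ordinary CPU memcpy.
  std::memcpy(pinned_in, pageable_src, bytes);

  // H2D upload, kernel, and D2H download are queued on the same stream, so
  // stream ordering alone makes the kernel wait for the upload and the
  // download wait for the kernel -- no explicit events needed.
  CUDA_CHECK(cudaMemcpyAsync(
      device_buf, pinned_in, bytes, cudaMemcpyHostToDevice, stream));
  // ... launch the kernel on `stream` here ...
  CUDA_CHECK(cudaMemcpyAsync(
      pinned_out, device_buf, bytes, cudaMemcpyDeviceToHost, stream));

  // Single synchronization point before results are read on the host.
  CUDA_CHECK(cudaStreamSynchronize(stream));
  std::memcpy(pageable_dst, pinned_out, bytes);

  CUDA_CHECK(cudaFree(device_buf));
  CUDA_CHECK(cudaFreeHost(pinned_in));
  CUDA_CHECK(cudaFreeHost(pinned_out));
}

Note that pinning host memory is not free (allocation is slower and the pages are removed from the pageable pool), so reusing the staging buffers across executions rather than allocating them per call is the usual follow-up optimization.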