Skip to content

Commit ae46c63

Browse files
committed
[WIP][CUDA backend]: Async copy between host<->device
1 parent 33ec615 commit ae46c63

File tree

3 files changed

+143
-6
lines changed

3 files changed

+143
-6
lines changed

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ class ET_EXPERIMENTAL CudaBackend final
200200
DelegateHandle* handle_,
201201
Span<EValue*> args) const override {
202202
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
203+
cudaStream_t stream = static_cast<cudaStream_t>(handle->cuda_stream);
203204

204205
size_t n_inputs;
205206
handle->get_num_inputs(handle->container_handle, &n_inputs);
@@ -251,11 +252,11 @@ class ET_EXPERIMENTAL CudaBackend final
251252

252253
gpu_inputs[i] = gpu_input_handle;
253254

254-
// Copy data from CPU to GPU
255+
// Async copy data from CPU to GPU
255256
ET_CHECK_OR_RETURN_ERROR(
256-
aoti_torch_copy_(gpu_inputs[i], cpu_tensor, 0) == Error::Ok,
257+
aoti_torch_copy_async(gpu_inputs[i], cpu_tensor, stream) == Error::Ok,
257258
Internal,
258-
"Failed to copy input %d from CPU to GPU",
259+
"Failed to async copy input %d from CPU to GPU",
259260
i);
260261
}
261262
// Process output tensors: create GPU counterparts for ExecuTorch CPU
@@ -288,6 +289,8 @@ class ET_EXPERIMENTAL CudaBackend final
288289
gpu_outputs[i] = gpu_output_handle;
289290
}
290291
// Run AOTI container with GPU tensors
292+
// Note: kernel is queued on the same stream as H2D copies,
293+
// so it will automatically wait for copies to complete
291294
AOTIRuntimeError error = handle->run(
292295
handle->container_handle,
293296
gpu_inputs.data(), // Use GPU input tensors
@@ -303,7 +306,7 @@ class ET_EXPERIMENTAL CudaBackend final
303306
"AOTInductorModelContainerRun failed with error code %d",
304307
error);
305308

306-
// Copy GPU output results back to CPU output tensors
309+
// Async copy GPU output results back to CPU output tensors
307310
for (int i = 0; i < n_outputs; i++) {
308311
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
309312
// For DYNAMIC_BOUND tensors we try to resize
@@ -312,11 +315,15 @@ class ET_EXPERIMENTAL CudaBackend final
312315
"Error resizing tensor at output index %d",
313316
i);
314317
ET_CHECK_OK_OR_RETURN_ERROR(
315-
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
316-
"Failed to copy GPU output %d back to CPU",
318+
aoti_torch_copy_async(cpu_output_tensor, gpu_outputs[i], stream),
319+
"Failed to async copy GPU output %d back to CPU",
317320
i);
318321
}
319322

323+
// Synchronize stream to ensure all async operations complete
324+
// before returning to the caller
325+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamSynchronize(stream));
326+
320327
return Error::Ok;
321328
}
322329

backends/cuda/runtime/shims/memory.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,111 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
582582
return Error::Ok;
583583
}
584584

585+
AOTITorchError
586+
aoti_torch_copy_async(Tensor* self, Tensor* src, cudaStream_t stream) {
587+
// Check for null pointers first
588+
ET_CHECK_OR_RETURN_ERROR(
589+
self != nullptr,
590+
InvalidArgument,
591+
"aoti_torch_copy_async failed: self tensor is null");
592+
593+
ET_CHECK_OR_RETURN_ERROR(
594+
src != nullptr,
595+
InvalidArgument,
596+
"aoti_torch_copy_async failed: src tensor is null");
597+
598+
// Get dtype information and validate compatibility
599+
int32_t self_dtype, src_dtype;
600+
aoti_torch_get_dtype(self, &self_dtype);
601+
aoti_torch_get_dtype(src, &src_dtype);
602+
603+
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype));
604+
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype));
605+
606+
// Check dtype compatibility - both tensors must have the same dtype
607+
ET_CHECK_OR_RETURN_ERROR(
608+
self_dtype == src_dtype,
609+
InvalidArgument,
610+
"dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_async requires same dtypes",
611+
self_dtype,
612+
src_dtype);
613+
614+
// Check total number of elements compatibility
615+
int64_t self_numel = self->numel();
616+
int64_t src_numel = src->numel();
617+
618+
ET_CHECK_OR_RETURN_ERROR(
619+
self_numel == src_numel,
620+
InvalidArgument,
621+
"numel mismatch. self.numel()=%ld, src.numel()=%ld",
622+
self_numel,
623+
src_numel);
624+
625+
// Get tensor metadata
626+
int64_t* self_strides;
627+
int64_t* src_strides;
628+
aoti_torch_get_strides(self, &self_strides);
629+
aoti_torch_get_strides(src, &src_strides);
630+
631+
// Check if tensors have the same strides (required for async copy)
632+
bool same_strides = true;
633+
for (int i = 0; i < self->dim(); i++) {
634+
if (self_strides[i] != src_strides[i]) {
635+
same_strides = false;
636+
break;
637+
}
638+
}
639+
640+
ET_CHECK_OR_RETURN_ERROR(
641+
same_strides,
642+
InvalidArgument,
643+
"aoti_torch_copy_async requires tensors with same strides. Use aoti_torch_copy_ for non-contiguous tensors");
644+
645+
// Determine device locations
646+
cudaPointerAttributes srcAttributes{};
647+
cudaPointerAttributes dstAttributes{};
648+
649+
ET_CUDA_CHECK_OR_RETURN_ERROR(
650+
cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
651+
652+
ET_CUDA_CHECK_OR_RETURN_ERROR(
653+
cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
654+
655+
bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
656+
bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
657+
658+
size_t total_bytes = src->nbytes();
659+
660+
// Determine copy direction and perform async copy
661+
if (srcIsDevice && dstIsDevice) {
662+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
663+
self->mutable_data_ptr(),
664+
src->data_ptr(),
665+
total_bytes,
666+
cudaMemcpyDeviceToDevice,
667+
stream));
668+
} else if (srcIsDevice && !dstIsDevice) {
669+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
670+
self->mutable_data_ptr(),
671+
src->data_ptr(),
672+
total_bytes,
673+
cudaMemcpyDeviceToHost,
674+
stream));
675+
} else if (!srcIsDevice && dstIsDevice) {
676+
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
677+
self->mutable_data_ptr(),
678+
src->data_ptr(),
679+
total_bytes,
680+
cudaMemcpyHostToDevice,
681+
stream));
682+
} else {
683+
// Host to host - use regular memcpy (no async benefit)
684+
std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
685+
}
686+
687+
return Error::Ok;
688+
}
689+
585690
AOTITorchError aoti_torch__reinterpret_tensor(
586691
Tensor* self,
587692
int64_t ndim,

backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,31 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
140140
AOTI_SHIM_EXPORT AOTITorchError
141141
aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking);
142142

143+
/**
144+
* Asynchronously copies data from source tensor to destination tensor.
145+
*
146+
* This function performs an asynchronous memory copy between tensors using
147+
* cudaMemcpyAsync. The copy is queued on the specified CUDA stream and returns
148+
* immediately without waiting for completion. The caller must synchronize the
149+
* stream before accessing the destination data.
150+
*
151+
* Requirements:
152+
* - Both tensors must have the same dtype and number of elements
153+
 * - Both tensors must have identical strides (the same memory layout)
155+
 * - For tensors with differing strides, use aoti_torch_copy_ instead
155+
*
156+
* @param self Destination tensor (data will be overwritten)
157+
* @param src Source tensor (data will be copied from this tensor)
158+
* @param stream CUDA stream on which to queue the async copy
159+
*
160+
* @return Error::Ok on success, appropriate error code on failure:
161+
* - Error::InvalidArgument: null pointers, dtype mismatch, numel
162+
* mismatch, or non-contiguous tensors
163+
* - Error::Internal: CUDA operation failures
164+
*/
165+
AOTI_SHIM_EXPORT AOTITorchError
166+
aoti_torch_copy_async(Tensor* self, Tensor* src, cudaStream_t stream);
167+
143168
/**
144169
* Creates a new tensor handle from an existing one.
145170
*

0 commit comments

Comments
 (0)