@@ -582,6 +582,111 @@ aoti_torch_copy_(Tensor* self, Tensor* src, int32_t non_blocking) {
   return Error::Ok;
 }
 
+AOTITorchError
+aoti_torch_copy_async(Tensor* self, Tensor* src, cudaStream_t stream) {
+  // Check for null pointers first
+  ET_CHECK_OR_RETURN_ERROR(
+      self != nullptr,
+      InvalidArgument,
+      "aoti_torch_copy_async failed: self tensor is null");
+
+  ET_CHECK_OR_RETURN_ERROR(
+      src != nullptr,
+      InvalidArgument,
+      "aoti_torch_copy_async failed: src tensor is null");
+
+  // Get dtype information and validate compatibility
+  int32_t self_dtype, src_dtype;
+  aoti_torch_get_dtype(self, &self_dtype);
+  aoti_torch_get_dtype(src, &src_dtype);
+
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(self_dtype));
+  ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(src_dtype));
+
+  // Check dtype compatibility - both tensors must have the same dtype
+  ET_CHECK_OR_RETURN_ERROR(
+      self_dtype == src_dtype,
+      InvalidArgument,
+      "dtype mismatch. self.dtype=%d, src.dtype=%d. aoti_torch_copy_async requires same dtypes",
+      self_dtype,
+      src_dtype);
+
+  // Check total number of elements compatibility
+  int64_t self_numel = self->numel();
+  int64_t src_numel = src->numel();
+
+  ET_CHECK_OR_RETURN_ERROR(
+      self_numel == src_numel,
+      InvalidArgument,
+      "numel mismatch. self.numel()=%ld, src.numel()=%ld",
+      self_numel,
+      src_numel);
+
+  // Get tensor metadata
+  int64_t* self_strides;
+  int64_t* src_strides;
+  aoti_torch_get_strides(self, &self_strides);
+  aoti_torch_get_strides(src, &src_strides);
+
+  // Check if tensors have the same strides (required for async copy)
+  bool same_strides = true;
+  for (int i = 0; i < self->dim(); i++) {
+    if (self_strides[i] != src_strides[i]) {
+      same_strides = false;
+      break;
+    }
+  }
+
+  ET_CHECK_OR_RETURN_ERROR(
+      same_strides,
+      InvalidArgument,
+      "aoti_torch_copy_async requires tensors with same strides. Use aoti_torch_copy_ for non-contiguous tensors");
+
+  // Determine device locations
+  cudaPointerAttributes srcAttributes{};
+  cudaPointerAttributes dstAttributes{};
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&srcAttributes, src->data_ptr()));
+
+  ET_CUDA_CHECK_OR_RETURN_ERROR(
+      cudaPointerGetAttributes(&dstAttributes, self->data_ptr()));
+
+  bool srcIsDevice = srcAttributes.type == cudaMemoryTypeDevice;
+  bool dstIsDevice = dstAttributes.type == cudaMemoryTypeDevice;
+
+  size_t total_bytes = src->nbytes();
+
+  // Determine copy direction and perform async copy
+  if (srcIsDevice && dstIsDevice) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+        self->mutable_data_ptr(),
+        src->data_ptr(),
+        total_bytes,
+        cudaMemcpyDeviceToDevice,
+        stream));
+  } else if (srcIsDevice && !dstIsDevice) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+        self->mutable_data_ptr(),
+        src->data_ptr(),
+        total_bytes,
+        cudaMemcpyDeviceToHost,
+        stream));
+  } else if (!srcIsDevice && dstIsDevice) {
+    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMemcpyAsync(
+        self->mutable_data_ptr(),
+        src->data_ptr(),
+        total_bytes,
+        cudaMemcpyHostToDevice,
+        stream));
+  } else {
+    // Host to host - use regular memcpy (no async benefit)
+    std::memcpy(self->mutable_data_ptr(), src->data_ptr(), total_bytes);
+  }
+
+  return Error::Ok;
+}
+
 AOTITorchError aoti_torch__reinterpret_tensor(
     Tensor* self,
     int64_t ndim,
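A minimal caller-side sketch (not part of this diff) of how the new entry point could be driven: create a non-blocking stream, enqueue the copy, optionally overlap other work, then synchronize before touching the destination. The dst/src handles, the copy_on_stream wrapper name, and the Error::Internal fallback are illustrative assumptions rather than code from this PR; note also that host buffers generally need to be pinned (e.g. via cudaMallocHost) for host/device transfers to actually overlap with other work.

// Usage sketch only - not part of this diff. Assumes the AOTI CUDA shim
// header declaring aoti_torch_copy_async is included, and that dst and src
// are existing Tensor* handles with matching dtype, numel, and strides
// (placeholder names, not from this PR).
#include <cuda_runtime.h>

AOTITorchError copy_on_stream(Tensor* dst, Tensor* src) {
  cudaStream_t stream = nullptr;
  if (cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) != cudaSuccess) {
    return Error::Internal;  // assumed error code, for illustration only
  }

  // Enqueue the copy; the call returns once the transfer is queued on `stream`.
  AOTITorchError err = aoti_torch_copy_async(dst, src, stream);

  // Other independent work could be launched on other streams here to
  // overlap with the transfer.

  // Block until the copy has finished before reading `dst`, then clean up.
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return err;
}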