@@ -1559,23 +1559,18 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
1559
1559
return false ;
1560
1560
}
1561
1561
1562
+ // Peer access must be enabled in both directions for aclrtMemcpyAsync between devices.
1563
+ ggml_cann_set_device (cann_ctx_dst->device );
1564
+ ACL_CHECK (aclrtDeviceEnablePeerAccess (cann_ctx_src->device , 0 ));
1562
1565
ggml_cann_set_device (cann_ctx_src->device );
1563
1566
ACL_CHECK (aclrtDeviceEnablePeerAccess (cann_ctx_dst->device , 0 ));
1567
+
1564
1568
ACL_CHECK (aclrtMemcpyAsync (dst->data , copy_size, src->data , copy_size,
1565
1569
ACL_MEMCPY_DEVICE_TO_DEVICE,
1566
- cann_ctx_dst->stream ()));
1567
-
1568
- // record event on src stream
1569
- if (!cann_ctx_src->copy_event ) {
1570
- ACL_CHECK (aclrtCreateEvent (&cann_ctx_src->copy_event ));
1571
- }
1572
-
1573
- ACL_CHECK (
1574
- aclrtRecordEvent (cann_ctx_src->copy_event , cann_ctx_src->stream ()));
1570
+ cann_ctx_src->stream ()));
1575
1571
1576
- // wait on dst stream for the copy to complete
1577
- ACL_CHECK (aclrtStreamWaitEvent (cann_ctx_dst->stream (),
1578
- cann_ctx_src->copy_event ));
1572
+ // TODO: workaround; event-based synchronization didn't work here, so synchronize the source stream instead.
1573
+ aclrtSynchronizeStream (cann_ctx_src->stream ());
1579
1574
} else {
1580
1575
// src and dst are on the same backend
1581
1576
ACL_CHECK (aclrtMemcpyAsync (dst->data , copy_size, src->data , copy_size,
@@ -1763,8 +1758,8 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1763
1758
*
1764
1759
* This function determines whether the CANN backend supports the given backend
1765
1760
* buffer type by comparing the device context of the backend and buffer type.
1766
- * It returns true if the device associated with the buffer type matches the
1767
- * device associated with the backend .
1761
+ * It returns true if the backend context and the buffer type context
1762
+ * refer to the same device.
1768
1763
*
1769
1764
* @param backend Pointer to the CANN backend.
1770
1765
* @param buft Pointer to the backend buffer type to check.
@@ -1773,9 +1768,14 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1773
1768
*/
1774
1769
GGML_CALL static bool ggml_backend_cann_supports_buft (
1775
1770
ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1776
- return buft->iface .get_name == ggml_backend_cann_buffer_type_name;
1777
-
1778
- GGML_UNUSED (backend);
1771
+ if (ggml_backend_buft_is_cann (buft)) {
1772
+ ggml_backend_cann_context * cann_ctx =
1773
+ (ggml_backend_cann_context *)backend->context ;
1774
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1775
+ (ggml_backend_cann_buffer_type_context *)buft->context ;
1776
+ return buft_ctx->device == cann_ctx->device ;
1777
+ }
1778
+ return false ;
1779
1779
}
1780
1780
1781
1781
/* *
0 commit comments