@@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
471471 */
472472std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device (
473473 int device) {
474- if (device == 0 ) {
475- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm (device));
476- }
477- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg (device));
474+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm (device));
478475}
479476
480477// cann buffer
@@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
486483 */
487484struct ggml_backend_cann_buffer_context {
488485 int32_t device; // /< The device ID associated with this buffer context.
489- ggml_cann_pool_alloc* alloc; // /< Pointer to the device memory allocated for the buffer.
486+ void * dev_ptr = nullptr ;
490487
491488 /* *
492489 * @brief Constructor to initialize the CANN buffer context.
493490 *
494491 * @param device The device ID associated with this buffer context.
495- * @param alloc Pointer to the device memory allocated for the buffer.
496492 */
497- ggml_backend_cann_buffer_context (int32_t device, ggml_cann_pool_alloc* alloc )
493+ ggml_backend_cann_buffer_context (int32_t device, void * dev_ptr )
498494 : device(device),
499- alloc (alloc ) {}
495+ dev_ptr (dev_ptr ) {}
500496
501497 /* *
502498 * @brief Destructor to free the device memory allocated for the buffer.
503499 */
504- ~ggml_backend_cann_buffer_context () { delete alloc; }
500+ ~ggml_backend_cann_buffer_context () { ACL_CHECK ( aclrtFree (dev_ptr)); }
505501};
506502
507503/* *
@@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
547543 ggml_backend_buffer_t buffer) {
548544 ggml_backend_cann_buffer_context* ctx =
549545 (ggml_backend_cann_buffer_context*)buffer->context ;
550- return ctx->alloc -> get () ;
546+ return ctx->dev_ptr ;
551547}
552548
553549/* *
@@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
954950 (ggml_backend_cann_buffer_context*)buffer->context ;
955951
956952 ggml_cann_set_device (ctx->device );
957- ACL_CHECK (aclrtMemset (ctx->alloc -> get () , buffer->size , value, buffer->size ));
953+ ACL_CHECK (aclrtMemset (ctx->dev_ptr , buffer->size , value, buffer->size ));
958954}
959955
960956/* *
@@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
10161012static ggml_backend_buffer_t
10171013ggml_backend_cann_buffer_type_alloc_buffer (ggml_backend_buffer_type_t buft,
10181014 size_t size) {
1019- ggml_backend_cann_context* cann_ctx =
1020- (ggml_backend_cann_context*)buft->device ->context ;
1015+ ggml_backend_cann_buffer_type_context* buft_ctx =
1016+ (ggml_backend_cann_buffer_type_context*)buft->context ;
1017+
1018+ ggml_cann_set_device (buft_ctx->device );
10211019
1022- ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc (cann_ctx->pool (), size);
1020+ size = std::max (size, (size_t )1 );
1021+
1022+ void * dev_ptr;
1023+ aclError err = aclrtMalloc (&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1024+ if (err != ACL_SUCCESS) {
1025+ GGML_LOG_ERROR (
1026+ " %s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n " ,
1027+ __func__, size / 1024.0 / 1024.0 , buft_ctx->device ,
1028+ aclGetRecentErrMsg ());
1029+ return nullptr ;
1030+ }
10231031
10241032 ggml_backend_cann_buffer_context* ctx =
1025- new ggml_backend_cann_buffer_context (cann_ctx ->device , alloc );
1033+ new ggml_backend_cann_buffer_context (buft_ctx ->device , dev_ptr );
10261034
10271035 return ggml_backend_buffer_init (buft, ggml_backend_cann_buffer_interface,
10281036 ctx, size);
0 commit comments