Skip to content

Commit 1c79893

Browse files
committed
some modifications after review
1 parent 58652e4 commit 1c79893

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 23 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -471,10 +471,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
471471
*/
472472
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
473473
int device) {
474-
if (device == 0) {
475-
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
476-
}
477-
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
474+
return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); // every device now gets a ggml_cann_pool_vmm; the device==0 check and the ggml_cann_pool_leg branch are removed
478475
}
479476

480477
// cann buffer
@@ -486,22 +483,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
486483
*/
487484
struct ggml_backend_cann_buffer_context {
488485
int32_t device; ///< The device ID associated with this buffer context.
489-
ggml_cann_pool_alloc* alloc; ///< Pointer to the device memory allocated for the buffer.
486+
void* dev_ptr = nullptr; ///< Raw pointer to the buffer's device memory; passed to aclrtFree() in the destructor.
490487

491488
/**
492489
* @brief Constructor to initialize the CANN buffer context.
493490
*
494491
* @param device The device ID associated with this buffer context.
495-
* @param alloc Pointer to the device memory allocated for the buffer.
496492
*/
497-
ggml_backend_cann_buffer_context(int32_t device, ggml_cann_pool_alloc* alloc)
493+
ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr) // dev_ptr: device memory pointer stored in this context and freed by its destructor
498494
: device(device),
499-
alloc(alloc) {}
495+
dev_ptr(dev_ptr) {}
500496

501497
/**
502498
* @brief Destructor to free the device memory allocated for the buffer.
503499
*/
504-
~ggml_backend_cann_buffer_context() { delete alloc; }
500+
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr));} // NOTE(review): no copy/move operations are declared, so a copied context would call aclrtFree on the same dev_ptr twice — confirm instances are never copied
505501
};
506502

507503
/**
@@ -547,7 +543,7 @@ static void* ggml_backend_cann_buffer_get_base(
547543
ggml_backend_buffer_t buffer) {
548544
ggml_backend_cann_buffer_context* ctx =
549545
(ggml_backend_cann_buffer_context*)buffer->context;
550-
return ctx->alloc->get();
546+
return ctx->dev_ptr;
551547
}
552548

553549
/**
@@ -954,7 +950,7 @@ static void ggml_backend_cann_buffer_clear(
954950
(ggml_backend_cann_buffer_context*)buffer->context;
955951

956952
ggml_cann_set_device(ctx->device);
957-
ACL_CHECK(aclrtMemset(ctx->alloc->get(), buffer->size, value, buffer->size));
953+
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
958954
}
959955

960956
/**
@@ -1016,13 +1012,25 @@ static const char* ggml_backend_cann_buffer_type_name(
10161012
static ggml_backend_buffer_t
10171013
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
10181014
size_t size) {
1019-
ggml_backend_cann_context* cann_ctx =
1020-
(ggml_backend_cann_context*)buft->device->context;
1015+
ggml_backend_cann_buffer_type_context* buft_ctx =
1016+
(ggml_backend_cann_buffer_type_context*)buft->context;
1017+
1018+
ggml_cann_set_device(buft_ctx->device);
10211019

1022-
ggml_cann_pool_alloc* alloc = new ggml_cann_pool_alloc(cann_ctx->pool(), size);
1020+
size = std::max(size, (size_t)1);
1021+
1022+
void* dev_ptr;
1023+
aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1024+
if (err != ACL_SUCCESS) {
1025+
GGML_LOG_ERROR(
1026+
"%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1027+
__func__, size / 1024.0 / 1024.0, buft_ctx->device,
1028+
aclGetRecentErrMsg());
1029+
return nullptr;
1030+
}
10231031

10241032
ggml_backend_cann_buffer_context* ctx =
1025-
new ggml_backend_cann_buffer_context(cann_ctx->device, alloc);
1033+
new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
10261034

10271035
return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
10281036
ctx, size);

0 commit comments

Comments
 (0)