Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,30 +1117,62 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
}

// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
namespace {
void* g_nz_workspace = nullptr;
size_t g_nz_workspace_allocated = 0;

void release_nz_workspace() {
if (g_nz_workspace) {
aclrtFree(g_nz_workspace);
g_nz_workspace = nullptr;
g_nz_workspace_allocated = 0;
class NzWorkspace {
public:
// Constructor: initialize with no allocated buffer
NzWorkspace() : ptr_(nullptr), allocated_(0) {}

// Reset workspace to uninitialized state:
// - Free allocated device memory (if any)
// - Clear internal pointer and size
// Equivalent to release_nz_workspace(device) in old version
void init() {
if (ptr_) {
aclrtFree(ptr_);
ptr_ = nullptr;
allocated_ = 0;
}
}

void relloc_nz_workspace(size_t new_size) {
if (new_size > g_nz_workspace_allocated) {
if (g_nz_workspace) {
aclrtFree(g_nz_workspace);
g_nz_workspace = nullptr;
// Allocate or reallocate the workspace buffer:
// - If requested size > currently allocated size:
// * Free the old buffer (if any)
// * Allocate a new buffer with requested size on device
// - If requested size <= currently allocated size:
// * Do nothing (reuse existing buffer)
// Equivalent to relloc_nz_workspace(device, new_size) in old version
void realloc(size_t new_size) {
if (new_size > allocated_) {
init();
ACL_CHECK(aclrtMalloc(&ptr_, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
allocated_ = new_size;
}
ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
g_nz_workspace_allocated = new_size;
}

// Return raw device pointer (may be nullptr if not allocated)
// Equivalent to get_nz_workspace(device) in old version
void* get() const { return ptr_; }

private:
void* ptr_; // Pointer to allocated device buffer
size_t allocated_; // Size of currently allocated buffer (bytes)
};

// Global array of NzWorkspace, one per device
// g_nz_workspaces[device] corresponds to workspace of given device
static std::array<NzWorkspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;

// Accessor for workspace of a given device
// - Throws std::out_of_range if device index is invalid
// - Caller can then use .init(), .realloc(), .get()
inline NzWorkspace& get_workspace(int device) {
if (device < 0 || device >= static_cast<int>(g_nz_workspaces.size())) {
throw std::out_of_range("device id out of range");
}
return g_nz_workspaces[device];
}


/**
* @brief Convert tensor weights to NZ format using Ascend CANN API.
*
Expand All @@ -1149,13 +1181,13 @@ namespace {
* improve performance on certain hardware.
*
* @param tensor Pointer to the input ggml_tensor containing the weights.
* @param data Pointer to the raw data buffer for the tensor weights.
* @param offset Byte offset within the tensor data buffer where weights start.
* @param device device id.
*
* @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
Expand All @@ -1165,7 +1197,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
&workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace.
relloc_nz_workspace(workspaceSize);
get_workspace(device).realloc(workspaceSize);

void* g_nz_workspace = get_workspace(device).get();

ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed));
Expand Down Expand Up @@ -1203,7 +1237,7 @@ static void ggml_backend_cann_buffer_set_tensor(
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset);
weight_format_to_nz(tensor, offset, ctx->device);
}
} else {
void *transform_buffer = malloc(size);
Expand Down Expand Up @@ -2246,7 +2280,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
release_nz_workspace();
get_workspace(cann_ctx->device).init();

#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
Expand Down
Loading