@@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
11161116 return GGML_STATUS_SUCCESS;
11171117}
11181118
1119- // ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120- namespace {
1121- void * g_nz_workspace = nullptr ;
1122- size_t g_nz_workspace_allocated = 0 ;
1123-
1124- void release_nz_workspace () {
1125- if (g_nz_workspace) {
1126- aclrtFree (g_nz_workspace);
1127- g_nz_workspace = nullptr ;
1128- g_nz_workspace_allocated = 0 ;
1119+ /* *
1120+ * @brief Workspace for caching NZ buffers per device.
1121+ *
1122+ * This struct manages a device buffer used in NZ computations. It supports
1123+ * allocation, reallocation, and clearing of cached memory. The struct is
1124+ * designed to be used with a global array, one per device.
1125+ */
1126+ struct ggml_cann_nz_workspace {
1127+ void * ptr; // Pointer to allocated device buffer
1128+ size_t allocated; // Size of currently allocated buffer in bytes
1129+
1130+ /* *
1131+ * @brief Constructor. Initializes the workspace with no allocated memory.
1132+ */
1133+ ggml_cann_nz_workspace () : ptr(nullptr ), allocated(0 ) {}
1134+
1135+ /* *
1136+ * @brief Free cached memory and reset the workspace.
1137+ *
1138+ * If a buffer has been allocated, this function releases it using
1139+ * aclrtFree and resets internal state.
1140+ */
1141+ void clear () {
1142+ if (ptr) {
1143+ ACL_CHECK (aclrtFree (ptr));
1144+ ptr = nullptr ;
1145+ allocated = 0 ;
11291146 }
11301147 }
11311148
1132- void relloc_nz_workspace (size_t new_size) {
1133- if (new_size > g_nz_workspace_allocated) {
1134- if (g_nz_workspace) {
1135- aclrtFree (g_nz_workspace);
1136- g_nz_workspace = nullptr ;
1149+ /* *
1150+ * @brief Allocate or reallocate the workspace buffer.
1151+ *
1152+ * If the requested size is larger than the currently allocated size,
1153+ * the old buffer will be freed and a new buffer of the requested size
1154+ * will be allocated on the device.
1155+ *
1156+ * @param new_size Size in bytes to allocate for the workspace.
1157+ */
1158+ void realloc (size_t new_size) {
1159+ if (new_size > allocated) {
1160+ clear ();
1161+ ACL_CHECK (aclrtMalloc (&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162+ allocated = new_size;
11371163 }
1138- ACL_CHECK (aclrtMalloc (&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1139- g_nz_workspace_allocated = new_size;
1140- }
11411164 }
1142- }
1165+
1166+ /* *
1167+ * @brief Get the device buffer pointer.
1168+ *
1169+ * @return Pointer to the allocated buffer, or nullptr if not allocated.
1170+ */
1171+ void * get () const { return ptr; }
1172+ };
1173+
1174+ /* *
1175+ * @brief Global array of NZ workspaces, one per device.
1176+ */
1177+ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
11431178
11441179/* *
11451180 * @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1149,13 +1184,13 @@ namespace {
11491184 * improve performance on certain hardware.
11501185 *
11511186 * @param tensor Pointer to the input ggml_tensor containing the weights.
1152- * @param data Pointer to the raw data buffer for the tensor weights.
11531187 * @param offset Byte offset within the tensor data buffer where weights start.
1188+ * @param device device id.
11541189 *
11551190 * @note The workspace buffer used in this function is managed globally and reused
11561191 * across calls. This reduces overhead from repeated memory allocation and deallocation.
11571192 */
1158- static void weight_format_to_nz (ggml_tensor *tensor, size_t offset) {
1193+ static void weight_format_to_nz (ggml_tensor *tensor, size_t offset, int device ) {
11591194 aclTensor* weightTransposed = ggml_cann_create_tensor (tensor, tensor->ne ,
11601195 tensor->nb , 2 , ACL_FORMAT_ND, offset);
11611196 uint64_t workspaceSize = 0 ;
@@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
11651200 ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed,
11661201 &workspaceSize, &executor));
11671202 // Avoid frequent malloc/free of the workspace.
1168- relloc_nz_workspace (workspaceSize);
1203+ g_nz_workspaces[device].realloc (workspaceSize);
1204+
1205+ void * g_nz_workspace = g_nz_workspaces[device].get ();
11691206
11701207 ACL_CHECK (aclnnTransMatmulWeight (g_nz_workspace, workspaceSize, executor, nullptr ));
11711208 ACL_CHECK (aclDestroyTensor (weightTransposed));
@@ -1203,7 +1240,7 @@ static void ggml_backend_cann_buffer_set_tensor(
12031240 if (weight_to_nz && is_matmul_weight ((const ggml_tensor*)tensor)) {
12041241 GGML_ASSERT (tensor->ne [2 ] == 1 );
12051242 GGML_ASSERT (tensor->ne [3 ] == 1 );
1206- weight_format_to_nz (tensor, offset);
1243+ weight_format_to_nz (tensor, offset, ctx-> device );
12071244 }
12081245 } else {
12091246 void *transform_buffer = malloc (size);
@@ -2262,7 +2299,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
22622299 ggml_backend_cann_context* cann_ctx =
22632300 (ggml_backend_cann_context*)backend->context ;
22642301 ggml_cann_set_device (cann_ctx->device );
2265- release_nz_workspace ();
2302+ g_nz_workspaces[cann_ctx-> device ]. clear ();
22662303
22672304#ifdef USE_ACL_GRAPH
22682305 bool use_cann_graph = true ;
0 commit comments