@@ -1116,56 +1116,77 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
11161116    return  GGML_STATUS_SUCCESS;
11171117}
11181118
1119- //  ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
1120- class  NzWorkspace  {
1121- public: 
1122-     //  Constructor: initialize with no allocated buffer
1123-     NzWorkspace () : ptr_(nullptr ), allocated_(0 ) {}
1124- 
1125-     //  Reset workspace to uninitialized state:
1126-     //  - Free allocated device memory (if any)
1127-     //  - Clear internal pointer and size
1128-     //  Equivalent to release_nz_workspace(device) in old version
1129-     void  init () {
1130-         if  (ptr_) {
1131-             aclrtFree (ptr_);
1132-             ptr_ = nullptr ;
1133-             allocated_ = 0 ;
1119+ /* *
1120+  * @brief Cached device workspace buffer for ND-to-NZ conversion, one per device.
1121+  * 
1122+  * Holds at most one buffer obtained with aclrtMalloc: realloc() grows it on
1123+  * demand and clear() releases it. Instances are intended to live in a global
1124+  * array indexed by device. No destructor and no locking: thread-safety is not guaranteed.
1125+  */  
1126+ struct  ggml_cann_nz_workspace  {
1127+     void *  ptr;       //  Device buffer from aclrtMalloc; nullptr when nothing is allocated
1128+     size_t  allocated; //  Capacity of ptr in bytes; 0 when ptr is nullptr
1129+ 
1130+     /* *
1131+      * @brief Construct an empty workspace (ptr == nullptr, allocated == 0).
1132+      */  
1133+     ggml_cann_nz_workspace () : ptr(nullptr ), allocated(0 ) {}
1134+ 
1135+     /* *
1136+      * @brief Free the cached buffer, if any, and reset to the empty state.
1137+      * 
1138+      * No-op when ptr is nullptr. Otherwise releases ptr with aclrtFree (its
1139+      * return value is not checked) and sets ptr to nullptr and allocated to 0.
1140+      */  
1141+     void  clear () {
1142+         if  (ptr) {
1143+             aclrtFree (ptr);
1144+             ptr = nullptr ;
1145+             allocated = 0 ;
11341146        }
11351147    }
11361148
1137-     //  Allocate or reallocate the workspace buffer:
1138-     //  - If requested size > currently allocated size:
1139-     //    * Free the old buffer (if any)
1140-     //    * Allocate a new buffer with requested size on device
1141-     //  - If requested size <= currently allocated size:
1142-     //    * Do nothing (reuse existing buffer)
1143-     //  Equivalent to relloc_nz_workspace(device, new_size) in old version
1149+     /* *
1150+      * @brief Ensure the workspace buffer holds at least new_size bytes.
1151+      * 
1152+      * If new_size exceeds the current capacity, the old buffer is freed via
1153+      * clear() and a fresh one is allocated with ACL_MEM_MALLOC_HUGE_FIRST; the
1154+      * previous contents are not copied. Otherwise the existing buffer is reused.
1155+      * 
1156+      * @param new_size Required capacity in bytes; 0 never triggers an allocation.
1157+      */  
11441158    void  realloc (size_t  new_size) {
1145-         if  (new_size > allocated_ ) {
1146-             init ();
1147-             ACL_CHECK (aclrtMalloc (&ptr_ , new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1148-             allocated_  = new_size;
1159+         if  (new_size > allocated ) {
1160+             clear ();
1161+             ACL_CHECK (aclrtMalloc (&ptr , new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162+             allocated  = new_size;
11491163        }
11501164    }
11511165
1152-     //  Return raw device pointer (may be nullptr if not allocated)
1153-     //  Equivalent to get_nz_workspace(device) in old version
1154-     void * get () const  { return  ptr_; }
1155- 
1156- private: 
1157-     void * ptr_;  //  Pointer to allocated device buffer
1158-     size_t  allocated_;  //  Size of currently allocated buffer (bytes)
1166+     /* *
1167+      * @brief Get the raw device buffer pointer without transferring ownership.
1168+      * 
1169+      * @return Current buffer pointer (nullptr if none); invalidated by clear() and by a growing realloc().
1170+      */  
1171+     void * get () const  { return  ptr; }
11591172};
11601173
1161- //  Global array of NzWorkspace, one per device
1162- //  g_nz_workspaces[device] corresponds to workspace of given device
1163- static  std::array<NzWorkspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;
1174+ /* *
1175+  * @brief Global array of NZ workspaces: g_nz_workspaces[device] is the workspace of that device (GGML_CANN_MAX_DEVICES entries).
1176+  */  
1177+ static  std::array<ggml_cann_nz_workspace, GGML_CANN_MAX_DEVICES> g_nz_workspaces;
11641178
1165- //  Accessor for workspace of a given device
1166- //  - Throws std::out_of_range if device index is invalid
1167- //  - Caller can then use .init(), .realloc(), .get()
1168- inline  NzWorkspace& get_workspace (int  device) {
1179+ /* *
1180+  * @brief Get the NZ workspace for a specific device. 
1181+  * 
1182+  * This function returns a reference to the workspace corresponding to the 
1183+  * given device index. 
1184+  * 
1185+  * @param device Device index (0-based). Must be less than GGML_CANN_MAX_DEVICES. 
1186+  * @return Reference to the device's NZ workspace. 
1187+  * @throws std::out_of_range if device index is invalid. 
1188+  */  
1189+ inline  ggml_cann_nz_workspace& get_nz_workspace (int  device) {
11691190    if  (device < 0  || device >= static_cast <int >(g_nz_workspaces.size ())) {
11701191        throw  std::out_of_range (" device id out of range" 
11711192    }
@@ -1197,9 +1218,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
11971218    ACL_CHECK (aclnnTransMatmulWeightGetWorkspaceSize (weightTransposed,
11981219                                                    &workspaceSize, &executor));
11991220    //  Avoid frequent malloc/free of the workspace.
1200-     get_workspace (device).realloc (workspaceSize);
1221+     get_nz_workspace (device).realloc (workspaceSize);
12011222
1202-     void * g_nz_workspace = get_workspace (device).get ();
1223+     void * g_nz_workspace = get_nz_workspace (device).get ();
12031224
12041225    ACL_CHECK (aclnnTransMatmulWeight (g_nz_workspace, workspaceSize, executor, nullptr ));
12051226    ACL_CHECK (aclDestroyTensor (weightTransposed));
@@ -2280,7 +2301,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
22802301    ggml_backend_cann_context* cann_ctx =
22812302        (ggml_backend_cann_context*)backend->context ;
22822303    ggml_cann_set_device (cann_ctx->device );
2283-     get_workspace (cann_ctx->device ).init ();
2304+     get_nz_workspace (cann_ctx->device ).clear ();
22842305
22852306#ifdef  USE_ACL_GRAPH
22862307    bool  use_cann_graph = true ;
0 commit comments