@@ -181,55 +181,6 @@ static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
     ACL_CHECK(aclDestroyScalar(alpha));
 }
 
-/**
- * @brief Computes the argmax of a tensor along the specified dimension using the CANN backend.
- *
- * This function performs the argmax operation on the input tensor (`acl_src`)
- * and stores the result in the destination tensor (`acl_dst`). The argmax is
- * computed along a specified axis, and the result is the index of the maximum value
- * along that axis. The operation is performed using the CANN backend, and
- * necessary memory allocation is handled automatically.
- *
- * @param ctx The context for CANN backend operations.
- * @param acl_src The source tensor on which the argmax operation will be performed.
- * @param acl_dst The destination tensor that will hold the resulting indices.
- * @param dst The destination tensor object that stores the result after the argmax operation.
- */
-static void aclnn_argmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                         aclTensor* acl_dst, ggml_tensor* dst) {
-    ggml_cann_pool_alloc dst_buffer_allocator(
-        ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
-    void* buffer = dst_buffer_allocator.get();
-    int64_t dst_buffer_ne[4] = {1, dst->ne[0], dst->ne[1], dst->ne[2]};
-    size_t dst_buffer_nb[GGML_MAX_DIMS];
-    dst_buffer_nb[0] = ggml_type_size(dst->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        dst_buffer_nb[i] = dst_buffer_nb[i - 1] * dst_buffer_ne[i - 1];
-    }
-
-    aclTensor* dst_buffer_tensor =
-        ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
-                                dst_buffer_ne, dst_buffer_nb, 4);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-
-    ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, true, dst_buffer_tensor,
-                                          &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    size_t cpy_size = ggml_nbytes(dst);
-    ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, buffer, cpy_size,
-                               ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-}
-
 void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
     ggml_tensor* src1 = dst->src[1];
@@ -3500,8 +3451,24 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor * src0 = dst->src[0];
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
-    aclnn_argmax(ctx, acl_src, acl_dst, dst);
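+    // dst viewed as 3-D: ArgMax below reduces over ACL dim 3 (ggml dim 0) with keepdim = false, so one dim is dropped.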
+    int64_t dst_ne[3] = {dst->ne[0], dst->ne[1], dst->ne[2]};
+    size_t dst_nb[3] = {dst->nb[0], dst->nb[1], dst->nb[2]};
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst_ne, dst_nb, 3);
+
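+    // aclnn two-step launch: query the workspace size, allocate it from the pool if needed, run on the context stream.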
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, false, acl_dst,
+                                          &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
     ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
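
For reference, a minimal CPU-only sketch of the reduction the new code performs: an argmax over ggml dim 0 (the contiguous dim, i.e. the last ACL dim) with the reduced dim dropped, which is what the ArgMax call above (dim = 3, keepdim = false) writes into the 3-D `acl_dst` view. The shape `{4, 3, 2}` and the host buffers are hypothetical; no ggml or ACL APIs are used.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical ggml-style shape: ne0 is the contiguous dim being reduced.
    const int64_t ne0 = 4, ne1 = 3, ne2 = 2;

    std::vector<float> src(ne0 * ne1 * ne2);
    for (size_t i = 0; i < src.size(); ++i) src[i] = float((i * 7) % 11);

    // Result drops dim 0: shape {ne1, ne2}, which the 3-D dst view describes
    // (trailing dims of size 1 omitted here).
    std::vector<int32_t> dst(ne1 * ne2);
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const float* row = src.data() + (i2 * ne1 + i1) * ne0;
            int32_t best = 0;
            for (int64_t i0 = 1; i0 < ne0; ++i0) {
                if (row[i0] > row[best]) best = (int32_t) i0;
            }
            dst[i2 * ne1 + i1] = best;
        }
    }

    for (int64_t i = 0; i < ne1 * ne2; ++i) printf("%d ", (int) dst[i]);
    printf("\n");
}
```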