Skip to content

Commit bbe0fd2

Browse files
committed
[CANN]codestyle adjustment
Signed-off-by: noemotiovon <[email protected]>
1 parent def7d45 commit bbe0fd2

File tree

1 file changed

+16
-51
lines changed

1 file changed

+16
-51
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 16 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -181,55 +181,6 @@ static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
181181
ACL_CHECK(aclDestroyScalar(alpha));
182182
}
183183

184-
/**
185-
* @brief Computes the argmax of a tensor using the CANN backend.
186-
*
187-
* Calls aclnnArgMax on the input tensor (`acl_src`) with the arguments `3`
188-
* and `true`, writing into a temporary tensor backed by a buffer taken from
189-
* the context pool. The temporary tensor is described with the shape
190-
* {1, dst->ne[0], dst->ne[1], dst->ne[2]} and matching byte strides. After the
191-
* operator is queued, `ggml_nbytes(dst)` bytes are copied asynchronously from
192-
* the temporary buffer into `dst->data` on the context stream.
193-
* @param ctx The context for CANN backend operations.
194-
* @param acl_src The source tensor on which the argmax operation will be performed.
195-
* @param acl_dst The destination ACL tensor; it is not referenced in this function body.
196-
* @param dst The destination ggml tensor whose `data` receives the copied result.
197-
*/
198-
static void aclnn_argmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
199-
aclTensor* acl_dst, ggml_tensor* dst) {
200-
// Temporary result buffer from the context pool: one dst-typed slot per element of dst.
ggml_cann_pool_alloc dst_buffer_allocator(
201-
ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
202-
void* buffer = dst_buffer_allocator.get();
203-
// Shape given to the temporary tensor: dst's first three extents behind a leading 1.
int64_t dst_buffer_ne[4] = {1, dst->ne[0], dst->ne[1], dst->ne[2]};
204-
size_t dst_buffer_nb[GGML_MAX_DIMS];
205-
// Byte strides derived from the shape above, starting from the element size.
dst_buffer_nb[0] = ggml_type_size(dst->type);
206-
for (int i = 1; i < GGML_MAX_DIMS; i++) {
207-
dst_buffer_nb[i] = dst_buffer_nb[i - 1] * dst_buffer_ne[i - 1];
208-
}
209-
210-
aclTensor* dst_buffer_tensor =
211-
ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
212-
dst_buffer_ne, dst_buffer_nb, 4);
213-
214-
uint64_t workspaceSize = 0;
215-
aclOpExecutor* executor;
216-
void* workspaceAddr = nullptr;
217-
218-
219-
// The literals 3 and true are presumably the reduction dim and keepdim flag — TODO confirm against the CANN API.
ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, true, dst_buffer_tensor,
220-
&workspaceSize, &executor));
221-
if (workspaceSize > 0) {
222-
// NOTE(review): workspace_allocator is scoped to this if-block, so it is destroyed
// before aclnnArgMax below receives workspaceAddr — confirm the pool keeps the memory valid.
ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
223-
workspaceAddr = workspace_allocator.get();
224-
}
225-
226-
ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
227-
228-
// Queue a device-to-device copy of the temporary buffer into dst->data on the same stream.
size_t cpy_size = ggml_nbytes(dst);
229-
ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, buffer, cpy_size,
230-
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
231-
// NOTE(review): dst_buffer_tensor is not destroyed anywhere in this function.
}
232-
233184
void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
234185
ggml_tensor* src0 = dst->src[0];
235186
ggml_tensor* src1 = dst->src[1];
@@ -3500,8 +3451,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
35003451
ggml_tensor * src0 = dst->src[0];
35013452

35023453
aclTensor* acl_src = ggml_cann_create_tensor(src0);
3503-
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
3504-
aclnn_argmax(ctx, acl_src, acl_dst, dst);
3454+
int64_t dst_ne[3] = {dst->ne[0], dst->ne[1], dst->ne[2]};
3455+
size_t dst_nb[3] = {dst->nb[0], dst->nb[1], dst->nb[2]};  // byte strides (nb) for every dim; the middle entry previously used the element count ne[1]
3456+
aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst_ne, dst_nb, 3);
3457+
3458+
uint64_t workspaceSize = 0;
3459+
aclOpExecutor* executor;
3460+
void* workspaceAddr = nullptr;
3461+
3462+
ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, false, acl_dst,
3463+
&workspaceSize, &executor));
3464+
if (workspaceSize > 0) {
3465+
ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
3466+
workspaceAddr = workspace_allocator.get();
3467+
}
3468+
ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
3469+
35053470
ACL_CHECK(aclDestroyTensor(acl_src));
35063471
ACL_CHECK(aclDestroyTensor(acl_dst));
35073472
}

0 commit comments

Comments
 (0)