@@ -2544,7 +2544,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
         int64_t shifts[] = { 1 };
         int64_t dims[] = { 3 };
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+        aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);
 
         // init [-1, 1, -1, 1, ...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
@@ -2564,7 +2564,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
         }
         int64_t index_num = src0->ne[0];
         float value = -1;
-        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, index_num, value);
+        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor.get(), dim, index, index_num, value);
     } else {
         // roll input: [q0,q1,q2,...] ->
         // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
@@ -2576,7 +2576,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
         int64_t shifts[] = { src0->ne[0] / 2 };
         int64_t dims[] = { 3 };
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+        aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims);
 
         // init [-1, -1, -1, 1, 1, 1, ...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
@@ -2599,7 +2599,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                                    first_half_ne, first_half_nb, GGML_MAX_DIMS);
         bool inplace = true;
         float scale = -1;
-        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
+        aclnn_muls(ctx, acl_first_half_tensor.get(), scale, nullptr, inplace);
     }
 
     // TODO: n_dims < ne0
@@ -2620,14 +2620,15 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
         ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
                                 src0->ne, input_nb, GGML_MAX_DIMS);
 
-    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, acl_input_roll_mul_scale_tensor);
+    aclnn_mul(ctx, acl_input_roll_reshape_tensor.get(), acl_minus_one_tensor.get(),
+              acl_input_roll_mul_scale_tensor.get());
 
     // output
     void * output_fp32_buffer;
     if (src0->type == GGML_TYPE_F32) {
-        aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor);
-        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor);
-        aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst);
+        aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get());
+        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get());
+        aclnn_add(ctx, acl_src.get(), acl_input_roll_mul_scale_tensor.get(), acl_dst.get());
         // TODO: ne0 != n_dims in mode2
     } else if (src0->type == GGML_TYPE_F16) {
         size_t input_fp32_nb[GGML_MAX_DIMS];
@@ -2648,10 +2649,10 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
         output_fp32_buffer = fp32_allocator.get();
         acl_tensor_ptr output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float),
                                                                     dst->ne, input_fp32_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
-        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, input_fp32_tensor2);
-        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, output_fp32_tensor);
-        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
+        aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get(), input_fp32_tensor1.get());
+        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get(), input_fp32_tensor2.get());
+        aclnn_add(ctx, input_fp32_tensor1.get(), input_fp32_tensor2.get(), output_fp32_tensor.get());
+        aclnn_cast(ctx, output_fp32_tensor.get(), acl_dst.get(), ACL_FLOAT16);
     }
     return;
 #endif
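
Background, not part of the patch: the mechanical .get() changes above come from migrating raw aclTensor * handles to the owning acl_tensor_ptr type, so each aclnn_* call site now passes the underlying raw pointer. A minimal sketch of what such a wrapper could look like, assuming a std::unique_ptr with a custom deleter (the real definition lives in the CANN backend headers and may differ):

#include <memory>

// Assumed declarations; in the backend these come from the ACL headers.
struct aclTensor;
extern "C" int aclDestroyTensor(const aclTensor * tensor);

// Deleter that releases the ACL handle exactly once.
struct acl_tensor_deleter {
    void operator()(aclTensor * t) const {
        if (t != nullptr) {
            aclDestroyTensor(t);
        }
    }
};

using acl_tensor_ptr = std::unique_ptr<aclTensor, acl_tensor_deleter>;

Because the aclnn_* helpers still take raw aclTensor * arguments, the patch only touches call sites (ptr.get()) and leaves the surrounding RoPE logic unchanged.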
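For reference, also not part of the patch: the else branch above (roll by src0->ne[0] / 2 plus a [-1, ..., -1, 1, ..., 1] sign mask) implements the rotate-half RoPE identity dst = src * cos + roll(src) * mask * sin. A host-side sketch of that identity under the same layout assumption; rope_rotate_half, pos, and freq_base are illustrative names, not backend API:

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate-half RoPE on one row: each pair (x[i], x[i + half]) is rotated by
// theta_i, which matches src*cos + roll(src, half)*mask*sin with mask = -1
// on the first half and +1 on the second half.
static void rope_rotate_half(std::vector<float> & x, int pos, float freq_base = 10000.0f) {
    const size_t half = x.size() / 2;
    for (size_t i = 0; i < half; ++i) {
        const float theta = pos * std::pow(freq_base, -2.0f * (float) i / (float) x.size());
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = x[i];
        const float x1 = x[i + half];
        x[i]        = x0 * c - x1 * s;  // rolled value is x1, mask = -1
        x[i + half] = x1 * c + x0 * s;  // rolled value is x0, mask = +1
    }
}

int main() {
    std::vector<float> q = { 1.0f, 2.0f, 3.0f, 4.0f };
    rope_rotate_half(q, /*pos=*/3);
    for (float v : q) {
        std::printf("%g ", v);
    }
    std::printf("\n");
    return 0;
}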