diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100755 new mode 100644 index bc33b99d96e..94a8aa14f41 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2229,6 +2229,10 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, ggml_cann_release_resources(ctx, acl_index, acl_value); } + + + + /** * @brief Initializes and caches sine/cosine positional encoding values * (used in RoPE, Rotary Position Embedding) for attention layers. @@ -2445,275 +2449,539 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, } #endif +// Helper: YaRN theta correction (copied from ggml-common.c logic) +static void rope_yarn( + float theta, float freq_scale, const float corr_dims[2], + int64_t i0, float ext_factor, float mscale, + float* cos_val, float* sin_val +) { + // Original freq = theta + float inv_freq = theta; + float freq = inv_freq; + if (ext_factor != 0.0f) { + float t = i0 / 2; + float a = corr_dims[0]; + float b = corr_dims[1]; + float smooth = 1.0f + (a / (b + t)); + freq *= powf(smooth, ext_factor); + } + freq *= freq_scale; + float c = cosf(freq) * mscale; + float s = sinf(freq) * mscale; + *cos_val = c; + *sin_val = s; +} + +static void aclnn_compute_mrope_tables_host( + float* cos_out, float* sin_out, + const int32_t* pos_ids, + const float* freq_factors, + int64_t seq_len, int64_t n_dims, + const std::array& sections, + float freq_base, float freq_scale, + float ext_factor, float attn_factor, + bool is_neox, + const float corr_dims[2] +) { + const float theta_scale = powf(freq_base, -2.0f / n_dims); + const int total_sect = sections[0] + sections[1] + sections[2] + sections[3]; + GGML_ASSERT(total_sect > 0 && total_sect <= n_dims); + GGML_ASSERT(n_dims % 2 == 0); + + // Precompute section boundaries + int sec_w = sections[0] + sections[1]; + // int sec_e = sec_w + sections[2]; + + for (int64_t t = 0; t < seq_len; ++t) { + float theta_base_t = (float)pos_ids[0 * seq_len + t]; // [t0, t1, ... , h0, h1, ] + float theta_base_h = (float)pos_ids[1 * seq_len + t]; + float theta_base_w = (float)pos_ids[2 * seq_len + t]; + float theta_base_e = (float)pos_ids[3 * seq_len + t]; + + float theta_t = theta_base_t; + float theta_h = theta_base_h; + float theta_w = theta_base_w; + float theta_e = theta_base_e; // extra position id for vision encoder + + for (int64_t i = 0; i < n_dims; i += 2) { + int64_t idx = i / 2; // theta_idx + int sector = idx % total_sect; + + const float ff = freq_factors ? freq_factors[idx] : 1.0f; + + float theta = theta_t; + if (sector >= sections[0] && sector < sec_w) { + theta = theta_h; + } + else if (sector >= sec_w && sector < sec_w + sections[2]) { + theta = theta_w; + } + else if (sector >= sec_w + sections[2]) { + theta = theta_e; + } + + float c, s; + rope_yarn( + theta/ff, freq_scale, corr_dims, i, ext_factor, attn_factor, &c, &s + ); + + cos_out[t * n_dims + idx ] = c; + cos_out[t * n_dims + idx + n_dims / 2] = c; + sin_out[t * n_dims + idx ] = s; + sin_out[t * n_dims + idx + n_dims / 2] = s; + + theta_t *= theta_scale; + theta_w *= theta_scale; + theta_h *= theta_scale; + theta_e *= theta_scale; + } + } +} + + void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input + ggml_tensor* src1 = dst->src[1]; // pos_ids: (size = 4 * seq_len), eg. 
src1->ne[0]: 8, src1->ne[1]: 1, seq_len: 2 + ggml_tensor* src2 = dst->src[2]; // freq_factors (optional) + const int32_t* params = (const int32_t*)dst->op_params; + const int n_dims = params[1]; + const int mode = params[2]; + const int n_ctx_orig = params[4]; // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; // const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t*)dst->op_params)[1]; - const int mode = ((int32_t*)dst->op_params)[2]; + // const int n_dims = ((int32_t*)dst->op_params)[1]; + // const int mode = ((int32_t*)dst->op_params)[2]; // const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t*)dst->op_params)[4]; + // const int n_ctx_orig = ((int32_t*)dst->op_params)[4]; - GGML_TENSOR_UNARY_OP_LOCALS + if(mode == GGML_ROPE_TYPE_MROPE){ + memcpy(&freq_base, params + 5, sizeof(float)); + memcpy(&freq_scale, params + 6, sizeof(float)); + memcpy(&ext_factor, params + 7, sizeof(float)); + memcpy(&attn_factor, params + 8, sizeof(float)); + memcpy(&beta_fast, params + 9, sizeof(float)); + memcpy(&beta_slow, params + 10, sizeof(float)); - memcpy(&freq_base, (int32_t*)dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t*)dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t*)dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t*)dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float)); + std::array sections; + memcpy(sections.data(), params + 11, 4 * sizeof(int32_t)); - // TODO: n_dims <= ne0 - GGML_ASSERT(n_dims == ne0); - GGML_ASSERT(n_dims % 2 == 0); - // TODO: ext_factor != 0 - GGML_ASSERT(ext_factor == 0); + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; - const float theta_scale = powf(freq_base, -2.0f / n_dims); + if (mrope_used) { + // MROPE always uses Neox-style rotation + (void)is_neox; // suppress unused warning + } - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, - beta_slow, corr_dims); + // YaRN + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + // seq_len = src0->ne[2] for layout [head_dim, n_heads, seq_len, batch] + int64_t seq_len = src0->ne[2]; - // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox); + // Read pos_ids (size = 4 * seq_len) + GGML_ASSERT(src1->ne[0] == 4 * seq_len); - int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; - size_t sin_reshape_nb[GGML_MAX_DIMS]; - sin_reshape_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; - } - aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), - sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + std::vector pos_ids_host(4 * seq_len); - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); + ACL_CHECK(aclrtMemcpy(pos_ids_host.data(), 4 * seq_len * sizeof(int32_t), + src1->data, 4 * seq_len * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST)); -#ifdef ASCEND_310P - // Special 
ROPE operation for 310P - - // roll input - void* input_roll_buffer; - aclTensor* acl_minus_one_tensor; - void* minus_one_scale_buffer = nullptr; - ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0)); - ggml_cann_pool_alloc minus_one_scale_allocator( - ctx.pool(), sizeof(float_t) * src0->ne[0]); - if (!is_neox) { - // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...] - input_roll_buffer = roll_allocator.get(); - int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2), - src0->ne[2], src0->ne[3]}; - size_t input_roll_nb[GGML_MAX_DIMS]; - input_roll_nb[0] = ggml_type_size(src0->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1]; + // Read freq_factors + std::vector freq_factors_host; + float* freq_factors_ptr = nullptr; + if (src2) { + freq_factors_host.resize(ggml_nelements(src2)); + ACL_CHECK(aclrtMemcpy(freq_factors_host.data(), ggml_nbytes(src2), + src2->data, ggml_nbytes(src2), + ACL_MEMCPY_DEVICE_TO_HOST)); + freq_factors_ptr = freq_factors_host.data(); } - aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), input_roll_ne, input_roll_nb, - GGML_MAX_DIMS); - aclTensor* acl_input_tensor = ggml_cann_create_tensor( - src0->data, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), input_roll_ne, input_roll_nb, - GGML_MAX_DIMS); - int64_t shifts[] = {1}; - int64_t dims[] = {3}; - aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); + // Compute cos/sin table on host: [seq_len, n_dims] + size_t table_nelements = seq_len * n_dims; + size_t table_bytes = table_nelements * sizeof(float); + std::vector cos_host(table_nelements); + std::vector sin_host(table_nelements); + + aclnn_compute_mrope_tables_host( + cos_host.data(), sin_host.data(), + pos_ids_host.data(), freq_factors_ptr, + seq_len, n_dims, sections, + freq_base, freq_scale, + ext_factor, attn_factor, + is_neox, + corr_dims + ); - // init [-1, 1, -1, 1, ...] 
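+
+        // Note on the table layout: for every position t the host tables hold the
+        // rotation angle of pair i at index i/2 and a duplicate at i/2 + n_dims/2
+        // ([seq_len, n_dims], row-major). The broadcast cos/sin views built for the
+        // device computation below read only the first n_dims/2 entries of each row,
+        // since the Neox-style split applies the same angle to x[i] and x[i + n_dims/2].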
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
+        // Allocate device memory
+        size_t f32_nelements = ggml_nelements(dst);
+        size_t f32_nbytes = f32_nelements * sizeof(float);
+
+        ggml_cann_pool_alloc workspace(ctx.pool());
+        char* dev_ptr = (char*)workspace.alloc(table_bytes * 2 + f32_nbytes * 3);
+
+        float* cos_dev = (float*)dev_ptr;
+        float* sin_dev = (float*)(dev_ptr + table_bytes);
+        float* x_dev = (float*)(dev_ptr + table_bytes * 2);
+        float* y_dev = x_dev + f32_nelements;
+        float* tmp_dev = y_dev + f32_nelements;
+
+        ACL_CHECK(aclrtMemcpyAsync(cos_dev, table_bytes, cos_host.data(), table_bytes,
+                                   ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+        ACL_CHECK(aclrtMemcpyAsync(sin_dev, table_bytes, sin_host.data(), table_bytes,
+                                   ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+
+        int64_t x_ne[] = {src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]};
+        size_t x_nb[] = {
+            sizeof(float),
+            sizeof(float) * x_ne[0],
+            sizeof(float) * x_ne[0] * x_ne[1],
+            sizeof(float) * x_ne[0] * x_ne[1] * x_ne[2]
+        };
+        aclTensor* acl_x = ggml_cann_create_tensor(x_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4);
+        aclTensor* acl_src = ggml_cann_create_tensor(src0);
+        aclnn_cast(ctx, acl_src, acl_x, ACL_FLOAT);
+        ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+        aclTensor* acl_y_full_init = ggml_cann_create_tensor(y_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_y_full_init, acl_x);
+
+        ggml_cann_release_resources(ctx, acl_y_full_init);
+        int64_t half_dim = n_dims / 2;
+        int64_t ne_half[] = {half_dim, src0->ne[1], src0->ne[2], src0->ne[3]};
+        size_t* nb_half = x_nb;
+        float* x0_dev = x_dev;
+        float* x1_dev = x_dev + half_dim;
+        aclTensor* acl_x0 = ggml_cann_create_tensor(x0_dev, ACL_FLOAT, sizeof(float), ne_half, nb_half, 4);
+        aclTensor* acl_x1 = ggml_cann_create_tensor(x1_dev, ACL_FLOAT, sizeof(float), ne_half, nb_half, 4);
+        int64_t table_ne_broadcast[] = {half_dim, 1, seq_len, 1};
+        size_t table_nb_broadcast[4];
+        table_nb_broadcast[0] = sizeof(float);
+        table_nb_broadcast[1] = 0;
+        table_nb_broadcast[2] = sizeof(float) * n_dims;
+        table_nb_broadcast[3] = 0;
+
+        aclTensor* acl_cos0 = ggml_cann_create_tensor(cos_dev, ACL_FLOAT, sizeof(float),
+                                                      table_ne_broadcast, table_nb_broadcast, 4);
+        aclTensor* acl_sin0 = ggml_cann_create_tensor(sin_dev, ACL_FLOAT, sizeof(float),
+                                                      table_ne_broadcast, table_nb_broadcast, 4);
+
+        // Output halves
+        float* y0_dev = y_dev;
+        float* y1_dev = y_dev + half_dim;
+        aclTensor* acl_y0 = ggml_cann_create_tensor(y0_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+        aclTensor* acl_y1 = ggml_cann_create_tensor(y1_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+
+        // Temp tensor
+        aclTensor* acl_tmp = ggml_cann_create_tensor(tmp_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+
+        // y0 = x0 * cos0 - x1 * sin0
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x0, acl_cos0, acl_y0);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x1, acl_sin0, acl_tmp);
+        float alpha_neg = -1.0f;
+        aclScalar* s_alpha_neg = aclCreateScalar(&alpha_neg, ACL_FLOAT);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_y0, acl_tmp, s_alpha_neg, acl_y0);
+        aclDestroyScalar(s_alpha_neg);
+
+        // y1 = x0 * sin0 + x1 * cos0
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x0, acl_sin0, acl_y1);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x1, acl_cos0, acl_tmp);
+        float alpha_pos = 1.0f;
+        aclScalar* s_alpha_pos = aclCreateScalar(&alpha_pos, ACL_FLOAT);
+ 
GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_y1, acl_tmp, s_alpha_pos, acl_y1); + aclDestroyScalar(s_alpha_pos); + + // Full output tensor + aclTensor* acl_y_full = ggml_cann_create_tensor(y_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, acl_y_full, acl_dst, ggml_cann_type_mapping(dst->type)); - int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; - size_t minus_one_nb[GGML_MAX_DIMS]; - minus_one_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; - } - acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); - int64_t dim = 3; - int64_t* index = new int64_t[src0->ne[0]]; - for (int i = 0; i < src0->ne[0]; i++) { - index[i] = i / 2 * 2; - } - int64_t index_num = src0->ne[0]; - float value = -1; - aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, - index_num, value); - } else { - // roll input: [q0,q1,q2,...] -> - // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1] - input_roll_buffer = roll_allocator.get(); - aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); - aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0); - - int64_t shifts[] = {src0->ne[0] / 2}; - int64_t dims[] = {3}; - aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - - ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); - // init [-1, -1, -1, 1, 1,1,...] - minus_one_scale_buffer = minus_one_scale_allocator.get(); - int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; - size_t minus_one_nb[GGML_MAX_DIMS]; - minus_one_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; - } - acl_minus_one_tensor = aclnn_values( - ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], - minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); - // -1 * first half - int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1}; - size_t first_half_nb[GGML_MAX_DIMS]; - first_half_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1]; - } - aclTensor* acl_first_half_tensor = ggml_cann_create_tensor( - minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne, - first_half_nb, GGML_MAX_DIMS); - bool inplace = true; - float scale = -1; - aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); - ggml_cann_release_resources(ctx, acl_first_half_tensor); - } + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); - // TODO: n_dims < ne0 - GGML_ASSERT(n_dims == src0->ne[0]); + std::vector dst_host(ggml_nelements(dst)); + ACL_CHECK(aclrtMemcpy(dst_host.data(), ggml_nbytes(dst), + dst->data, ggml_nbytes(dst), + ACL_MEMCPY_DEVICE_TO_HOST)); - // input * scale - ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(), - ggml_nbytes(src0)); - void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get(); - size_t input_nb[GGML_MAX_DIMS]; - input_nb[0] = ggml_type_size(src0->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - input_nb[i] = input_nb[i - 1] * src0->ne[i - 1]; + // Clean up + ggml_cann_release_resources(ctx, + acl_src, acl_x, acl_x0, acl_x1, + acl_cos0, acl_sin0, + acl_y0, acl_y1, acl_y_full, + acl_tmp, acl_dst + ); } - aclTensor* 
acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor( - input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); - aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor( - input_roll_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); + else{ + GGML_TENSOR_UNARY_OP_LOCALS - aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, - acl_input_roll_mul_scale_tensor); + memcpy(&freq_base, (int32_t*)dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t*)dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t*)dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t*)dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float)); - // output - void* output_fp32_buffer; - if (src0->type == GGML_TYPE_F32) { - aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, - acl_sin_reshape_tensor); - aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); - // TODO: ne0 != n_dims in mode2 - } else if (src0->type == GGML_TYPE_F16) { - size_t input_fp32_nb[GGML_MAX_DIMS]; - input_fp32_nb[0] = sizeof(float_t); + // TODO: n_dims <= ne0 + GGML_ASSERT(n_dims == ne0); + GGML_ASSERT(n_dims % 2 == 0); + // TODO: ext_factor != 0 + // GGML_ASSERT(ext_factor == 0); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, + beta_slow, corr_dims); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + // init ctx.rope_cos/rope_sin cache + aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox); + + int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; + size_t sin_reshape_nb[GGML_MAX_DIMS]; + sin_reshape_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { - input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1]; + sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } - ggml_cann_pool_alloc fp32_allocator1( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); - void* input_fp32_buffer1 = fp32_allocator1.get(); - aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor( - input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); - ggml_cann_pool_alloc fp32_allocator2( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); - void* input_fp32_buffer2 = fp32_allocator2.get(); - aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor( - input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); - - ggml_cann_pool_alloc fp32_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); - output_fp32_buffer = fp32_allocator.get(); - aclTensor* output_fp32_tensor = ggml_cann_create_tensor( - output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne, - input_fp32_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, - input_fp32_tensor2); - aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, - output_fp32_tensor); - aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); - - ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, - output_fp32_tensor, acl_sin_reshape_tensor, - acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, - 
acl_input_roll_reshape_tensor, acl_src); - } - return; -#endif + aclTensor* acl_sin_reshape_tensor = + ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); + aclTensor* acl_cos_reshape_tensor = + ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), + sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); - // ggml_mode = 0 --> aclnn_model = 1 - int64_t acl_mode = mode == 0 ? 1 : mode; + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); - switch (src0->type) { - case GGML_TYPE_F32: { - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, - acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst); - break; +// #define ASCEND_310P +#ifdef ASCEND_310P + // Special ROPE operation for 310P + + // roll input + void* input_roll_buffer; + aclTensor* acl_minus_one_tensor; + void* minus_one_scale_buffer = nullptr; + ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0)); + ggml_cann_pool_alloc minus_one_scale_allocator( + ctx.pool(), sizeof(float_t) * src0->ne[0]); + if (!is_neox) { + // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...] + input_roll_buffer = roll_allocator.get(); + int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2), + src0->ne[2], src0->ne[3]}; + size_t input_roll_nb[GGML_MAX_DIMS]; + input_roll_nb[0] = ggml_type_size(src0->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1]; + } + aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( + input_roll_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), input_roll_ne, input_roll_nb, + GGML_MAX_DIMS); + aclTensor* acl_input_tensor = ggml_cann_create_tensor( + src0->data, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), input_roll_ne, input_roll_nb, + GGML_MAX_DIMS); + + int64_t shifts[] = {1}; + int64_t dims[] = {3}; + aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); + + // init [-1, 1, -1, 1, ...] + minus_one_scale_buffer = minus_one_scale_allocator.get(); + + int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; + size_t minus_one_nb[GGML_MAX_DIMS]; + minus_one_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; + } + acl_minus_one_tensor = aclnn_values( + ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], + minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); + int64_t dim = 3; + int64_t* index = new int64_t[src0->ne[0]]; + for (int i = 0; i < src0->ne[0]; i++) { + index[i] = i / 2 * 2; + } + int64_t index_num = src0->ne[0]; + float value = -1; + aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, + index_num, value); + } else { + // roll input: [q0,q1,q2,...] -> + // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1] + input_roll_buffer = roll_allocator.get(); + aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor( + input_roll_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); + aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0); + + int64_t shifts[] = {src0->ne[0] / 2}; + int64_t dims[] = {3}; + aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); + + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); + // init [-1, -1, -1, 1, 1,1,...] 
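+        // The scale vector negates only the first half. Combined with the roll by
+        // ne0/2 above, the rolled-and-scaled term is [-x[i + ne0/2], ..., +x[i], ...],
+        // so the result of the multiply-add below is y[i] = x[i]*cos - x[i + ne0/2]*sin
+        // and y[i + ne0/2] = x[i + ne0/2]*cos + x[i]*sin (Neox-style rotation).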
+ minus_one_scale_buffer = minus_one_scale_allocator.get(); + int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; + size_t minus_one_nb[GGML_MAX_DIMS]; + minus_one_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; + } + acl_minus_one_tensor = aclnn_values( + ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], + minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); + // -1 * first half + int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1}; + size_t first_half_nb[GGML_MAX_DIMS]; + first_half_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1]; + } + aclTensor* acl_first_half_tensor = ggml_cann_create_tensor( + minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne, + first_half_nb, GGML_MAX_DIMS); + bool inplace = true; + float scale = -1; + aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); + ggml_cann_release_resources(ctx, acl_first_half_tensor); } - case GGML_TYPE_F16: { - ggml_cann_pool_alloc src_trans_allocator( - ctx.pool(), ggml_nelements(src0) * sizeof(float)); - void* src_trans_buffer = src_trans_allocator.get(); - ggml_cann_pool_alloc dst_trans_allocator( - ctx.pool(), ggml_nelements(dst) * sizeof(float)); - void* dst_trans_buffer = dst_trans_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float); + // TODO: n_dims < ne0 + GGML_ASSERT(n_dims == src0->ne[0]); + + // input * scale + ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(), + ggml_nbytes(src0)); + void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get(); + size_t input_nb[GGML_MAX_DIMS]; + input_nb[0] = ggml_type_size(src0->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + input_nb[i] = input_nb[i - 1] * src0->ne[i - 1]; + } + aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor( + input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); + aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor( + input_roll_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); + + aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, + acl_input_roll_mul_scale_tensor); + + // output + void* output_fp32_buffer; + if (src0->type == GGML_TYPE_F32) { + aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, + acl_sin_reshape_tensor); + aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); + // TODO: ne0 != n_dims in mode2 + } else if (src0->type == GGML_TYPE_F16) { + size_t input_fp32_nb[GGML_MAX_DIMS]; + input_fp32_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1]; } + ggml_cann_pool_alloc fp32_allocator1( + ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + void* input_fp32_buffer1 = fp32_allocator1.get(); + aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor( + input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne, + input_fp32_nb, GGML_MAX_DIMS); + ggml_cann_pool_alloc fp32_allocator2( + ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + void* input_fp32_buffer2 = fp32_allocator2.get(); + aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor( + input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne, + input_fp32_nb, 
GGML_MAX_DIMS); + + ggml_cann_pool_alloc fp32_allocator( + ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); + output_fp32_buffer = fp32_allocator.get(); + aclTensor* output_fp32_tensor = ggml_cann_create_tensor( + output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne, + input_fp32_nb, GGML_MAX_DIMS); + aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, + input_fp32_tensor2); + aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, + output_fp32_tensor); + aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); + + ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, + output_fp32_tensor, acl_sin_reshape_tensor, + acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, + acl_input_roll_reshape_tensor, acl_src); + } + return; +#endif - aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, - GGML_MAX_DIMS); - aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( - dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, - GGML_MAX_DIMS); + // ggml_mode = 0 --> aclnn_model = 1 + int64_t acl_mode = mode == 0 ? 1 : mode; + + switch (src0->type) { + case GGML_TYPE_F32: { + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst); + break; + } + case GGML_TYPE_F16: { + ggml_cann_pool_alloc src_trans_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float)); + void* src_trans_buffer = src_trans_allocator.get(); + ggml_cann_pool_alloc dst_trans_allocator( + ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void* dst_trans_buffer = dst_trans_allocator.get(); + + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } - aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); + aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( + dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, + GGML_MAX_DIMS); - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, - acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, - acl_dst_trans_tensor); + aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); - aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, + acl_dst_trans_tensor); - ggml_cann_release_resources(ctx, acl_src_trans_tensor, - acl_dst_trans_tensor); - break; + aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + + ggml_cann_release_resources(ctx, acl_src_trans_tensor, + acl_dst_trans_tensor); + break; + } + default: + GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); + break; } - default: - GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); - break; + ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_src, acl_dst); } - ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, - acl_sin_reshape_tensor, acl_src, acl_dst); + } @@ -3424,3 +3692,181 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ GGML_ABORT("Function is not implemented."); } } + +void 
ggml_cann_rope_multi(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0]; // [head_dim, n_heads, seq_len, batch]
+    ggml_tensor* src1 = dst->src[1]; // pos_ids: (size = 4 * seq_len), e.g. src1->ne[0]: 8, src1->ne[1]: 1, seq_len: 2
+    ggml_tensor* src2 = dst->src[2]; // freq_factors (optional)
+
+    const int32_t* params = (const int32_t*)dst->op_params;
+    const int n_dims = params[1];
+    const int mode = params[2];
+    const int n_ctx_orig = params[4];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base, params + 5, sizeof(float));
+    memcpy(&freq_scale, params + 6, sizeof(float));
+    memcpy(&ext_factor, params + 7, sizeof(float));
+    memcpy(&attn_factor, params + 8, sizeof(float));
+    memcpy(&beta_fast, params + 9, sizeof(float));
+    memcpy(&beta_slow, params + 10, sizeof(float));
+
+    std::array<int32_t, 4> sections;
+    memcpy(sections.data(), params + 11, 4 * sizeof(int32_t));
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+
+    if (mrope_used) {
+        // MROPE always uses Neox-style rotation
+        (void)is_neox; // suppress unused warning
+    }
+
+    // YaRN
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    // seq_len = src0->ne[2] for layout [head_dim, n_heads, seq_len, batch]
+    int64_t seq_len = src0->ne[2];
+
+    // Read pos_ids (size = 4 * seq_len)
+    GGML_ASSERT(src1->ne[0] == 4 * seq_len);
+
+    std::vector<int32_t> pos_ids_host(4 * seq_len);
+
+    ACL_CHECK(aclrtMemcpy(pos_ids_host.data(), 4 * seq_len * sizeof(int32_t),
+                          src1->data, 4 * seq_len * sizeof(int32_t),
+                          ACL_MEMCPY_DEVICE_TO_HOST));
+
+    // Read freq_factors
+    std::vector<float> freq_factors_host;
+    float* freq_factors_ptr = nullptr;
+    if (src2) {
+        freq_factors_host.resize(ggml_nelements(src2));
+        ACL_CHECK(aclrtMemcpy(freq_factors_host.data(), ggml_nbytes(src2),
+                              src2->data, ggml_nbytes(src2),
+                              ACL_MEMCPY_DEVICE_TO_HOST));
+        freq_factors_ptr = freq_factors_host.data();
+    }
+
+    // Compute cos/sin table on host: [seq_len, n_dims]
+    size_t table_nelements = seq_len * n_dims;
+    size_t table_bytes = table_nelements * sizeof(float);
+    std::vector<float> cos_host(table_nelements);
+    std::vector<float> sin_host(table_nelements);
+
+    aclnn_compute_mrope_tables_host(
+        cos_host.data(), sin_host.data(),
+        pos_ids_host.data(), freq_factors_ptr,
+        seq_len, n_dims, sections,
+        freq_base, freq_scale,
+        ext_factor, attn_factor,
+        is_neox,
+        corr_dims
+    );
+
+    // Allocate device memory
+    size_t f32_nelements = ggml_nelements(dst);
+    size_t f32_nbytes = f32_nelements * sizeof(float);
+
+    ggml_cann_pool_alloc workspace(ctx.pool());
+    char* dev_ptr = (char*)workspace.alloc(table_bytes * 2 + f32_nbytes * 3);
+
+    float* cos_dev = (float*)dev_ptr;
+    float* sin_dev = (float*)(dev_ptr + table_bytes);
+    float* x_dev = (float*)(dev_ptr + table_bytes * 2);
+    float* y_dev = x_dev + f32_nelements;
+    float* tmp_dev = y_dev + f32_nelements;
+
+    ACL_CHECK(aclrtMemcpyAsync(cos_dev, table_bytes, cos_host.data(), table_bytes,
+                               ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+    ACL_CHECK(aclrtMemcpyAsync(sin_dev, table_bytes, sin_host.data(), table_bytes,
+                               ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+
+    int64_t x_ne[] = {src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]};
+    size_t x_nb[] = {
+        sizeof(float),
+        sizeof(float) * x_ne[0],
+        sizeof(float) * x_ne[0] * x_ne[1],
+        sizeof(float) * x_ne[0] * x_ne[1] * x_ne[2]
+    };
+    aclTensor* acl_x = ggml_cann_create_tensor(x_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4);
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclnn_cast(ctx, acl_src, acl_x, ACL_FLOAT);
+    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+    aclTensor* acl_y_full_init = ggml_cann_create_tensor(y_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_y_full_init, acl_x);
+
+    ggml_cann_release_resources(ctx, acl_y_full_init);
+    int64_t half_dim = n_dims / 2;
+    int64_t ne_half[] = {half_dim, src0->ne[1], src0->ne[2], src0->ne[3]};
+    size_t* nb_half = x_nb;
+
+    float* x0_dev = x_dev;
+    float* x1_dev = x_dev + half_dim;
+
+    aclTensor* acl_x0 = ggml_cann_create_tensor(x0_dev, ACL_FLOAT, sizeof(float), ne_half, nb_half, 4);
+    aclTensor* acl_x1 = ggml_cann_create_tensor(x1_dev, ACL_FLOAT, sizeof(float), ne_half, nb_half, 4);
+    int64_t table_ne_broadcast[] = {half_dim, 1, seq_len, 1};
+    size_t table_nb_broadcast[4];
+
+    table_nb_broadcast[0] = sizeof(float);
+    table_nb_broadcast[1] = 0;
+    table_nb_broadcast[2] = sizeof(float) * n_dims;
+    table_nb_broadcast[3] = 0;
+
+    aclTensor* acl_cos0 = ggml_cann_create_tensor(cos_dev, ACL_FLOAT, sizeof(float),
+                                                  table_ne_broadcast, table_nb_broadcast, 4);
+    aclTensor* acl_sin0 = ggml_cann_create_tensor(sin_dev, ACL_FLOAT, sizeof(float),
+                                                  table_ne_broadcast, table_nb_broadcast, 4);
+
+    // Output halves
+    float* y0_dev = y_dev;
+    float* y1_dev = y_dev + half_dim;
+    aclTensor* acl_y0 = ggml_cann_create_tensor(y0_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+    aclTensor* acl_y1 = ggml_cann_create_tensor(y1_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+
+    // Temp tensor
+    aclTensor* acl_tmp = ggml_cann_create_tensor(tmp_dev, ACL_FLOAT, sizeof(float), ne_half, x_nb, 4);
+
+    // y0 = x0 * cos0 - x1 * sin0
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x0, acl_cos0, acl_y0);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x1, acl_sin0, acl_tmp);
+    float alpha_neg = -1.0f;
+    aclScalar* s_alpha_neg = aclCreateScalar(&alpha_neg, ACL_FLOAT);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_y0, acl_tmp, s_alpha_neg, acl_y0);
+    aclDestroyScalar(s_alpha_neg);
+
+    // y1 = x0 * sin0 + x1 * cos0
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x0, acl_sin0, acl_y1);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_x1, acl_cos0, acl_tmp);
+    float alpha_pos = 1.0f;
+    aclScalar* s_alpha_pos = aclCreateScalar(&alpha_pos, ACL_FLOAT);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_y1, acl_tmp, s_alpha_pos, acl_y1);
+    aclDestroyScalar(s_alpha_pos);
+
+    // Full output tensor
+    aclTensor* acl_y_full = ggml_cann_create_tensor(y_dev, ACL_FLOAT, sizeof(float), x_ne, x_nb, 4);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclnn_cast(ctx, acl_y_full, acl_dst, ggml_cann_type_mapping(dst->type));
+
+    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+    // Clean up
+    ggml_cann_release_resources(ctx,
+        acl_src, acl_x, acl_x0, acl_x1,
+        acl_cos0, acl_sin0,
+        acl_y0, acl_y1, acl_y_full,
+        acl_tmp, acl_dst
+    );
+}
\ No newline at end of file
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
old mode 100755
new mode 100644
index 5c510cc9932..cf9c372312e
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -479,6 +479,18 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_rope(ggml_backend_cann_context& 
ctx, ggml_tensor* dst); + +/** + * @brief Applies Multimodal Rotary Positional Embedding (MROPE). + * + * @details Handles RoPE with multiple position IDs mapped to different sections + * of the head dimension. + * + * @param ctx The backend CANN context. + * @param dst The destination tensor. + */ +void ggml_cann_rope_multi(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Computes the index of the maximum value along the specified dimension * of a ggml tensor using the CANN backend. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100755 new mode 100644 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100755 new mode 100644 index cb8af42ebf9..201edd4b506 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1836,9 +1836,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_SOFT_MAX: ggml_cann_softmax(ctx, dst); break; - case GGML_OP_ROPE: - ggml_cann_rope(ctx, dst); - break; case GGML_OP_IM2COL: ggml_cann_im2col(ctx, dst); break; @@ -1881,6 +1878,10 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_FLASH_ATTN_EXT: ggml_cann_flash_attn_ext(ctx, dst); break; + case GGML_OP_ROPE: { + ggml_cann_rope(ctx, dst); + break; + } default: return false; } @@ -2407,25 +2408,34 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // TODO: with ops-test v == 1 float ext_factor = 0.0f; memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float)); - // TODO: n_dims <= ne0 - if (op->src[0]->ne[0] != op->op_params[1]) { - return false; - } // TODO: ext_factor != 0 - if (ext_factor != 0) { - return false; - } + const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; + if (ext_factor != 0) { + if (mode != GGML_ROPE_TYPE_MROPE) { + return false; + } } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; + + // TODO: n_dims <= ne0 + if (op->src[0]->ne[0] != op->op_params[1]) { + if (mode != GGML_ROPE_TYPE_MROPE) { + return false; + } } + //if (mode & GGML_ROPE_TYPE_MROPE) { + // return false; + //} + //if (mode & GGML_ROPE_TYPE_VISION) { + // return false; + //} + if(!ggml_is_contiguous(op->src[0])){ - return false; + if(mode != GGML_ROPE_TYPE_MROPE){ + return false; + } } return true; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 460367cca09..e4a33ad71cd 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5336,24 +5336,6 @@ void ggml_compute_forward_get_rows( GGML_ABORT("fatal error"); } } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} } static void ggml_compute_forward_set_rows_f32( @@ -5512,24 +5494,6 @@ void ggml_compute_forward_get_rows_back( GGML_ABORT("fatal error"); } } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) 
dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} } // ggml_compute_forward_diag @@ -6098,7 +6062,7 @@ static void ggml_mrope_cache_init( int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; int sec_w = sections[1] + sections[0]; int sec_e = sections[2] + sec_w; - GGML_ASSERT(sect_dims <= ne0); + GGML_ASSERT(sect_dims <= ne0); // ne0 == n_dims for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; @@ -6172,9 +6136,6 @@ static void ggml_compute_forward_rope_f32( GGML_TENSOR_UNARY_OP_LOCALS - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - GGML_ASSERT(nb00 == sizeof(float)); const int ith = params->ith; @@ -6228,7 +6189,6 @@ static void ggml_compute_forward_rope_f32( for (int64_t i3 = 0; i3 < ne3; i3++) { // batch for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; if (!is_mrope) { const int64_t p = pos[i2]; @@ -6342,10 +6302,8 @@ static void ggml_compute_forward_rope_f16( float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; int sections[4]; - //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -6358,9 +6316,6 @@ static void ggml_compute_forward_rope_f16( GGML_TENSOR_UNARY_OP_LOCALS - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); const int ith = params->ith;