Skip to content

Commit a6d3cfe

Browse files
authored
CANN: optimize rope operator (ggml-org#15335)
* optimize rope ops * amendment * delete trailing whitespace * change the variable name
1 parent 67f09a3 commit a6d3cfe

File tree

2 files changed

+119
-64
lines changed

2 files changed

+119
-64
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 106 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,86 +2154,129 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
21542154

21552155
GGML_TENSOR_BINARY_OP_LOCALS
21562156

2157-
// theta_scale arange, [0,1,...,ne00/2 - 1]
21582157
int64_t theta_scale_length = ne00 / 2;
2159-
ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
2160-
theta_scale_length * sizeof(float_t));
2161-
void* theta_scale_buffer = theta_scale_allocator.get();
21622158
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
21632159
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
21642160
theta_scale_length * sizeof(float_t)};
21652161

2166-
aclTensor* acl_theta_scale_tensor =
2167-
ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t),
2168-
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2169-
float start = 0;
2170-
float step = 1;
2171-
float stop = ne00 / 2;
2172-
float n_elements = ne00 / 2;
2173-
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2174-
2175-
// power
2176-
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2177-
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2178-
acl_theta_scale_tensor);
2179-
2180-
// freq_scale
2181-
if (freq_scale != 1) {
2182-
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2183-
}
2184-
2185-
// freq_factors
2186-
if (src2) {
2187-
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
2188-
src2->data, ggml_cann_type_mapping(src2->type),
2189-
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2190-
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2191-
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
2192-
}
2193-
2194-
// position
21952162
GGML_ASSERT(src1->type == GGML_TYPE_I32);
21962163
int64_t position_length = src1->ne[0];
21972164
int64_t position_ne[] = {1, 1, position_length, 1};
21982165
size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t),
21992166
sizeof(int32_t) * position_length};
2200-
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
2201-
src1->data, ggml_cann_type_mapping(src1->type),
2202-
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
2203-
2204-
// power * position
2205-
int64_t theta_length = theta_scale_length * position_length;
2206-
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
2207-
theta_length * sizeof(float_t));
2208-
void* theta_buffer = theta_allocator.get();
2167+
22092168
int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
22102169
size_t theta_nb[GGML_MAX_DIMS];
22112170
theta_nb[0] = sizeof(float_t);
22122171
for (int i = 1; i < GGML_MAX_DIMS; i++) {
22132172
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
22142173
}
2215-
aclTensor* acl_theta_tensor =
2216-
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
2217-
theta_ne, theta_nb, GGML_MAX_DIMS);
2218-
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
2219-
acl_theta_tensor);
2220-
2221-
// sin/cos
2222-
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
2223-
theta_length * sizeof(float_t));
2224-
void* sin_buffer = sin_allocator.get();
2225-
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2226-
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2227-
GGML_MAX_DIMS, ACL_FORMAT_ND);
2228-
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
22292174

2230-
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
2231-
theta_length * sizeof(float_t));
2232-
void* cos_buffer = cos_allocator.get();
2175+
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
2176+
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
2177+
2178+
// used for accuracy testing
2179+
bool is_attention = is_q || is_k;
2180+
2181+
if(ctx.init_ptr == nullptr || !is_attention) {
2182+
// theta_scale arange, [0,1,...,ne00/2 - 1]
2183+
if(ctx.init_ptr != nullptr){
2184+
ACL_CHECK(aclrtFree(ctx.init_ptr));
2185+
}
2186+
ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2187+
2188+
aclTensor* acl_theta_scale_tensor =
2189+
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
2190+
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2191+
float start = 0;
2192+
float step = 1;
2193+
float stop = ne00 / 2;
2194+
float n_elements = ne00 / 2;
2195+
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2196+
2197+
// power
2198+
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2199+
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2200+
acl_theta_scale_tensor);
2201+
2202+
// freq_scale
2203+
if (freq_scale != 1) {
2204+
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2205+
}
2206+
2207+
// freq_factors
2208+
if (src2) {
2209+
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
2210+
src2->data, ggml_cann_type_mapping(src2->type),
2211+
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2212+
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2213+
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
2214+
}
2215+
// release
2216+
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
2217+
}
2218+
2219+
if(ctx.sin_ptr == nullptr) {
2220+
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
2221+
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2222+
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2223+
}
2224+
if(position_length > ctx.max_prompt_length) {
2225+
ctx.max_prompt_length = position_length;
2226+
int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
2227+
ACL_CHECK(aclrtFree(ctx.sin_ptr));
2228+
ACL_CHECK(aclrtFree(ctx.cos_ptr));
2229+
ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2230+
ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
2231+
}
2232+
2233+
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
2234+
2235+
if(is_fisrt_layer || !is_attention) {
2236+
2237+
aclTensor* acl_theta_scale_tensor =
2238+
ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
2239+
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2240+
2241+
// position
2242+
aclTensor* acl_position_tensor = ggml_cann_create_tensor(
2243+
src1->data, ggml_cann_type_mapping(src1->type),
2244+
ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
2245+
2246+
// power * position
2247+
int64_t theta_length = theta_scale_length * position_length;
2248+
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
2249+
theta_length * sizeof(float_t));
2250+
void* theta_buffer = theta_allocator.get();
2251+
2252+
aclTensor* acl_theta_tensor =
2253+
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
2254+
theta_ne, theta_nb, GGML_MAX_DIMS);
2255+
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
2256+
acl_theta_tensor);
2257+
2258+
// sin/cos
2259+
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2260+
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2261+
GGML_MAX_DIMS, ACL_FORMAT_ND);
2262+
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
2263+
2264+
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
2265+
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2266+
GGML_MAX_DIMS, ACL_FORMAT_ND);
2267+
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
2268+
2269+
// release
2270+
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
2271+
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
2272+
}
2273+
2274+
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2275+
ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2276+
GGML_MAX_DIMS, ACL_FORMAT_ND);
22332277
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
2234-
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2235-
GGML_MAX_DIMS, ACL_FORMAT_ND);
2236-
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
2278+
ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2279+
GGML_MAX_DIMS, ACL_FORMAT_ND);
22372280

22382281
// attn_factor
22392282
if (attn_factor != 1) {
@@ -2257,8 +2300,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
22572300
}
22582301

22592302
// release
2260-
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
2261-
acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
2303+
ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
22622304
}
22632305

22642306
#ifdef __cplusplus

ggml/src/ggml-cann/common.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,10 @@ struct ggml_backend_cann_context {
368368
std::string name; /**< Name of the device. */
369369
std::string description; /**< Description of the device. */
370370
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
371+
void* init_ptr = nullptr;
372+
void* sin_ptr = nullptr;
373+
void* cos_ptr = nullptr;
374+
int64_t max_prompt_length = 65536;
371375
#ifdef USE_ACL_GRAPH
372376
/// Cached CANN ACL graph used for executing the current ggml computation graph.
373377
std::unique_ptr<ggml_cann_graph> cann_graph;
@@ -414,6 +418,15 @@ struct ggml_backend_cann_context {
414418
ACL_CHECK(aclrtDestroyStream(streams[i]));
415419
}
416420
}
421+
if(init_ptr != nullptr) {
422+
ACL_CHECK(aclrtFree(init_ptr));
423+
}
424+
if(sin_ptr != nullptr) {
425+
ACL_CHECK(aclrtFree(sin_ptr));
426+
}
427+
if(cos_ptr != nullptr) {
428+
ACL_CHECK(aclrtFree(cos_ptr));
429+
}
417430
}
418431

419432
/**

0 commit comments

Comments
 (0)