diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 2a5cb8abfa137..8d7ea8fce1816 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2216,16 +2216,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale); } - if(ctx.sin_ptr == nullptr) { - int64_t theta_length = theta_scale_length * ctx.max_prompt_length; - ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); - } if(position_length > ctx.max_prompt_length) { ctx.max_prompt_length = position_length; int64_t theta_length = theta_scale_length * ctx.max_prompt_length; - ACL_CHECK(aclrtFree(ctx.sin_ptr)); - ACL_CHECK(aclrtFree(ctx.cos_ptr)); + + if (ctx.sin_ptr != nullptr) + ACL_CHECK(aclrtFree(ctx.sin_ptr)); + if (ctx.cos_ptr != nullptr) + ACL_CHECK(aclrtFree(ctx.cos_ptr)); ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 2c2033bfba857..29e0a65627c5f 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -371,7 +371,7 @@ struct ggml_backend_cann_context { void* init_ptr = nullptr; void* sin_ptr = nullptr; void* cos_ptr = nullptr; - int64_t max_prompt_length = 65536; + int64_t max_prompt_length = 0; #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. std::unique_ptr cann_graph;