Skip to content

Commit df68663

Browse files
committed
some modification after review
1 parent d3c57c1 commit df68663

File tree

2 files changed

+22
-207
lines changed

2 files changed

+22
-207
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 2 additions & 195 deletions
Original file line numberDiff line numberDiff line change
@@ -2531,47 +2531,6 @@ static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, aclTensor* acl_inpu
25312531
* multiplication will be stored.
25322532
*/
25332533
static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
2534-
ggml_tensor* dst) {
2535-
ggml_tensor* weight = dst->src[0]; // weight
2536-
ggml_tensor* input = dst->src[1]; // input
2537-
2538-
// when weight ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize will auto
2539-
// broadcast, when weight ne2 or ne3 is not 1, weight need repeat.
2540-
BCAST_MUL_MAT_SHAPE(input, weight, dst);
2541-
2542-
// transpose weight: [1,2,3,4] -> [1,2,4,3]
2543-
int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
2544-
bcast_weight_ne[2], bcast_weight_ne[3],
2545-
bcast_weight_ne[4], bcast_weight_ne[5]};
2546-
size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
2547-
bcast_weight_nb[2], bcast_weight_nb[3],
2548-
bcast_weight_nb[4], bcast_weight_nb[5]};
2549-
2550-
aclTensor* acl_weight_tensor =
2551-
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
2552-
aclTensor* acl_input_tensor =
2553-
ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
2554-
aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
2555-
aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2556-
2557-
ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
2558-
ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2559-
ACL_CHECK(aclDestroyTensor(acl_dst));
2560-
}
2561-
2562-
/**
2563-
* @brief Performs matrix multiplication with floating-point precision on
2564-
* tensors using the CANN backend.
2565-
*
2566-
* This function performs matrix multiplication of the input tensor and the
2567-
* weight tensor, handling broadcasting and transposing as needed, and stores
2568-
* the result in the destination tensor `dst`.
2569-
*
2570-
* @param ctx The context for the CANN backend operations.
2571-
* @param dst The destination tensor where the result of the matrix
2572-
* multiplication will be stored.
2573-
*/
2574-
static void ggml_cann_mat_mul_fp2(ggml_backend_cann_context& ctx,
25752534
ggml_tensor* dst) {
25762535
ggml_tensor* weight = dst->src[0]; // weight
25772536
ggml_tensor* input = dst->src[1]; // input
@@ -2637,158 +2596,6 @@ static void ggml_cann_mat_mul_fp2(ggml_backend_cann_context& ctx,
26372596
* multiplication will be stored.
26382597
*/
26392598
static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
2640-
ggml_tensor* dst,
2641-
const enum ggml_type type) {
2642-
ggml_tensor* src0 = dst->src[0]; // weight
2643-
ggml_tensor* src1 = dst->src[1]; // input
2644-
2645-
// The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
2646-
// is regarded as batch. weight need transpose.
2647-
int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
2648-
float weight_elem_size;
2649-
if (type == GGML_TYPE_Q4_0) {
2650-
weight_elem_size = float(sizeof(uint8_t)) / 2;
2651-
}
2652-
else if (type == GGML_TYPE_Q8_0) {
2653-
weight_elem_size = float(sizeof(uint8_t));
2654-
}
2655-
else {
2656-
GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
2657-
}
2658-
float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
2659-
2660-
// size of one matrix is element_size * height * width.
2661-
size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
2662-
size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
2663-
2664-
// scale stored at the end of weight. Also need transpose.
2665-
GGML_ASSERT(QK4_0 == QK8_0);
2666-
int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
2667-
size_t scale_elem_size = sizeof(uint16_t);
2668-
size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
2669-
scale_elem_size};
2670-
size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
2671-
char* scale_offset = (char*)src0->data + weight_size;
2672-
2673-
// input
2674-
void* input_buffer;
2675-
size_t input_elem_size = sizeof(uint16_t);
2676-
int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
2677-
size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
2678-
size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
2679-
2680-
ggml_cann_pool_alloc input_alloctor(ctx.pool());
2681-
if (src1->type != GGML_TYPE_F16) {
2682-
aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
2683-
input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
2684-
input_buffer = input_alloctor.get();
2685-
2686-
int64_t* input_cast_ne = src1->ne;
2687-
size_t input_cast_nb[GGML_MAX_DIMS];
2688-
input_cast_nb[0] = sizeof(uint16_t);
2689-
for (int i = 1; i < GGML_MAX_DIMS; i++) {
2690-
input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
2691-
}
2692-
2693-
aclTensor* acl_input_tensor = ggml_cann_create_tensor(
2694-
input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
2695-
input_cast_nb, GGML_MAX_DIMS);
2696-
aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
2697-
ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2698-
ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
2699-
} else {
2700-
input_buffer = src1->data;
2701-
}
2702-
2703-
// output
2704-
size_t output_elem_size = sizeof(uint16_t);
2705-
int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
2706-
size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
2707-
ggml_cann_pool_alloc output_alloctor(
2708-
ctx.pool(), ggml_nelements(dst) * output_elem_size);
2709-
void* output_buffer = output_alloctor.get();
2710-
size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
2711-
2712-
// aclnn
2713-
uint64_t workspaceSize = 0;
2714-
aclOpExecutor* executor;
2715-
void* workspaceAddr = nullptr;
2716-
2717-
for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
2718-
for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
2719-
int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
2720-
int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
2721-
2722-
int64_t batch1 = n1 * src1->ne[2] + c1;
2723-
int64_t batch0 = n0 * src0->ne[2] + c0;
2724-
2725-
aclTensor* acl_input_tensor = ggml_cann_create_tensor(
2726-
(char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
2727-
input_elem_size, input_ne, input_nb, 2);
2728-
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
2729-
(char*)src0->data + batch0 * weight_stride,
2730-
ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
2731-
weight_nb, 2);
2732-
aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
2733-
scale_offset + batch0 * scale_stride, ACL_FLOAT16,
2734-
scale_elem_size, scale_ne, scale_nb, 2);
2735-
aclTensor* acl_output_tensor = ggml_cann_create_tensor(
2736-
(char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
2737-
output_elem_size, output_ne, output_nb, 2);
2738-
2739-
ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
2740-
acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
2741-
nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
2742-
&workspaceSize, &executor));
2743-
2744-
if (workspaceSize > 0 && workspaceAddr == nullptr) {
2745-
ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
2746-
workspaceSize);
2747-
workspaceAddr = workspace_allocator.get();
2748-
}
2749-
2750-
ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
2751-
workspaceAddr, workspaceSize, executor, ctx.stream()));
2752-
2753-
ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2754-
ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
2755-
ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
2756-
ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2757-
}
2758-
}
2759-
2760-
// cast out
2761-
int64_t* output_cast_ne = dst->ne;
2762-
size_t output_cast_nb[GGML_MAX_DIMS];
2763-
output_cast_nb[0] = sizeof(uint16_t);
2764-
for (int i = 1; i < GGML_MAX_DIMS; i++) {
2765-
output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
2766-
}
2767-
2768-
aclTensor* acl_output_tensor =
2769-
ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
2770-
output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
2771-
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
2772-
aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
2773-
2774-
ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2775-
ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
2776-
}
2777-
2778-
/**
2779-
* @brief Performs matrix multiplication with quantized weights and
2780-
* floating-point inputs using the CANN backend.
2781-
*
2782-
* This function performs matrix multiplication of the input tensor `src1` and
2783-
* the weight tensor `src0`, handling broadcasting, transposing, and
2784-
* quantization as needed, and stores the result in the destination tensor
2785-
* `dst`.
2786-
*
2787-
* @param ctx The context for the CANN backend operations.
2788-
* @param dst The destination tensor where the result of the matrix
2789-
* multiplication will be stored.
2790-
*/
2791-
static void ggml_cann_mul_mat_quant2(ggml_backend_cann_context& ctx,
27922599
ggml_tensor* dst,
27932600
const enum ggml_type type) {
27942601
ggml_tensor* src0 = dst->src[0]; // weight
@@ -2979,11 +2786,11 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
29792786
switch (type) {
29802787
case GGML_TYPE_F32:
29812788
case GGML_TYPE_F16:
2982-
ggml_cann_mat_mul_fp2(ctx, dst);
2789+
ggml_cann_mat_mul_fp(ctx, dst);
29832790
break;
29842791
case GGML_TYPE_Q4_0:
29852792
case GGML_TYPE_Q8_0:
2986-
ggml_cann_mul_mat_quant2(ctx, dst, type);
2793+
ggml_cann_mul_mat_quant(ctx, dst, type);
29872794
break;
29882795
default:
29892796
GGML_ABORT("fatal error");

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,6 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
341341
std::vector<void*> map_offsets;
342342

343343
/**
344-
* @brief Constructor to initialize the buffer pool with virtual memory for
345344
* @brief Constructor to initialize the buffer pool with virtual memory for
346345
* a specific device.
347346
*
@@ -1872,17 +1871,17 @@ struct ggml_backend_cann_device_context {
18721871
};
18731872

18741873
static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
1875-
ggml_backend_cann_context * ctx = (ggml_backend_cann_context *)dev->context;
1874+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
18761875
return ctx->name.c_str();
18771876
}
18781877

18791878
static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
1880-
ggml_backend_cann_context * ctx = (ggml_backend_cann_context *)dev->context;
1879+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
18811880
return ctx->description.c_str();
18821881
}
18831882

18841883
static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1885-
ggml_backend_cann_context * ctx = (ggml_backend_cann_context *)dev->context;
1884+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
18861885
ggml_backend_cann_get_device_memory(ctx->device, free, total);
18871886
}
18881887

@@ -1909,7 +1908,7 @@ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_back
19091908

19101909
static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
19111910
GGML_UNUSED(params);
1912-
ggml_backend_cann_context * ctx = (ggml_backend_cann_context *)dev->context;
1911+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
19131912
return ggml_backend_cann_init(ctx->device);
19141913
}
19151914

@@ -1929,7 +1928,7 @@ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, cons
19291928
static bool ggml_backend_cann_supports_buft(
19301929
ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
19311930
if (ggml_backend_buft_is_cann(buft)) {
1932-
ggml_backend_cann_context * dev_ctx = (ggml_backend_cann_context *)dev->context;
1931+
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
19331932
ggml_backend_cann_buffer_type_context * buft_ctx =
19341933
(ggml_backend_cann_buffer_type_context *)buft->context;
19351934
return buft_ctx->device == dev_ctx->device;
@@ -1938,7 +1937,7 @@ static bool ggml_backend_cann_supports_buft(
19381937
}
19391938

19401939
static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
1941-
ggml_backend_cann_context * ctx = (ggml_backend_cann_context*)dev->context;
1940+
ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
19421941
return ggml_backend_cann_buffer_type(ctx->device);
19431942
}
19441943

@@ -1959,7 +1958,7 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
19591958
*/
19601959
static ggml_backend_event_t ggml_backend_cann_device_event_new(
19611960
ggml_backend_dev_t dev) {
1962-
ggml_backend_cann_context * dev_ctx = (ggml_backend_cann_context *)dev->context;
1961+
ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
19631962

19641963
ggml_cann_set_device(dev_ctx->device);
19651964

@@ -2067,7 +2066,11 @@ ggml_backend_reg_t ggml_backend_cann_reg() {
20672066
ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
20682067

20692068
for (int i = 0; i < ggml_cann_info().device_count; i++) {
2070-
ggml_backend_cann_context* dev_ctx = new ggml_backend_cann_context(i);
2069+
ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2070+
dev_ctx->description = aclrtGetSocName();
2071+
dev_ctx->device = i;
2072+
dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2073+
ggml_cann_set_device(i);
20712074
ggml_backend_dev_t dev = new ggml_backend_device {
20722075
/* .interface = */ ggml_backend_cann_device_interface,
20732076
/* .reg = */ &reg,
@@ -2095,12 +2098,17 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
20952098
return nullptr;
20962099
}
20972100

2098-
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device);
2101+
ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2102+
if (ctx == nullptr) {
2103+
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2104+
return nullptr;
2105+
}
2106+
ggml_cann_set_device(ctx->device);
20992107
ggml_backend_t cann_backend =
21002108
new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
21012109
/* .interface = */ ggml_backend_cann_interface,
2102-
/* .device = */ dev,
2103-
/* .context = */ dev->context};
2110+
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2111+
/* .context = */ ctx};
21042112

21052113
return cann_backend;
21062114
}

0 commit comments

Comments
 (0)