1- #include " zdnn.h"
2- #include " ggml-zdnn.h"
31#include " ggml-zdnn-impl.h"
4-
52#include " ggml-impl.h"
63#include " ggml-backend-impl.h"
74
5+ #include " ggml-zdnn/common.hpp"
6+ #include " ggml-zdnn/mmf.hpp"
7+ #include " ggml.h"
8+
89#include < vector>
910#include < memory>
1011#include < csignal>
@@ -88,80 +89,6 @@ inline void ggml_zdnn_init_tensor(ggml_backend_zdnn_buffer * buffer, const ggml_
     ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&buffer->pre_tfm_desc, &buffer->tfm_desc, &buffer->ztensor));
 }
 
-static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const enum ggml_type type = src0->type;
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    const ggml_tensor * weights = src0;
-    const ggml_tensor * inputs  = src1;
-    ggml_tensor * output = dst;
-
-    ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra;
-    ggml_backend_zdnn_buffer * inputs_extra  = (ggml_backend_zdnn_buffer *)inputs->extra;
-    ggml_backend_zdnn_buffer * output_extra  = (ggml_backend_zdnn_buffer *)output->extra;
-    ggml_backend_zdnn_buffer * bias_extra    = (ggml_backend_zdnn_buffer *)output_extra->extra;
-
-    const int64_t weights_rows = ne01;
-    const int64_t weights_cols = ne00;
-    const int64_t inputs_rows  = ne11;
-    const int64_t inputs_cols  = ne10;
-
-    assert(inputs_cols == weights_cols);
-
-    const int64_t output_rows = ne1;
-    const int64_t output_cols = ne0;
-
-    // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n",
-    //               __func__, weights_extra->name,
-    //               weights->ne[3], weights->ne[2], weights->ne[1], weights->ne[0],
-    //               weights_extra->pre_tfm_desc.dim1,
-    //               weights_extra->pre_tfm_desc.dim2,
-    //               weights_extra->pre_tfm_desc.dim3,
-    //               weights_extra->pre_tfm_desc.dim4);
-
-    // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n",
-    //               __func__, inputs_extra->name,
-    //               inputs->ne[3], inputs->ne[2], inputs->ne[1], inputs->ne[0],
-    //               inputs_extra->pre_tfm_desc.dim1,
-    //               inputs_extra->pre_tfm_desc.dim2,
-    //               inputs_extra->pre_tfm_desc.dim3,
-    //               inputs_extra->pre_tfm_desc.dim4);
-
-    GGML_ASSERT(weights_extra->pre_tfm_desc.dim1 == weights->ne[0] && "weights_extra->pre_tfm_desc.dim1 must match weights->ne[0]");
-    GGML_ASSERT(weights_extra->pre_tfm_desc.dim2 == weights->ne[1] && "weights_extra->pre_tfm_desc.dim2 must match weights->ne[1]");
-    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1  == inputs->ne[0]  && "inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]");
-    GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2  == inputs->ne[1]  && "inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]");
-
-    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor,
-                                        false, true, MATMUL_OP_ADDITION, &output_extra->ztensor));
-    // TODO: Remove in the future as we are currently DLF16 -> FP32 then in the next op, FP32 -> DLF16 again. Inefficient.
-    ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data));
-
-    GGML_UNUSED(ctx);
-    GGML_UNUSED(weights_rows);
-    GGML_UNUSED(weights_cols);
-    GGML_UNUSED(inputs_rows);
-    GGML_UNUSED(inputs_cols);
-    GGML_UNUSED(output_rows);
-    GGML_UNUSED(output_cols);
-}
-
 static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     // debug helpers
     // GGML_LOG_INFO("%s: use_mul_mat_vec = %d\n", __func__, use_mul_mat_vec);
@@ -174,7 +101,7 @@ static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const gg
     // GGML_LOG_INFO("%s: src0 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // GGML_LOG_INFO("%s: src1 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    ggml_zdnn_mul_mat_op(ctx, src0, src1, dst);
+    ggml_zdnn_mul_mat_f(ctx, src0, src1, dst);
 }
 
 static bool ggml_zdnn_compute_forward(ggml_backend_zdnn_context * ctx, ggml_tensor * dst) {
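
The removed `ggml_zdnn_mul_mat_op` body appears to move into the new `ggml-zdnn/mmf.hpp` header as `ggml_zdnn_mul_mat_f`, which the dispatcher now calls. Below is a minimal sketch of what that relocated function could look like, assuming it keeps the removed logic unchanged; the header name and `static inline` linkage are inferred from the diff, not confirmed against the new file, and the shape/stride assertions are elided for brevity:

```cpp
// Hypothetical core of ggml-zdnn/mmf.hpp, reconstructed from the removed
// ggml_zdnn_mul_mat_op body above. The actual header may differ.
#pragma once

#include "ggml.h"
#include "ggml-zdnn/common.hpp"

static inline void ggml_zdnn_mul_mat_f(ggml_backend_zdnn_context * ctx,
                                       const ggml_tensor * src0,   // weights
                                       const ggml_tensor * src1,   // inputs
                                       ggml_tensor * dst) {        // output
    // Per-tensor zDNN state (descriptors + stickified ztensor) hangs off extra;
    // the bias ztensor is chained off the output buffer, as in the removed code.
    ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)src0->extra;
    ggml_backend_zdnn_buffer * inputs_extra  = (ggml_backend_zdnn_buffer *)src1->extra;
    ggml_backend_zdnn_buffer * output_extra  = (ggml_backend_zdnn_buffer *)dst->extra;
    ggml_backend_zdnn_buffer * bias_extra    = (ggml_backend_zdnn_buffer *)output_extra->extra;

    // output = inputs * weights^T + bias on the accelerator:
    // transpose_a = false, transpose_b = true, with fused bias addition.
    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor,
                                        &bias_extra->ztensor, false, true,
                                        MATMUL_OP_ADDITION, &output_extra->ztensor));

    // Convert the stickified DLFLOAT16 result back to the FP32 layout ggml
    // expects (flagged as a TODO in the removed code: DLF16 -> FP32 -> DLF16
    // round-trips between consecutive ops are wasteful).
    ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, dst->data));

    GGML_UNUSED(ctx);
}
```

The `_f` suffix and the `mmf.hpp` split suggest float matmul kernels are being factored out per-op, so the dispatcher body stays a thin router.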