 
 // AMX type_traits
 namespace ggml::cpu::amx {
-    class tensor_traits : public ggml::cpu::tensor_traits {
+class tensor_traits : public ggml::cpu::tensor_traits {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        size = ggml_backend_amx_desired_wsize(op);
+        return true;
+    }
 
-        bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-            size = ggml_backend_amx_desired_wsize(op);
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT) {
+            ggml_backend_amx_mul_mat(params, op);
             return true;
         }
-
-        bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
-            if (op->op == GGML_OP_MUL_MAT) {
-                ggml_backend_amx_mul_mat(params, op);
-                return true;
-            }
-            return false;
-        }
-    };
-
-    static ggml::cpu::tensor_traits* get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
-        static tensor_traits traits;
-        return &traits;
+        return false;
     }
+};
+
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
 }
+}  // namespace ggml::cpu::amx
 
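The traits object returned by `get_tensor_traits` is what the CPU backend later recovers through `tensor->extra`. A minimal sketch of a caller, not part of this diff (the helper name is invented; the virtuals are the ones overridden above):

```cpp
// Hypothetical dispatch helper, for illustration only: recover the traits
// installed on the weight tensor and let them run the op if they claim it.
static bool try_extra_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
    auto * traits = (ggml::cpu::tensor_traits *) op->src[0]->extra;
    if (traits != nullptr) {
        return traits->compute_forward(params, op);  // true -> AMX kernel handled it
    }
    return false;  // no traits installed -> generic CPU path
}
```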
 // AMX buffer interface
 static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     free(buffer->context);
 }
 
 static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return (void *)(buffer->context);
+    return (void *) (buffer->context);
 }
 
 static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
-    tensor->extra = (void *)ggml::cpu::amx::get_tensor_traits(buffer, tensor);
+    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
 }
 
-static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
-    memset((char *)tensor->data + offset, value, size);
+static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                  uint8_t value, size_t offset, size_t size) {
+    memset((char *) tensor->data + offset, value, size);
 
     GGML_UNUSED(buffer);
 }
 
-static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                               const void * data, size_t offset, size_t size) {
     if (qtype_has_amx_kernels(tensor->type)) {
         ggml_backend_amx_convert_weight(tensor, data, offset, size);
     } else {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
     }
 
     GGML_UNUSED(buffer);
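Note that `set_tensor` is where the AMX weight layout is produced: quant types with AMX kernels are repacked by `ggml_backend_amx_convert_weight`, everything else is copied verbatim. A usage sketch (assumes `tensor` was already allocated in an AMX buffer and `weight_data` holds the original weights):

```cpp
// Uploading through the generic backend API routes into
// ggml_backend_amx_buffer_set_tensor above and triggers the repack.
ggml_backend_tensor_set(tensor, weight_data, 0, ggml_nbytes(tensor));
```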
@@ -136,49 +137,42 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
 }
 
 namespace ggml::cpu::amx {
-    class extra_buffer_type : ggml::cpu::extra_buffer_type {
-        bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-            // handle only 2d gemm for now
-            auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-                return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-            };
-
-            if (op->op == GGML_OP_MUL_MAT &&
-                is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
-                is_contiguous_2d(op->src[1]) &&  // src1 must be contiguous
-                op->src[0]->buffer &&
-                op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
-                op->ne[0] % (TILE_N * 2) == 0 &&  // out_features is 32x
-                (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))
-            )
-            {
-                // src1 must be host buffer
-                if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
-                    return false;
-                }
-                // src1 must be float32
-                if (op->src[1]->type == GGML_TYPE_F32) {
-                    return true;
-                }
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        // handle only 2d gemm for now
+        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+        };
+
+        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
+            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
+            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
+            op->ne[0] % (TILE_N * 2) == 0 &&  // out_features is 32x
+            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
+            // src1 must be host buffer
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
             }
-            return false;
-        }
-
-        ggml::cpu::tensor_traits* get_tensor_traits(const struct ggml_tensor * op) override {
-            if (op->op == GGML_OP_MUL_MAT &&
-                op->src[0]->buffer &&
-                op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()
-            )
-            {
-                return (ggml::cpu::tensor_traits*) op->src[0]->extra;
+            // src1 must be float32
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
             }
+        }
+        return false;
+    }
 
-            return nullptr;
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
+            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
+            return (ggml::cpu::tensor_traits *) op->src[0]->extra;
         }
-    };
-}
 
-static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::amx
+
+static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     return ggml_backend_amx_get_alloc_size(tensor);
 
     GGML_UNUSED(buft);
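The `op->ne[0] % (TILE_N * 2) == 0` gate in `supports_op` is easy to misread, so here is a standalone restatement. It assumes `TILE_N == 16` (the tile width used by the AMX mul_mat kernels; the constant is not defined in this hunk), meaning the output feature dimension must be a multiple of 32:

```cpp
#include <cstdint>

// Illustrative restatement of the shape gate; TILE_N == 16 is an assumption
// taken from the AMX kernel tile configuration, not from this diff.
static bool amx_out_features_ok(int64_t ne0) {
    const int64_t TILE_N = 16;
    return ne0 % (TILE_N * 2) == 0;  // e.g. ne0 = 4096 -> true, ne0 = 4100 -> false
}
```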
@@ -200,25 +194,26 @@ static bool ggml_amx_init() {
     return true;
 #endif
 }
+
 ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
     static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
         /* .iface    = */ {
-            /* .get_name = */ ggml_backend_amx_buffer_type_get_name,
-            /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
-            /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
-            /* .get_max_size = */ NULL, // defaults to SIZE_MAX
-            /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
-            /* .is_host = */ nullptr,
-            },
+            /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
+            /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
+            /* .is_host          = */ nullptr,
+        },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
         /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
     };
 
     if (!ggml_amx_init()) {
-        return NULL;
+        return nullptr;
     }
 
     return &ggml_backend_buffer_type_amx;
 }
 
-#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+#endif  // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
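For context, a hedged usage sketch of the buffer type returned here: `ctx` is a hypothetical `ggml_context` holding the weight tensors, and `ggml_backend_alloc_ctx_tensors_from_buft` is the stock ggml allocator entry point.

```cpp
// The buffer type is nullptr when ggml_amx_init() fails, e.g. when the
// XTILE_DATA permission request is denied by the kernel.
ggml_backend_buffer_type_t buft = ggml_backend_amx_buffer_type();
if (buft != nullptr) {
    // init_tensor installs the AMX tensor_traits; set_tensor repacks weights
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
    GGML_ASSERT(buf != nullptr);
}
```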