@@ -129,11 +129,7 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten
129129 const int64_t output_rows = ne1;
130130 const int64_t output_cols = ne0;
131131
132- const int64_t weights_dim[GGML_MAX_DIMS] = { 1 , 1 , weights_cols, weights_rows };
133- const int64_t inputs_dim[GGML_MAX_DIMS] = { 1 , 1 , inputs_cols, inputs_rows };
134132 const int64_t bias_dim [GGML_MAX_DIMS] = { 1 , 1 , 1 , output_cols };
135- const int64_t output_dim[GGML_MAX_DIMS] = { 1 , 1 , output_cols, output_rows };
136-
137133 ggml_zdnn_create_tensor (ptd_bias, td_bias, zt_bias, output, bias_dim, ZDNN_1D);
138134
139135 void * bias_data = (void *)calloc (ne0, ggml_element_size (output));
@@ -277,13 +273,12 @@ static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_d
277273 const int64_t ne0 = op->ne [0 ];
278274 const int64_t ne1 = op->ne [1 ];
279275
280- const int64_t max_batch = zdnn_get_nnpa_max_dim_idx_size () ;
276+ const int64_t max_batch = ctx_dev-> max_size ;
281277
282278 return ggml_is_contiguous (src0) &&
283279 ggml_is_contiguous (src1) &&
284- src1->type == GGML_TYPE_F32 &&
285- (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch) &&
286- (src0->type == GGML_TYPE_F32 || ggml_get_type_traits (src0->type )->to_float != NULL );
280+ src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
281+ (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch);
287282 } break ;
288283
289284 default :
0 commit comments