@@ -1458,11 +1458,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
14581458 ACL_CHECK (aclrtSynchronizeDevice ());
14591459 ACL_CHECK (aclrtResetDevice (cann_ctx->device ));
14601460
1461- // finalize when last backend freed.
1462- if (cann_ctx->device == ggml_backend_cann_get_device_count () - 1 ) {
1463- ACL_CHECK (aclFinalize ());
1464- }
1465-
14661461 delete cann_ctx;
14671462 delete backend;
14681463}
@@ -1688,11 +1683,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
16881683 }
16891684 case GGML_OP_MUL_MAT: {
16901685 switch (op->src [0 ]->type ) {
1691- case GGML_TYPE_Q8_0:
16921686 case GGML_TYPE_F16:
16931687 case GGML_TYPE_F32:
1694- case GGML_TYPE_Q4_0:
16951688 return true ;
1689+ case GGML_TYPE_Q8_0:
1690+ case GGML_TYPE_Q4_0:
1691+ // only support contiguous for quantized types.
1692+ return ggml_is_contiguous (op->src [0 ]) &&
1693+ ggml_is_contiguous (op->src [1 ]);
16961694 default :
16971695 return false ;
16981696 }
@@ -1738,13 +1736,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
17381736 }
17391737 case GGML_OP_ROPE: {
17401738 // TODO: with ops-test v == 1
1741- float * ext_factor = (float *)((int32_t *)op->op_params + 7 );
1739+ float ext_factor = 0 .0f ;
1740+ memcpy (&ext_factor, (const float *) op->op_params + 7 , sizeof (float ));
17421741 // TODO: n_dims <= ne0
17431742 if (op->src [0 ]->ne [0 ] != op->op_params [1 ]) {
17441743 return false ;
17451744 }
17461745 // TODO: ext_factor != 0
1747- if (* ext_factor != 0 ) {
1746+ if (ext_factor != 0 ) {
17481747 return false ;
17491748 }
17501749
@@ -1766,6 +1765,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
17661765 }
17671766 return true ;
17681767 }
1768+ case GGML_OP_POOL_2D: {
1769+ const int32_t * opts = (const int32_t *) op->op_params ;
1770+ const int k0 = opts[1 ];
1771+ const int k1 = opts[2 ];
1772+ const int p0 = opts[5 ];
1773+ const int p1 = opts[6 ];
1774+ // value of paddingH should be at most half of kernelH
1775+ // value of paddingW should be at most half of kernelW
1776+ return (p0 <= (k0 / 2 )) && (p1 <= (k1 / 2 ));
1777+ }
17691778 case GGML_OP_DUP:
17701779 case GGML_OP_IM2COL:
17711780 case GGML_OP_CONCAT:
@@ -1785,7 +1794,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
17851794 case GGML_OP_CLAMP:
17861795 case GGML_OP_DIAG_MASK_INF:
17871796 case GGML_OP_SOFT_MAX:
1788- case GGML_OP_POOL_2D:
17891797 case GGML_OP_SUM_ROWS:
17901798 case GGML_OP_ARGSORT:
17911799 case GGML_OP_ACC:
0 commit comments