@@ -1098,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
10981098 * @param dst The destination tensor. Its src[0] is treated as the input tensor.
10991099 */
11001100template <void unary_op (ggml_backend_cann_context&, aclTensor*, aclTensor*)>
1101- void ggml_cann_unary_op (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1101+ void ggml_cann_op_unary (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
11021102 ggml_tensor* src = dst->src [0 ];
11031103
11041104 aclTensor* acl_src = ggml_cann_create_tensor (src);
@@ -1109,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
11091109}
11101110
11111111/**
1112- * @brief Applies a unary operation to a ggml tensor using the CANN backend.
1112+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
11131113 *
1114- * @details This function performs a unary operation on the input tensor using
1115- * a user-provided lambda or callable object `unary_op`, which accepts the CANN
1116- * context and two ACL tensors (source and destination). Internally, this function
1117- * creates ACL representations of the ggml tensors and invokes the unary operation.
1118- * The result is stored in the destination tensor `dst`. This utility abstracts the
1119- * common boilerplate of tensor conversion and cleanup when implementing unary ops.
1114+ * @details This function applies a unary operation to the input tensor using
1115+ * a user-provided lambda or callable `unary_op`. The lambda receives the
1116+ * CANN backend context and two ACL tensors: the source and the destination.
11201117 *
1121- * @param unary_op A callable that performs the unary operation using CANN APIs.
1122- * @param ctx The CANN context used for operations.
1123- * @param dst The destination tensor where the result will be stored.
1124- * The source tensor is retrieved from `dst->src[0]`.
1118+ * Internally, this function handles the conversion from GGML tensors to ACL tensors,
1119+ * calls the provided unary op, and manages resource cleanup. The input is assumed
1120+ * to be `dst->src[0]`, and the result is written to `dst`.
1121+ *
1122+ * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
1123+ *
1124+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1125+ * @param ctx The CANN context for operation execution.
1126+ * @param dst The destination ggml_tensor where the result will be stored.
1127+ * The input tensor is assumed to be `dst->src[0]`.
1128+ *
1129+ * @see GGML_CANN_CALL_OP_UNARY
11251130 */
1126- void ggml_cann_unary_op ( // previous name, removed by this change
1131+ void ggml_cann_op_unary ( // new name, added by this change
11271132 std::function<void (ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op, // callable taking (ctx, source ACL tensor, destination ACL tensor)
11281133 ggml_backend_cann_context& ctx, ggml_tensor* dst); // declaration only; the definition is not visible in this chunk
11291134
11301135/**
1131- * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
1136+ * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
1137+ *
1138+ * @details This function performs a gated activation such as GEGLU or ReGLU.
1139+ * It supports two input modes:
1140+ *
1141+ * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
1142+ * These are used directly as the value and gate tensors.
1143+ *
1144+ * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
1145+ * contain a concatenation of value and gate along the first dimension. This tensor
1146+ * will be split into two equal halves to form the value and gate inputs.
1147+ *
1148+ * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
1149+ * then multiplies the result in-place with the gate tensor:
1150+ *
1151+ * @code
1152+ * dst = unary_op(value) * gate;
1153+ * @endcode
1154+ *
1155+ * The `swapped` flag is not a function argument; it is read from `dst->op_params[1]`
1156+ * and allows flipping the order of value/gate in the packed input case.
1157+ *
1158+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
1159+ * It receives (ctx, acl_value_tensor, acl_output_tensor).
1160+ * @param ctx The CANN context used for execution.
1161+ * @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
1162+ *
1163+ * @see GGML_CANN_CALL_OP_UNARY_GATED
1164+ */
1165+ void ggml_cann_op_unary_gated ( // takes the same parameter list as ggml_cann_op_unary
1166+ std::function<void (ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op, // callable taking (ctx, value ACL tensor, output ACL tensor)
1167+ ggml_backend_cann_context& ctx, ggml_tensor* dst); // declaration only; the definition is not visible in this chunk
1168+
1169+ /**
1170+ * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
1171+ *
1172+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1173+ * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
1174+ * unary ops in the CANN backend.
1175+ *
1176+ * Internally, this macro expands to a lambda like:
1177+ * @code
1178+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1179+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1180+ * };
1181+ * @endcode
1182+ *
1183+ * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
1184+ *
1185+ * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
1186+ *
1187+ * @see ggml_cann_op_unary
1188+ * @see GGML_CANN_CALL_ACLNN_OP
1189+ */
1190+ #define GGML_CANN_CALL_OP_UNARY (OP_NAME ) \
1191+ do { /* do/while(0) wrapper so the expansion behaves as one statement */ \
1192+ auto lambda = [](ggml_backend_cann_context& ctx, /* capture-less: this ctx is the lambda's own parameter */ \
1193+ aclTensor* acl_src, \
1194+ aclTensor* acl_dst) { \
1195+ GGML_CANN_CALL_ACLNN_OP (ctx, OP_NAME, acl_src, acl_dst); \
1196+ }; \
1197+ ggml_cann_op_unary (lambda, ctx, dst); /* ctx and dst are not macro parameters; both must be in scope where the macro is expanded */ \
1198+ } \
1199+ while (0 )
1200+
1201+ /**
1202+ * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
11321203 *
1133- * This macro defines an inline lambda wrapping a specific ACL operation name,
1134- * and passes it to the templated ggml_cann_unary_op function. It simplifies
1135- * calling unary ops by hiding the lambda boilerplate.
1204+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
1205+ * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
1206+ * executing gated unary ops in the CANN backend.
11361207 *
1137- * Internally, the lambda will call:
1208+ * Internally, this macro expands to a lambda like:
11381209 * @code
1139- * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1210+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
1211+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
1212+ * };
11401213 * @endcode
11411214 *
1215+ * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
1216+ *
11421217 * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
11431218 *
1144- * @see ggml_cann_unary_op
1219+ * @see ggml_cann_op_unary_gated
11451220 * @see GGML_CANN_CALL_ACLNN_OP
11461221 */
1147- #define GGML_CANN_CALL_UNARY_OP (OP_NAME ) \
1222+ #define GGML_CANN_CALL_OP_UNARY_GATED (OP_NAME ) \
11481223 do { /* do/while(0) wrapper so the expansion behaves as one statement */ \
11491224 auto lambda = [](ggml_backend_cann_context& ctx, /* capture-less: this ctx is the lambda's own parameter */ \
11501225 aclTensor* acl_src, \
11511226 aclTensor* acl_dst) { \
11521227 GGML_CANN_CALL_ACLNN_OP (ctx, OP_NAME, acl_src, acl_dst); \
11531228 }; \
1154- ggml_cann_unary_op (lambda, ctx, dst); \
1229+ ggml_cann_op_unary_gated (lambda, ctx, dst); /* ctx and dst are not macro parameters; both must be in scope where the macro is expanded */ \
11551230 } \
11561231 while (0 )
1232+
11571233#endif // CANN_ACLNN_OPS
0 commit comments