|
37 | 37 | GGML_METAL_KERNEL_TYPE_DIV_ROW, |
38 | 38 | GGML_METAL_KERNEL_TYPE_SCALE, |
39 | 39 | GGML_METAL_KERNEL_TYPE_SCALE_4, |
| 40 | + GGML_METAL_KERNEL_TYPE_CLAMP, |
40 | 41 | GGML_METAL_KERNEL_TYPE_TANH, |
41 | 42 | GGML_METAL_KERNEL_TYPE_RELU, |
42 | 43 | GGML_METAL_KERNEL_TYPE_GELU, |
@@ -468,6 +469,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ |
468 | 469 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); |
469 | 470 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); |
470 | 471 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); |
| 472 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); |
471 | 473 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); |
472 | 474 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); |
473 | 475 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); |
@@ -713,6 +715,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const |
713 | 715 | case GGML_OP_MUL: |
714 | 716 | case GGML_OP_DIV: |
715 | 717 | case GGML_OP_SCALE: |
| 718 | + case GGML_OP_CLAMP: |
716 | 719 | case GGML_OP_SQR: |
717 | 720 | case GGML_OP_SUM_ROWS: |
718 | 721 | return true; |
@@ -1154,6 +1157,25 @@ static enum ggml_status ggml_metal_graph_compute( |
1154 | 1157 |
|
1155 | 1158 | [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
1156 | 1159 | } break; |
case GGML_OP_CLAMP:
    {
        // Element-wise clamp of src0 into dst.
        // The lower and upper bounds are stored in op_params slots 0 and 1 as
        // float bit patterns; copy them out with memcpy instead of
        // dereferencing a type-punned pointer.
        const int32_t * clamp_params = (int32_t *) dst->op_params;

        float clamp_lo;
        float clamp_hi;
        memcpy(&clamp_lo, &clamp_params[0], sizeof(float));
        memcpy(&clamp_hi, &clamp_params[1], sizeof(float));

        const int64_t n_elem = ggml_nelements(dst);

        id<MTLComputePipelineState> clamp_pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;

        // Kernel arguments: 0 = source buffer, 1 = destination buffer,
        // 2 = lower bound, 3 = upper bound.
        [encoder setComputePipelineState:clamp_pipeline];
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
        [encoder setBytes:&clamp_lo length:sizeof(clamp_lo) atIndex:2];
        [encoder setBytes:&clamp_hi length:sizeof(clamp_hi) atIndex:3];

        // One threadgroup per destination element, one thread per threadgroup.
        [encoder dispatchThreadgroups:MTLSizeMake(n_elem, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
    } break;
1157 | 1179 | case GGML_OP_UNARY: |
1158 | 1180 | switch (ggml_get_unary_op(gf->nodes[i])) { |
1159 | 1181 | case GGML_UNARY_OP_TANH: |
|
0 commit comments