diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 5ec3b91ae2b39..34a2992c769b9 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 3d93f4a8acdf2..a28fc8b7fc893 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float kqmax_new_j = kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c247b50c9e690..093ae9000ab37 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -306,6 +306,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, GGML_METAL_KERNEL_TYPE_ARANGE_F32, @@ -390,6 +392,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, + GGML_METAL_KERNEL_TYPE_ARGMAX, GGML_METAL_KERNEL_TYPE_COUNT }; @@ -870,6 +873,8 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); @@ -952,6 +957,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } @@ -1069,6 +1075,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: + case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_SQR: case GGML_OP_SQRT: @@ -1081,6 +1088,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex return has_simdgroup_reduction; case GGML_OP_RMS_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0); + case GGML_OP_ARGMAX: case GGML_OP_NORM: case GGML_OP_ROPE: return true; @@ -3138,6 +3146,49 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + + const int32_t IC = src1->ne[1]; + const int32_t IL = src1->ne[0]; + + const int32_t K = src0->ne[0]; + + const int32_t OL = dst->ne[0]; + const int32_t OC = dst->ne[1]; + + id pipeline; + + switch (src0->type) { + case GGML_TYPE_F32: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; + } break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + + [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_UPSCALE: { GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -3797,6 +3848,31 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; + case GGML_OP_ARGMAX: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + + const int64_t nrows = ggml_nrows(src0); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 7567f326200fc..5caa0846a8b53 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1366,6 +1366,63 @@ kernel void kernel_ssm_scan_f32( } } +kernel void kernel_argmax( + device const void * x, + device int32_t * dst, + constant int64_t & ncols, + constant uint64_t & nb01, + threadgroup float * shared_maxval [[threadgroup(0)]], + threadgroup int32_t * shared_argmax [[threadgroup(1)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + + float lmax = -INFINITY; + int32_t larg = -1; + + for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + if (x_row[i00] > lmax) { + lmax = x_row[i00]; + larg = i00; + } + } + + // find the argmax value in the block + float max_val = simd_max(lmax); + int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + shared_maxval[tiisg] = -INFINITY; + shared_argmax[tiisg] = -1; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shared_maxval[sgitg] = max_val; + shared_argmax[sgitg] = arg_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = shared_maxval[tiisg]; + arg_val = shared_argmax[tiisg]; + + float max_val_reduced = simd_max(max_val); + int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); + + dst[tgpig] = arg_val_reduced; + + return; + } + + dst[tgpig] = arg_val; +} + kernel void kernel_norm( constant ggml_metal_kargs_norm & args, device const char * src0, @@ -2671,6 +2728,79 @@ kernel void kernel_im2col_ext( template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +typedef void (conv_transpose_1d_t)( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_1d( + device const T * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]) { + + float v = 0.0f; + + for (int64_t c = 0; c < IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; + + for (int64_t i = 0; i < IL; i++) { + if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) { + v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i]; + } + } + } + + device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1); + + dst_ptr[0] = v; +} + +template [[host_name("kernel_conv_transpose_1d_f32_f32")]] +kernel void kernel_conv_transpose_1d( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template [[host_name("kernel_conv_transpose_1d_f16_f32")]] +kernel void kernel_conv_transpose_1d( + device const half * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + kernel void kernel_upscale_f32( device const char * src0, device char * dst, diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index d0815cf89e2f5..8cf25b77f98ea 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -73,7 +73,6 @@ while read c; do src/ggml*.h \ src/ggml*.c \ src/ggml*.cpp \ - src/ggml-amx/* \ src/ggml-blas/* \ src/ggml-cann/* \ src/ggml-cpu/* \ @@ -124,7 +123,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml*.c -> ggml/src/ggml*.c # src/ggml*.cpp -> ggml/src/ggml*.cpp # src/ggml*.h -> ggml/src/ggml*.h - # src/ggml-amx/* -> ggml/src/ggml-amx/* # src/ggml-blas/* -> ggml/src/ggml-blas/* # src/ggml-cann/* -> ggml/src/ggml-cann/* # src/ggml-cpu/* -> ggml/src/ggml-cpu/* @@ -151,7 +149,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index fd71a0a46fa8e..27769c93b512d 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c598cbe30621251e80acbcf3b601589a37c17f4d +b903ffe79daf18c0aaacbebe44a7b93a6b8d0982 diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 000270afbfd1e..f81615bb6b76d 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -7,7 +7,6 @@ cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake cp -rpv ../ggml/src/ggml*.c ./ggml/src/ cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ cp -rpv ../ggml/src/ggml*.h ./ggml/src/ -cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/ cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/ cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c786da4c35d7c..87c92dadd9bcc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3460,13 +3460,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); - test_cases.emplace_back(new test_argmax()); - test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); - test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); + test_cases.emplace_back(new test_count_equal()); + + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1})); - - test_cases.emplace_back(new test_count_equal()); + test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1})); for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));