diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239c2bd..d5d61a7d25ca2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ option(LLAMA_AVX512 "llama: enable AVX512" option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF) +option(LLAMA_AMX "llama: enable AMX" OFF) option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512 if (NOT MSVC) @@ -1072,6 +1073,14 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW add_compile_definitions($<$:__AVX512BF16__>) add_compile_definitions($<$:__AVX512BF16__>) endif() + if (LLAMA_AMX) + add_compile_definitions($<$:__AMX_TILE__>) + add_compile_definitions($<$:__AMX_TILE__>) + add_compile_definitions($<$:__AMX_INT8__>) + add_compile_definitions($<$:__AMX_INT8__>) + add_compile_definitions($<$:__AMX_BF16__>) + add_compile_definitions($<$:__AMX_BF16__>) + endif() elseif (LLAMA_AVX2) list(APPEND ARCH_FLAGS /arch:AVX2) elseif (LLAMA_AVX) @@ -1106,6 +1115,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW if (LLAMA_AVX512_BF16) list(APPEND ARCH_FLAGS -mavx512bf16) endif() + if (LLAMA_AMX) + list(APPEND ARCH_FLAGS -mavx512vl -mavx512dq) + list(APPEND ARCH_FLAGS -mamx-tile -mamx-int8 -mamx-bf16) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") diff --git a/ggml.c b/ggml.c index 5145ceec9f4b2..8e58cd08e15e2 100644 --- a/ggml.c +++ b/ggml.c @@ -26,6 +26,10 @@ #include #if defined(__gnu_linux__) #include +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILEDATA 18 +#endif #endif #ifdef GGML_USE_METAL @@ -36,6 +40,13 @@ #undef GGML_USE_LLAMAFILE #endif +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#undef GGML_USE_LLAMAFILE +#define AMX_TILE_MN 16 +#define AMX_TILE_K 16 +#define AMX_BLCK_SIZE 64 +#endif + #ifdef GGML_USE_LLAMAFILE #include "sgemm.h" #endif @@ -1834,7 +1845,100 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * *s = sumf; } +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +static inline void ggml_transpose_8x8xpack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs) { + __m256 row0 = _mm256_loadu_ps((const float *)((const int8_t *)s + 0*bs)); + __m256 row1 = _mm256_loadu_ps((const float *)((const int8_t *)s + 1*bs)); + __m256 row2 = _mm256_loadu_ps((const float *)((const int8_t *)s + 2*bs)); + __m256 row3 = _mm256_loadu_ps((const float *)((const int8_t *)s + 3*bs)); + __m256 row4 = _mm256_loadu_ps((const float *)((const int8_t *)s + 4*bs)); + __m256 row5 = _mm256_loadu_ps((const float *)((const int8_t *)s + 5*bs)); + __m256 row6 = _mm256_loadu_ps((const float *)((const int8_t *)s + 6*bs)); + __m256 row7 = _mm256_loadu_ps((const float *)((const int8_t *)s + 7*bs)); + + __m256 tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7; + __m256 tr8, tr9, tr10, tr11, tr12, tr13, tr14, tr15; + tr0 = _mm256_unpacklo_ps(row0, row1); + tr1 = _mm256_unpackhi_ps(row0, row1); + tr2 = _mm256_unpacklo_ps(row2, row3); + tr3 = _mm256_unpackhi_ps(row2, row3); + tr4 = _mm256_unpacklo_ps(row4, row5); + tr5 = _mm256_unpackhi_ps(row4, row5); + tr6 = _mm256_unpacklo_ps(row6, row7); + tr7 = _mm256_unpackhi_ps(row6, row7); + tr8 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(1, 0, 1, 0)); + tr9 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(3, 2, 3, 2)); + tr10 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(1, 0, 1, 0)); + tr11 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(3, 2, 3, 2)); + tr12 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(1, 0, 1, 0)); + tr13 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(3, 2, 3, 2)); + tr14 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(1, 0, 1, 0)); + tr15 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(3, 2, 3, 2)); + row0 = _mm256_permute2f128_ps(tr8, tr12, 0x20); + row1 = _mm256_permute2f128_ps(tr9, tr13, 0x20); + row2 = _mm256_permute2f128_ps(tr10, tr14, 0x20); + row3 = _mm256_permute2f128_ps(tr11, tr15, 0x20); + row4 = _mm256_permute2f128_ps(tr8, tr12, 0x31); + row5 = _mm256_permute2f128_ps(tr9, tr13, 0x31); + row6 = _mm256_permute2f128_ps(tr10, tr14, 0x31); + row7 = _mm256_permute2f128_ps(tr11, tr15, 0x31); + + _mm256_storeu_ps((float *)((int8_t *)d + 0*bd), row0); + _mm256_storeu_ps((float *)((int8_t *)d + 1*bd), row1); + _mm256_storeu_ps((float *)((int8_t *)d + 2*bd), row2); + _mm256_storeu_ps((float *)((int8_t *)d + 3*bd), row3); + _mm256_storeu_ps((float *)((int8_t *)d + 4*bd), row4); + _mm256_storeu_ps((float *)((int8_t *)d + 5*bd), row5); + _mm256_storeu_ps((float *)((int8_t *)d + 6*bd), row6); + _mm256_storeu_ps((float *)((int8_t *)d + 7*bd), row7); +} + +static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs, const size_t nr, const size_t nc) { + assert(nr % 8 == 0); + assert(nc % 8 == 0); + for (size_t bi = 0; bi < nr; bi += 8) { + for (size_t bj = 0; bj < nc; bj += 8) { + ggml_transpose_8x8xpack4((int8_t *)d + bj*bd + bi*4, bd, (const int8_t *)s + bi*bs + bj*4, bs); + } + } +} + +typedef struct __tile_config +{ + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[8]; + uint8_t reserved_1[16]; + uint8_t rows[8]; + uint8_t reserved_2[8]; +} __tile_config_t; +#endif + static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (nrc == AMX_TILE_MN) { + assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0); + // 0: zt, 1: yt, 2: xt + __tile_config_t cfg = { + .palette_id = 1, + .start_row = 0, + .colsb = {AMX_TILE_MN*sizeof(float), AMX_TILE_K*4, AMX_TILE_MN*4, 0,}, + .rows = {AMX_TILE_MN, AMX_TILE_K, AMX_TILE_MN, 0,}, + }; + _tile_loadconfig(&cfg); + _tile_zero(0); + for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) { + ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)]; + ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K); + _tile_loadd(1, y + i, by); + _tile_loadd(2, axt, AMX_TILE_MN*4); + _tile_dpbf16ps(0, 1, 2); + } + _tile_stored(0, s, bs*sizeof(float)); + return; + } +#endif assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -12269,17 +12373,62 @@ static void ggml_compute_forward_mul_mat_one_chunk( assert(ne13 % ne03 == 0); // block-tiling attempt +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16; + const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16; +#else const int64_t blck_0 = 16; const int64_t blck_1 = 16; +#endif const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; // attempt to reduce false-sharing (does not seem to make a difference) +#ifdef __ARM_FEATURE_MATMUL_INT8 // 16 * 2, accounting for mmla kernels float tmp[32]; +#elif defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + assert(AMX_TILE_MN <= blck_0 && AMX_TILE_MN <= blck_1); + assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0); + assert(src1->type == GGML_TYPE_F32); + } + // AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels + float tmp[AMX_BLCK_SIZE*AMX_TILE_MN]; + uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096); + ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase); + ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)); + float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)); + xf32 = (float *) (((size_t)xf32 + 4095) & ~4095); +#else + float tmp[16]; +#endif for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + const int64_t ii13 = (iir1 / (ne12 * ne1)); + const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1; + const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1); + + // broadcast src0 into src1 + const int64_t ii03 = ii13 / r3; + const int64_t ii02 = ii12 / r2; + + ggml_to_float_t const to_float = type_traits[type].to_float; + + const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03); + const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13; + for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) { + to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00); + ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00); + } + for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) { + ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00); + } + } +#endif for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { const int64_t i13 = (ir1 / (ne12 * ne1)); const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; @@ -12293,28 +12442,37 @@ static void ggml_compute_forward_mul_mat_one_chunk( const int64_t i2 = i12; const int64_t i3 = i13; - const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char*)wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if (num_rows_per_vec_dot == AMX_TILE_MN) { + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN); + } + } + else +#endif + if (true) { + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char*)wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { - memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); } } } @@ -12473,6 +12631,11 @@ UseGgmlGemm1:; } // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. atomic_store(&state->shared->current_chunk, nth); +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (ne01 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) { + return; + } +#endif if (src1->type != vec_dot_type) { char * wdata = params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); @@ -12540,9 +12703,18 @@ UseGgmlGemm2:; if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { num_rows_per_vec_dot = 1; } +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (nr0 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) { + num_rows_per_vec_dot = AMX_TILE_MN; + } +#endif // Now select a reasonable chunk size. +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + int chunk_size = AMX_BLCK_SIZE; +#else int chunk_size = 16; +#endif // We need to step up the size if it's small if (nr0 == 1 || nr1 == 1) { @@ -19328,6 +19500,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__gnu_linux__) + // refer to https://www.kernel.org/doc/Documentation/x86/xstate.rst + syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); +#endif + int node_n = -1; int task_phase = GGML_TASK_TYPE_FINALIZE; @@ -19525,6 +19702,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa * node->src[1]->ne[2]*node->src[1]->ne[3]; } } else +#endif +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) { + cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096); + } else #endif if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); @@ -22877,4 +23059,12 @@ int ggml_cpu_has_matmul_int8(void) { #endif } +int ggml_cpu_has_amx(void) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + return 1; +#else + return 0; +#endif +} + //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml.h b/ggml.h index f803ba7241fe1..72517142c6245 100644 --- a/ggml.h +++ b/ggml.h @@ -2421,6 +2421,7 @@ extern "C" { GGML_API int ggml_cpu_has_sycl (void); GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_matmul_int8(void); + GGML_API int ggml_cpu_has_amx (void); // // Internal types and functions exposed for tests and benchmarks diff --git a/llama.cpp b/llama.cpp index f67cb7e232945..d3ef5238ec3a0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -18355,6 +18355,7 @@ const char * llama_print_system_info(void) { s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | "; + s += "AMX = " + std::to_string(ggml_cpu_has_amx()) + " | "; #ifdef GGML_USE_LLAMAFILE s += "LLAMAFILE = 1 | "; #else