13 changes: 13 additions & 0 deletions CMakeLists.txt
@@ -79,6 +79,7 @@ option(LLAMA_AVX512 "llama: enable AVX512"
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_AMX "llama: enable AMX" OFF)
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
@@ -1072,6 +1073,14 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif()
if (LLAMA_AMX)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX)
@@ -1106,6 +1115,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
endif()
if (LLAMA_AMX)
list(APPEND ARCH_FLAGS -mavx512vl -mavx512dq)
list(APPEND ARCH_FLAGS -mamx-tile -mamx-int8 -mamx-bf16)
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected")
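Note: LLAMA_AMX defaults to OFF and nothing auto-detects the hardware at configure time, so the new path must be requested explicitly, e.g. cmake -B build -DLLAMA_AMX=ON. On GCC/Clang this appends -mamx-tile -mamx-int8 -mamx-bf16 (plus -mavx512vl -mavx512dq, which the AVX transpose kernel needs); on MSVC, which has no AMX /arch switch, the feature macros __AMX_TILE__, __AMX_INT8__ and __AMX_BF16__ are defined directly.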
224 changes: 207 additions & 17 deletions ggml.c
@@ -26,6 +26,10 @@
#include <signal.h>
#if defined(__gnu_linux__)
#include <syscall.h>
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA 18
#endif
#endif

#ifdef GGML_USE_METAL
@@ -36,6 +40,13 @@
#undef GGML_USE_LLAMAFILE
#endif

#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
#undef GGML_USE_LLAMAFILE
#define AMX_TILE_MN 16
#define AMX_TILE_K 16
#define AMX_BLCK_SIZE 64
#endif

#ifdef GGML_USE_LLAMAFILE
#include "sgemm.h"
#endif
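Note the scale these macros imply: sizeof(ggml_bf16_t) == 2, so one tile operand is AMX_TILE_MN = 16 rows of AMX_TILE_K*4 = 64 bytes per row, i.e. 32 bf16 values, and each _tile_dpbf16ps advances the reduction by AMX_TILE_K*4/sizeof(ggml_bf16_t) = 32 elements of n; that quantity is the divisibility requirement checked before the AMX path is taken. AMX_BLCK_SIZE = 64 is the outer cache-blocking factor in rows. GGML_USE_LLAMAFILE is force-disabled here so eligible matmuls go to the tile kernels instead of the llamafile sgemm.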
@@ -1834,7 +1845,100 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
*s = sumf;
}

#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
static inline void ggml_transpose_8x8xpack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs) {
__m256 row0 = _mm256_loadu_ps((const float *)((const int8_t *)s + 0*bs));
__m256 row1 = _mm256_loadu_ps((const float *)((const int8_t *)s + 1*bs));
__m256 row2 = _mm256_loadu_ps((const float *)((const int8_t *)s + 2*bs));
__m256 row3 = _mm256_loadu_ps((const float *)((const int8_t *)s + 3*bs));
__m256 row4 = _mm256_loadu_ps((const float *)((const int8_t *)s + 4*bs));
__m256 row5 = _mm256_loadu_ps((const float *)((const int8_t *)s + 5*bs));
__m256 row6 = _mm256_loadu_ps((const float *)((const int8_t *)s + 6*bs));
__m256 row7 = _mm256_loadu_ps((const float *)((const int8_t *)s + 7*bs));

__m256 tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7;
__m256 tr8, tr9, tr10, tr11, tr12, tr13, tr14, tr15;
tr0 = _mm256_unpacklo_ps(row0, row1);
tr1 = _mm256_unpackhi_ps(row0, row1);
tr2 = _mm256_unpacklo_ps(row2, row3);
tr3 = _mm256_unpackhi_ps(row2, row3);
tr4 = _mm256_unpacklo_ps(row4, row5);
tr5 = _mm256_unpackhi_ps(row4, row5);
tr6 = _mm256_unpacklo_ps(row6, row7);
tr7 = _mm256_unpackhi_ps(row6, row7);
tr8 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(1, 0, 1, 0));
tr9 = _mm256_shuffle_ps(tr0, tr2, _MM_SHUFFLE(3, 2, 3, 2));
tr10 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(1, 0, 1, 0));
tr11 = _mm256_shuffle_ps(tr1, tr3, _MM_SHUFFLE(3, 2, 3, 2));
tr12 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(1, 0, 1, 0));
tr13 = _mm256_shuffle_ps(tr4, tr6, _MM_SHUFFLE(3, 2, 3, 2));
tr14 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(1, 0, 1, 0));
tr15 = _mm256_shuffle_ps(tr5, tr7, _MM_SHUFFLE(3, 2, 3, 2));
row0 = _mm256_permute2f128_ps(tr8, tr12, 0x20);
row1 = _mm256_permute2f128_ps(tr9, tr13, 0x20);
row2 = _mm256_permute2f128_ps(tr10, tr14, 0x20);
row3 = _mm256_permute2f128_ps(tr11, tr15, 0x20);
row4 = _mm256_permute2f128_ps(tr8, tr12, 0x31);
row5 = _mm256_permute2f128_ps(tr9, tr13, 0x31);
row6 = _mm256_permute2f128_ps(tr10, tr14, 0x31);
row7 = _mm256_permute2f128_ps(tr11, tr15, 0x31);

_mm256_storeu_ps((float *)((int8_t *)d + 0*bd), row0);
_mm256_storeu_ps((float *)((int8_t *)d + 1*bd), row1);
_mm256_storeu_ps((float *)((int8_t *)d + 2*bd), row2);
_mm256_storeu_ps((float *)((int8_t *)d + 3*bd), row3);
_mm256_storeu_ps((float *)((int8_t *)d + 4*bd), row4);
_mm256_storeu_ps((float *)((int8_t *)d + 5*bd), row5);
_mm256_storeu_ps((float *)((int8_t *)d + 6*bd), row6);
_mm256_storeu_ps((float *)((int8_t *)d + 7*bd), row7);
}

static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void * restrict s, const size_t bs, const size_t nr, const size_t nc) {
assert(nr % 8 == 0);
assert(nc % 8 == 0);
for (size_t bi = 0; bi < nr; bi += 8) {
for (size_t bj = 0; bj < nc; bj += 8) {
ggml_transpose_8x8xpack4((int8_t *)d + bj*bd + bi*4, bd, (const int8_t *)s + bi*bs + bj*4, bs);
}
}
}
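// note: the 4-byte "pack" unit here is one bf16 pair, exactly the granularity
// _tile_dpbf16ps reduces over; transposing at pair granularity rather than per
// element is what turns row-major x into the VNNI layout expected of the
// second tile operand.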

typedef struct __tile_config
{
uint8_t palette_id;
uint8_t start_row;
uint8_t reserved_0[14];
uint16_t colsb[8];
uint8_t reserved_1[16];
uint8_t rows[8];
uint8_t reserved_2[8];
} __tile_config_t;
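// note: this mirrors the 64-byte ldtilecfg memory operand: palette_id at byte
// 0, start_row at byte 1, colsb[] (uint16) from byte 16, rows[] (uint8) from
// byte 48. Only tiles 0-7 are described; the instruction also expects the
// block to be 64-byte aligned.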
#endif

static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (nrc == AMX_TILE_MN) {
assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0);
// 0: zt, 1: yt, 2: xt
__tile_config_t cfg = {
.palette_id = 1,
.start_row = 0,
.colsb = {AMX_TILE_MN*sizeof(float), AMX_TILE_K*4, AMX_TILE_MN*4, 0,},
.rows = {AMX_TILE_MN, AMX_TILE_K, AMX_TILE_MN, 0,},
};
_tile_loadconfig(&cfg);
_tile_zero(0);
for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) {
ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)];
ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K);
_tile_loadd(1, y + i, by);
_tile_loadd(2, axt, AMX_TILE_MN*4);
_tile_dpbf16ps(0, 1, 2);
}
_tile_stored(0, s, bs*sizeof(float));
return;
}
#endif
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
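Condensed, the tile flow above is: configure once, zero the accumulator tile, stream (y, transposed-x) tile pairs through _tile_dpbf16ps, then store the f32 result. Below is a minimal standalone sketch of that flow, assuming Linux, an AMX-capable CPU and GCC/Clang with -mamx-tile -mamx-bf16; the array names and scaffolding are mine, not the PR's, and b_packed is assumed to be pre-packed into VNNI pair order, which is what ggml_transpose_pack4 does in the PR.

#include <immintrin.h>
#include <stdint.h>
#include <sys/syscall.h>
#include <unistd.h>

#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILEDATA  18

static uint16_t a[16][32];        // bf16 bit patterns: 16 rows x 32 values
static uint16_t b_packed[16][32]; // second operand, pre-packed in VNNI pairs
static float    c[16][16];        // f32 accumulator

int main(void) {
    // Linux requires opting in to tile state before any AMX instruction.
    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) != 0) {
        return 1;
    }
    // ldtilecfg takes a 64-byte, 64-byte-aligned config block:
    // palette at byte 0, colsb[] at byte 16, rows[] at byte 48.
    _Alignas(64) uint8_t cfg[64] = {0};
    cfg[0] = 1;                                   // palette 1
    uint16_t * colsb = (uint16_t *)(cfg + 16);
    uint8_t  * rows  = cfg + 48;
    colsb[0] = 16 * sizeof(float); rows[0] = 16;  // tile0: accumulator
    colsb[1] = 64;                 rows[1] = 16;  // tile1: a, 64 B per row
    colsb[2] = 64;                 rows[2] = 16;  // tile2: b_packed
    _tile_loadconfig(cfg);

    _tile_zero(0);
    _tile_loadd(1, a,        64);                 // strides are in bytes
    _tile_loadd(2, b_packed, 64);
    _tile_dpbf16ps(0, 1, 2);                      // tile0 += tile1 * tile2
    _tile_stored(0, c, 16 * sizeof(float));
    _tile_release();
    return 0;
}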
@@ -12269,17 +12373,62 @@ static void ggml_compute_forward_mul_mat_one_chunk(
assert(ne13 % ne03 == 0);

// block-tiling attempt
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
const int64_t blck_0 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
const int64_t blck_1 = num_rows_per_vec_dot == AMX_TILE_MN ? AMX_BLCK_SIZE : 16;
#else
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
#endif

const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;

// attempt to reduce false-sharing (does not seem to make a difference)
#ifdef __ARM_FEATURE_MATMUL_INT8
// 16 * 2, accounting for mmla kernels
float tmp[32];
#elif defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
assert(AMX_TILE_MN <= blck_0 && AMX_TILE_MN <= blck_1);
assert(blck_0 % AMX_TILE_MN == 0 && blck_1 % AMX_TILE_MN == 0);
assert(src1->type == GGML_TYPE_F32);
}
// AMX_BLCK_SIZE * AMX_TILE_MN, accounting for amx kernels
float tmp[AMX_BLCK_SIZE*AMX_TILE_MN];
uint8_t * wbase = (uint8_t *) (params->wdata) + params->ith*(2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t)+ne00*sizeof(float)+4096);
ggml_bf16_t * xbf16 = (ggml_bf16_t *)(wbase);
ggml_bf16_t * ybf16 = (ggml_bf16_t *)(wbase + 1*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
float * xf32 = (float *) (wbase + 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t));
xf32 = (float *) (((size_t)xf32 + 4095) & ~4095);
#else
float tmp[16];
#endif

for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
const int64_t ii13 = (iir1 / (ne12 * ne1));
const int64_t ii12 = (iir1 - ii13 * ne12 * ne1) / ne1;
const int64_t ii11 = (iir1 - ii13 * ne12 * ne1 - ii12 * ne1);

// broadcast src0 into src1
const int64_t ii03 = ii13 / r3;
const int64_t ii02 = ii12 / r2;

ggml_to_float_t const to_float = type_traits[type].to_float;

const char * src0_row = (const char*)src0->data + (0 + ii02 * nb02 + ii03 * nb03);
const uint8_t * src1_col = (const uint8_t *)src1->data + ii11 * nb11 + ii12 * nb12 + ii13 * nb13;
for (int i = 0; i < blck_0 && iir0 + i < ir0_end; ++i) {
to_float((const uint8_t *)src0_row + iir0*nb01 + i*nb01, xf32, ne00);
ggml_fp32_to_bf16_row(xf32, xbf16 + i*ne00, ne00);
}
for (int i = 0; i < blck_1 && iir1 + i < ir1_end; ++i) {
ggml_fp32_to_bf16_row((const float *)(src1_col + i*nb11), ybf16 + i*ne00, ne00);
}
}
#endif
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
const int64_t i13 = (ir1 / (ne12 * ne1));
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
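Each thread gets a private slice of params->wdata with a fixed layout: AMX_BLCK_SIZE rows of src0 converted to bf16, AMX_BLCK_SIZE rows of src1 converted to bf16, one f32 staging row for to_float, and 4096 bytes of slack so xf32 can be rounded up to a 4 KiB boundary. A sketch of that sizing as a helper (hypothetical name; the PR inlines the expression both here and in ggml_graph_plan):

// Hypothetical helper mirroring the inline sizing above; ne00 is the
// src0 row length in elements.
static size_t amx_scratch_per_thread(int64_t ne00) {
    return 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t) // xbf16 + ybf16 blocks
         + ne00*sizeof(float)                       // xf32 staging row
         + 4096;                                    // 4 KiB alignment slack
}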
@@ -12293,28 +12442,37 @@ static void ggml_compute_forward_mul_mat_one_chunk(
const int64_t i2 = i12;
const int64_t i3 = i13;

const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (num_rows_per_vec_dot == AMX_TILE_MN) {
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
ggml_vec_dot_bf16(ne00, &tmp[ir0 - iir0], blck_0, xbf16 + (ir0-iir0)*ne00, ne00*sizeof(ggml_bf16_t), ybf16 + (ir1-iir1)*ne00, ne00*sizeof(ggml_bf16_t), AMX_TILE_MN);
}
}
else
#endif
if (true) {
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char*)wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12 + i13 * nb13));

//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}

for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? blck_0 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}

for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * blck_0), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
}
}
}
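Two details in the hunk above are easy to misread: when num_rows_per_vec_dot == AMX_TILE_MN, the freshly converted xbf16/ybf16 blocks are consumed 16 rows at a time by ggml_vec_dot_bf16 and the tmp row stride widens from 16 to blck_0 floats; and the "if (true)" wrapper exists only so the generic path can serve as the else-branch of the AMX branch while still compiling unchanged when the AMX macros are absent.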
Expand Down Expand Up @@ -12473,6 +12631,11 @@ UseGgmlGemm1:;
}
// Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
atomic_store(&state->shared->current_chunk, nth);
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (ne01 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
return;
}
#endif
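// note: when the AMX tile path will be taken (same condition as the dispatch
// below), the generic src1 -> vec_dot_type conversion into wdata is skipped
// entirely by the early return above; the tile kernel instead converts both
// operands to bf16 into its own per-thread scratch on the fly.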
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -12540,9 +12703,18 @@ UseGgmlGemm2:;
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
num_rows_per_vec_dot = 1;
}
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if ((ne00 % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (nr0 % AMX_TILE_MN == 0) && (ne11 % AMX_TILE_MN == 0)) {
num_rows_per_vec_dot = AMX_TILE_MN;
}
#endif

// Now select a reasonable chunk size.
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
int chunk_size = AMX_BLCK_SIZE;
#else
int chunk_size = 16;
#endif

// We need to step up the size if it's small
if (nr0 == 1 || nr1 == 1) {
@@ -19328,6 +19500,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {

set_numa_thread_affinity(state->ith);

#if defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__gnu_linux__)
// refer to https://www.kernel.org/doc/Documentation/x86/xstate.rst
syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
#endif

int node_n = -1;
int task_phase = GGML_TASK_TYPE_FINALIZE;

@@ -19525,6 +19702,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
* node->src[1]->ne[2]*node->src[1]->ne[3];
}
} else
#endif
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if ((node->src[0]->ne[0] % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0) && (node->src[0]->ne[1] % AMX_TILE_MN == 0) && (node->src[1]->ne[1] % AMX_TILE_MN == 0)) {
cur = n_threads*(2*AMX_BLCK_SIZE*node->src[0]->ne[0]*sizeof(ggml_bf16_t)+node->src[0]->ne[0]*sizeof(float)+4096);
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
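Note that this sizing must stay in lockstep with the per-thread pointer arithmetic in ggml_compute_forward_mul_mat_one_chunk: it is the same 2*AMX_BLCK_SIZE*ne00*sizeof(ggml_bf16_t) + ne00*sizeof(float) + 4096 slice, multiplied by n_threads.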
@@ -22877,4 +23059,12 @@ int ggml_cpu_has_matmul_int8(void) {
#endif
}

int ggml_cpu_has_amx(void) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
return 1;
#else
return 0;
#endif
}

////////////////////////////////////////////////////////////////////////////////
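Like the other ggml_cpu_has_* helpers, ggml_cpu_has_amx reports what the binary was compiled for, not what the machine running it supports. If a runtime probe were ever wanted, AMX-TILE and AMX-BF16 are CPUID leaf 7, sub-leaf 0, EDX bits 24 and 22; a sketch (hypothetical function using GCC/Clang cpuid.h, not part of this PR):

#include <cpuid.h>

// Hypothetical runtime probe, not part of the PR.
static int cpu_supports_amx_bf16(void) {
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) {
        return 0;
    }
    // EDX bit 24: AMX-TILE, bit 22: AMX-BF16
    return (edx & (1u << 24)) != 0 && (edx & (1u << 22)) != 0;
}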
1 change: 1 addition & 0 deletions ggml.h
@@ -2421,6 +2421,7 @@ extern "C" {
GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_vsx (void);
GGML_API int ggml_cpu_has_matmul_int8(void);
GGML_API int ggml_cpu_has_amx (void);

//
// Internal types and functions exposed for tests and benchmarks
1 change: 1 addition & 0 deletions llama.cpp
@@ -18355,6 +18355,7 @@ const char * llama_print_system_info(void) {
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
s += "AMX = " + std::to_string(ggml_cpu_has_amx()) + " | ";
#ifdef GGML_USE_LLAMAFILE
s += "LLAMAFILE = 1 | ";
#else