Skip to content

Commit 9bd2a52

Browse files
committed
Q4 tiled Gemm Implementation
This patch implements a tiled GEMM approach similar to SGEMM. However, it degrades performance compared to the current qgemm implementation. Signed-off-by: Shalini Salomi Bodapati <[email protected]>
1 parent ee3a9fc commit 9bd2a52

File tree

1 file changed

+200
-50
lines changed

1 file changed

+200
-50
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 200 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,13 +1582,79 @@ class tinyBLAS_Q0_PPC {
15821582
float *C, int64_t ldc,
15831583
int ith, int nth)
15841584
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1585+
kc = 8;
15851586
}
15861587

15871588
void matmul(int64_t m, int64_t n) {
1589+
int mc = 8; int nc = 8;
1590+
if (m%mc == 0 && n%nc == 0 && k%kc == 0) {
1591+
//debug_print_q4_0((const block_q4_0 *)A, lda, m);
1592+
//debug_print_q8_0((const block_q8_0 *)B, ldb, n);
1593+
matmul_tiled(m, n, mc, nc, kc);
1594+
}
1595+
else {
1596+
//debug_print_q4_0((const block_q4_0 *)A, lda, m);
1597+
//debug_print_q8_0((const block_q8_0 *)B, ldb, n);
15881598
mnpack(0, m, 0, n);
1599+
}
15891600
}
15901601

15911602
private:
1603+
void debug_print_q4_0(const block_q4_0 *A, int lda, int m) {
1604+
printf("\n===== Matrix A (Q4_0) =====\n");
1605+
for (int i = 0; i < m; i++) {
1606+
// each block holds QK4_0 values (usually 32)
1607+
for (int blk = 0; blk < lda; blk++) {
1608+
const block_q4_0* bb = A + i*lda + blk;
1609+
float d = GGML_FP16_TO_FP32(bb->d);
1610+
printf("Row %d: d = %f, qs = ", i, d);
1611+
for ( int x = 0; x< QK4_0/2; x++) {
1612+
uint8_t q = bb->qs[x];
1613+
int8_t q0 = (q & 0x0F) - 8; // lower nibble
1614+
int8_t q1 = ((q >> 4) & 0x0F) - 8; // upper nibble
1615+
printf("%d %d ", q0, q1);
1616+
}
1617+
printf("\n");
1618+
}
1619+
}
1620+
}
1621+
1622+
1623+
void debug_print_q8_0(const block_q8_0 *B, int ldb, int n) {
1624+
printf("\n===== Matrix B (Q8_0) =====\n");
1625+
for (int j = 0; j < n; j++) {
1626+
printf("Col %d : ", j);
1627+
for (int blk = 0; blk < k; blk++) {
1628+
const block_q8_0 *bb = B + j*ldb + blk;
1629+
float d = GGML_FP16_TO_FP32(bb->d);
1630+
printf(" [d=%f, qs=", d);
1631+
for (int x = 0; x < QK8_0; x++) {
1632+
printf("%d ", bb->qs[x]);
1633+
}
1634+
printf("]\n");
1635+
}
1636+
printf("\n");
1637+
}
1638+
}
1639+
void print_vec_q4(const char* name, vec_t vec) {
1640+
printf("%s:\t", name);
1641+
for (int i = 0; i < 16; i++) {
1642+
uint8_t byte = (uint8_t) vec[i]; // take the raw 8-bit value
1643+
1644+
int8_t lo = (byte & 0x0F) - 8; // lower nibble (0–15) → shift to signed (-8..7)
1645+
int8_t hi = ((byte >> 4) & 0x0F) - 8; // upper nibble
1646+
1647+
printf("(%2d,%2d) ", lo, hi);
1648+
}
1649+
printf("\n");
1650+
}
1651+
1652+
// Debug helper: prints the 16 lanes of `vec` as signed 8-bit integers,
// each left-justified in a 5-character field, followed by a newline.
void print_vec_q8(vec_t vec){
    for (int lane = 0; lane < 16; ++lane) {
        // Read the lane by value and convert it. Taking the address of a
        // vector element (&vec[i]) is accepted by GCC but not by all
        // compilers, and the value conversion gives the same result.
        const int8_t value = static_cast<int8_t>(vec[lane]);
        printf("%-5d ", value);
    }
    printf("\n");
}
15921658

15931659
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
15941660
for (int I = 0; I < RM; I++) {
@@ -1598,8 +1664,17 @@ class tinyBLAS_Q0_PPC {
15981664
}
15991665
}
16001666

1667+
// Accumulates an RM x RN block of results into C (C += block).
//   ii, jj  - row / column of the block's first element in C
//   idx     - index of the first vector in fin_res belonging to this block
//   fin_res - result vectors; vector idx+I holds row I, lane J is column J
// C is addressed as C[row + col*ldc].
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
    for (int I = 0; I < RM; I++) {
        // View row I of the result block as RN consecutive floats.
        const float * lanes = (const float *)(fin_res + idx + I);
        for (int J = 0; J < RN; J++) {
            float * dst = C + ii + ((jj + J) * ldc) + I;
            *dst += lanes[J];
        }
    }
}
1675+
16011676
template<int size>
1602-
inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
1677+
inline void compute(acc_t* ACC, int c_idx, int s_idx, int* comparray, vector float* vs, vector float* fin_res) {
16031678
vector signed int vec_C[4];
16041679
vector float CA[4] = {0};
16051680
vector float res[4] = {0};
@@ -1630,6 +1705,27 @@ class tinyBLAS_Q0_PPC {
16301705
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
16311706
}
16321707

1708+
inline void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
1709+
float a_scales[8];
1710+
for (int I = 0; I < 8; ++I) {
1711+
a_scales[I] = unhalf((A + ((ii + I) * lda) + blk)->d);
1712+
}
1713+
1714+
float tmp_bl[4], tmp_br[4];
1715+
for (int J = 0; J < 4; ++J) {
1716+
tmp_bl[J] = unhalf((B + ((jj + J) * ldb) + blk)->d);
1717+
tmp_br[J] = unhalf((B + ((jj + J + 4) * ldb) + blk)->d);
1718+
}
1719+
vector float vec_bl = vec_xl(0, tmp_bl); // or vec_xl(0, tmp_bl)
1720+
vector float vec_br = vec_xl(0, tmp_br);
1721+
1722+
for (int I = 0; I < 8; ++I) {
1723+
vector float a_vec = vec_splats(a_scales[I]);
1724+
vs[I] = vec_mul(a_vec, vec_bl); // left half
1725+
vs[I + 8] = vec_mul(a_vec, vec_br); // right half
1726+
}
1727+
}
1728+
16331729
template <typename V1, typename V2>
16341730
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
16351731
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
@@ -1661,7 +1757,7 @@ class tinyBLAS_Q0_PPC {
16611757
}
16621758

16631759
template<int size>
1664-
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
1760+
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int* comparray) {
16651761
int64_t i, j;
16661762
TA *aoffset = NULL;
16671763
int8_t *vecOffset = NULL;
@@ -1670,7 +1766,9 @@ class tinyBLAS_Q0_PPC {
16701766
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
16711767
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
16721768
aoffset = const_cast<TA*>(a);
1769+
int index = 0;
16731770
vecOffset = vec;
1771+
//int kc = 1;
16741772
j = (rows >> 3);
16751773
if (j > 0) {
16761774
do {
@@ -1683,43 +1781,36 @@ class tinyBLAS_Q0_PPC {
16831781
aoffset7 = aoffset6 + lda;
16841782
aoffset8 = aoffset7 + lda;
16851783
aoffset += 8 * lda;
1686-
i = (cols >> 2);
1687-
if (i > 0) {
1688-
do {
1689-
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
1690-
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
1691-
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
1692-
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
1693-
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
1694-
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
1695-
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
1696-
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
1697-
1698-
process_q4_elements(c1, &comparray[0]);
1699-
process_q4_elements(c2, &comparray[1]);
1700-
process_q4_elements(c3, &comparray[2]);
1701-
process_q4_elements(c4, &comparray[3]);
1702-
process_q4_elements(c5, &comparray[4]);
1703-
process_q4_elements(c6, &comparray[5]);
1704-
process_q4_elements(c7, &comparray[6]);
1705-
process_q4_elements(c8, &comparray[7]);
1784+
for (int blk = 0; blk < kc; blk++) {
1785+
//float scale = GGML_FP16_TO_FP32((aoffset1+blk)->d);
1786+
//printf("packed block0 with scale=%f\n", scale);
1787+
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
1788+
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
1789+
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
1790+
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
1791+
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
1792+
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
1793+
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
1794+
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
1795+
//scale = GGML_FP16_TO_FP32((aoffset8+blk)->d);
1796+
//printf("packed block8 with scale=%f\n", scale);
1797+
1798+
process_q4_elements(c1, &comparray[index + 8*blk+0]);
1799+
process_q4_elements(c2, &comparray[index + 8*blk+1]);
1800+
process_q4_elements(c3, &comparray[index + 8*blk+2]);
1801+
process_q4_elements(c4, &comparray[index + 8*blk+3]);
1802+
process_q4_elements(c5, &comparray[index + 8*blk+4]);
1803+
process_q4_elements(c6, &comparray[index + 8*blk+5]);
1804+
process_q4_elements(c7, &comparray[index + 8*blk+6]);
1805+
process_q4_elements(c8, &comparray[index + 8*blk+7]);
17061806
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
17071807
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
17081808
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
17091809
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
1710-
aoffset1 += lda;
1711-
aoffset2 += lda;
1712-
aoffset3 += lda;
1713-
aoffset4 += lda;
1714-
aoffset5 += lda;
1715-
aoffset6 += lda;
1716-
aoffset7 += lda;
1717-
aoffset8 += lda;
17181810
vecOffset += 256;
1719-
i--;
1720-
} while (i > 0);
1721-
}
1811+
}
17221812
j--;
1813+
index += 8*kc;
17231814
} while (j > 0);
17241815
}
17251816

@@ -1792,19 +1883,16 @@ class tinyBLAS_Q0_PPC {
17921883
VB c1[8] = {0}; VB c2[8] = {0};
17931884
aoffset = const_cast<block_q8_0*>(a);
17941885
vecOffset = vec;
1886+
//int kc = 1;
17951887
j = (rows >> 3);
17961888
if (j > 0) {
17971889
do {
1798-
aoffsets[0] = aoffset;
1799-
for (int it = 1; it < 8; it++)
1800-
aoffsets[it] = aoffsets[it-1] + lda;
1890+
for (int it = 0; it < 8; it++)
1891+
aoffsets[it] = aoffset + it*lda;
18011892
aoffset += 8 * lda;
1802-
1803-
i = (cols >> 3);
1804-
if (i > 0) {
1805-
do {
1893+
for (int blk = 0; blk < kc; blk++) {
18061894
for (int it = 0; it < 8; it++) {
1807-
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
1895+
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
18081896
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
18091897
c1[it] = c[it][0];
18101898
c2[it] = c[it][1];
@@ -1813,12 +1901,8 @@ class tinyBLAS_Q0_PPC {
18131901
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
18141902
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
18151903
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
1816-
for (int it = 0; it < 8; it++)
1817-
aoffsets[it] += lda;
18181904
vecOffset += 256;
1819-
i--;
1820-
} while(i > 0);
1821-
}
1905+
}
18221906
j--;
18231907
} while(j > 0);
18241908
}
@@ -1918,7 +2002,8 @@ class tinyBLAS_Q0_PPC {
19182002
void KERNEL_4x8(int64_t ii, int64_t jj) {
19192003
vec_t vec_A[8], vec_B[16] = {0};
19202004
acc_t acc_0, acc_1;
1921-
std::array<int, 4> comparray {};
2005+
//std::array<int, 4> comparray {};
2006+
int comparray[8] = {0};
19222007
vector float fin_res[8] = {0};
19232008
vector float vs[8] = {0};
19242009
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
@@ -1963,7 +2048,8 @@ class tinyBLAS_Q0_PPC {
19632048
void KERNEL_8x4(int64_t ii, int64_t jj) {
19642049
vec_t vec_A[16], vec_B[8] = {0};
19652050
acc_t acc_0, acc_1;
1966-
std::array<int, 8> comparray {};
2051+
//std::array<int, 8> comparray {};
2052+
int comparray[8] = {0};
19672053
vector float fin_res[8] = {0};
19682054
vector float vs[8] = {0};
19692055
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
@@ -2005,9 +2091,11 @@ class tinyBLAS_Q0_PPC {
20052091
}
20062092

20072093
void KERNEL_8x8(int64_t ii, int64_t jj) {
2094+
printf("In kernel 8x8 with ii = %ld jj = %ld\n", ii, jj);
20082095
vec_t vec_A[16], vec_B[16] = {0};
20092096
acc_t acc_0, acc_1, acc_2, acc_3;
2010-
std::array<int, 8> comparray {};
2097+
//std::array<int, 8> comparray {};
2098+
int comparray[8] = {0};
20112099
vector float fin_res[16] = {0};
20122100
vector float vs[16] = {0};
20132101
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
@@ -2017,10 +2105,12 @@ class tinyBLAS_Q0_PPC {
20172105
__builtin_mma_xxsetaccz(&acc_2);
20182106
__builtin_mma_xxsetaccz(&acc_3);
20192107
if (std::is_same_v<TA, block_q4_0>) {
2108+
printf("calling packNormal for A matrix l = %d\n", l);
20202109
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
20212110
} else {
20222111
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
20232112
}
2113+
printf("calling packNormal for B matrix l = %d\n", l);
20242114
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
20252115
for(int x = 0; x < 8; x++) {
20262116
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
@@ -2057,6 +2147,64 @@ class tinyBLAS_Q0_PPC {
20572147
save_res(ii+4, jj+4, 12, fin_res);
20582148
}
20592149

2150+
void KERNEL_Q4(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, vec_t *vec_A, vec_t *vec_B, int *comparray) {
2151+
acc_t acc[4];
2152+
for (int i = 0; i < mc ; i += 8) {
2153+
for (int j = 0; j < nc; j += 8) {
2154+
vector float fin_res[16] = {0};
2155+
vector float vs[16] = {0};
2156+
for (int64_t kk = 0; kk < kc; kk++) {
2157+
for (int x = 0; x < 4; x++) {
2158+
__builtin_mma_xxsetaccz(&acc[x]);
2159+
}
2160+
int A_block_idx = (i/8)*(16*kc) + kk*16;
2161+
int B_block_idx = (j/8)*(16*kc)+ kk*16;
2162+
vec_t *A_block = &vec_A[A_block_idx];
2163+
vec_t *B_block = &vec_B[B_block_idx];
2164+
2165+
for (int x = 0; x < 8; x++) {
2166+
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
2167+
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
2168+
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
2169+
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
2170+
}
2171+
compute_scale(ii+i, jj+j, kk, vs);
2172+
int c_index = (i/8)*(8*kc)+ kk*8;
2173+
int* c_block = &comparray[c_index];
2174+
compute<8>(&acc[0], 0, 0, c_block, vs, fin_res);
2175+
compute<8>(&acc[1], 4, 4, c_block, vs, fin_res);
2176+
compute<8>(&acc[2], 0, 8, c_block, vs, fin_res);
2177+
compute<8>(&acc[3], 4, 12, c_block, vs, fin_res);
2178+
}
2179+
add_save_res(ii+i, jj+j, 0, fin_res);
2180+
add_save_res(ii+i+4, jj+j, 4, fin_res);
2181+
add_save_res(ii+i, jj+j+4, 8, fin_res);
2182+
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
2183+
}
2184+
2185+
}
2186+
}
2187+
2188+
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
2189+
int64_t ytiles = m / mc;
2190+
int64_t xtiles = n / nc;
2191+
int64_t tiles = xtiles * ytiles;
2192+
2193+
for (int64_t job = 0; job < tiles; job++) {
2194+
int64_t ii = (job / xtiles) * mc;
2195+
int64_t jj = (job % xtiles) * nc;
2196+
2197+
for (int64_t kk = 0; kk < k; kk += kc) {
2198+
vec_t A_pack[mc*kc*2]; // int4 → int8_t storage
2199+
vec_t B_pack[nc*kc*2];
2200+
int comparray[mc*kc]; // scales for A
2201+
packNormalInt4<8>(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
2202+
packNormal<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
2203+
KERNEL_Q4(ii, jj, mc, nc, kc, A_pack, B_pack, comparray);
2204+
}
2205+
}
2206+
}
2207+
20602208
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
20612209
int64_t ytiles = (m - m0) / RM;
20622210
int64_t xtiles = (n - n0) / RN;
@@ -2074,7 +2222,8 @@ class tinyBLAS_Q0_PPC {
20742222
for (int64_t job = start; job < end; ++job) {
20752223
int64_t ii = m0 + job / xtiles * RM;
20762224
int64_t jj = n0 + job % xtiles * RN;
2077-
std::array<int, 4> comparray{};
2225+
//std::array<int, 4> comparray{};
2226+
int comparray[4] = {0};
20782227
vector float res[4] = {0};
20792228
vector float fin_res[4] = {0};
20802229
vector float vs[4] = {0};
@@ -2159,6 +2308,7 @@ class tinyBLAS_Q0_PPC {
21592308
const block_q8_0 *const B;
21602309
float *C;
21612310
const int64_t k;
2311+
int64_t kc;
21622312
const int64_t lda;
21632313
const int64_t ldb;
21642314
const int64_t ldc;

0 commit comments

Comments
 (0)