
Commit 89a12fa

Author: Chip Kerchner (committed)

MMA BF16 GEMV code.

1 parent 7947970, commit 89a12fa

File tree: 9 files changed, 1011 additions and 69 deletions


kernel/power/gemm_common.c

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 #include <altivec.h>
 
+#define NBMAX 4096
+
 #define FORCEINLINE inline __attribute__((always_inline))
 
 #ifdef __clang__

kernel/power/sbgemv_common.c

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, v
   return (v_inp0 * v_in00);
 }
 
-FORCEINLINE vec_f32 vec_loadNHi_multi2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero)
+FORCEINLINE vec_f32 vec_loadNHi_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero)
 {
   vec_f32 v_in00 = vec_loadNHi(in, n, zero);
 
kernel/power/sbgemv_common_power10.c

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
+/***************************************************************************
+Copyright (c) 2024, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#ifndef SBGEMV_COMMON_MMA_C
+#define SBGEMV_COMMON_MMA_C
+#include "sbgemv_common.c"
+
+FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
+{
+  vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
+}
+
+FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp)
+{
+  vec_bf16 in0[2];
+
+  vec_load_pair((vec_f32 *)in0, (vec_f32 *)in);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]);
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]);
+}
+
+FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(in, n);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp);
+}
+
+FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in00 = vec_mergeh(in0, in0);
+
+  __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
+}
+
+FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp)
+{
+  vec_bf16 in01 = vec_mergel(in0, in0);
+
+  vec_mult1_mma(&out[0], in0, inp);
+
+  __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
+}
+
+FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp)
+{
+  vec_mult2_mma(out + 0, in0[0], inp);
+  vec_mult2_mma(out + 2, in0[1], inp);
+}
+
+FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(in, n);
+
+  vec_mult1_mma(out, in0, inp);
+}
+
+FORCEINLINE void vec_loadN_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(in, n);
+
+  vec_mult2_mma(out, in0, inp);
+}
+
+FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
+{
+  vec_bf16 in0 = (vec_bf16)vec_load_vec(in);
+
+  vec_mult2_mma(out, in0, inp);
+}
+
+FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp)
+{
+  vec_bf16 in0[4];
+
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2));
+
+  vec_mult4_mma(&out[0], in0 + 0, inp);
+  vec_mult4_mma(&out[4], in0 + 2, inp);
+}
+
+FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  __builtin_mma_disassemble_acc((void*)temp, &out[0]);
+
+  vy0[0] += (temp[0] * v_alpha);
+}
+
+FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce1_mma(&out[0], &temp[0], v_alpha, &vy0[0]);
+  vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]);
+}
+
+FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0)
+{
+  vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0);
+  vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2);
+  vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4);
+  vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6);
+}
+
+FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
+{
+  vec_bf16 in00 = vec_mergeh(in0, in1);
+
+  __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00);
+}
+
+FORCEINLINE void vec_mult2a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
+{
+  vec_bf16 in01 = vec_mergel(in0, in1);
+
+  vec_mult11a_mma(&out[0], in0, in1, inp);
+
+  __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01);
+}
+
+FORCEINLINE void vec_mult4a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
+{
+  vec_mult2a_mma(out + 0, in0[0], in1[0], inp);
+  vec_mult2a_mma(out + 2, in0[1], in1[1], inp);
+}
+
+FORCEINLINE void vec_loadN_mult11a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(ina, n);
+  vec_bf16 in1 = vec_loadN(inb, n);
+
+  vec_mult11a_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
+  vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
+
+  vec_mult2a_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
+  vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
+  vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+
+  vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp);
+  vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp);
+}
+
+FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(ina, n);
+  vec_bf16 in1 = vec_loadN(inb, n);
+
+  vec_mult2a_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_mult11b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
+{
+  vec_bf16 in00 = vec_mergeh(in0, in1);
+
+  __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00);
+}
+
+FORCEINLINE void vec_mult2b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp)
+{
+  vec_bf16 in01 = vec_mergel(in0, in1);
+
+  vec_mult11b_mma(&out[0], in0, in1, inp);
+
+  __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01);
+}
+
+FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp)
+{
+  vec_mult2b_mma(out + 0, in0[0], in1[0], inp);
+  vec_mult2b_mma(out + 2, in0[1], in1[1], inp);
+}
+
+FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(ina, n);
+  vec_bf16 in1 = vec_loadN(inb, n);
+
+  vec_mult11b_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0 = (vec_bf16)vec_load_vec(ina);
+  vec_bf16 in1 = (vec_bf16)vec_load_vec(inb);
+
+  vec_mult2b_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp)
+{
+  vec_bf16 in0[4], in1[4];
+
+  vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0));
+  vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0));
+  vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2));
+  vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2));
+
+  vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp);
+  vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp);
+}
+
+FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n)
+{
+  vec_bf16 in0 = vec_loadN(ina, n);
+  vec_bf16 in1 = vec_loadN(inb, n);
+
+  vec_mult2b_mma(out, in0, in1, inp);
+}
+
+FORCEINLINE void vec_load4_pair(vec_f32 *vy0, vec_f32 *v_y)
+{
+  vec_load_pair(vy0 + 0, v_y + 0);
+  vec_load_pair(vy0 + 2, v_y + 2);
+  vec_load_pair(vy0 + 4, v_y + 4);
+  vec_load_pair(vy0 + 6, v_y + 6);
+}
+
+FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0)
+{
+  vec_store_pair(v_y + 0, vy0 + 0);
+  vec_store_pair(v_y + 2, vy0 + 2);
+  vec_store_pair(v_y + 4, vy0 + 4);
+  vec_store_pair(v_y + 6, vy0 + 6);
+}
+
+#endif
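
For orientation, here is a minimal usage sketch (not part of this commit) of how the helpers above might be chained in one GEMV column step: an MMA accumulator collects bf16 rank-2 updates via vec_load_mult_mma, then vec_reduce1_mma disassembles the accumulator and folds its alpha-scaled first row into the y buffer. The function name, the nvec parameter, and the calling convention are assumptions for illustration; only the helper signatures come from the file above.

/* Hypothetical usage sketch -- not code from this commit.  Assumes the
 * vec_bf16 / vec_f32 typedefs and the helpers above are in scope (this file
 * is included after sbgemv_common.c). */
static inline void sketch_accumulate_column(vec_bf16 *a, vec_bf16 x_bcast,
                                            BLASLONG nvec, vec_f32 v_alpha,
                                            vec_f32 *vy0)
{
  __vector_quad acc;
  vec_f32 temp[4];

  __builtin_mma_xxsetaccz(&acc);           /* clear the 4x4 f32 accumulator */

  for (BLASLONG i = 0; i < nvec; i++) {
    /* load 8 bf16 elements of A and accumulate a bf16 outer product */
    vec_load_mult_mma(&acc, &a[i], x_bcast);
  }

  /* disassemble the accumulator and add temp[0] * alpha into vy0[0] */
  vec_reduce1_mma(&acc, temp, v_alpha, vy0);
}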

kernel/power/sbgemv_n.c

Lines changed: 22 additions & 0 deletions
@@ -87,6 +87,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto
   }
 }
 
+#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX))
+#define USE_N_8
+#endif
+
 int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y)
 {
   IFLOAT *x_ptr, *ap[4];
@@ -100,7 +104,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
   y_ptr = y;
 
   BLASLONG lda4 = lda << 2;
+#ifdef USE_N_8
   BLASLONG lda8 = lda << 3;
+#endif
   BLASLONG NB = NBMAX;
   BLASLONG m2 = (m & (NBMAX - 1));
 
@@ -126,6 +132,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
     ap[3] = ap[2] + lda;
 
     if (inc_x == 1) {
+#ifdef USE_N_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
         BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha);
         ap[0] += lda8;
@@ -135,9 +142,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
         x_ptr += 8;
       }
       if (n & 4) {
+#else
+      for (BLASLONG j = 0; j + 4 <= n; j += 4) {
+#endif
        BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
        ap[0] += lda4;
        ap[1] += lda4;
+#ifndef USE_N_8
+       ap[2] += lda4;
+       ap[3] += lda4;
+#endif
        x_ptr += 4;
       }
       if (n & 2) {
@@ -149,6 +163,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
         BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha);
       }
     } else {
+#ifdef USE_N_8
       for (BLASLONG j = 0; j + 8 <= n; j += 8) {
         copy_x(8, x_ptr, xbuffer, inc_x);
         BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha);
@@ -159,10 +174,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
         x_ptr += 8 * inc_x;
       }
       if (n & 4) {
+#else
+      for (BLASLONG j = 0; j + 4 <= n; j += 4) {
+#endif
        copy_x(4, x_ptr, xbuffer, inc_x);
        BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha);
        ap[0] += lda4;
        ap[1] += lda4;
+#ifndef USE_N_8
+       ap[2] += lda4;
+       ap[3] += lda4;
+#endif
        x_ptr += 4 * inc_x;
       }
       if (n & 2) {
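
The interleaved guards above are easier to read with USE_N_8 resolved both ways. A hypothetical, condensed view of the inc_x == 1 column loop follows; the pointer updates elided by the hunk in the 8-wide path are assumed from context, and BF16GEMV_N_8 / BF16GEMV_N_4 are the kernels referenced in the diff, not redefined here.

/* Sketch only -- a condensed reading of the diff above, not code from the
 * commit.  With USE_N_8 defined, columns are consumed 8 at a time with a
 * 4-column tail; without it, the same braces become a plain 4-column loop. */
#ifdef USE_N_8
  for (BLASLONG j = 0; j + 8 <= n; j += 8) {
    BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha);
    ap[0] += lda8;                 /* remaining pointer updates not shown   */
    x_ptr += 8;
  }
  if (n & 4) {                     /* at most one 4-column remainder        */
    BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
    ap[0] += lda4;
    ap[1] += lda4;
    x_ptr += 4;
  }
#else
  for (BLASLONG j = 0; j + 4 <= n; j += 4) {   /* 4 columns per pass        */
    BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha);
    ap[0] += lda4;
    ap[1] += lda4;
    ap[2] += lda4;                 /* all four column pointers advance here */
    ap[3] += lda4;
    x_ptr += 4;
  }
#endif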
