Skip to content

Commit 09a016f

Browse files
committed
Split sbgemv test from sbgemm test
1 parent 3f110c8 commit 09a016f

File tree

5 files changed: 153 additions and 84 deletions.

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ test/ZBLAT2.SUMM
8181
test/ZBLAT3.SUMM
8282
test/ZBLAT3_3M.SUMM
8383
test/SHBLAT3.SUMM
84+
test/SBBLAT2.SUMM
8485
test/SBBLAT3.SUMM
8586
test/BBLAT3.SUMM
8687
test/cblat1
@@ -97,6 +98,7 @@ test/sblat3
9798
test/sblat3_3m
9899
test/test_shgemm
99100
test/test_sbgemm
101+
test/test_sbgemv
100102
test/test_bgemm
101103
test/zblat1
102104
test/zblat2

test/Makefile

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ endif
119119
endif
120120
endif
121121

122+
ifeq ($(BUILD_BFLOAT16), 1)
123+
B2 = test_sbgemv
124+
endif
122125
ifeq ($(BUILD_SINGLE),1)
123126
S2=sblat2
124127
endif
@@ -132,11 +135,15 @@ ifeq ($(BUILD_COMPLEX16),1)
132135
Z2=zblat2
133136
endif
134137

135-
level2: $(S2) $(D2) $(C2) $(Z2)
138+
level2: $(B2) $(S2) $(D2) $(C2) $(Z2)
136139

137140

138141
ifneq ($(CROSS), 1)
139142
rm -f ?BLAT2.SUMM
143+
ifeq ($(BUILD_BFLOAT16),1)
144+
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./test_sbgemv > SBBLAT2.SUMM
145+
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
146+
endif
140147
ifeq ($(BUILD_SINGLE),1)
141148
OPENBLAS_NUM_THREADS=1 OMP_NUM_THREADS=1 ./sblat2 < ./sblat2.dat
142149
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -156,6 +163,10 @@ endif
156163
ifdef SMP
157164
rm -f ?BLAT2.SUMM
158165
ifeq ($(USE_OPENMP), 1)
166+
ifeq ($(BUILD_BFLOAT16),1)
167+
OMP_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
168+
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
169+
endif
159170
ifeq ($(BUILD_SINGLE),1)
160171
OMP_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
161172
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -173,6 +184,10 @@ ifeq ($(BUILD_COMPLEX16),1)
173184
@$(GREP) -q FATAL ZBLAT2.SUMM && cat ZBLAT2.SUMM || exit 0
174185
endif
175186
else
187+
ifeq ($(BUILD_BFLOAT16),1) # non-OpenMP threaded run: set OPENBLAS_NUM_THREADS, like the sblat2 run that follows
188+
OPENBLAS_NUM_THREADS=2 ./test_sbgemv > SBBLAT2.SUMM
189+
@$(GREP) -q FATAL SBBLAT2.SUMM && cat SBBLAT2.SUMM || exit 0
190+
endif
176191
ifeq ($(BUILD_SINGLE),1)
177192
OPENBLAS_NUM_THREADS=2 ./sblat2 < ./sblat2.dat
178193
@$(GREP) -q FATAL SBLAT2.SUMM && cat SBLAT2.SUMM || exit 0
@@ -195,7 +210,7 @@ endif
195210

196211
ifeq ($(BUILD_BFLOAT16),1)
197212
BF3= test_bgemm
198-
B3= test_sbgemm
213+
B3 = test_sbgemm
199214
endif
200215
ifeq ($(BUILD_SINGLE),1)
201216
S3=sblat3
@@ -408,6 +423,9 @@ test_bgemm : compare_sgemm_bgemm.c test_helpers.h ../$(LIBNAME)
408423

409424
test_sbgemm : compare_sgemm_sbgemm.c test_helpers.h ../$(LIBNAME)
410425
$(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
426+
427+
test_sbgemv : compare_sgemv_sbgemv.c test_helpers.h ../$(LIBNAME) # test_helpers.h is included by the .c file, so it must trigger a rebuild
428+
$(CC) $(CLDFLAGS) -o test_sbgemv compare_sgemv_sbgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
411429
endif
412430

413431
ifeq ($(BUILD_COMPLEX),1)
@@ -426,7 +444,7 @@ clean:
426444
@rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \
427445
sblat1 dblat1 cblat1 zblat1 \
428446
sblat2 dblat2 cblat2 zblat2 \
429-
test_bgemm test_sbgemm sblat3 dblat3 cblat3 zblat3 \
447+
test_bgemm test_sbgemm test_sbgemv sblat3 dblat3 cblat3 zblat3 \
430448
sblat1p dblat1p cblat1p zblat1p \
431449
sblat2p dblat2p cblat2p zblat2p \
432450
sblat3p dblat3p cblat3p zblat3p \

test/compare_sgemm_bgemm.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ main (int argc, char *argv[])
158158

159159
if (ret != 0) {
160160
fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
161-
return ret;
162161
}
162+
163+
return ret;
163164
}

test/compare_sgemm_sbgemm.c

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -141,87 +141,7 @@ main (int argc, char *argv[])
141141

142142
if (ret != 0) {
143143
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
144-
return ret;
145144
}
146145

147-
for (beta = 0; beta < 3; beta += 1) {
148-
for (alpha = 0; alpha < 3; alpha += 1) {
149-
for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one.
150-
for (x = 1; x <= loop; x++)
151-
{
152-
k = (x == 0) ? 0 : l + 1;
153-
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
154-
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
155-
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
156-
bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16));
157-
bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l);
158-
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
159-
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
160-
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
161-
(DD == NULL) || (CC == NULL))
162-
return 1;
163-
blasint one = 1;
164-
165-
for (j = 0; j < x; j++)
166-
{
167-
for (i = 0; i < x; i++)
168-
{
169-
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
170-
sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one);
171-
}
172-
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
173-
sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one);
174-
175-
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
176-
}
177-
178-
for (y = 0; y < 2; y++)
179-
{
180-
if (y == 0) {
181-
transA = 'N';
182-
} else {
183-
transA = 'T';
184-
}
185-
186-
memset(CC, 0, x * sizeof(FLOAT) << l);
187-
memset(DD, 0, x * sizeof(FLOAT));
188-
memset(C, 0, x * sizeof(FLOAT) << l);
189-
190-
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k);
191-
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k);
192-
193-
for (int i = 0; i < x; i ++) DD[i] *= beta;
194-
195-
for (j = 0; j < x; j++)
196-
for (i = 0; i < x; i++)
197-
if (transA == 'N') {
198-
DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]);
199-
} else if (transA == 'T') {
200-
DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]);
201-
}
202-
203-
for (j = 0; j < x; j++) {
204-
if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) {
205-
ret++;
206-
}
207-
if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) {
208-
ret++;
209-
}
210-
}
211-
}
212-
free(A);
213-
free(B);
214-
free(C);
215-
free(AA);
216-
free(BB);
217-
free(DD);
218-
free(CC);
219-
} // x
220-
} // l
221-
} // alpha
222-
} // beta
223-
224-
if (ret != 0)
225-
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
226146
return ret;
227147
}

test/compare_sgemv_sbgemv.c

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/***************************************************************************
2+
Copyright (c) 2020,2025 The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25+
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
#include <stdio.h>
28+
#include <stdint.h>
29+
#include "../common.h"
30+
31+
#include "test_helpers.h"
32+
33+
#define SGEMV BLASFUNC(sgemv)
34+
#define SBGEMV BLASFUNC(sbgemv)
35+
#define SBGEMV_LARGEST 256
36+
37+
int
38+
main (int argc, char *argv[]) // Compares SBGEMV with SGEMV and with a scalar bfloat16 reference for orders 1..SBGEMV_LARGEST; returns the number of failed comparisons (0 = pass), or 1 if an allocation returns NULL.
39+
{
40+
blasint k; // increment passed for both the input and the output vector
41+
int i, j, l; // l selects the stride: logical element j lives at index j << l
42+
blasint x, y; // x: matrix order (x-by-x); y: 0 selects 'N', 1 selects 'T'
43+
int ret = 0; // running count of failed comparisons
44+
int loop = SBGEMV_LARGEST; // largest order tested
45+
char transA = 'N';
46+
float alpha = 1.0, beta = 0.0; // initial values are overwritten by the loops below
47+
48+
for (beta = 0; beta < 3; beta += 1) { // beta takes 0, 1, 2
49+
for (alpha = 0; alpha < 3; alpha += 1) { // alpha takes 0, 1, 2
50+
for (l = 0; l < 2; l++) { // l = 0: unit increments; l = 1: increments of 2, to test inc_x and inc_y not equal to one.
51+
for (x = 1; x <= loop; x++)
52+
{
53+
k = (x == 0) ? 0 : l + 1; // increment is l + 1; x starts at 1, so the x == 0 branch is never taken in this loop
54+
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); // x-by-x float matrix for SGEMV
55+
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); // float input vector, sized for stride 1 << l
56+
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); // SGEMV output vector
57+
bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16)); // bfloat16 counterpart of A
58+
bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l); // bfloat16 counterpart of B
59+
float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); // densely packed scalar reference result
60+
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); // SBGEMV output vector
61+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
62+
(DD == NULL) || (CC == NULL))
63+
return 1; // NOTE(review): buffers that were successfully allocated are not freed on this path
64+
blasint one = 1; // element count and increment for the one-element sbstobf16_ calls
65+
66+
for (j = 0; j < x; j++)
67+
{
68+
for (i = 0; i < x; i++)
69+
{
70+
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; // random value in [0.5, 1.5]
71+
sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one); // fill the matching AA element from A, one element per call
72+
}
73+
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; // random value in [0.5, 1.5] at the strided position
74+
sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one); // fill the matching BB element from B
75+
76+
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; // NOTE(review): overwritten by the memset calls below before any GEMV call runs
77+
}
78+
79+
for (y = 0; y < 2; y++)
80+
{
81+
if (y == 0) {
82+
transA = 'N';
83+
} else {
84+
transA = 'T';
85+
}
86+
87+
memset(CC, 0, x * sizeof(FLOAT) << l); // NOTE(review): both outputs start at zero, so beta presumably only ever scales zeros - confirm this is intended
88+
memset(DD, 0, x * sizeof(FLOAT)); // reset the reference accumulator for this transpose mode
89+
memset(C, 0, x * sizeof(FLOAT) << l);
90+
91+
SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); // float result, written to C
92+
SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); // bfloat16 inputs, float result written to CC
93+
94+
for (int i = 0; i < x; i ++) DD[i] *= beta; // DD was zeroed just above, so this leaves it at zero; this i shadows the outer i
95+
96+
for (j = 0; j < x; j++)
97+
for (i = 0; i < x; i++)
98+
if (transA == 'N') {
99+
DD[i] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); // 'N': element (i, j) contributes to output i
100+
} else if (transA == 'T') {
101+
DD[j] += alpha * float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); // 'T': element (i, j) contributes to output j
102+
}
103+
104+
for (j = 0; j < x; j++) {
105+
if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) { // SBGEMV against SGEMV, looser tolerances
106+
ret++;
107+
}
108+
if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) { // SBGEMV against the scalar bfloat16 reference, tighter tolerances
109+
ret++;
110+
}
111+
}
112+
}
113+
free(A);
114+
free(B);
115+
free(C);
116+
free(AA);
117+
free(BB);
118+
free(DD);
119+
free(CC);
120+
} // x
121+
} // l
122+
} // alpha
123+
} // beta
124+
125+
if (ret != 0)
126+
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
127+
return ret; // mismatch count doubles as the process exit status
128+
}

0 commit comments

Comments
 (0)