Skip to content

Commit 3f110c8

Browse files
committed
Improve bgemm and sbgemm testing
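- Fixes the wrong return type of `is_close` (it was declared as returning `float` although it returns the result of a comparison; it now returns `bool`)
- Adds stricter compiler flags (`-std=c11 -Wall -Werror`) for the test files so that this kind of issue is caught at build time in the future
- Re-uses the test helper functions between compare_sgemm_sbgemm.c and compare_sgemm_bgemm.c by moving them into test_helpers.h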
1 parent 81b30d4 commit 3f110c8

File tree

4 files changed

+84
-105
lines changed

4 files changed

+84
-105
lines changed

test/Makefile

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# 3. Neither the name of the OpenBLAS project nor the names of
1414
# its contributors may be used to endorse or promote products
1515
# derived from this software without specific prior written permission.
16-
#
16+
#
1717
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
1818
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
1919
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
@@ -34,6 +34,7 @@ ifneq (, $(filter $(CORE),LOONGSON3R3 LOONGSON3R4))
3434
endif
3535
override FFLAGS += -fno-tree-vectorize
3636
endif
37+
override CFLAGS += -std=c11 -Wall -Werror
3738

3839
SUPPORT_GEMM3M = 0
3940

@@ -402,10 +403,10 @@ zblat3 : zblat3.$(SUFFIX) ../$(LIBNAME)
402403
endif
403404

404405
ifeq ($(BUILD_BFLOAT16),1)
405-
test_bgemm : compare_sgemm_bgemm.c ../$(LIBNAME)
406+
test_bgemm : compare_sgemm_bgemm.c test_helpers.h ../$(LIBNAME)
406407
$(CC) $(CLDFLAGS) -o test_bgemm compare_sgemm_bgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
407408

408-
test_sbgemm : compare_sgemm_sbgemm.c ../$(LIBNAME)
409+
test_sbgemm : compare_sgemm_sbgemm.c test_helpers.h ../$(LIBNAME)
409410
$(CC) $(CLDFLAGS) -o test_sbgemm compare_sgemm_sbgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB)
410411
endif
411412

test/compare_sgemm_bgemm.c

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,13 @@ THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828
#include <stdint.h>
2929
#include <stdio.h>
3030

31+
#include "test_helpers.h"
3132

3233
#define SGEMM BLASFUNC(sgemm)
3334
#define BGEMM BLASFUNC(bgemm)
3435
#define BGEMM_LARGEST 256
3536

36-
static float float16to32(bfloat16 value)
37-
{
38-
blasint one = 1;
39-
float result;
40-
sbf16tos_(&one, &value, &one, &result, &one);
41-
return result;
42-
}
43-
44-
static float truncate_float(float value) {
37+
static float truncate_float32_to_bfloat16(float value) {
4538
blasint one = 1;
4639
bfloat16 tmp;
4740
float result;
@@ -50,17 +43,6 @@ static float truncate_float(float value) {
5043
return result;
5144
}
5245

53-
static void *malloc_safe(size_t size) {
54-
if (size == 0)
55-
return malloc(1);
56-
else
57-
return malloc(size);
58-
}
59-
60-
static float is_close(float a, float b, float rtol, float atol) {
61-
return fabs(a - b) <= (atol + rtol*fabs(b));
62-
}
63-
6446
int
6547
main (int argc, char *argv[])
6648
{
@@ -151,15 +133,15 @@ main (int argc, char *argv[])
151133
DD[i * m + j] +=
152134
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]);
153135
}
154-
if (!is_close(float16to32(CC[i * m + j]), truncate_float(C[i * m + j]), 0.01, 0.001)) {
136+
if (!is_close(float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(C[i * m + j]), 0.01, 0.001)) {
155137
printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, C=%.6f\n",
156-
i, j, k, float16to32(CC[i * m + j]), truncate_float(C[i * m + j]));
138+
i, j, k, float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(C[i * m + j]));
157139
ret++;
158140
}
159141

160-
if (!is_close(float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]), 0.0001, 0.00001)) {
142+
if (!is_close(float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(DD[i * m + j]), 0.0001, 0.00001)) {
161143
printf("Mismatch at i=%d, j=%d, k=%d: CC=%.6f, DD=%.6f\n",
162-
i, j, k, float16to32(CC[i * m + j]), truncate_float(DD[i * m + j]));
144+
i, j, k, float16to32(CC[i * m + j]), truncate_float32_to_bfloat16(DD[i * m + j]));
163145
ret++;
164146
}
165147

test/compare_sgemm_sbgemm.c

Lines changed: 19 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -27,72 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727
#include <stdio.h>
2828
#include <stdint.h>
2929
#include "../common.h"
30+
31+
#include "test_helpers.h"
32+
3033
#define SGEMM BLASFUNC(sgemm)
3134
#define SBGEMM BLASFUNC(sbgemm)
3235
#define SGEMV BLASFUNC(sgemv)
3336
#define SBGEMV BLASFUNC(sbgemv)
34-
typedef union
35-
{
36-
unsigned short v;
37-
#if defined(_AIX)
38-
struct __attribute__((packed))
39-
#else
40-
struct
41-
#endif
42-
{
43-
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
44-
unsigned short s:1;
45-
unsigned short e:8;
46-
unsigned short m:7;
47-
#else
48-
unsigned short m:7;
49-
unsigned short e:8;
50-
unsigned short s:1;
51-
#endif
52-
} bits;
53-
} bfloat16_bits;
54-
55-
typedef union
56-
{
57-
float v;
58-
#if defined(_AIX)
59-
struct __attribute__((packed))
60-
#else
61-
struct
62-
#endif
63-
{
64-
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
65-
uint32_t s:1;
66-
uint32_t e:8;
67-
uint32_t m:23;
68-
#else
69-
uint32_t m:23;
70-
uint32_t e:8;
71-
uint32_t s:1;
72-
#endif
73-
} bits;
74-
} float32_bits;
75-
76-
float
77-
float16to32 (bfloat16_bits f16)
78-
{
79-
float32_bits f32;
80-
f32.bits.s = f16.bits.s;
81-
f32.bits.e = f16.bits.e;
82-
f32.bits.m = (uint32_t) f16.bits.m << 16;
83-
return f32.v;
84-
}
85-
8637
#define SBGEMM_LARGEST 256
8738

88-
void *malloc_safe(size_t size)
89-
{
90-
if (size == 0)
91-
return malloc(1);
92-
else
93-
return malloc(size);
94-
}
95-
9639
int
9740
main (int argc, char *argv[])
9841
{
@@ -111,32 +54,29 @@ main (int argc, char *argv[])
11154
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
11255
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
11356
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
114-
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits));
115-
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits));
57+
bfloat16 *AA = (bfloat16 *)malloc_safe(m * k * sizeof(bfloat16));
58+
bfloat16 *BB = (bfloat16 *)malloc_safe(k * n * sizeof(bfloat16));
11659
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
11760
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
11861
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
11962
(DD == NULL) || (CC == NULL))
12063
return 1;
121-
bfloat16 atmp,btmp;
12264
blasint one=1;
12365

12466
for (j = 0; j < m; j++)
12567
{
12668
for (i = 0; i < k; i++)
12769
{
12870
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
129-
sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
130-
AA[j * k + i].v = atmp;
71+
sbstobf16_(&one, &A[j*k+i], &one, &AA[j * k + i], &one);
13172
}
13273
}
13374
for (j = 0; j < n; j++)
13475
{
13576
for (i = 0; i < k; i++)
13677
{
13778
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
138-
sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
139-
BB[j * k + i].v = btmp;
79+
sbstobf16_(&one, &B[j*k+i], &one, &BB[j * k + i], &one);
14080
}
14181
}
14282
for (y = 0; y < 4; y++)
@@ -182,10 +122,12 @@ main (int argc, char *argv[])
182122
DD[i * m + j] +=
183123
float16to32 (AA[k * j + l]) * float16to32 (BB[i + l * n]);
184124
}
185-
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
125+
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
186126
ret++;
187-
if (fabs (CC[i * m + j] - DD[i * m + j]) > 1.0)
127+
}
128+
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
188129
ret++;
130+
}
189131
}
190132
}
191133
free(A);
@@ -211,27 +153,24 @@ main (int argc, char *argv[])
211153
float *A = (float *)malloc_safe(x * x * sizeof(FLOAT));
212154
float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l);
213155
float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l);
214-
bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits));
215-
bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l);
156+
bfloat16 *AA = (bfloat16 *)malloc_safe(x * x * sizeof(bfloat16));
157+
bfloat16 *BB = (bfloat16 *)malloc_safe(x * sizeof(bfloat16) << l);
216158
float *DD = (float *)malloc_safe(x * sizeof(FLOAT));
217159
float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l);
218160
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
219161
(DD == NULL) || (CC == NULL))
220162
return 1;
221-
bfloat16 atmp, btmp;
222163
blasint one = 1;
223164

224165
for (j = 0; j < x; j++)
225166
{
226167
for (i = 0; i < x; i++)
227168
{
228169
A[j * x + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
229-
sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one);
230-
AA[j * x + i].v = atmp;
170+
sbstobf16_(&one, &A[j*x+i], &one, &AA[j * x + i], &one);
231171
}
232172
B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
233-
sbstobf16_(&one, &B[j << l], &one, &btmp, &one);
234-
BB[j << l].v = btmp;
173+
sbstobf16_(&one, &B[j << l], &one, &BB[j << l], &one);
235174

236175
CC[j << l] = C[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
237176
}
@@ -262,10 +201,12 @@ main (int argc, char *argv[])
262201
}
263202

264203
for (j = 0; j < x; j++) {
265-
if (fabs (CC[j << l] - C[j << l]) > 1.0)
204+
if (!is_close(CC[j << l], C[j << l], 0.01, 0.001)) {
266205
ret++;
267-
if (fabs (CC[j << l] - DD[j]) > 1.0)
206+
}
207+
if (!is_close(CC[j << l], DD[j], 0.001, 0.0001)) {
268208
ret++;
209+
}
269210
}
270211
}
271212
free(A);

test/test_helpers.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/***************************************************************************
2+
Copyright (c) 2025 The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
22+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
25+
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
28+
#ifndef TEST_HELPERS_H
29+
#define TEST_HELPERS_H
30+
#include <stdbool.h>
31+
32+
#include "../common.h"
33+
34+
#if IFLOAT == bfloat16
35+
static float float16to32(bfloat16 value)
36+
{
37+
blasint one = 1;
38+
float result;
39+
sbf16tos_(&one, &value, &one, &result, &one);
40+
return result;
41+
}
42+
#endif
43+
44+
/*
 * malloc_safe - malloc wrapper that never issues a zero-byte request.
 *
 * A request for 0 bytes is turned into a request for 1 byte; every other
 * size is passed to malloc unchanged.  The test programs treat a NULL
 * return as an allocation failure, so this avoids depending on what
 * malloc(0) returns.
 *
 * Returns the pointer obtained from malloc (NULL on failure); the caller
 * releases it with free().
 */
static void *malloc_safe(size_t size) {
  return malloc(size == 0 ? 1 : size);
}
50+
51+
/*
 * is_close - approximate equality of two floats.
 *
 * a     value under test
 * b     reference value (the relative tolerance is scaled by |b|)
 * rtol  relative tolerance
 * atol  absolute tolerance
 *
 * Returns true when |a - b| <= atol + rtol * |b|.
 *
 * Non-finite operands are compared for exact equality instead.  Without
 * this, an infinite b makes the tolerance itself infinite, so any non-NaN
 * a would be accepted, while two equal infinities would be rejected
 * because inf - inf is NaN.  A NaN on either side is never close to
 * anything.
 */
static bool is_close(float a, float b, float rtol, float atol) {
  if (!isfinite(a) || !isfinite(b))
    return a == b;

  return fabs(a - b) <= (atol + rtol * fabs(b));
}
54+
55+
#endif

0 commit comments

Comments
 (0)