Skip to content

Commit cefd1b7

Browse files
sw: add unrolled kernels for dp-fdotp and dp-faxpy
1 parent c60a7b2 commit cefd1b7

File tree

7 files changed

+139
-0
lines changed

7 files changed

+139
-0
lines changed

sw/spatzBenchmarks/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ include_directories(${SNRUNTIME_INCLUDE_DIRS})
2424

2525
add_compile_options(-O3 -g -ffunction-sections)
2626

27+
# Build the unrolled kernel variants of the spatzBenchmarks when UNROLL is set
28+
if (UNROLL)
29+
add_definitions(-DUNROLL)
30+
endif()
2731

2832
# Macro to regenerate the golden values and compile a module
2933
macro(add_spatz_test_oneParam name file param1)

sw/spatzBenchmarks/dp-faxpy/kernel/faxpy.c

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,54 @@ void faxpy_v64b(const double a, const double *x, const double *y,
4545
} while (avl > 0);
4646
}
4747

48+
// Unrolled 64-bit AXPY: y = a * x + y
49+
void faxpy_v64b_unrl(const double a, const double *x, const double *y, unsigned int avl) {
50+
51+
unsigned int vl;
52+
double *y2;
53+
54+
// Stripmine and accumulate a partial vector
55+
do {
56+
// Set the vl
57+
asm volatile("vsetvli %0, %1, e64, m8, ta, ma" : "=r"(vl) : "r"(avl));
58+
59+
// Load vectors
60+
asm volatile("vle64.v v0, (%0)" ::"r"(x));
61+
asm volatile("vle64.v v8, (%0)" ::"r"(y));
62+
63+
// Multiply-accumulate
64+
asm volatile("vfmacc.vf v8, %0, v0" ::"f"(a));
65+
avl -= vl;
66+
if (avl > 0) {
67+
// Set the vl
68+
asm volatile("vsetvli %0, %1, e64, m8, ta, ma" : "=r"(vl) : "r"(avl));
69+
70+
// Load vectors
71+
x += vl;
72+
asm volatile("vle64.v v16, (%0)" ::"r"(x));
73+
y2 = y + vl;
74+
asm volatile("vle64.v v24, (%0)" ::"r"(y2));
75+
76+
// Multiply-accumulate
77+
asm volatile("vfmacc.vf v24, %0, v16" ::"f"(a));
78+
}
79+
80+
// Store results
81+
asm volatile("vse64.v v8, (%0)" ::"r"(y));
82+
if (avl > 0) {
83+
// Store results
84+
y += vl;
85+
asm volatile("vse64.v v24, (%0)" ::"r"(y));
86+
avl -= vl;
87+
}
88+
89+
// Bump pointers
90+
x += vl;
91+
y += vl;
92+
93+
} while (avl > 0);
94+
}
95+
4896
// 32-bit AXPY: y = a * x + y
4997
void faxpy_v32b(const float a, const float *x, const float *y,
5098
unsigned int avl) {

sw/spatzBenchmarks/dp-faxpy/kernel/faxpy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
inline void faxpy_v64b(const double a, const double *x, const double *y,
2323
unsigned int avl) __attribute__((always_inline));
24+
inline void faxpy_v64b_unrl(const double a, const double *x, const double *y,
25+
unsigned int avl) __attribute__((always_inline));
2426
inline void faxpy_v32b(const float a, const float *x, const float *y,
2527
unsigned int avl) __attribute__((always_inline));
2628
inline void faxpy_v16b(const _Float16 a, const _Float16 *x, const _Float16 *y,

sw/spatzBenchmarks/dp-faxpy/main.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ int main() {
6262

6363
snrt_dma_start_1d(x, axpy_X_dram, dim * sizeof(double));
6464
snrt_dma_start_1d(y, axpy_Y_dram, dim * sizeof(double));
65+
snrt_dma_wait_all();
6566
}
6667

6768
// Wait for all cores to finish
@@ -83,7 +84,11 @@ int main() {
8384
timer = benchmark_get_cycle();
8485

8586
// Call AXPY
87+
#ifdef UNROLL
88+
faxpy_v64b_unrl(*a, x_int, y_int, dim_core);
89+
#else
8690
faxpy_v64b(*a, x_int, y_int, dim_core);
91+
#endif
8792

8893
// Wait for all cores to finish
8994
snrt_cluster_hw_barrier();

sw/spatzBenchmarks/dp-fdotp/kernel/fdotp.c

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,80 @@ double fdotp_v64b(const double *a, const double *b, unsigned int avl) {
5959
return red;
6060
}
6161

62+
// 64-bit dot-product: a * b
63+
// m8 allows only for partial register re-allocation with factor-2 unrolling
64+
double fdotp_v64b_m8_unrl(const double *a, const double *b, unsigned int avl) {
65+
const unsigned int orig_avl = avl;
66+
unsigned int vl;
67+
68+
double red;
69+
70+
// Stripmine and accumulate a partial reduced vector
71+
do {
72+
// Set the vl
73+
asm volatile("vsetvli %0, %1, e64, m8, ta, ma" : "=r"(vl) : "r"(avl));
74+
75+
// Load chunk a and b
76+
asm volatile("vle64.v v8, (%0)" ::"r"(a));
77+
asm volatile("vle64.v v16, (%0)" ::"r"(b));
78+
79+
// Multiply and accumulate
80+
if (avl == orig_avl) {
81+
asm volatile("vfmul.vv v24, v8, v16");
82+
} else {
83+
asm volatile("vfmacc.vv v24, v8, v16");
84+
}
85+
86+
// Bump pointers
87+
a += vl;
88+
b += vl;
89+
avl -= vl;
90+
91+
if (avl <= 0) break;
92+
93+
// Set the vl
94+
asm volatile("vsetvli %0, %1, e64, m8, ta, ma" : "=r"(vl) : "r"(avl));
95+
96+
// Load chunk a and b
97+
asm volatile("vle64.v v0, (%0)" ::"r"(a));
98+
asm volatile("vle64.v v8, (%0)" ::"r"(b));
99+
100+
// Multiply and accumulate
101+
asm volatile("vfmacc.vv v24, v0, v8");
102+
103+
// Bump pointers
104+
a += vl;
105+
b += vl;
106+
avl -= vl;
107+
108+
if (avl <= 0) break;
109+
110+
// Set the vl
111+
asm volatile("vsetvli %0, %1, e64, m8, ta, ma" : "=r"(vl) : "r"(avl));
112+
113+
// Load chunk a and b
114+
asm volatile("vle64.v v16, (%0)" ::"r"(a));
115+
asm volatile("vle64.v v0, (%0)" ::"r"(b));
116+
117+
// Multiply and accumulate
118+
asm volatile("vfmacc.vv v24, v0, v16");
119+
120+
// Bump pointers
121+
a += vl;
122+
b += vl;
123+
avl -= vl;
124+
} while (avl > 0);
125+
126+
// Clean the accumulator
127+
asm volatile("vmv.s.x v0, zero");
128+
129+
// Reduce and return
130+
asm volatile("vfredusum.vs v0, v24, v0");
131+
asm volatile("vfmv.f.s %0, v0" : "=f"(red));
132+
133+
return red;
134+
}
135+
62136
// 32-bit dot-product: a * b
63137
float fdotp_v32b(const float *a, const float *b, unsigned int avl) {
64138
const unsigned int orig_avl = avl;

sw/spatzBenchmarks/dp-fdotp/kernel/fdotp.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
inline double fdotp_v64b(const double *a, const double *b, unsigned int avl)
2323
__attribute__((always_inline));
24+
inline double fdotp_v64b_m8_unrl(const double *a, const double *b, unsigned int avl)
25+
__attribute__((always_inline));
2426
inline float fdotp_v32b(const float *a, const float *b, unsigned int avl)
2527
__attribute__((always_inline));
2628
inline _Float16 fdotp_v16b(const _Float16 *a, const _Float16 *b,

sw/spatzBenchmarks/dp-fdotp/main.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ int main() {
8282

8383
// Calculate dotp
8484
double acc;
85+
#ifdef UNROLL
86+
acc = fdotp_v64b_m8_unrl(a_int, b_int, dim);
87+
#else
8588
acc = fdotp_v64b(a_int, b_int, dim);
89+
#endif
8690
result[cid] = acc;
8791

8892
// Wait for all cores to finish

0 commit comments

Comments
 (0)