Skip to content

Commit c889558

Browse files
authored
Rework for DYNAMIC_ARCH use and use of SGEMM functions by SSYMM
1 parent 4ae3e37 commit c889558

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

kernel/arm64/sgemm_direct_alpha_beta_arm64_sme1.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
#define COMBINE2(a,b) COMBINE(a,b)
2020
#define SME1_PREPROCESS_BASE sgemm_direct_sme1_preprocess
2121
#define SME1_PREPROCESS COMBINE2(SME1_PREPROCESS_BASE,TS)
22+
#define SME1_KERNEL2X2_BASE sgemm_direct_alpha_beta_sme1_2VLx2VL
23+
#define SME1_KERNEL2X2 COMBINE2(SME1_KERNEL2X2_BASE,TS)
2224
#else
2325
#define SME1_PREPROCESS sgemm_direct_sme1_preprocess
26+
#define SME1_KERNEL2X2 sgemm_direct_alpha_beta_sme1_2VLx2VL
2427
#endif
2528
/* Function prototypes */
2629
extern void SME1_PREPROCESS(uint64_t nbr, uint64_t nbc,\
@@ -111,7 +114,7 @@ return;
111114
}
112115

113116
__arm_new("za") __arm_locally_streaming
114-
static void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
117+
void SME1_KERNEL2X2(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
115118
const float *ba, const float *restrict bb, const float* beta,\
116119
float *restrict C) {
117120

@@ -151,7 +154,7 @@ static void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_
151154
}
152155

153156
#else
154-
void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
157+
void SME1_KERNEL2X2(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
155158
const float *ba, const float *restrict bb, const float* beta,\
156159
float *restrict C){fprintf(stderr,"empty sgemm_alpha_beta2x2 should never get called!!!\n");}
157160
#endif
@@ -197,7 +200,7 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict
197200

198201
/* Calculate C = alpha*A*B + beta*C */
199202

200-
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R);
203+
SME1_KERNEL2X2(M, K, N, &alpha, A_mod, B, &beta, R);
201204

202205
free(A_mod);
203206
}

kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,23 @@
1414
#include <arm_sme.h>
1515
#endif
1616

17+
#if defined(DYNAMIC_ARCH)
18+
#define COMBINE(a,b) a ## b
19+
#define COMBINE2(a,b) COMBINE(a,b)
20+
#define SGEMM_PREPROCESS_BASE sgemm_direct_sme1_preprocess
21+
#define SGEMM_PREPROCESS COMBINE2(SGEMM_PREPROCESS_BASE,TS)
22+
#define SGEMM_DIRECT2X2_BASE sgemm_direct_alpha_beta_sme1_2VLx2VL
23+
#define SGEMM_DIRECT2X2 COMBINE2(SGEMM_DIRECT2X2_BASE,TS)
24+
#else
25+
#define SGEMM_PREPROCESS sgemm_direct_sme1_preprocess
26+
#define SGEMM_DIRECT2X2 sgemm_direct_alpha_beta_sme1_2VLx2VL
27+
#endif
28+
1729
/* Function prototypes */
18-
extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
19-
const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
30+
extern void SGEMM_PREPROCESS(uint64_t nbr, uint64_t nbc,\
31+
const float * restrict a, float * a_mod);
2032

21-
extern void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
33+
extern void SGEMM_DIRECT2X2(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
2234
const float *ba, const float *restrict bb, const float* beta,\
2335
float *restrict C);
2436
/* Function Definitions */
@@ -212,7 +224,7 @@ void CNAME(BLASLONG M, BLASLONG N, float alpha, float *__restrict A,
212224
#endif
213225

214226
/* Calculate C = alpha*A*B + beta*C */
215-
sgemm_direct_alpha_beta_sme1_2VLx2VL(M, M, N, &alpha, A_mod, B, &beta, R);
227+
SGEMM_DIRECT2X2(M, M, N, &alpha, A_mod, B, &beta, R);
216228
free(A_mod);
217229
}
218230

0 commit comments

Comments
 (0)