|
19 | 19 | #define COMBINE2(a,b) COMBINE(a,b) |
20 | 20 | #define SME1_PREPROCESS_BASE sgemm_direct_sme1_preprocess |
21 | 21 | #define SME1_PREPROCESS COMBINE2(SME1_PREPROCESS_BASE,TS) |
| 22 | +#define SME1_KERNEL2X2_BASE sgemm_direct_alpha_beta_sme1_2VLx2VL |
| 23 | +#define SME1_KERNEL2X2 COMBINE2(SME1_KERNEL2X2_BASE,TS) |
22 | 24 | #else |
23 | 25 | #define SME1_PREPROCESS sgemm_direct_sme1_preprocess |
| 26 | +#define SME1_KERNEL2X2 sgemm_direct_alpha_beta_sme1_2VLx2VL |
24 | 27 | #endif |
25 | 28 | /* Function prototypes */ |
26 | 29 | extern void SME1_PREPROCESS(uint64_t nbr, uint64_t nbc,\ |
@@ -111,7 +114,7 @@ return; |
111 | 114 | } |
112 | 115 |
|
113 | 116 | __arm_new("za") __arm_locally_streaming |
114 | | -static void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ |
| 117 | +void SME1_KERNEL2X2(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ |
115 | 118 | const float *ba, const float *restrict bb, const float* beta,\ |
116 | 119 | float *restrict C) { |
117 | 120 |
|
@@ -151,7 +154,7 @@ static void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_ |
151 | 154 | } |
152 | 155 |
|
153 | 156 | #else |
154 | | -void sgemm_direct_alpha_beta_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ |
| 157 | +void SME1_KERNEL2X2(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ |
155 | 158 | const float *ba, const float *restrict bb, const float* beta,\ |
156 | 159 | float *restrict C){fprintf(stderr,"empty sgemm_alpha_beta2x2 should never get called!!!\n");} |
157 | 160 | #endif |
@@ -197,7 +200,7 @@ void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict |
197 | 200 |
|
198 | 201 | /* Calculate C = alpha*A*B + beta*C */ |
199 | 202 |
|
200 | | - sgemm_direct_alpha_beta_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B, &beta, R); |
| 203 | + SME1_KERNEL2X2(M, K, N, &alpha, A_mod, B, &beta, R); |
201 | 204 |
|
202 | 205 | free(A_mod); |
203 | 206 | } |
|
0 commit comments