Skip to content

Commit 1ef97c4

Browse files
authored
Merge pull request #3550 from guowangy/smatrix-mask-fix
Small Matrix: use proper inline asm input constraint for AVX512 mask
2 parents 10b0428 + 2256832 commit 1ef97c4

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

kernel/x86_64/dgemm_small_kernel_nn_skylakex.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4848
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
4949
#define MASK_STORE_512(M, N) \
5050
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
51-
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
51+
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
5252
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
5353
#endif
5454

@@ -266,7 +266,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
266266
int mm = M - i;
267267
if (!mm) return 0;
268268
if (mm > 4 || K < 16) {
269-
register __mmask8 mask asm("k1") = (1UL << mm) - 1;
269+
register __mmask8 mask = (1UL << mm) - 1;
270270
for (j = 0; j < n6; j += 6) {
271271
DECLARE_RESULT_512(0, 0);
272272
DECLARE_RESULT_512(0, 1);

kernel/x86_64/dgemm_small_kernel_nt_skylakex.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5555
_mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N)
5656
#define MASK_STORE_512(M, N) \
5757
result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
58-
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \
58+
asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "Yk"(mask)); \
5959
_mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N)
6060
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
6161
__m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
@@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
303303
}
304304
int mm = M - i;
305305
if (mm >= 6) {
306-
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
306+
register __mmask16 mask = (1UL << mm) - 1;
307307
for (j = 0; j < n8; j += 8) {
308308
DECLARE_RESULT_512(0, 0);
309309
DECLARE_RESULT_512(0, 1);

kernel/x86_64/sgemm_small_kernel_nn_skylakex.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4848
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
4949
#define MASK_STORE_512(M, N) \
5050
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
51-
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
51+
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
5252
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
5353
#endif
5454

@@ -267,7 +267,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
267267
int mm = M - i;
268268
if (!mm) return 0;
269269
if (mm > 8 || K < 32) {
270-
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
270+
register __mmask16 mask = (1UL << mm) - 1;
271271
for (j = 0; j < n6; j += 6) {
272272
DECLARE_RESULT_512(0, 0);
273273
DECLARE_RESULT_512(0, 1);

kernel/x86_64/sgemm_small_kernel_nt_skylakex.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
5555
_mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
5656
#define MASK_STORE_512(M, N) \
5757
result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
58-
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
58+
asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "Yk"(mask)); \
5959
_mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
6060
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
6161
__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
@@ -303,7 +303,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alp
303303
}
304304
int mm = M - i;
305305
if (mm >= 12) {
306-
register __mmask16 mask asm("k1") = (1UL << mm) - 1;
306+
register __mmask16 mask = (1UL << mm) - 1;
307307
for (j = 0; j < n8; j += 8) {
308308
DECLARE_RESULT_512(0, 0);
309309
DECLARE_RESULT_512(0, 1);

0 commit comments

Comments
 (0)