Skip to content

Commit 9bfc361

Browse files
authored
Merge branch 'OpenMathLib:develop' into issue5414
2 parents 47a66ae + b6d5057 commit 9bfc361

File tree

15 files changed

+384
-16
lines changed

15 files changed

+384
-16
lines changed

.github/workflows/riscv64_vector.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ jobs:
2626
opts: TARGET=RISCV64_ZVL128B BINARY=64 ARCH=riscv64
2727
qemu_cpu: rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=128,elen=64
2828
- target: RISCV64_ZVL256B
29-
opts: TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64
29+
opts: TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64 BUILD_BFLOAT16=1 BUILD_HFLOAT16=1
3030
qemu_cpu: rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64
3131
- target: DYNAMIC_ARCH=1
3232
opts: TARGET=RISCV64_GENERIC BINARY=64 ARCH=riscv64 DYNAMIC_ARCH=1
@@ -40,7 +40,7 @@ jobs:
4040
run: |
4141
sudo apt-get update
4242
sudo apt-get install autoconf automake autotools-dev ninja-build make \
43-
libgomp1-riscv64-cross ccache
43+
libgomp1-riscv64-cross ccache qemu-kvm
4444
wget ${riscv_gnu_toolchain}/${riscv_gnu_toolchain_nightly_download_path}
4545
tar -xvf $(basename ${riscv_gnu_toolchain_nightly_download_path}) -C /opt
4646

cmake/cc.cmake

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -213,7 +213,7 @@ endif ()
213213

214214
if (${CORE} STREQUAL A64FX)
215215
if (NOT DYNAMIC_ARCH)
216-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
216+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
217217
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=a64fx")
218218
elseif (${GCC_VERSION} VERSION_GREATER 11.0 OR ${GCC_VERSION} VERSION_EQUAL 11.0)
219219
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.2-a+sve -mtune=a64fx")
@@ -227,7 +227,7 @@ if (${CORE} STREQUAL NEOVERSEV2)
227227
if (NOT DYNAMIC_ARCH)
228228
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
229229
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-v2")
230-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
230+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
231231
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v2")
232232
else ()
233233
if (${GCC_VERSION} VERSION_GREATER 13.0 OR ${GCC_VERSION} VERSION_EQUAL 13.0)
@@ -245,7 +245,7 @@ if (${CORE} STREQUAL NEOVERSEN2)
245245
if (NOT DYNAMIC_ARCH)
246246
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
247247
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-n2")
248-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
248+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
249249
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v2")
250250
else ()
251251
if (${GCC_VERSION} VERSION_GREATER 11.1 OR ${GCC_VERSION} VERSION_EQUAL 11.1)
@@ -261,7 +261,7 @@ if (${CORE} STREQUAL NEOVERSEV1)
261261
if (NOT DYNAMIC_ARCH)
262262
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
263263
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.4-a+sve+bf16 -mtune=neoverse-v1")
264-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
264+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
265265
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v1")
266266
else ()
267267
if (${GCC_VERSION} VERSION_GREATER 10.4 OR ${GCC_VERSION} VERSION_EQUAL 10.4)
@@ -275,7 +275,7 @@ endif ()
275275

276276
if (${CORE} STREQUAL NEOVERSEN1)
277277
if (NOT DYNAMIC_ARCH)
278-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
278+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
279279
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-n1")
280280
elseif (${GCC_VERSION} VERSION_GREATER 9.4 OR ${GCC_VERSION} VERSION_EQUAL 9.4)
281281
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.2-a -mtune=neoverse-n1")
@@ -287,7 +287,7 @@ endif ()
287287

288288
if (${CORE} STREQUAL AMPEREONE)
289289
if (NOT DYNAMIC_ARCH)
290-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC")
290+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC")
291291
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-n1")
292292
elseif (${GCC_VERSION} VERSION_GREATER 12.1)
293293
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.6-a+crypto+crc+fp16+sha3+rng -mtune=ampereone")
@@ -301,7 +301,7 @@ if (${CORE} STREQUAL ARMV8SVE)
301301
if (NOT DYNAMIC_ARCH)
302302
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
303303
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8-a+sve")
304-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
304+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
305305
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=host")
306306
else ()
307307
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve")
@@ -311,7 +311,7 @@ endif ()
311311

312312
if (${CORE} STREQUAL ARMV9SME)
313313
if (NOT DYNAMIC_ARCH)
314-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
314+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
315315
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=host")
316316
else ()
317317
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme")

common_level3.h

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,23 @@ void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
7272
float beta,
7373
float * R, BLASLONG strideR);
7474

75+
void strmm_direct_LNUN(BLASLONG M, BLASLONG N,
76+
float alpha,
77+
float * A, BLASLONG strideA,
78+
float * B, BLASLONG strideB);
79+
void strmm_direct_LNLN(BLASLONG M, BLASLONG N,
80+
float alpha,
81+
float * A, BLASLONG strideA,
82+
float * B, BLASLONG strideB);
83+
void strmm_direct_LTUN(BLASLONG M, BLASLONG N,
84+
float alpha,
85+
float * A, BLASLONG strideA,
86+
float * B, BLASLONG strideB);
87+
void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
88+
float alpha,
89+
float * A, BLASLONG strideA,
90+
float * B, BLASLONG strideB);
91+
7592
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
7693

7794
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -262,6 +262,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
262262

263263
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
264264
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
265+
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
266+
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
267+
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
268+
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265269
#endif
266270

267271

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,10 @@
5252
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
5353
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
5454
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
55+
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
56+
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
57+
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
58+
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
5559

5660
#define SGEMM_ONCOPY sgemm_oncopy
5761
#define SGEMM_OTCOPY sgemm_otcopy
@@ -224,6 +228,10 @@
224228
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
225229
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
226230
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
231+
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
232+
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
233+
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
234+
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
227235
#endif
228236

229237
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

getarch.c

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2060,10 +2060,9 @@ int main(int argc, char *argv[]){
20602060
#endif
20612061

20622062

2063-
#ifdef INTEL_AMD
2064-
#ifndef FORCE
2063+
#if defined(INTEL_AMD) && !defined(FORCE)
20652064
get_sse();
2066-
#else
2065+
#elif defined(FORCE_INTEL)
20672066

20682067
sprintf(buffer, "%s", ARCHCONFIG);
20692068

@@ -2093,7 +2092,6 @@ int main(int argc, char *argv[]){
20932092
} else p ++;
20942093
}
20952094
#endif
2096-
#endif
20972095

20982096
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
20992097
printf("__BYTE_ORDER__=__ORDER_BIG_ENDIAN__\n");

interface/trsm.c

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -355,6 +355,23 @@ void CNAME(enum CBLAS_ORDER order,
355355
return;
356356
}
357357

358+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
359+
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
360+
#if defined(DYNAMIC_ARCH)
361+
if (support_sme1())
362+
#endif
363+
if (args.m == 0 || args.n == 0) return;
364+
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft && m == lda && n == ldb) {
365+
if (Trans == CblasNoTrans) {
366+
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, alpha, a, lda, b, ldb);
367+
} else if (Trans == CblasTrans) {
368+
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, alpha, a, lda, b, ldb);
369+
}
370+
return;
371+
}
372+
#endif
373+
#endif
374+
358375
#endif
359376

360377
if ((args.m == 0) || (args.n == 0)) return;

kernel/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
241241
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
242242
set(USE_TRMM true)
243243
endif ()
244+
set(USE_DIRECT_STRMM false)
245+
if (ARM64)
246+
set(USE_DIRECT_STRMM true)
247+
endif()
244248
set(USE_DIRECT_SGEMM false)
245249
if (X86_64 OR ARM64)
246250
set(USE_DIRECT_SGEMM true)
@@ -285,6 +289,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
285289
endif ()
286290
endif()
287291

292+
if (USE_DIRECT_STRMM)
293+
if (ARM64)
294+
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
295+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
296+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
297+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
298+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
299+
endif ()
300+
endif ()
301+
288302
foreach (float_type SINGLE DOUBLE)
289303
string(SUBSTRING ${float_type} 0 1 float_char)
290304
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
@@ -460,6 +474,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
460474
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
461475
endif ()
462476

477+
463478
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
464479

465480
# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects

kernel/Makefile.L3

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
56+
USE_DIRECT_STRMM = 1
5657
endif
5758

5859
ifeq ($(ARCH), riscv64)
@@ -153,6 +154,18 @@ endif
153154
endif
154155
endif
155156

157+
ifdef USE_DIRECT_STRMM
158+
ifndef STRMMDIRECTKERNEL
159+
ifeq ($(ARCH), arm64)
160+
ifeq ($(TARGET_CORE), ARMV9SME)
161+
HAVE_SME = 1
162+
endif
163+
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
164+
endif
165+
endif
166+
endif
167+
168+
156169
ifeq ($(BUILD_BFLOAT16), 1)
157170
ifndef BGEMMKERNEL
158171
BGEMM_BETA = ../generic/gemm_beta.c
@@ -245,6 +258,14 @@ SKERNELOBJS += \
245258
endif
246259
endif
247260

261+
ifdef USE_DIRECT_STRMM
262+
ifeq ($(ARCH), arm64)
263+
SKERNELOBJS += \
264+
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
265+
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
266+
endif
267+
endif
268+
248269
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
249270
DKERNELOBJS += \
250271
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1186,6 +1207,23 @@ else
11861207
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
11871208
endif
11881209

1210+
1211+
ifdef USE_DIRECT_STRMM
1212+
ifeq ($(ARCH), arm64)
1213+
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1214+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@
1215+
1216+
$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1217+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@
1218+
1219+
$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1220+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@
1221+
1222+
$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1223+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@
1224+
endif
1225+
endif
1226+
11891227
$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
11901228
ifeq ($(OS), AIX)
11911229
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s

kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,11 @@ static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
201201
}
202202
}
203203
}
204-
204+
#else
205+
static void ssymm_direct_sme1_preprocessLU(uint64_t nbr, uint64_t nbc,
206+
const float *restrict a, float *restrict a_mod){}
207+
static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
208+
const float *restrict a, float *restrict a_mod){}
205209
#endif
206210

207211
//

0 commit comments

Comments
 (0)