diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 78232f0..7e14a01 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,12 +27,12 @@ jobs: # version: 'nightly' # os: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v5 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: @@ -50,4 +50,3 @@ jobs: file: lcov.info env: COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} - diff --git a/Project.toml b/Project.toml index ca4fb7b..e98b524 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.1" [deps] DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/BLIS.jl b/src/BLIS.jl index 08fc700..c30e2ca 100644 --- a/src/BLIS.jl +++ b/src/BLIS.jl @@ -5,6 +5,7 @@ using Libdl using blis_jll using LinearAlgebra +global blis_path = "" global libblis = C_NULL __init__() = begin @@ -13,13 +14,21 @@ __init__() = begin @info "Using custom defined BLIS installation instead of blis_jll." global libblis = dlopen(joinpath(get(ENV, "BLISDIR", ""), "lib/libblis")) else - blis_path = blis_jll.blis_path + global blis_path = blis_jll.blis_path # Use BinaryBuilder provided BLIS library. @info "blis_jll yields BLIS installation: $blis_path." global libblis = dlopen(blis_path) end end +export switch_blas +switch_blas(; clear=false, verbose=false) = begin + if libblis == C_NULL + throw(ErrorException("BLIS library not found under $blis_path.")) + end + BLAS.lbt_forward(blis_path, clear=clear, verbose=verbose) +end + # Data types. module Types include("types.jl") diff --git a/src/interface_linalg/level1.jl b/src/interface_linalg/level1.jl index c278afb..d364e30 100644 --- a/src/interface_linalg/level1.jl +++ b/src/interface_linalg/level1.jl @@ -1,6 +1,7 @@ # Level-1 LinearAlgebra.BLAS interface. # +#= # NOTE: scal! and blascopy! have incx in its parmeter. # No further StridedVector interface is provided. macro blis_interface_linalg_lv1_scal(T1, targetfunc, bliname) @@ -63,6 +64,7 @@ end @blis_interface_linalg_lv1_copy Float64 blascopy! copyv! @blis_interface_linalg_lv1_copy ComplexF32 blascopy! copyv! @blis_interface_linalg_lv1_copy ComplexF64 blascopy! copyv! +=# macro blis_interface_linalg_lv1_axpy(Tc1, T1, T2, targetfunc, bliname) diff --git a/src/interface_linalg/level2.jl b/src/interface_linalg/level2.jl index c26bf01..f1027d5 100644 --- a/src/interface_linalg/level2.jl +++ b/src/interface_linalg/level2.jl @@ -53,14 +53,8 @@ end @doc """ gemv!(tA, α, A, x, β, y) - BLIS-based GEMV with strides support & mixed-precision. + BLIS-based GEMV with strides support. """ gemv! -@blis_interface_linalg_lv2_gemv(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - gemv!, gemv!) @blis_interface_linalg_lv2_gemv Float32 Float32 Float32 Float32 Float32 gemv! gemv! @blis_interface_linalg_lv2_gemv Float64 Float64 Float64 Float64 Float64 gemv! gemv! @blis_interface_linalg_lv2_gemv ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 gemv! gemv! @@ -94,7 +88,7 @@ macro blis_interface_linalg_lv2_hemv(Tc1, T1, T2, Tc2, T3, targetfunc, bliname, oβ = BliObj(β) oy = BliObj(y) - ObjectBackend.bli_obj_set_uplo!(bli_ul, oA) + ObjectBackend.bli_obj_set_uplo!(bli_ul, oA.obj) ObjectBackend.bli_obj_set_struc!($bli_struc, oA.obj) $blifunc(oα, oA, ox, oβ, oy) y @@ -105,15 +99,8 @@ end @doc """ hemv!(ul, α, A, x, β, y) - BLIS-based HEMV with strides support & mixed-precision. + BLIS-based HEMV with strides support. """ hemv! -@blis_interface_linalg_lv2_hemv(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - hemv!, hemv!, - BLIS_HERMITIAN) @blis_interface_linalg_lv2_hemv Float32 Float32 Float32 Float32 Float32 hemv! hemv! BLIS_HERMITIAN @blis_interface_linalg_lv2_hemv Float64 Float64 Float64 Float64 Float64 hemv! hemv! BLIS_HERMITIAN @blis_interface_linalg_lv2_hemv ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 hemv! hemv! BLIS_HERMITIAN @@ -121,15 +108,8 @@ end @doc """ symv!(ul, α, A, x, β, y) - BLIS-based SYMV with strides support & mixed-precision. + BLIS-based SYMV with strides support. """ symv! -@blis_interface_linalg_lv2_hemv(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - symv!, symv!, - BLIS_SYMMETRIC) @blis_interface_linalg_lv2_hemv Float32 Float32 Float32 Float32 Float32 symv! symv! BLIS_SYMMETRIC @blis_interface_linalg_lv2_hemv Float64 Float64 Float64 Float64 Float64 symv! symv! BLIS_SYMMETRIC @blis_interface_linalg_lv2_hemv ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 symv! symv! BLIS_SYMMETRIC @@ -162,9 +142,10 @@ macro blis_interface_linalg_lv2_trmv(T1, T2, targetfunc, bliname) oA = BliObj(A) ob = BliObj(b) - ObjectBackend.bli_obj_set_uplo!(bli_ul, oA) - ObjectBackend.bli_obj_set_diag!(bli_dA, oA) - ObjectBackend.bli_obj_set_onlytrans!(bli_tA, oA) + ObjectBackend.bli_obj_set_uplo!(bli_ul, oA.obj) + ObjectBackend.bli_obj_set_diag!(bli_dA, oA.obj) + ObjectBackend.bli_obj_set_onlytrans!(bli_tA, oA.obj) + ObjectBackend.bli_obj_set_struc!(BLIS_TRIANGULAR, oA.obj) $blifunc(oα, oA, ob) b @@ -174,11 +155,8 @@ end @doc """ trmv!(ul, tA, dA, A, b) - BLIS-based TRMV with strides support & mixed-precision. + BLIS-based TRMV with strides support. """ -@blis_interface_linalg_lv2_trmv(BliCompatibleType, - BliCompatibleType, - trmv!, trmv!) @blis_interface_linalg_lv2_trmv Float32 Float32 trmv! trmv! @blis_interface_linalg_lv2_trmv Float64 Float64 trmv! trmv! @blis_interface_linalg_lv2_trmv ComplexF32 ComplexF32 trmv! trmv! @@ -186,11 +164,8 @@ end @doc """ trsv!(ul, tA, dA, A, b) - BLIS-based TRSV with strides support & mixed-precision. + BLIS-based TRSV with strides support. """ -@blis_interface_linalg_lv2_trmv(BliCompatibleType, - BliCompatibleType, - trsv!, trsv!) @blis_interface_linalg_lv2_trmv Float32 Float32 trsv! trsv! @blis_interface_linalg_lv2_trmv Float64 Float64 trsv! trsv! @blis_interface_linalg_lv2_trmv ComplexF32 ComplexF32 trsv! trsv! @@ -228,13 +203,8 @@ end @doc """ ger!(α, x, y, A) - BLIS-based GER with strides support & mixed-precision. + BLIS-based GER with strides support. """ ger! -@blis_interface_linalg_lv2_ger(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - ger!, ger!) @blis_interface_linalg_lv2_ger Float32 Float32 Float32 Float32 ger! ger! @blis_interface_linalg_lv2_ger Float64 Float64 Float64 Float64 ger! ger! @blis_interface_linalg_lv2_ger ComplexF32 ComplexF32 ComplexF32 ComplexF32 ger! ger! @@ -264,7 +234,7 @@ macro blis_interface_linalg_lv2_her(Tc1, T1, T2, targetfunc, bliname, bli_struc) ox = BliObj(x) oA = BliObj(A) - ObjectBackend.bli_obj_set_uplo!(bli_ul, oA) + ObjectBackend.bli_obj_set_uplo!(bli_ul, oA.obj) ObjectBackend.bli_obj_set_struc!($bli_struc, oA.obj) $blifunc(oα, ox, oA) A @@ -275,13 +245,8 @@ end @doc """ her!(uplo, α, x, A) - BLIS-based HER with strides support & mixed-precision. + BLIS-based HER with strides support. """ her! -@blis_interface_linalg_lv2_her(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - her!, her!, - BLIS_HERMITIAN) @blis_interface_linalg_lv2_her Float32 Float32 Float32 her! her! BLIS_HERMITIAN @blis_interface_linalg_lv2_her Float64 Float64 Float64 her! her! BLIS_HERMITIAN @blis_interface_linalg_lv2_her Float32 ComplexF32 ComplexF32 her! her! BLIS_HERMITIAN @@ -289,13 +254,8 @@ end @doc """ syr!(uplo, α, x, A) - BLIS-based SYR with strides support & mixed-precision. + BLIS-based SYR with strides support. """ syr! -@blis_interface_linalg_lv2_her(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - syr!, syr!, - BLIS_SYMMETRIC) @blis_interface_linalg_lv2_her Float32 Float32 Float32 syr! syr! BLIS_SYMMETRIC @blis_interface_linalg_lv2_her Float64 Float64 Float64 syr! syr! BLIS_SYMMETRIC @blis_interface_linalg_lv2_her ComplexF32 ComplexF32 ComplexF32 syr! syr! BLIS_SYMMETRIC diff --git a/src/interface_linalg/level3.jl b/src/interface_linalg/level3.jl index 4668b25..a7460b7 100644 --- a/src/interface_linalg/level3.jl +++ b/src/interface_linalg/level3.jl @@ -165,16 +165,9 @@ end if "hemm" ∉ blacklist @doc """ hemm!(side, uplo, α, A, B, β, C) - BLIS-based HEMM with generic strides & mixed precision directly supported. + BLIS-based HEMM with generic strides directly supported. `A` expresses a `Hermitian` matrix with its `uplo` triangle. """ hemm! -@blis_interface_linalg_lv3_hemm(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - hemm!, hemm!, - BLIS_HERMITIAN) @blis_interface_linalg_lv3_hemm Float32 Float32 Float32 Float32 Float32 hemm! hemm! BLIS_HERMITIAN @blis_interface_linalg_lv3_hemm Float64 Float64 Float64 Float64 Float64 hemm! hemm! BLIS_HERMITIAN @blis_interface_linalg_lv3_hemm ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 hemm! hemm! BLIS_HERMITIAN @@ -184,16 +177,9 @@ end if "symm" ∉ blacklist @doc """ symm!(side, uplo, α, A, B, β, C) - BLIS-based SYMM with generic strides & mixed precision directly supported. + BLIS-based SYMM with generic strides directly supported. `A` expresses a `Symmetric` matrix with its `uplo` triangle. """ symm! -@blis_interface_linalg_lv3_hemm(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - symm!, symm!, - BLIS_SYMMETRIC) @blis_interface_linalg_lv3_hemm Float32 Float32 Float32 Float32 Float32 symm! symm! BLIS_SYMMETRIC @blis_interface_linalg_lv3_hemm Float64 Float64 Float64 Float64 Float64 symm! symm! BLIS_SYMMETRIC @blis_interface_linalg_lv3_hemm ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 symm! symm! BLIS_SYMMETRIC @@ -244,19 +230,12 @@ end if "her2k" ∉ blacklist @doc """ her2k!(uplo, tAB, α, A, B, β, C) - BLIS-based HER2K with generic strides & mixed precision directly supported. + BLIS-based HER2K with generic strides directly supported. Performs rank-2k update on `Hermitian` matrix `C` (expressed by `uplo`-triangle): ```math C = β C + (α A B^† + \\bar α B A^†)^{tAB} ``` """ her2k! -@blis_interface_linalg_lv3_her2k(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - her2k!, her2k!, - BLIS_HERMITIAN) @blis_interface_linalg_lv3_her2k Float32 Float32 Float32 Float32 Float32 her2k! her2k! BLIS_HERMITIAN @blis_interface_linalg_lv3_her2k Float64 Float64 Float64 Float64 Float64 her2k! her2k! BLIS_HERMITIAN @blis_interface_linalg_lv3_her2k ComplexF32 ComplexF32 ComplexF32 Float32 ComplexF32 her2k! her2k! BLIS_HERMITIAN @@ -266,19 +245,12 @@ end if "syr2k" ∉ blacklist @doc """ syr2k!(uplo, tAB, α, A, B, β, C) - BLIS-based SYR2K with generic strides & mixed precision directly supported. + BLIS-based SYR2K with generic strides directly supported. Performs rank-2k update on `Symmetric` matrix `C` (expressed by `uplo`-triangle): ```math C = β C + (α A B^T + \\bar α B A^T)^{tAB} ``` """ syr2k! -@blis_interface_linalg_lv3_her2k(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - syr2k!, syr2k!, - BLIS_SYMMETRIC) @blis_interface_linalg_lv3_her2k Float32 Float32 Float32 Float32 Float32 syr2k! syr2k! BLIS_SYMMETRIC @blis_interface_linalg_lv3_her2k Float64 Float64 Float64 Float64 Float64 syr2k! syr2k! BLIS_SYMMETRIC @blis_interface_linalg_lv3_her2k ComplexF32 ComplexF32 ComplexF32 ComplexF32 ComplexF32 syr2k! syr2k! BLIS_SYMMETRIC @@ -326,15 +298,9 @@ end if "herk" ∉ blacklist @doc """ herk!(uplo, tA, α, A, β, C) - BLIS-based HERK with generic strides & mixed precision directly supported. + BLIS-based HERK with generic strides directly supported. Performs rank-k update on `Hermitian` matrix `C` (expressed by `uplo`-triangle). """ herk! -@blis_interface_linalg_lv3_herk(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - herk!, herk!, - BLIS_HERMITIAN) @blis_interface_linalg_lv3_herk Float32 Float32 Float32 Float32 herk! herk! BLIS_HERMITIAN @blis_interface_linalg_lv3_herk Float64 Float64 Float64 Float64 herk! herk! BLIS_HERMITIAN @blis_interface_linalg_lv3_herk Float32 ComplexF32 Float32 ComplexF32 herk! herk! BLIS_HERMITIAN @@ -344,15 +310,9 @@ end if "syrk" ∉ blacklist @doc """ syrk!(uplo, tA, α, A, β, C) - BLIS-based SYRK with generic strides & mixed precision directly supported. + BLIS-based SYRK with generic strides directly supported. Performs rank-k update on `Symmetric` matrix `C` (expressed by `uplo`-triangle). """ syrk! -@blis_interface_linalg_lv3_herk(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - syrk!, syrk!, - BLIS_SYMMETRIC) @blis_interface_linalg_lv3_herk Float32 Float32 Float32 Float32 syrk! syrk! BLIS_SYMMETRIC @blis_interface_linalg_lv3_herk Float64 Float64 Float64 Float64 syrk! syrk! BLIS_SYMMETRIC @blis_interface_linalg_lv3_herk ComplexF32 ComplexF32 ComplexF32 ComplexF32 syrk! syrk! BLIS_SYMMETRIC @@ -410,12 +370,8 @@ end if "trmm" ∉ blacklist @doc """ trmm!(side, uplo, tA, dA, α, A, B) - BLIS-based TRMM with generic strides & mixed precision directly supported. + BLIS-based TRMM with generic strides directly supported. """ trmm! -@blis_interface_linalg_lv3_trmm(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - trmm!, trmm!) @blis_interface_linalg_lv3_trmm Float32 Float32 Float32 trmm! trmm! @blis_interface_linalg_lv3_trmm Float64 Float64 Float64 trmm! trmm! @blis_interface_linalg_lv3_trmm ComplexF32 ComplexF32 ComplexF32 trmm! trmm! @@ -425,12 +381,8 @@ end if "trsm" ∉ blacklist @doc """ trsm!(side, uplo, tA, dA, α, A, B) - BLIS-based TRSM with generic strides & mixed precision directly supported. + BLIS-based TRSM with generic strides directly supported. """ trsm! -@blis_interface_linalg_lv3_trmm(BliCompatibleType, - BliCompatibleType, - BliCompatibleType, - trsm!, trsm!) @blis_interface_linalg_lv3_trmm Float32 Float32 Float32 trsm! trsm! @blis_interface_linalg_lv3_trmm Float64 Float64 Float64 trsm! trsm! @blis_interface_linalg_lv3_trmm ComplexF32 ComplexF32 ComplexF32 trsm! trsm! diff --git a/src/switch_blas.jl b/src/switch_blas.jl deleted file mode 100644 index e69de29..0000000 diff --git a/test/init_test_mmul.jl b/test/init_test_mmul.jl deleted file mode 100644 index 0614369..0000000 --- a/test/init_test_mmul.jl +++ /dev/null @@ -1,269 +0,0 @@ -# Simple test on matrix multiplication. -# -# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\vvvvvv/!!!!!!!!!!!!!!!!!!!! -# !CAUTION: THIS TEST MUST OCCUR >BEFORE< BLIS WAS IMPORTED.! -# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!/^^^^^^\!!!!!!!!!!!!!!!!!!!! -# -using Test -using Random -using LinearAlgebra -using LinearAlgebra: BLAS -using DelimitedFiles -using Statistics -Random.seed!(1234) -rtype(::Type{Complex{T}}) where {T} = T -zrtest(val, atol, label) = begin - iszr = ≈(val, 0.0, atol=atol) - if !iszr - @info "`$label` test failed. Consider adding it to ~/.blis_jlbla_blacklist." - end - return iszr -end - -@testset "BLAS level-3 LinearAlgebra interface" begin -αr = 1.1 -βr = 1.1 -αc = 1.1 + 0.2im -βc = 1.1 + 0.3im - -χlarge = 500 -χsmall = 20 - -Alarge_base = rand(ComplexF64, χlarge, χlarge) -Blarge_base = rand(ComplexF64, χlarge, χlarge) -Clarge_base = rand(ComplexF64, χlarge, χlarge) - -Asmall_base = rand(ComplexF64, χsmall, χsmall) -Bsmall_base = rand(ComplexF64, χsmall, χsmall) -Csmall_base = rand(ComplexF64, χsmall, χsmall) - -global Clarge_gemm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_hemm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_symm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_her2k= [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_syr2k= [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_herk = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_syrk = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Clarge_trmm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] -# global Clarge_trsm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] TRSM is unstable on random A. -global Csmall_gemm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_hemm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_symm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_her2k= [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_syr2k= [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_herk = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_syrk = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_trmm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Csmall_trsm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)] - -global Cst_lg_gemm = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Cst_lg_hemm = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Cst_lg_symm = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -# global Cst_lg_her2k= [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -# global Cst_lg_syr2k= [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -# global Cst_lg_herk = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -# global Cst_lg_syrk = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Cst_sm_gemm = [zeros(T, χsmall÷2, χsmall÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Cst_sm_hemm = [zeros(T, χsmall÷2, χsmall÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] -global Cst_sm_symm = [zeros(T, χsmall÷2, χsmall÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)] - -for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64]) - Alarge = zeros(T, χlarge, χlarge) - Blarge = zeros(T, χlarge, χlarge) - Asmall = zeros(T, χsmall, χsmall) - Bsmall = zeros(T, χsmall, χsmall) - - local elconv, αu, βu, αR, βR - local locl_hemm!, locl_her2k!, locl_herk! - if eltype(Alarge)<:Complex - elconv = x -> x - αu = T(αc) - βu = T(βc) - αR = rtype(T)(αr) - βR = rtype(T)(βr) - locl_hemm! = BLAS.hemm! - locl_herk! = BLAS.herk! - locl_her2k! = BLAS.her2k! - else - elconv = x -> real(x) - αu = T(αr) - βu = T(βr) - αR = αu - βR = βu - locl_hemm! = BLAS.symm! - locl_herk! = BLAS.syrk! - locl_her2k! = BLAS.syr2k! - end - Alarge .= elconv.(Alarge_base) - Blarge .= elconv.(Blarge_base) - Asmall .= elconv.(Asmall_base) - Bsmall .= elconv.(Bsmall_base) - Clarge_gemm[i] .= elconv.(Clarge_base) - Clarge_hemm[i] .= elconv.(Clarge_base) - Clarge_symm[i] .= elconv.(Clarge_base) - Clarge_her2k[i] .= elconv.(Clarge_base) - Clarge_syr2k[i] .= elconv.(Clarge_base) - Clarge_herk[i] .= elconv.(Clarge_base) - Clarge_syrk[i] .= elconv.(Clarge_base) - Clarge_trmm[i] .= elconv.(Clarge_base) - Csmall_gemm[i] .= elconv.(Csmall_base) - Csmall_hemm[i] .= elconv.(Csmall_base) - Csmall_symm[i] .= elconv.(Csmall_base) - Csmall_her2k[i] .= elconv.(Csmall_base) - Csmall_syr2k[i] .= elconv.(Csmall_base) - Csmall_herk[i] .= elconv.(Csmall_base) - Csmall_syrk[i] .= elconv.(Csmall_base) - Csmall_trmm[i] .= elconv.(Csmall_base) - Csmall_trsm[i] .= elconv.(Csmall_base) - - # Strided. - Ast_lg = view(Alarge, 1:2:χlarge, 1:2:χlarge) - Bst_lg = view(Blarge, 1:2:χlarge, 1:2:χlarge) - Ast_sm = view(Asmall, 1:2:χsmall, 1:2:χsmall) - Bst_sm = view(Bsmall, 1:2:χsmall, 1:2:χsmall) - - # Execute: column-major. - BLAS.gemm!('N', 'N', αu, Alarge, Blarge, βu, Clarge_gemm[i]) - BLAS.gemm!('N', 'N', αu, Asmall, Bsmall, βu, Csmall_gemm[i]) - locl_hemm!('L', 'U', αu, Alarge, Blarge, βu, Clarge_hemm[i]) - locl_hemm!('R', 'U', αu, Asmall, Bsmall, βu, Csmall_hemm[i]) - BLAS.symm!('L', 'L', αu, Alarge, Blarge, βu, Clarge_symm[i]) - BLAS.symm!('L', 'L', αu, Asmall, Bsmall, βu, Csmall_symm[i]) - locl_her2k!('U', 'N', αu, Alarge, Blarge, βR, Clarge_her2k[i]) - locl_her2k!('U', 'N', αu, Asmall, Bsmall, βR, Csmall_her2k[i]) - BLAS.syr2k!('U', 'N', αu, Alarge, Blarge, βu, Clarge_syr2k[i]) - BLAS.syr2k!('U', 'N', αu, Asmall, Bsmall, βu, Csmall_syr2k[i]) - locl_herk!('U', 'N', αR, Alarge, βR, Clarge_herk[i]) - locl_herk!('U', 'N', αR, Asmall, βR, Csmall_herk[i]) - BLAS.syrk!('U', 'N', αu, Alarge, βu, Clarge_syrk[i]) - BLAS.syrk!('U', 'N', αu, Asmall, βu, Csmall_syrk[i]) - BLAS.trmm!('L', 'U', 'N', 'N', αu, Alarge, Clarge_trmm[i]) - BLAS.trmm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trmm[i]) - BLAS.trsm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trsm[i]) - - # Execute: generic-strided. - Cst_lg_gemm[i] .= Ast_lg * Bst_lg - Cst_sm_gemm[i] .= Ast_sm * Bst_sm - Cst_lg_hemm[i] .= Hermitian(Array(Ast_lg)) * Array(Bst_lg) - Cst_sm_hemm[i] .= Hermitian(Array(Ast_sm)) * Array(Bst_sm) - Cst_lg_symm[i] .= Symmetric(Array(Ast_lg)) * Array(Bst_lg) - Cst_sm_symm[i] .= Symmetric(Array(Ast_sm)) * Array(Bst_sm) -end - -# Import BLIS for testing. -using BLIS -if length(BLIS.BLASInterface.blacklist) > 0 - @info "Blacklisted methods: $(BLIS.BLASInterface.blacklist)." -end - -for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64]) - Alarge = zeros(T, χlarge, χlarge) - Blarge = zeros(T, χlarge, χlarge) - Asmall = zeros(T, χsmall, χsmall) - Bsmall = zeros(T, χsmall, χsmall) - - local elconv, αu, βu, αR, βR - local locl_hemm!, locl_her2k!, locl_herk! - if eltype(Alarge)<:Complex - elconv = x -> x - αu = T(αc) - βu = T(βc) - αR = rtype(T)(αr) - βR = rtype(T)(βr) - locl_hemm! = BLAS.hemm! - locl_herk! = BLAS.herk! - locl_her2k! = BLAS.her2k! - else - elconv = x -> real(x) - αu = T(αr) - βu = T(βr) - αR = αu - βR = βu - locl_hemm! = BLAS.symm! - locl_herk! = BLAS.syrk! - locl_her2k! = BLAS.syr2k! - end - Alarge .= elconv.(Alarge_base) - Blarge .= elconv.(Blarge_base) - Asmall .= elconv.(Asmall_base) - Bsmall .= elconv.(Bsmall_base) - Clarge_gemm_cur = T.(elconv.(Clarge_base)) - Clarge_hemm_cur = T.(elconv.(Clarge_base)) - Clarge_symm_cur = T.(elconv.(Clarge_base)) - Clarge_her2k_cur = T.(elconv.(Clarge_base)) - Clarge_syr2k_cur = T.(elconv.(Clarge_base)) - Clarge_herk_cur = T.(elconv.(Clarge_base)) - Clarge_syrk_cur = T.(elconv.(Clarge_base)) - Clarge_trmm_cur = T.(elconv.(Clarge_base)) - Csmall_gemm_cur = T.(elconv.(Csmall_base)) - Csmall_hemm_cur = T.(elconv.(Csmall_base)) - Csmall_symm_cur = T.(elconv.(Csmall_base)) - Csmall_her2k_cur = T.(elconv.(Csmall_base)) - Csmall_syr2k_cur = T.(elconv.(Csmall_base)) - Csmall_herk_cur = T.(elconv.(Csmall_base)) - Csmall_syrk_cur = T.(elconv.(Csmall_base)) - Csmall_trmm_cur = T.(elconv.(Csmall_base)) - Csmall_trsm_cur = T.(elconv.(Csmall_base)) - - # Strided. - Ast_lg = view(Alarge, 1:2:χlarge, 1:2:χlarge) - Bst_lg = view(Blarge, 1:2:χlarge, 1:2:χlarge) - Ast_sm = view(Asmall, 1:2:χsmall, 1:2:χsmall) - Bst_sm = view(Bsmall, 1:2:χsmall, 1:2:χsmall) - - # Execute: column-major. - BLAS.gemm!('N', 'N', αu, Alarge, Blarge, βu, Clarge_gemm_cur) - BLAS.gemm!('N', 'N', αu, Asmall, Bsmall, βu, Csmall_gemm_cur) - locl_hemm!('L', 'U', αu, Alarge, Blarge, βu, Clarge_hemm_cur) - locl_hemm!('R', 'U', αu, Asmall, Bsmall, βu, Csmall_hemm_cur) - BLAS.symm!('L', 'L', αu, Alarge, Blarge, βu, Clarge_symm_cur) - BLAS.symm!('L', 'L', αu, Asmall, Bsmall, βu, Csmall_symm_cur) - locl_her2k!('U', 'N', αu, Alarge, Blarge, βR, Clarge_her2k_cur) - locl_her2k!('U', 'N', αu, Asmall, Bsmall, βR, Csmall_her2k_cur) - BLAS.syr2k!('U', 'N', αu, Alarge, Blarge, βu, Clarge_syr2k_cur) - BLAS.syr2k!('U', 'N', αu, Asmall, Bsmall, βu, Csmall_syr2k_cur) - locl_herk!('U', 'N', αR, Alarge, βR, Clarge_herk_cur) - locl_herk!('U', 'N', αR, Asmall, βR, Csmall_herk_cur) - BLAS.syrk!('U', 'N', αu, Alarge, βu, Clarge_syrk_cur) - BLAS.syrk!('U', 'N', αu, Asmall, βu, Csmall_syrk_cur) - BLAS.trmm!('L', 'U', 'N', 'N', αu, Alarge, Clarge_trmm_cur) - BLAS.trmm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trmm_cur) - BLAS.trsm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trsm_cur) - - # Execute: generic-strided. - Cst_lg_gemm_cur = Ast_lg * Bst_lg - Cst_sm_gemm_cur = Ast_sm * Bst_sm - Cst_lg_hemm_cur = Hermitian(Ast_lg) * Bst_lg - Cst_sm_hemm_cur = Hermitian(Ast_sm) * Bst_sm - Cst_lg_symm_cur = Symmetric(Ast_lg) * Bst_lg - Cst_sm_symm_cur = Symmetric(Ast_sm) * Bst_sm - - # Check. - @test zrtest(mean(abs.(Clarge_gemm_cur - Clarge_gemm[i] )), 1e-6*χlarge^1.2, "500_gemm_$T") - @test zrtest(mean(abs.(Clarge_hemm_cur - Clarge_hemm[i] )), 1e-6*χlarge^1.2, "500_hemm_$T") - @test zrtest(mean(abs.(Clarge_symm_cur - Clarge_symm[i] )), 1e-6*χlarge^1.2, "500_symm_$T") - @test zrtest(mean(abs.(Clarge_her2k_cur - Clarge_her2k[i])), 1e-6*χlarge^1.2, "500_her2k_$T") - @test zrtest(mean(abs.(Clarge_syr2k_cur - Clarge_syr2k[i])), 1e-6*χlarge^1.2, "500_syr2k_$T") - @test zrtest(mean(abs.(Clarge_herk_cur - Clarge_herk[i] )), 1e-6*χlarge^1.2, "500_herk_$T") - @test zrtest(mean(abs.(Clarge_syrk_cur - Clarge_syrk[i] )), 1e-6*χlarge^1.2, "500_syrk_$T") - @test zrtest(mean(abs.(Clarge_trmm_cur - Clarge_trmm[i] )), 1e-6*χlarge^1.2, "500_trmm_$T") - @test zrtest(mean(abs.(Csmall_gemm_cur - Csmall_gemm[i] )), 1e-6*χsmall^1.2, "20_gemm_$T") - @test zrtest(mean(abs.(Csmall_hemm_cur - Csmall_hemm[i] )), 1e-6*χsmall^1.2, "20_hemm_$T") - @test zrtest(mean(abs.(Csmall_symm_cur - Csmall_symm[i] )), 1e-6*χsmall^1.2, "20_symm_$T") - @test zrtest(mean(abs.(Csmall_her2k_cur - Csmall_her2k[i])), 1e-6*χsmall^1.2, "20_her2k_$T") - @test zrtest(mean(abs.(Csmall_syr2k_cur - Csmall_syr2k[i])), 1e-6*χsmall^1.2, "20_syr2k_$T") - @test zrtest(mean(abs.(Csmall_herk_cur - Csmall_herk[i] )), 1e-6*χsmall^1.2, "20_herk_$T") - @test zrtest(mean(abs.(Csmall_syrk_cur - Csmall_syrk[i] )), 1e-6*χsmall^1.2, "20_syrk_$T") - @test zrtest(mean(abs.(Csmall_trmm_cur - Csmall_trmm[i] )), 1e-6*χsmall^1.2, "20_trmm_$T") - @test zrtest(mean(abs.(Csmall_trsm_cur - Csmall_trsm[i] )), 1e-2*χsmall^1.2, "20_trsm_$T") # Large TRSM err on random A. - - # Check - strided. - @test zrtest(mean(abs.(Cst_lg_gemm_cur - Cst_lg_gemm[i])), 1e-6*χlarge^1.2, "250_rs2_gemm_$T") - @test zrtest(mean(abs.(Cst_sm_gemm_cur - Cst_sm_gemm[i])), 1e-6*χsmall^1.2, "250_rs2_gemm_$T") - @test zrtest(mean(abs.(Cst_lg_hemm_cur - Cst_lg_hemm[i])), 1e-6*χlarge^1.2, "250_rs2_hemm_$T") - @test zrtest(mean(abs.(Cst_sm_hemm_cur - Cst_sm_hemm[i])), 1e-6*χsmall^1.2, "250_rs2_hemm_$T") - @test zrtest(mean(abs.(Cst_lg_symm_cur - Cst_lg_symm[i])), 1e-6*χlarge^1.2, "250_rs2_symm_$T") - @test zrtest(mean(abs.(Cst_sm_symm_cur - Cst_sm_symm[i])), 1e-6*χsmall^1.2, "250_rs2_symm_$T") -end -end - diff --git a/test/runtests.jl b/test/runtests.jl index e44812c..8697c77 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,37 @@ -# Enter test. -# -# Make sure this file occurs first BEFORE LOADING BLIS. -include("init_test_mmul.jl") +using Test +using Random +using LinearAlgebra +using LinearAlgebra: BLAS +using InteractiveUtils: @which +using DelimitedFiles +using Statistics +using BLIS + +Random.seed!(1234) + +rtype(::Type{Complex{T}}) where {T} = T + +zrtest(val, atol, label) = begin + iszr = ≈(val, 0.0, atol=atol) + if !iszr + @info "`$label` test failed (err=$val). Consider adding it to ~/.blis_jlbla_blacklist." + end + return iszr +end + +"Run & evict this method." +macro run_evict(func, io, largs) + return quote + # Invokation. + $func($(esc(largs))..., $(esc(io))) + + # Evict. + method_blis = @which $func($(esc(largs))..., $(esc(io))) + Base.delete_method(method_blis) + end +end + +include("test_level1.jl") +include("test_level2.jl") +include("test_level3.jl") diff --git a/test/test_level1.jl b/test/test_level1.jl new file mode 100644 index 0000000..3d9946d --- /dev/null +++ b/test/test_level1.jl @@ -0,0 +1,49 @@ +# Level-1 BLAS: Test & evict. +# + +macro l1_test_evict(func, χ, type, ctype, largs) + return quote + local χ = $χ + local αr = 1.1 + local βr = 1.1 + local αc = 1.1 + 0.2im + local βc = 1.1 + 0.3im + αr = $type(αr) + βr = $type(βr) + if $type <: Complex + αc = $type(αc) + end + if $type <: Complex + βc = $type(βc) + end + + local x = rand($type, χ) + local y = rand($type, χ) + local y_= copy(y) + + @run_evict $func y $(largs) + + x = $type.(x) + + # Check. + $func($(largs)..., y_) + abs.(y - y_) + end +end + + +@testset "BLAS level-1 LinearAlgebra interface" begin + + @test zrtest(reduce(max, @l1_test_evict BLAS.axpy! 100 Float32 Float64 (αr, x) ), 1e-6, "xaxpy_100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpy! 100 Float32 Float32 (αr, x) ), 1e-6, "saxpy_100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpy! 30 Float64 Float64 (αr, x) ), 1e-6, "daxpy_30") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpy! 1100 ComplexF32 ComplexF32 (αc, x) ), 4e-3, "caxpy_1100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpy! 300 ComplexF64 ComplexF64 (αc, x) ), 4e-7, "zaxpy_300") + + @test zrtest(reduce(max, @l1_test_evict BLAS.axpby! 100 Float32 Float64 (αr, x, βr) ), 1e-6, "xaxpby_100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpby! 100 Float32 Float32 (αr, x, βr) ), 1e-6, "saxpby_100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpby! 30 Float64 Float64 (αr, x, βr) ), 1e-6, "daxpby_30") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpby! 1100 ComplexF32 ComplexF32 (αc, x, βc) ), 4e-3, "caxpby_1100") + @test zrtest(reduce(max, @l1_test_evict BLAS.axpby! 300 ComplexF64 ComplexF64 (αc, x, βc) ), 4e-7, "zaxpby_300") + +end diff --git a/test/test_level2.jl b/test/test_level2.jl new file mode 100644 index 0000000..0f101db --- /dev/null +++ b/test/test_level2.jl @@ -0,0 +1,106 @@ +# Level-2 BLAS: Test & evict. +# + +macro l2_test_evict(func, χ, type, largs) + return quote + local χ = $χ + local αr = 1.1 + local βr = 1.1 + local αc = 1.1 + 0.2im + local βc = 1.1 + 0.3im + αr = $type(αr) + βr = $type(βr) + if $type <: Complex + αc = $type(αc) + end + if $type <: Complex + βc = $type(βc) + end + + local A = rand($type, χ, χ) + local x = rand($type, χ) + local y = rand($type, χ) + local y_= copy(y) + + @run_evict $func y $(largs) + + A = $type.(A) + x = $type.(x) + + # Check. + $func($(largs)..., y_) + abs.(y - y_) + end +end + +macro l2r_test_evict(func, χ, αtype, type, largs) + return quote + local χ = $χ + local αr = 1.1 + local βr = 1.1 + local αc = 1.1 + 0.2im + local βc = 1.1 + 0.3im + αr = $αtype(αr) + βr = $type(βr) + if $type <: Complex + αc = $type(αc) + end + if $type <: Complex + βc = $type(βc) + end + + local x = rand($type, χ) + local y = rand($type, χ) + local C = rand($type, χ, χ) + local C_= copy(C) + + @run_evict $func C $(largs) + + x = $type.(x) + y = $type.(y) + + # Check. + $func($(largs)..., C_) + abs.(C - C_) + end +end + +@testset "BLAS level-2 LinearAlgebra interface" begin + + @test zrtest(reduce(max, @l2_test_evict BLAS.gemv! 100 Float32 ('N', αr, A, x, βr) ), 1e-4, "sgemv_n_100") + @test zrtest(reduce(max, @l2_test_evict BLAS.gemv! 30 Float64 ('T', αr, A, x, βr) ), 1e-6, "dgemv_t_30") + @test zrtest(reduce(max, @l2_test_evict BLAS.gemv! 1100 ComplexF32 ('N', αc, A, x, βc) ), 4e-3, "cgemv_n_1100") + @test zrtest(reduce(max, @l2_test_evict BLAS.gemv! 300 ComplexF64 ('N', αc, A, x, βc) ), 4e-7, "zgemv_n_300") + + @test zrtest(reduce(max, @l2_test_evict BLAS.hemv! 1100 ComplexF32 ('U', αc, A, x, βc) ), 4e-3, "chemv_u_1100") + @test zrtest(reduce(max, @l2_test_evict BLAS.hemv! 300 ComplexF64 ('L', αc, A, x, βc) ), 4e-7, "zhemv_l_300") + + @test zrtest(reduce(max, @l2_test_evict BLAS.symv! 32 Float32 ('U', αr, A, x, βr) ), 1e-4, "ssymv_u_32") + @test zrtest(reduce(max, @l2_test_evict BLAS.symv! 2000 Float64 ('U', αr, A, x, βr) ), 1e-3, "dsymv_u_2000") + @test zrtest(reduce(max, @l2_test_evict BLAS.symv! 24 ComplexF32 ('L', αc, A, x, βc) ), 4e-6, "csymv_l_24") + @test zrtest(reduce(max, @l2_test_evict BLAS.symv! 1200 ComplexF64 ('L', αc, A, x, βc) ), 4e-7, "zsymv_l_1200") + + @test zrtest(reduce(max, @l2_test_evict BLAS.trmv! 32 Float32 ('U', 'N', 'U', A) ), 1e-4, "strmv_u_32") + @test zrtest(reduce(max, @l2_test_evict BLAS.trmv! 2000 Float64 ('U', 'T', 'N', A) ), 1e-3, "dtrmv_u_2000") + @test zrtest(reduce(max, @l2_test_evict BLAS.trmv! 24 ComplexF32 ('L', 'N', 'U', A) ), 4e-3, "ctrmv_l_24") + @test zrtest(reduce(max, @l2_test_evict BLAS.trmv! 1200 ComplexF64 ('L', 'T', 'N', A) ), 4e-7, "ztrmv_l_1200") + + @test zrtest(reduce(max, @l2_test_evict BLAS.trsv! 100 Float32 ('U', 'N', 'U', A) ), 1e-2, "strsv_u_100") + @test zrtest(reduce(max, @l2_test_evict BLAS.trsv! 30 Float64 ('U', 'T', 'N', A) ), 1e-4, "dtrsv_u_30") + @test zrtest(reduce(max, @l2_test_evict BLAS.trsv! 100 ComplexF32 ('L', 'N', 'U', A) ), 1e-1, "ctrsv_l_100") + @test zrtest(reduce(max, @l2_test_evict BLAS.trsv! 100 ComplexF64 ('L', 'T', 'N', A) ), 1e-2, "ztrsv_l_100") + + @test zrtest(reduce(max, @l2r_test_evict BLAS.ger! 100 Float32 Float32 (αr, x, y) ), 1e-4, "sger_100") + @test zrtest(reduce(max, @l2r_test_evict BLAS.ger! 30 Float64 Float64 (αr, x, y) ), 1e-6, "dger_30") + # @test zrtest(reduce(max, @l2r_test_evict BLAS.ger! 1100 ComplexF32 ComplexF32 (αc, x, y) ), 4e-3, "cger_1100") + # @test zrtest(reduce(max, @l2r_test_evict BLAS.ger! 300 ComplexF64 ComplexF64 (αc, x, y) ), 4e-7, "zger_300") + + @test zrtest(reduce(max, @l2r_test_evict BLAS.her! 200 Float32 ComplexF32 ('U', αr, x) ), 4e-4, "csyr_u_200") + @test zrtest(reduce(max, @l2r_test_evict BLAS.her! 800 Float64 ComplexF64 ('L', αr, x) ), 4e-5, "zsyr_l_800") + + @test zrtest(reduce(max, @l2r_test_evict BLAS.syr! 100 Float32 Float32 ('U', αr, x) ), 1e-4, "ssyr_u_100") + @test zrtest(reduce(max, @l2r_test_evict BLAS.syr! 30 Float64 Float64 ('L', αr, x) ), 1e-6, "dsyr_l_30") + @test zrtest(reduce(max, @l2r_test_evict BLAS.syr! 1100 ComplexF32 ComplexF32 ('U', αc, x) ), 4e-3, "csyr_u_1100") + @test zrtest(reduce(max, @l2r_test_evict BLAS.syr! 300 ComplexF64 ComplexF64 ('L', αc, x) ), 4e-7, "zsyr_l_300") + +end diff --git a/test/test_level3.jl b/test/test_level3.jl new file mode 100644 index 0000000..61b7a7f --- /dev/null +++ b/test/test_level3.jl @@ -0,0 +1,84 @@ +# Level-3 BLAS: Test & evict. +# + +macro l3_test_evict_(func, χ, αtype, βtype, atype, btype, ctype, largs) + return quote + local χ = $χ + local αr = 1.1 + local βr = 1.1 + local αc = 1.1 + 0.2im + local βc = 1.1 + 0.3im + αr = $αtype(αr) + βr = $βtype(βr) + if $αtype <: Complex + αc = $αtype(αc) + end + if $βtype <: Complex + βc = $βtype(βc) + end + + local A = rand($atype, χ, χ) + local B = rand($btype, χ, χ) + local C = rand($ctype, χ, χ) + local C_= copy(C) + + @run_evict $func C $(largs) + + A = $ctype.(A) + B = $ctype.(B) + + # Check. + $func($(largs)..., C_) + abs.(C - C_) + end +end + +macro l3_test_evict(func, χ, atype, btype, ctype, largs) + return quote + @l3_test_evict_ $func $χ $ctype $ctype $atype $btype $ctype $largs + end +end + +@testset "BLAS level-3 LinearAlgebra interface" begin + + @test zrtest(reduce(max, @l3_test_evict BLAS.gemm! 100 Float32 ComplexF32 ComplexF64 ('N', 'N', αr, A, B, βr) ), 1e-6, "xgemm_nn_100") + @test zrtest(reduce(max, @l3_test_evict BLAS.gemm! 20 Float32 Float32 Float32 ('N', 'N', αr, A, B, βr) ), 1e-3, "sgemm_nn_20") + @test zrtest(reduce(max, @l3_test_evict BLAS.gemm! 1200 Float64 Float64 Float64 ('T', 'N', αr, A, B, βr) ), 1e-6, "dgemm_tn_1200") + @test zrtest(reduce(max, @l3_test_evict BLAS.gemm! 240 ComplexF32 ComplexF32 ComplexF32 ('T', 'T', αc, A, B, βc) ), 1e-3, "cgemm_tt_240") + @test zrtest(reduce(max, @l3_test_evict BLAS.gemm! 1200 ComplexF64 ComplexF64 ComplexF64 ('N', 'T', αc, A, B, βc) ), 1e-6, "zgemm_nt_1200") + + @test zrtest(reduce(max, @l3_test_evict BLAS.hemm! 240 ComplexF32 ComplexF32 ComplexF32 ('R', 'L', αc, A, B, βc) ), 1e-3, "chemm_rl_240") + @test zrtest(reduce(max, @l3_test_evict BLAS.hemm! 1200 ComplexF64 ComplexF64 ComplexF64 ('R', 'L', αc, A, B, βc) ), 1e-6, "zhemm_rl_1200") + + @test zrtest(reduce(max, @l3_test_evict BLAS.symm! 2000 Float32 Float32 Float32 ('L', 'U', αr, A, B, βr) ), 1e-3, "ssymm_lu_2000") + @test zrtest(reduce(max, @l3_test_evict BLAS.symm! 32 Float64 Float64 Float64 ('R', 'U', αr, A, B, βr) ), 1e-6, "dsymm_ru_32") + @test zrtest(reduce(max, @l3_test_evict BLAS.symm! 2400 ComplexF32 ComplexF32 ComplexF32 ('R', 'L', αc, A, B, βc) ), 1e-2, "csymm_rl_2400") + @test zrtest(reduce(max, @l3_test_evict BLAS.symm! 50 ComplexF64 ComplexF64 ComplexF64 ('R', 'L', αc, A, B, βc) ), 1e-6, "zsymm_rl_50") + + @test zrtest(reduce(max, @l3_test_evict_ BLAS.her2k! 240 ComplexF32 Float32 ComplexF32 ComplexF32 ComplexF32 ('U', 'N', αr, A, B, βr) ), 1e-3, "cher2k_un_240") + @test zrtest(reduce(max, @l3_test_evict_ BLAS.her2k! 1200 ComplexF64 Float64 ComplexF64 ComplexF64 ComplexF64 ('L', 'N', αr, A, B, βr) ), 1e-6, "zher2k_lt_1200") + + @test zrtest(reduce(max, @l3_test_evict BLAS.syr2k! 2000 Float32 Float32 Float32 ('U', 'N', αr, A, B, βr) ), 1e-3, "ssyr2k_un_2000") + @test zrtest(reduce(max, @l3_test_evict BLAS.syr2k! 32 Float64 Float64 Float64 ('U', 'T', αr, A, B, βr) ), 1e-6, "dsyr2k_ut_32") + @test zrtest(reduce(max, @l3_test_evict BLAS.syr2k! 2400 ComplexF32 ComplexF32 ComplexF32 ('L', 'N', αc, A, B, βc) ), 1e-2, "csyr2k_ln_2400") + @test zrtest(reduce(max, @l3_test_evict BLAS.syr2k! 50 ComplexF64 ComplexF64 ComplexF64 ('L', 'T', αc, A, B, βc) ), 1e-6, "zsyr2k_lt_50") + + @test zrtest(reduce(max, @l3_test_evict_ BLAS.herk! 240 Float32 Float32 ComplexF32 ComplexF32 ComplexF32 ('U', 'N', αr, A, βr) ), 1e-3, "cherk_un_240") + @test zrtest(reduce(max, @l3_test_evict_ BLAS.herk! 1200 Float64 Float64 ComplexF64 ComplexF64 ComplexF64 ('L', 'N', αr, A, βr) ), 1e-6, "zherk_lt_1200") + + @test zrtest(reduce(max, @l3_test_evict BLAS.syrk! 2000 Float32 Float32 Float32 ('U', 'N', αr, A, βr) ), 1e-3, "ssyrk_un_2000") + @test zrtest(reduce(max, @l3_test_evict BLAS.syrk! 32 Float64 Float64 Float64 ('U', 'T', αr, A, βr) ), 1e-6, "dsyrk_ut_32") + @test zrtest(reduce(max, @l3_test_evict BLAS.syrk! 2400 ComplexF32 ComplexF32 ComplexF32 ('L', 'N', αc, A, βc) ), 1e-2, "csyrk_ln_2400") + @test zrtest(reduce(max, @l3_test_evict BLAS.syrk! 50 ComplexF64 ComplexF64 ComplexF64 ('L', 'T', αc, A, βc) ), 1e-6, "zsyrk_lt_50") + + @test zrtest(reduce(max, @l3_test_evict BLAS.trmm! 20 Float32 Float32 Float32 ('L', 'U', 'N', 'U', αr, A) ), 1e-3, "strmm_lunu_20") + @test zrtest(reduce(max, @l3_test_evict BLAS.trmm! 1200 Float64 Float64 Float64 ('L', 'L', 'N', 'U', αr, A) ), 1e-6, "dtrmm_llnu_1200") + @test zrtest(reduce(max, @l3_test_evict BLAS.trmm! 240 ComplexF32 ComplexF32 ComplexF32 ('R', 'U', 'T', 'N', αc, A) ), 1e-3, "ctrmm_rutn_240") + @test zrtest(reduce(max, @l3_test_evict BLAS.trmm! 1200 ComplexF64 ComplexF64 ComplexF64 ('R', 'L', 'T', 'N', αc, A) ), 1e-6, "ztrmm_rltn_1200") + + @test zrtest(reduce(max, @l3_test_evict BLAS.trsm! 100 Float32 Float32 Float32 ('L', 'U', 'N', 'U', αr, A) ), 1e-3, "strsm_un_100") + @test zrtest(reduce(max, @l3_test_evict BLAS.trsm! 100 Float64 Float64 Float64 ('L', 'L', 'N', 'U', αr, A) ), 1e-4, "dtrsm_ut_100") + @test zrtest(reduce(max, @l3_test_evict BLAS.trsm! 30 ComplexF32 ComplexF32 ComplexF32 ('R', 'U', 'T', 'N', αc, A) ), 1e-2, "ctrsm_ln_30") + @test zrtest(reduce(max, @l3_test_evict BLAS.trsm! 30 ComplexF64 ComplexF64 ComplexF64 ('R', 'L', 'T', 'N', αc, A) ), 1e-4, "ztrsm_lt_30") +end +