Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorOperationsTBLIS"
uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865"
authors = ["lkdvos <[email protected]>"]
version = "0.2.0"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
version = "0.3.0"

[deps]
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand All @@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0"

[compat]
Random = "1"
TensorOperations = "5"
TupleTools = "1"
Test = "1"
julia = "1.8"
tblis_jll = "1.2"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LinearAlgebra"]
test = ["Test", "LinearAlgebra", "Random"]
109 changes: 109 additions & 0 deletions src/LibTBLIS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
module LibTBLIS

using tblis_jll
using LinearAlgebra: BlasFloat

export tblis_scalar, tblis_tensor
export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot
export tblis_set_num_threads, tblis_get_num_threads

const ptrdiff_t = Cptrdiff_t

const scomplex = ComplexF32
const dcomplex = ComplexF64

const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET)
if Sys.isapple() && Sys.ARCH === :aarch64
include("lib/aarch64-apple-darwin20.jl")
elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL
include("lib/aarch64-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL
include("lib/aarch64-linux-musl.jl")
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL
include("lib/armv7l-linux-gnueabihf.jl")
elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL
include("lib/armv7l-linux-musleabihf.jl")
elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL
include("lib/i686-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL
include("lib/i686-linux-musl.jl")
elseif Sys.iswindows() && Sys.ARCH === :i686
include("lib/i686-w64-mingw32.jl")
elseif Sys.islinux() && Sys.ARCH === :powerpc64le
include("lib/powerpc64le-linux-gnu.jl")
elseif Sys.isapple() && Sys.ARCH === :x86_64
include("lib/x86_64-apple-darwin14.jl")
elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL
include("lib/x86_64-linux-gnu.jl")
elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL
include("lib/x86_64-linux-musl.jl")
elseif Sys.isbsd() && !Sys.isapple()
include("lib/x86_64-unknown-freebsd.jl")
elseif Sys.iswindows() && Sys.ARCH === :x86_64
include("lib/x86_64-w64-mingw32.jl")
else
error("Unknown platform: $(Base.BUILD_TRIPLET)")
end

# tblis_scalar and tblis_tensor
# -----------------------------
"""
tblis_scalar(s::Number)

Initializes a tblis scalar from a number.
"""
function tblis_scalar end

"""
tblis_tensor(A::AbstractArray{T<:BlasFloat}, [szA::Vector{Int}, strA::Vector{Int}, scalar::Number])

Initializes a tblis tensor from an array that should be strided and admit a pointer to its
data. This operation is deemed unsafe, in the sense that the user is responsible for ensuring
that the reference to the array and the sizes and strides stays alive during the lifetime of
the tensor.
"""
function tblis_tensor end

for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in
((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s),
(:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d),
(:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c),
(:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z))
@eval begin
function tblis_scalar(s::$T)
t = Ref{tblis_scalar}()
$tblis_init_scalar(t, s)
return t[]
end
function tblis_tensor(A::AbstractArray{$T,N},
s::Number=one(T),
szA::Vector{len_type}=collect(len_type, size(A)),
strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N}
t = Ref{tblis_tensor}()
if isone(s)
$tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA))
else
$tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A),
pointer(strA))
end
return t[]
end
end
end

# tensor operations
# ------------------
function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB)
return tblis_tensor_add(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB)
end

function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor,
idxC)
return tblis_tensor_mult(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C), idxC)
end

function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar)
return tblis_tensor_dot(C_NULL, C_NULL, Ref(A), idxA, Ref(B), idxB, Ref(C))
end

end
146 changes: 0 additions & 146 deletions src/LibTblis.jl

This file was deleted.

97 changes: 19 additions & 78 deletions src/TensorOperationsTBLIS.jl
Original file line number Diff line number Diff line change
@@ -1,93 +1,34 @@
module TensorOperationsTBLIS

using TensorOperations
using TensorOperations: StridedView, DefaultAllocator, IndexError
using TensorOperations: istrivialpermutation, BlasFloat, linearize
using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd,
argcheck_tensortrace, dimcheck_tensortrace,
argcheck_tensorcontract, dimcheck_tensorcontract
using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError
using LinearAlgebra: BlasFloat, rmul!
using LinearAlgebra: BlasFloat
using TupleTools

include("LibTblis.jl")
using .LibTblis
include("LibTBLIS.jl")
using .LibTBLIS
using .LibTBLIS: LibTBLIS, len_type, stride_type

export tblis_set_num_threads, tblis_get_num_threads
export tblisBackend
export TBLIS
export get_num_tblis_threads, set_num_tblis_threads

get_num_tblis_threads() = convert(Int, LibTBLIS.tblis_get_num_threads())
set_num_tblis_threads(n) = LibTBLIS.tblis_set_num_threads(convert(Cuint, n))

# TensorOperations
#------------------

struct tblisBackend <: TensorOperations.AbstractBackend end

function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T},
pA::Index2Tuple, conjA::Bool,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensoradd(C, A, pA)
TensorOperations.dimcheck_tensoradd(C, A, pA)

szC = collect(size(C))
strC = collect(strides(C))
C_tblis = tblis_tensor(C, szC, strC, β)

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

einA, einC = TensorOperations.add_labels(pA)
tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))

return C
end

function TensorOperations.tensorcontract!(C::StridedArray{T},
A::StridedArray{T}, pA::Index2Tuple,
conjA::Bool, B::StridedArray{T},
pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB)
TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB)

rmul!(C, β) # TODO: is it possible to use tblis scaling here?
szC = ndims(C) == 0 ? Int[] : collect(size(C))
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
C_tblis = tblis_tensor(C, szC, strC)

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

szB = collect(size(B))
strB = collect(strides(B))
B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1)

einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB)
tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis,
string(einC...))

return C
end

function TensorOperations.tensortrace!(C::StridedArray{T},
A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple,
conjA::Bool,
α::Number, β::Number,
::tblisBackend) where {T<:BlasFloat}
TensorOperations.argcheck_tensortrace(C, A, p, q)
TensorOperations.dimcheck_tensortrace(C, A, p, q)

rmul!(C, β) # TODO: is it possible to use tblis scaling here?
szC = ndims(C) == 0 ? Int[] : collect(size(C))
strC = ndims(C) == 0 ? Int[] : collect(strides(C))
C_tblis = tblis_tensor(C, szC, strC)

szA = collect(size(A))
strA = collect(strides(A))
A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α)

einA, einC = TensorOperations.trace_labels(p, q)
struct TBLIS <: TensorOperations.AbstractBackend end

tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...))
Base.@deprecate(tblisBackend(), TBLIS())
Base.@deprecate(tblis_get_num_threads(), get_num_tblis_threads())
Base.@deprecate(tblis_set_num_threads(n), set_num_tblis_threads(n))

return C
end
include("strided.jl")

end # module TensorOperationsTBLIS
Loading