-
Notifications
You must be signed in to change notification settings - Fork 1
Refactor and add StridedView support
#5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
d837949
complete overhaul
Jutho 118e5d3
improve tests and fix other comments
Jutho d2f747c
add TBLIS docstring
Jutho 0fbff61
Update author email
lkdvos a079d1f
`get_num_tblis_threads` to unexported `get_num_threads`
lkdvos e3ef895
Update README
lkdvos efc9fec
Add codecov ignore for `lib`
lkdvos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| name = "TensorOperationsTBLIS" | ||
| uuid = "1e289f0c-8058-4c3e-8acf-f8ef036bd865" | ||
| authors = ["lkdvos <[email protected]>"] | ||
| version = "0.2.0" | ||
| authors = ["Lukas Devos <lukas.devos@ugent.be>", "Jutho Haegeman <jutho.haegeman@ugent.be>"] | ||
| version = "0.3.0" | ||
|
|
||
| [deps] | ||
| Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | ||
|
|
@@ -11,14 +11,17 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" | |
| tblis_jll = "9c7f617c-f299-5d18-afb6-044c7798b3d0" | ||
|
|
||
| [compat] | ||
| Random = "1" | ||
| TensorOperations = "5" | ||
| TupleTools = "1" | ||
| Test = "1" | ||
| julia = "1.8" | ||
| tblis_jll = "1.2" | ||
|
|
||
| [extras] | ||
| LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
| Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
| Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
|
||
| [targets] | ||
| test = ["Test", "LinearAlgebra"] | ||
| test = ["Test", "LinearAlgebra", "Random"] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| module LibTBLIS | ||
|
|
||
| using tblis_jll | ||
| using LinearAlgebra: BlasFloat | ||
|
|
||
| export tblis_scalar, tblis_tensor | ||
| export tblis_tensor_add, tblis_tensor_mult, tblis_tensor_dot | ||
| export tblis_set_num_threads, tblis_get_num_threads | ||
|
|
||
| # Alias for Julia's `Cptrdiff_t` under its C name; presumably referenced by the | ||
| # platform-specific binding files included below (TODO confirm). | ||
| const ptrdiff_t = Cptrdiff_t | ||
|
|
||
| # Aliases for the single- and double-precision complex types. | ||
| const scomplex = ComplexF32 | ||
| const dcomplex = ComplexF64 | ||
|
|
||
| # Include exactly one platform-specific file from `lib/`, chosen from the operating | ||
| # system, the CPU architecture (`Sys.ARCH`) and, on Linux, whether the build triplet | ||
| # contains "musl". Branch order matters: the first matching condition wins, and a | ||
| # platform matching none of them raises an error while the module is loading. | ||
| const IS_LIBC_MUSL = occursin("musl", Base.BUILD_TRIPLET) | ||
| if Sys.isapple() && Sys.ARCH === :aarch64 | ||
| include("lib/aarch64-apple-darwin20.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :aarch64 && !IS_LIBC_MUSL | ||
| include("lib/aarch64-linux-gnu.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :aarch64 && IS_LIBC_MUSL | ||
| include("lib/aarch64-linux-musl.jl") | ||
| elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && !IS_LIBC_MUSL | ||
| include("lib/armv7l-linux-gnueabihf.jl") | ||
| elseif Sys.islinux() && startswith(string(Sys.ARCH), "arm") && IS_LIBC_MUSL | ||
| include("lib/armv7l-linux-musleabihf.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :i686 && !IS_LIBC_MUSL | ||
| include("lib/i686-linux-gnu.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :i686 && IS_LIBC_MUSL | ||
| include("lib/i686-linux-musl.jl") | ||
| elseif Sys.iswindows() && Sys.ARCH === :i686 | ||
| include("lib/i686-w64-mingw32.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :powerpc64le | ||
| include("lib/powerpc64le-linux-gnu.jl") | ||
| elseif Sys.isapple() && Sys.ARCH === :x86_64 | ||
| include("lib/x86_64-apple-darwin14.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :x86_64 && !IS_LIBC_MUSL | ||
| include("lib/x86_64-linux-gnu.jl") | ||
| elseif Sys.islinux() && Sys.ARCH === :x86_64 && IS_LIBC_MUSL | ||
| include("lib/x86_64-linux-musl.jl") | ||
| # NOTE(review): this branch does not test `Sys.ARCH`, so every non-Apple BSD system | ||
| # receives the x86_64 FreeBSD file -- confirm that this is intended. | ||
| elseif Sys.isbsd() && !Sys.isapple() | ||
| include("lib/x86_64-unknown-freebsd.jl") | ||
| elseif Sys.iswindows() && Sys.ARCH === :x86_64 | ||
| include("lib/x86_64-w64-mingw32.jl") | ||
| else | ||
| error("Unknown platform: $(Base.BUILD_TRIPLET)") | ||
| end | ||
|
|
||
| # tblis_scalar and tblis_tensor | ||
| # ----------------------------- | ||
| """ | ||
| tblis_scalar(s::Number) | ||
|
|
||
| Initialize a tblis scalar from the number `s`. Methods are defined for `s` of type | ||
| `Float32`, `Float64`, `ComplexF32` and `ComplexF64`. | ||
| """ | ||
| function tblis_scalar end | ||
|
|
||
| """ | ||
| tblis_tensor(A::AbstractArray{T<:BlasFloat}, [scalar::Number, szA::Vector{len_type}, strA::Vector{stride_type}]) | ||
|
|
||
| Initialize a tblis tensor from an array that should be strided and admit a pointer to its | ||
| data. The optional arguments are, in order, a scaling factor (defaulting to one), the sizes | ||
| (defaulting to `size(A)`) and the strides (defaulting to `strides(A)`). This operation is | ||
| deemed unsafe, in the sense that the user is responsible for ensuring that the references to | ||
| the array, the sizes and the strides stay alive during the lifetime of the tensor. | ||
| """ | ||
| function tblis_tensor end | ||
|
|
||
| # Generate the `tblis_scalar` and `tblis_tensor` methods for each supported element type, | ||
| # each one forwarding to the initialisation routine carrying the matching type suffix | ||
| # (`_s`, `_d`, `_c`, `_z`). | ||
| for (T, tblis_init_scalar, tblis_init_tensor, tblis_init_tensor_scaled) in | ||
| ((:Float32, :tblis_init_scalar_s, :tblis_init_tensor_s, :tblis_init_tensor_scaled_s), | ||
| (:Float64, :tblis_init_scalar_d, :tblis_init_tensor_d, :tblis_init_tensor_scaled_d), | ||
| (:ComplexF32, :tblis_init_scalar_c, :tblis_init_tensor_c, :tblis_init_tensor_scaled_c), | ||
| (:ComplexF64, :tblis_init_scalar_z, :tblis_init_tensor_z, :tblis_init_tensor_scaled_z)) | ||
| @eval begin | ||
| # Fill a fresh `tblis_scalar` through the type-specific initialiser and return it. | ||
| function tblis_scalar(s::$T) | ||
| t = Ref{tblis_scalar}() | ||
| $tblis_init_scalar(t, s) | ||
| return t[] | ||
| end | ||
| # Build a `tblis_tensor` that refers to the data of `A` through raw pointers. | ||
| # `s` is an optional scaling factor, `szA` the sizes and `strA` the strides. | ||
| # The default for `s` must interpolate the element type: the loop variable `T` no | ||
| # longer exists when the default is evaluated at call time, so a bare `one(T)` | ||
| # would raise an `UndefVarError`. | ||
| function tblis_tensor(A::AbstractArray{$T,N}, | ||
| s::Number=one($T), | ||
| szA::Vector{len_type}=collect(len_type, size(A)), | ||
| strA::Vector{stride_type}=collect(stride_type, strides(A))) where {N} | ||
| t = Ref{tblis_tensor}() | ||
| # Use the unscaled initialiser when the scaling factor is exactly one. | ||
| if isone(s) | ||
| $tblis_init_tensor(t, N, pointer(szA), pointer(A), pointer(strA)) | ||
| else | ||
| $tblis_init_tensor_scaled(t, $T(s), N, pointer(szA), pointer(A), | ||
| pointer(strA)) | ||
| end | ||
| # Only raw pointers are passed on; the caller must keep `A`, `szA` and `strA` | ||
| # alive for as long as the returned tensor is in use. | ||
| return t[] | ||
| end | ||
| end | ||
| end | ||
|
|
||
| # tensor operations | ||
| # ------------------ | ||
| # Convenience method: wrap both tensors in `Ref`s and forward to the six-argument | ||
| # `tblis_tensor_add`, passing `C_NULL` for its first two arguments. | ||
| function tblis_tensor_add(A::tblis_tensor, idxA, B::tblis_tensor, idxB) | ||
| refA = Ref(A) | ||
| refB = Ref(B) | ||
| return tblis_tensor_add(C_NULL, C_NULL, refA, idxA, refB, idxB) | ||
| end | ||
|
|
||
| # Convenience method: wrap the three tensors in `Ref`s and forward to the eight-argument | ||
| # `tblis_tensor_mult`, passing `C_NULL` for its first two arguments. | ||
| function tblis_tensor_mult(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_tensor, | ||
| idxC) | ||
| refA = Ref(A) | ||
| refB = Ref(B) | ||
| refC = Ref(C) | ||
| return tblis_tensor_mult(C_NULL, C_NULL, refA, idxA, refB, idxB, refC, idxC) | ||
| end | ||
|
|
||
| # Convenience method: wrap the tensors and the scalar in `Ref`s and forward to the | ||
| # seven-argument `tblis_tensor_dot`, passing `C_NULL` for its first two arguments. | ||
| # NOTE(review): the `Ref` wrapping `C` is created here and not returned, so anything the | ||
| # callee writes into it is not visible to the caller -- confirm this is intended. | ||
| function tblis_tensor_dot(A::tblis_tensor, idxA, B::tblis_tensor, idxB, C::tblis_scalar) | ||
| refA = Ref(A) | ||
| refB = Ref(B) | ||
| refC = Ref(C) | ||
| return tblis_tensor_dot(C_NULL, C_NULL, refA, idxA, refB, idxB, refC) | ||
| end | ||
|
|
||
| end |
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,93 +1,34 @@ | ||
| module TensorOperationsTBLIS | ||
|
|
||
| using TensorOperations | ||
| using TensorOperations: StridedView, DefaultAllocator, IndexError | ||
| using TensorOperations: istrivialpermutation, BlasFloat, linearize | ||
| using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd, | ||
| argcheck_tensortrace, dimcheck_tensortrace, | ||
| argcheck_tensorcontract, dimcheck_tensorcontract | ||
| using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError | ||
| using LinearAlgebra: BlasFloat, rmul! | ||
| using LinearAlgebra: BlasFloat | ||
| using TupleTools | ||
|
|
||
| include("LibTblis.jl") | ||
| using .LibTblis | ||
| include("LibTBLIS.jl") | ||
| using .LibTBLIS | ||
| using .LibTBLIS: LibTBLIS, len_type, stride_type | ||
|
|
||
| export tblis_set_num_threads, tblis_get_num_threads | ||
| export tblisBackend | ||
| export TBLIS | ||
| export get_num_tblis_threads, set_num_tblis_threads | ||
|
|
||
| # Return the value of `LibTBLIS.tblis_get_num_threads()` converted to `Int`. | ||
| function get_num_tblis_threads() | ||
| nthreads = LibTBLIS.tblis_get_num_threads() | ||
| return convert(Int, nthreads) | ||
| end | ||
| # Convert `n` to `Cuint` and pass it to `LibTBLIS.tblis_set_num_threads`. | ||
| function set_num_tblis_threads(n) | ||
| nthreads = convert(Cuint, n) | ||
| return LibTBLIS.tblis_set_num_threads(nthreads) | ||
| end | ||
|
|
||
| # TensorOperations | ||
| #------------------ | ||
|
|
||
| struct tblisBackend <: TensorOperations.AbstractBackend end | ||
|
|
||
| # `tensoradd!` method for the `tblisBackend`: validates the arguments, wraps `C` (with | ||
| # factor `β`) and `A` (with factor `α`) as tblis tensors, and calls `tblis_tensor_add` | ||
| # with index labels derived from `pA`. Returns `C`. | ||
| function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T}, | ||
| pA::Index2Tuple, conjA::Bool, | ||
| α::Number, β::Number, | ||
| ::tblisBackend) where {T<:BlasFloat} | ||
| TensorOperations.argcheck_tensoradd(C, A, pA) | ||
| TensorOperations.dimcheck_tensoradd(C, A, pA) | ||
|
|
||
| # Sizes and strides are collected into vectors that are handed to `tblis_tensor`. | ||
| szC = collect(size(C)) | ||
| strC = collect(strides(C)) | ||
| C_tblis = tblis_tensor(C, szC, strC, β) | ||
|
|
||
| szA = collect(size(A)) | ||
| strA = collect(strides(A)) | ||
| # NOTE(review): when `conjA` is true, `conj(A)` is a temporary that is passed to | ||
| # `tblis_tensor` and not otherwise referenced here -- verify its lifetime. | ||
| A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) | ||
|
|
||
| # The label collections are concatenated into strings for the tblis call. | ||
| einA, einC = TensorOperations.add_labels(pA) | ||
| tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) | ||
|
|
||
| return C | ||
| end | ||
|
|
||
| # `tensorcontract!` method for the `tblisBackend`: validates the arguments, scales `C` by | ||
| # `β` in place, wraps `C`, `A` (with factor `α`) and `B` as tblis tensors, and calls | ||
| # `tblis_tensor_mult` with index labels derived from `pA`, `pB` and `pAB`. Returns `C`. | ||
| function TensorOperations.tensorcontract!(C::StridedArray{T}, | ||
| A::StridedArray{T}, pA::Index2Tuple, | ||
| conjA::Bool, B::StridedArray{T}, | ||
| pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, | ||
| α::Number, β::Number, | ||
| ::tblisBackend) where {T<:BlasFloat} | ||
| TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
| TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB) | ||
|
|
||
| rmul!(C, β) # TODO: is it possible to use tblis scaling here? | ||
| # A zero-dimensional `C` gets explicitly typed empty size and stride vectors. | ||
| szC = ndims(C) == 0 ? Int[] : collect(size(C)) | ||
| strC = ndims(C) == 0 ? Int[] : collect(strides(C)) | ||
| C_tblis = tblis_tensor(C, szC, strC) | ||
|
|
||
| szA = collect(size(A)) | ||
| strA = collect(strides(A)) | ||
| # NOTE(review): when `conjA` / `conjB` is true, `conj(...)` is a temporary that is passed | ||
| # to `tblis_tensor` and not otherwise referenced here -- verify its lifetime. | ||
| A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) | ||
|
|
||
| szB = collect(size(B)) | ||
| strB = collect(strides(B)) | ||
| B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1) | ||
|
|
||
| # The label collections are concatenated into strings for the tblis call. | ||
| einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB) | ||
| tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis, | ||
| string(einC...)) | ||
|
|
||
| return C | ||
| end | ||
|
|
||
| function TensorOperations.tensortrace!(C::StridedArray{T}, | ||
| A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple, | ||
| conjA::Bool, | ||
| α::Number, β::Number, | ||
| ::tblisBackend) where {T<:BlasFloat} | ||
| TensorOperations.argcheck_tensortrace(C, A, p, q) | ||
| TensorOperations.dimcheck_tensortrace(C, A, p, q) | ||
|
|
||
| rmul!(C, β) # TODO: is it possible to use tblis scaling here? | ||
| szC = ndims(C) == 0 ? Int[] : collect(size(C)) | ||
| strC = ndims(C) == 0 ? Int[] : collect(strides(C)) | ||
| C_tblis = tblis_tensor(C, szC, strC) | ||
|
|
||
| szA = collect(size(A)) | ||
| strA = collect(strides(A)) | ||
| A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) | ||
|
|
||
| einA, einC = TensorOperations.trace_labels(p, q) | ||
| struct TBLIS <: TensorOperations.AbstractBackend end | ||
lkdvos marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) | ||
| Base.@deprecate(tblisBackend(), TBLIS()) | ||
| Base.@deprecate(tblis_get_num_threads(), get_num_tblis_threads()) | ||
| Base.@deprecate(tblis_set_num_threads(n), set_num_tblis_threads(n)) | ||
|
|
||
| return C | ||
| end | ||
| include("strided.jl") | ||
|
|
||
| end # module TensorOperationsTBLIS | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.