|
1 | 1 | module TensorOperationsTBLIS |
2 | 2 |
|
3 | 3 | using TensorOperations |
| 4 | +using TensorOperations: StridedView, DefaultAllocator, IndexError |
| 5 | +using TensorOperations: istrivialpermutation, BlasFloat, linearize |
| 6 | +using TensorOperations: argcheck_tensoradd, dimcheck_tensoradd, |
| 7 | + argcheck_tensortrace, dimcheck_tensortrace, |
| 8 | + argcheck_tensorcontract, dimcheck_tensorcontract |
4 | 9 | using TensorOperations: Index2Tuple, IndexTuple, linearize, IndexError |
5 | | -using LinearAlgebra: BlasFloat, rmul! |
| 10 | +using LinearAlgebra: BlasFloat |
6 | 11 | using TupleTools |
7 | 12 |
|
8 | | -include("LibTblis.jl") |
9 | | -using .LibTblis |
| 13 | +include("LibTBLIS.jl") |
| 14 | +using .LibTBLIS |
| 15 | +using .LibTBLIS: LibTBLIS, len_type, stride_type |
10 | 16 |
|
11 | | -export tblis_set_num_threads, tblis_get_num_threads |
12 | | -export tblisBackend |
| 17 | +export TBLIS |
| 18 | +export get_num_tblis_threads, set_num_tblis_threads |
| 19 | + |
| 20 | +get_num_tblis_threads() = convert(Int, LibTBLIS.tblis_get_num_threads()) |
| 21 | +set_num_tblis_threads(n) = LibTBLIS.tblis_set_num_threads(convert(Cuint, n)) |
13 | 22 |
|
14 | 23 | # TensorOperations |
15 | 24 | #------------------ |
16 | 25 |
|
17 | | -struct tblisBackend <: TensorOperations.AbstractBackend end |
18 | | - |
19 | | -function TensorOperations.tensoradd!(C::StridedArray{T}, A::StridedArray{T}, |
20 | | - pA::Index2Tuple, conjA::Bool, |
21 | | - α::Number, β::Number, |
22 | | - ::tblisBackend) where {T<:BlasFloat} |
23 | | - TensorOperations.argcheck_tensoradd(C, A, pA) |
24 | | - TensorOperations.dimcheck_tensoradd(C, A, pA) |
25 | | - |
26 | | - szC = collect(size(C)) |
27 | | - strC = collect(strides(C)) |
28 | | - C_tblis = tblis_tensor(C, szC, strC, β) |
29 | | - |
30 | | - szA = collect(size(A)) |
31 | | - strA = collect(strides(A)) |
32 | | - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
33 | | - |
34 | | - einA, einC = TensorOperations.add_labels(pA) |
35 | | - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) |
36 | | - |
37 | | - return C |
38 | | -end |
39 | | - |
40 | | -function TensorOperations.tensorcontract!(C::StridedArray{T}, |
41 | | - A::StridedArray{T}, pA::Index2Tuple, |
42 | | - conjA::Bool, B::StridedArray{T}, |
43 | | - pB::Index2Tuple, conjB::Bool, pAB::Index2Tuple, |
44 | | - α::Number, β::Number, |
45 | | - ::tblisBackend) where {T<:BlasFloat} |
46 | | - TensorOperations.argcheck_tensorcontract(C, A, pA, B, pB, pAB) |
47 | | - TensorOperations.dimcheck_tensorcontract(C, A, pA, B, pB, pAB) |
48 | | - |
49 | | - rmul!(C, β) # TODO: is it possible to use tblis scaling here? |
50 | | - szC = ndims(C) == 0 ? Int[] : collect(size(C)) |
51 | | - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) |
52 | | - C_tblis = tblis_tensor(C, szC, strC) |
53 | | - |
54 | | - szA = collect(size(A)) |
55 | | - strA = collect(strides(A)) |
56 | | - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
57 | | - |
58 | | - szB = collect(size(B)) |
59 | | - strB = collect(strides(B)) |
60 | | - B_tblis = tblis_tensor(conjB ? conj(B) : B, szB, strB, 1) |
61 | | - |
62 | | - einA, einB, einC = TensorOperations.contract_labels(pA, pB, pAB) |
63 | | - tblis_tensor_mult(A_tblis, string(einA...), B_tblis, string(einB...), C_tblis, |
64 | | - string(einC...)) |
65 | | - |
66 | | - return C |
67 | | -end |
68 | | - |
69 | | -function TensorOperations.tensortrace!(C::StridedArray{T}, |
70 | | - A::StridedArray{T}, p::Index2Tuple, q::Index2Tuple, |
71 | | - conjA::Bool, |
72 | | - α::Number, β::Number, |
73 | | - ::tblisBackend) where {T<:BlasFloat} |
74 | | - TensorOperations.argcheck_tensortrace(C, A, p, q) |
75 | | - TensorOperations.dimcheck_tensortrace(C, A, p, q) |
76 | | - |
77 | | - rmul!(C, β) # TODO: is it possible to use tblis scaling here? |
78 | | - szC = ndims(C) == 0 ? Int[] : collect(size(C)) |
79 | | - strC = ndims(C) == 0 ? Int[] : collect(strides(C)) |
80 | | - C_tblis = tblis_tensor(C, szC, strC) |
81 | | - |
82 | | - szA = collect(size(A)) |
83 | | - strA = collect(strides(A)) |
84 | | - A_tblis = tblis_tensor(conjA ? conj(A) : A, szA, strA, α) |
85 | | - |
86 | | - einA, einC = TensorOperations.trace_labels(p, q) |
| 26 | +struct TBLIS <: TensorOperations.AbstractBackend end |
87 | 27 |
|
88 | | - tblis_tensor_add(A_tblis, string(einA...), C_tblis, string(einC...)) |
| 28 | +Base.@deprecate(tblisBackend(), TBLIS()) |
| 29 | +Base.@deprecate(tblis_get_num_threads(), get_num_tblis_threads()) |
| 30 | +Base.@deprecate(tblis_set_num_threads(n), set_num_tblis_threads(n)) |
89 | 31 |
|
90 | | - return C |
91 | | -end |
| 32 | +include("strided.jl") |
92 | 33 |
|
93 | 34 | end # module TensorOperationsTBLIS |
0 commit comments