|
1 | 1 | module TensorOperationsBumperExt |
2 | 2 |
|
3 | 3 | using TensorOperations |
| 4 | +using TensorOperations: tensoralloc_add, tensoralloc_contract |
| 5 | +using VectorInterface: One, Zero |
| 6 | +using PrecompileTools |
4 | 7 | using Bumper |
5 | 8 |
|
6 | 9 | # Hack to normalize StridedView type to avoid too many specializations |
@@ -50,4 +53,97 @@ function TensorOperations._butensor(src, ex...) |
50 | 53 | return return Base.remove_linenums!(newex) |
51 | 54 | end |
52 | 55 |
|
| 56 | +if PrecompileTools.workload_enabled(@__MODULE__) |
| 57 | + buf = typeof(Bumper.default_buffer()) |
| 58 | + backend = TensorOperations.DefaultBackend |
| 59 | + |
| 60 | + # tensoradd! |
| 61 | + # ---------- |
| 62 | + for T in TensorOperations.PRECOMPILE_ELTYPES |
| 63 | + for N in 0:(TensorOperations.PRECOMPILE_ADD_NDIMS) |
| 64 | + TA = Array{T,N} |
| 65 | + pA = Index2Tuple{N,0} |
| 66 | + TA_buf = Core.Compiler.return_type(tensoralloc_add, |
| 67 | + Tuple{T,TA,pA,Bool,Val{true},buf}) |
| 68 | + for (C, A) in Iterators.product((TA, TA_buf), (TA, TA_buf)) |
| 69 | + C == A == TA && continue |
| 70 | + precompile(tensoradd!, (C, A, pA, Bool, One, Zero)) |
| 71 | + precompile(tensoradd!, (C, A, pA, Bool, T, Zero)) |
| 72 | + precompile(tensoradd!, (C, A, pA, Bool, T, T)) |
| 73 | + end |
| 74 | + |
| 75 | + precompile(tensoralloc_add, (T, TA_buf, pA, Bool, Val{true}, buf)) |
| 76 | + precompile(tensoralloc_add, (T, TA_buf, pA, Bool, Val{false}, buf)) |
| 77 | + end |
| 78 | + end |
| 79 | + |
| 80 | + # tensortrace! |
| 81 | + # ------------ |
| 82 | + for T in TensorOperations.PRECOMPILE_ELTYPES |
| 83 | + for N1 in 0:TensorOperations.PRECOMPILE_TRACE_NDIMS[1], |
| 84 | + N2 in 0:TensorOperations.PRECOMPILE_TRACE_NDIMS[2] |
| 85 | + |
| 86 | + TC = Array{T,N1} |
| 87 | + TA = Array{T,N1 + 2N2} |
| 88 | + p = Index2Tuple{N1,0} |
| 89 | + q = Index2Tuple{N2,N2} |
| 90 | + r = Index2Tuple{N1 + 2N2,0} |
| 91 | + |
| 92 | + TA_buf = Core.Compiler.return_type(tensoralloc_add, |
| 93 | + Tuple{T,TA,r,Bool,Val{true},buf}) |
| 94 | + TC_buf = Core.Compiler.return_type(tensoralloc_add, |
| 95 | + Tuple{T,TA,p,Bool,Val{true},buf}) |
| 96 | + |
| 97 | + for (C, A) in Iterators.product((TC, TC_buf), (TA, TA_buf)) |
| 98 | + C == TC && A == TA && continue |
| 99 | + precompile(tensortrace!, (C, A, p, q, Bool, One, Zero)) |
| 100 | + precompile(tensortrace!, (C, A, p, q, Bool, T, Zero)) |
| 101 | + precompile(tensortrace!, (C, A, p, q, Bool, T, T)) |
| 102 | + end |
| 103 | + |
| 104 | + # allocation re-uses tensoralloc_add |
| 105 | + end |
| 106 | + end |
| 107 | + |
| 108 | + # tensorcontract! |
| 109 | + # --------------- |
| 110 | + for T in TensorOperations.PRECOMPILE_ELTYPES |
| 111 | + for N1 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[1], |
| 112 | + N2 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[2], |
| 113 | + N3 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[1] |
| 114 | + |
| 115 | + NA = N1 + N2 |
| 116 | + NB = N2 + N3 |
| 117 | + NC = N1 + N3 |
| 118 | + TC, TA, TB = Array{T,NC}, Array{T,NA}, Array{T,NB} |
| 119 | + pA = Index2Tuple{N1,N2} |
| 120 | + pB = Index2Tuple{N2,N3} |
| 121 | + pAB = Index2Tuple{NC,0} |
| 122 | + |
| 123 | + TC_buf = Core.Compiler.return_type(tensoralloc_contract, |
| 124 | + Tuple{T,TA,pA,Bool,TB,pB,Bool,pAB,Val{true}, |
| 125 | + buf}) |
| 126 | + TA_buf = Core.Compiler.return_type(tensoralloc_add, |
| 127 | + Tuple{T,TA,pA,Bool,Val{true},buf}) |
| 128 | + TB_buf = Core.Compiler.return_type(tensoralloc_add, |
| 129 | + Tuple{T,TB,pB,Bool,Val{true},buf}) |
| 130 | + for (C, A, B) in Iterators.product((TC, TC_buf), (TA, TA_buf), (TB, TB_buf)) |
| 131 | + precompile(tensorcontract!, |
| 132 | + (C, A, pA, Bool, B, pB, Bool, pAB, One, Zero, backend, buf)) |
| 133 | + precompile(tensorcontract!, |
| 134 | + (C, A, pA, Bool, B, pB, Bool, pAB, T, Zero, backend, buf)) |
| 135 | + precompile(tensorcontract!, |
| 136 | + (C, A, pA, Bool, B, pB, Bool, pAB, T, T, backend, buf)) |
| 137 | + end |
| 138 | + |
| 139 | + for (A, B) in Iterators.product((TA, TA_buf), (TB, TB_buf)) |
| 140 | + precompile(tensoralloc_contract, |
| 141 | + (T, A, pA, Bool, B, pB, Bool, pAB, Val{true}, buf)) |
| 142 | + precompile(tensoralloc_contract, |
| 143 | + (T, A, pA, Bool, B, pB, Bool, pAB, Val{false}, buf)) |
| 144 | + end |
| 145 | + end |
| 146 | + end |
| 147 | +end |
| 148 | + |
53 | 149 | end |
0 commit comments