Skip to content

Commit 7164560

Browse files
committed
Include Bumper precompilation
1 parent 7721be6 commit 7164560

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

ext/TensorOperationsBumperExt.jl

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
module TensorOperationsBumperExt
22

33
using TensorOperations
4+
using TensorOperations: tensoralloc_add, tensoralloc_contract
5+
using VectorInterface: One, Zero
6+
using PrecompileTools
47
using Bumper
58

69
# Hack to normalize StridedView type to avoid too many specializations
@@ -50,4 +53,97 @@ function TensorOperations._butensor(src, ex...)
5053
return return Base.remove_linenums!(newex)
5154
end
5255

56+
# Precompilation workload
# -----------------------
# Registers `precompile` directives for `tensoradd!`, `tensortrace!`, `tensorcontract!` and
# the allocation functions, for every combination of plain `Array` arguments and the array
# type obtained from a temporary allocation in Bumper's default buffer. The element types
# and dimension ranges come from the `TensorOperations.PRECOMPILE_*` constants.
if PrecompileTools.workload_enabled(@__MODULE__)
    # Only types are needed below: `precompile` and `return_type` both operate on argument
    # types, never on values. `backend` is used as an argument type in the signatures.
    buf = typeof(Bumper.default_buffer())
    backend = TensorOperations.DefaultBackend

    # tensoradd!
    # ----------
    for T in TensorOperations.PRECOMPILE_ELTYPES
        for N in 0:(TensorOperations.PRECOMPILE_ADD_NDIMS)
            TA = Array{T,N}
            pA = Index2Tuple{N,0}
            # Array type returned when allocating a temporary in the buffer. The scalar
            # type is handed to `tensoralloc_add` as a type object, so the corresponding
            # argument type is `Type{T}`; a bare `T` would describe a *value* of type `T`
            # and infer a signature that is never actually called.
            TA_buf = Core.Compiler.return_type(tensoralloc_add,
                                               Tuple{Type{T},TA,pA,Bool,Val{true},buf})
            for (C, A) in Iterators.product((TA, TA_buf), (TA, TA_buf))
                # Skip the combination where both arguments are plain `Array`s
                # (presumably already covered by the main package's own workload).
                C == A == TA && continue
                # Scaling factors: (One, Zero), (scalar, Zero) and (scalar, scalar).
                precompile(tensoradd!, (C, A, pA, Bool, One, Zero))
                precompile(tensoradd!, (C, A, pA, Bool, T, Zero))
                precompile(tensoradd!, (C, A, pA, Bool, T, T))
            end

            precompile(tensoralloc_add, (Type{T}, TA_buf, pA, Bool, Val{true}, buf))
            precompile(tensoralloc_add, (Type{T}, TA_buf, pA, Bool, Val{false}, buf))
        end
    end

    # tensortrace!
    # ------------
    for T in TensorOperations.PRECOMPILE_ELTYPES
        for N1 in 0:TensorOperations.PRECOMPILE_TRACE_NDIMS[1],
            N2 in 0:TensorOperations.PRECOMPILE_TRACE_NDIMS[2]

            # N1 indices remain in the output, N2 pairs of indices are traced over.
            TC = Array{T,N1}
            TA = Array{T,N1 + 2N2}
            p = Index2Tuple{N1,0}
            q = Index2Tuple{N2,N2}
            r = Index2Tuple{N1 + 2N2,0}

            # Buffer-allocated counterparts of the input (`r`) and output (`p`) arrays;
            # as above, the scalar type argument is described by `Type{T}`.
            TA_buf = Core.Compiler.return_type(tensoralloc_add,
                                               Tuple{Type{T},TA,r,Bool,Val{true},buf})
            TC_buf = Core.Compiler.return_type(tensoralloc_add,
                                               Tuple{Type{T},TA,p,Bool,Val{true},buf})

            for (C, A) in Iterators.product((TC, TC_buf), (TA, TA_buf))
                # Skip the combination where both arguments are plain `Array`s.
                C == TC && A == TA && continue
                precompile(tensortrace!, (C, A, p, q, Bool, One, Zero))
                precompile(tensortrace!, (C, A, p, q, Bool, T, Zero))
                precompile(tensortrace!, (C, A, p, q, Bool, T, T))
            end

            # allocation re-uses tensoralloc_add
        end
    end

    # tensorcontract!
    # ---------------
    for T in TensorOperations.PRECOMPILE_ELTYPES
        for N1 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[1],
            N2 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[2],
            N3 in 0:TensorOperations.PRECOMPILE_CONTRACT_NDIMS[1]

            # A carries N1 + N2 indices, B carries N2 + N3 and C carries N1 + N3;
            # N2 is the count shared between A and B.
            NA = N1 + N2
            NB = N2 + N3
            NC = N1 + N3
            TC, TA, TB = Array{T,NC}, Array{T,NA}, Array{T,NB}
            pA = Index2Tuple{N1,N2}
            pB = Index2Tuple{N2,N3}
            pAB = Index2Tuple{NC,0}

            # Buffer-allocated counterparts of the output and of both inputs; the scalar
            # type argument of the allocation functions is described by `Type{T}`.
            TC_buf = Core.Compiler.return_type(tensoralloc_contract,
                                               Tuple{Type{T},TA,pA,Bool,TB,pB,Bool,pAB,
                                                     Val{true},buf})
            TA_buf = Core.Compiler.return_type(tensoralloc_add,
                                               Tuple{Type{T},TA,pA,Bool,Val{true},buf})
            TB_buf = Core.Compiler.return_type(tensoralloc_add,
                                               Tuple{Type{T},TB,pB,Bool,Val{true},buf})
            # No combination is skipped here: every signature ends in `buf`, so none of
            # them coincides with a signature that lacks the allocator argument.
            for (C, A, B) in Iterators.product((TC, TC_buf), (TA, TA_buf), (TB, TB_buf))
                precompile(tensorcontract!,
                           (C, A, pA, Bool, B, pB, Bool, pAB, One, Zero, backend, buf))
                precompile(tensorcontract!,
                           (C, A, pA, Bool, B, pB, Bool, pAB, T, Zero, backend, buf))
                precompile(tensorcontract!,
                           (C, A, pA, Bool, B, pB, Bool, pAB, T, T, backend, buf))
            end

            for (A, B) in Iterators.product((TA, TA_buf), (TB, TB_buf))
                precompile(tensoralloc_contract,
                           (Type{T}, A, pA, Bool, B, pB, Bool, pAB, Val{true}, buf))
                precompile(tensoralloc_contract,
                           (Type{T}, A, pA, Bool, B, pB, Bool, pAB, Val{false}, buf))
            end
        end
    end
end
148+
53149
end

0 commit comments

Comments
 (0)