Skip to content

Commit 5d5c95b

Browse files
lkdvosJutho
authored andcommitted
Apply code suggestions treetransformers
1 parent cae170a commit 5d5c95b

File tree

1 file changed

+33
-83
lines changed

1 file changed

+33
-83
lines changed

src/tensors/treetransformers.jl

Lines changed: 33 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ function TreeTransformer(transform::Function, Vsrc::HomSpace{S},
5050
cols = Int[]
5151
vals = sectorscalartype(sectortype(Vdst))[]
5252

53-
for (f1, f2) in structure_src.fusiontreelist
54-
row = structure_src.fusiontreeindices[(f1, f2)]
53+
for (row, (f1, f2)) in enumerate(structure_src.fusiontreelist)
5554
for ((f3, f4), coeff) in transform(f1, f2)
5655
col = structure_dst.fusiontreeindices[(f3, f4)]
5756
push!(rows, row)
@@ -70,88 +69,39 @@ function TreeTransformer(transform::Function, Vsrc::HomSpace{S},
7069
end
7170
end
7271

73-
# Transpose
74-
# ---------
75-
const treetransposercache = LRU{Any,Any}(; maxsize=10^5)
76-
const usetreetransposercache = Ref{Bool}(true)
72+
for (transform, transformer) in
73+
((:permute, :permuter), (:braid, :braider), (:transpose, :transposer))
74+
treetransformcache = Symbol("tree", transformer, "cache")
75+
usetreetransformcache = Symbol("usetree", transformer, "cache")
76+
treetransformer = Symbol("tree", transformer)
77+
_get_treetransformer = Symbol("_get_", treetransformer)
78+
_treetransformer = Symbol("_", treetransformer)
7779

78-
function treetransposer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
79-
return fusiontreetransform(f1, f2) = transpose(f1, f2, p...)
80-
end
81-
function treetransposer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
82-
if usetreetransposercache[]
83-
key = (space(tdst), space(tsrc), p)
84-
A = treetransformertype(space(tdst), space(tsrc))
85-
return _get_treetransposer(A, key)
86-
else
87-
return _treetransposer((space(tdst), space(tsrc), p))
88-
end
89-
end
90-
@noinline function _get_treetransposer(A, key)
91-
d::A = get!(treetransposercache, key) do
92-
return _treetransposer(key)
93-
end
94-
return d
95-
end
96-
function _treetransposer((Vdst, Vsrc, p))
97-
fusiontreetransform(f1, f2) = transpose(f1, f2, p...)
98-
return TreeTransformer(fusiontreetransform, Vsrc, Vdst)
99-
end
100-
101-
# Braid
102-
# -----
103-
const treebraidercache = LRU{Any,Any}(; maxsize=10^5)
104-
const usetreebraidercache = Ref{Bool}(true)
105-
106-
function treebraider(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple,
107-
l::Index2Tuple)
108-
return fusiontreetransform(f1, f2) = braid(f1, f2, p..., l...)
109-
end
110-
function treebraider(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple,
111-
l::Index2Tuple)
112-
if usetreebraidercache[]
113-
key = (space(tdst), space(tsrc), p, l)
114-
A = treetransformertype(space(tdst), space(tsrc))
115-
return _get_treebraider(A, key)
116-
else
117-
return _treebraider((space(tdst), space(tsrc), p, l))
118-
end
119-
end
120-
@noinline function _get_treebraider(A, key)
121-
d::A = get!(treebraidercache, key) do
122-
return _treebraider(key)
123-
end
124-
return d
125-
end
126-
function _treebraider((Vdst, Vsrc, p, l))
127-
fusiontreetransform(f1, f2) = braid(f1, f2, p..., l...)
128-
return TreeTransformer(fusiontreetransform, Vsrc, Vdst)
129-
end
130-
131-
# Permute
132-
# -------
133-
const treepermutercache = LRU{Any,Any}(; maxsize=10^5)
134-
const usetreepermutercache = Ref{Bool}(true)
80+
@eval begin
81+
const $treetransformcache = LRU{Any,Any}(; maxsize=10^5)
82+
const $usetreetransformcache = Ref{Bool}(true)
13583

136-
function treepermuter(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
137-
return fusiontreetransform(f1, f2) = permute(f1, f2, p...)
138-
end
139-
function treepermuter(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
140-
if usetreepermutercache[]
141-
key = (space(tdst), space(tsrc), p)
142-
A = treetransformertype(space(tdst), space(tsrc))
143-
return _get_treepermuter(A, key)
144-
else
145-
return _treepermuter((space(tdst), space(tsrc), p))
146-
end
147-
end
148-
@noinline function _get_treepermuter(A, key)
149-
d::A = get!(treepermutercache, key) do
150-
return _treepermuter(key)
84+
function $treetransformer(::AbstractTensorMap, ::AbstractTensorMap, p::Index2Tuple)
85+
return fusiontreetransform(f1, f2) = $transform(f1, f2, p...)
86+
end
87+
function $treetransformer(tdst::TensorMap, tsrc::TensorMap, p::Index2Tuple)
88+
if $usetreetransformcache[]
89+
key = (space(tdst), space(tsrc), p)
90+
A = treetransformertype(space(tdst), space(tsrc))
91+
return $_get_treetransformer(A, key)
92+
else
93+
return $_treetransformer((space(tdst), space(tsrc), p))
94+
end
95+
end
96+
@noinline function $_get_treetransformer(A, key)
97+
d::A = get!($treetransformcache, key) do
98+
return $_treetransformer(key)
99+
end
100+
return d
101+
end
102+
function $_treetransformer((Vdst, Vsrc, p))
103+
fusiontreetransform(f1, f2) = $transform(f1, f2, p...)
104+
return TreeTransformer(fusiontreetransform, Vsrc, Vdst)
105+
end
151106
end
152-
return d
153-
end
154-
function _treepermuter((Vdst, Vsrc, p))
155-
fusiontreetransform(f1, f2) = permute(f1, f2, p...)
156-
return TreeTransformer(fusiontreetransform, Vsrc, Vdst)
157107
end

0 commit comments

Comments
 (0)