Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ export scalar, add!, contract!
# truncation schemes
export notrunc, truncerr, truncdim, truncspace, truncbelow

# cache management
export empty_globalcaches!

# Imports
#---------
using TupleTools
Expand Down Expand Up @@ -134,6 +137,7 @@ using PackageExtensionCompat
# Auxiliary files
#-----------------
include("auxiliary/auxiliary.jl")
include("auxiliary/caches.jl")
include("auxiliary/dicts.jl")
include("auxiliary/iterators.jl")
include("auxiliary/linalg.jl")
Expand Down
148 changes: 148 additions & 0 deletions src/auxiliary/caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
const GLOBAL_CACHES = Any[]
function empty_globalcaches!()
foreach(empty!, GLOBAL_CACHES)
return nothing
end

abstract type CacheStyle end
struct NoCache <: CacheStyle end
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
struct GlobalLRUCache <: CacheStyle end

const DEFAULT_GLOBALCACHE_SIZE = Ref(10^5)

function CacheStyle(args...)
return GlobalLRUCache()
end

macro cached(ex)
Meta.isexpr(ex, :function) ||
error("cached macro can only be used on function definitions")
fcall = ex.args[1]
if Meta.isexpr(fcall, :where)
hasparams = true
params = fcall.args[2:end]
fcall = fcall.args[1]
else
hasparams = false
end
if Meta.isexpr(fcall, :(::))
typed = true
typeex = fcall.args[2]
fcall = fcall.args[1]
else
typed = false

Check warning on line 34 in src/auxiliary/caches.jl

View check run for this annotation

Codecov / codecov/patch

src/auxiliary/caches.jl#L34

Added line #L34 was not covered by tests
end
Meta.isexpr(fcall, :call) ||
error("cached macro can only be used on function definitions")
fname = fcall.args[1]
fargs = fcall.args[2:end]
fargnames = map(fargs) do arg
if Meta.isexpr(arg, :(::))
return arg.args[1]
else
return arg
end
end
_fbody = ex.args[2]

# actual implenetation, with underscore name
_fname = Symbol(:_, fname)
_fcall = Expr(:call, _fname, fargs...)
if hasparams
_fcall = Expr(:where, _fcall, params...)
end
_fex = Expr(:function, _fcall, _fbody)

# implementation that chooses the cache style
newfcall = fcall
if hasparams
newfcall = Expr(:where, newfcall, params...)
end
cachestylevar = gensym(:cachestyle)
cachestyleex = Expr(:(=), cachestylevar,
Expr(:call, :CacheStyle, fname, fargnames...))
newfbody = Expr(:block,
cachestyleex,
Expr(:call, fname, fargnames..., cachestylevar))
newfex = Expr(:function, newfcall, newfbody)

# nocache implementation
fnocachecall = Expr(:call, fname, fargs..., :(::NoCache))
if hasparams
fnocachecall = Expr(:where, fnocachecall, params...)
end
fnocachebody = Expr(:call, _fname, fargnames...)
if typed
T = gensym(:T)
fnocachebody = Expr(:block, Expr(:(=), T, typeex), Expr(:(::), fnocachebody, T))
end
fnocacheex = Expr(:function, fnocachecall, fnocachebody)

# tasklocal cache implementation
Dvar = gensym(:D)
flocalcachecall = Expr(:call, fname, fargs..., :(::TaskLocalCache{$Dvar}))
if hasparams
flocalcachecall = Expr(:where, flocalcachecall, params..., Dvar)
else
flocalcachecall = Expr(:where, flocalcachecall, Dvar)
end
localcachename = Symbol(:_tasklocal_, fname, :_cache)
cachevar = gensym(:cache)
getlocalcacheex = :($cachevar::$Dvar = get!(task_local_storage(), $localcachename) do
return $Dvar()

Check warning on line 93 in src/auxiliary/caches.jl

View check run for this annotation

Codecov / codecov/patch

src/auxiliary/caches.jl#L93

Added line #L93 was not covered by tests
end)
valvar = gensym(:val)
if length(fargnames) == 1
key = fargnames[1]
else
key = Expr(:tuple, fargnames...)
end
getvalex = :(get!($cachevar, $key) do
return $_fname($(fargnames...))
end)
if typed
T = gensym(:T)
flocalcachebody = Expr(:block,
getlocalcacheex,
Expr(:(=), T, typeex),
Expr(:(=), Expr(:(::), valvar, T), getvalex),
Expr(:return, valvar))
else
flocalcachebody = Expr(:block,

Check warning on line 112 in src/auxiliary/caches.jl

View check run for this annotation

Codecov / codecov/patch

src/auxiliary/caches.jl#L112

Added line #L112 was not covered by tests
getlocalcacheex,
Expr(:(=), valvar, getvalex),
Expr(:return, valvar))
end
flocalcacheex = Expr(:function, flocalcachecall, flocalcachebody)

# # global cache implementation
fglobalcachecall = Expr(:call, fname, fargs..., :(::GlobalLRUCache))
if hasparams
fglobalcachecall = Expr(:where, fglobalcachecall, params...)
end
globalcachename = Symbol(:GLOBAL_, uppercase(string(fname)), :_CACHE)
getglobalcachex = Expr(:(=), cachevar, globalcachename)
if typed
T = gensym(:T)
fglobalcachebody = Expr(:block,
getglobalcachex,
Expr(:(=), T, typeex),
Expr(:(=), Expr(:(::), valvar, T), getvalex),
Expr(:return, valvar))
else
fglobalcachebody = Expr(:block,

Check warning on line 134 in src/auxiliary/caches.jl

View check run for this annotation

Codecov / codecov/patch

src/auxiliary/caches.jl#L134

Added line #L134 was not covered by tests
getglobalcachex,
Expr(:(=), valvar, getvalex),
Expr(:return, valvar))
end
fglobalcacheex = Expr(:function, fglobalcachecall, fglobalcachebody)
fglobalcachedef = Expr(:const,
Expr(:(=), globalcachename,
:(LRU{Any,Any}(; maxsize=DEFAULT_GLOBALCACHE_SIZE[]))))
fglobalcacheregister = Expr(:call, :push!, :GLOBAL_CACHES, globalcachename)

# # total expression
return esc(Expr(:block, _fex, newfex, fnocacheex, flocalcacheex,
fglobalcachedef, fglobalcacheregister, fglobalcacheex))
end
97 changes: 37 additions & 60 deletions src/fusiontrees/manipulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,6 @@ function _recursive_repartition(f₁::FusionTree{I,N₁},
end
end

# transpose double fusion tree
const transposecache = LRU{Any,Any}(; maxsize=10^5)
const usetransposecache = Ref{Bool}(true)

"""
transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}
Expand All @@ -548,28 +544,24 @@ function Base.transpose(f₁::FusionTree{I}, f₂::FusionTree{I},
@assert length(f₁) + length(f₂) == N
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
@assert iscyclicpermutation(p)
if usetransposecache[]
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = fusiontreedict(I){Tuple{F₁,F₂},T}
return _get_transpose(D, (f₁, f₂, p1, p2))
else
return _transpose((f₁, f₂, p1, p2))
end
return fstranspose((f₁, f₂, p1, p2))
end

@noinline function _get_transpose(::Type{D}, @nospecialize(key)) where {D}
d::D = get!(transposecache, key) do
return _transpose(key)
end
return d
end
const FSTransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple{N₁},IndexTuple{N₂}}

const TransposeKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple{N₁},IndexTuple{N₂}}
Base.@pure function _fsdicttype(I, N₁, N₂)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
T = sectorscalartype(I)
return fusiontreedict(I){Tuple{F₁,F₂},T}
end

function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
@cached function fstranspose(key::FSTransposeKey{I,N₁,N₂})::_fsdicttype(I, N₁,
N₂) where {I<:Sector,
N₁,
N₂}
f₁, f₂, p1, p2 = key
N = N₁ + N₂
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
newtrees = repartition(f₁, f₂, N₁)
Expand Down Expand Up @@ -611,6 +603,14 @@ function _transpose((f₁, f₂, p1, p2)::TransposeKey{I,N₁,N₂}) where {I<:S
return newtrees
end

function CacheStyle(::typeof(fstranspose), k::FSTransposeKey{I}) where {I<:Sector}
if FusionStyle(I) isa UniqueFusion
return NoCache()
else
return GlobalLRUCache()
end
end

# COMPOSITE DUALITY MANIPULATIONS PART 2: Planar traces
#-------------------------------------------------------------------
# -> composite manipulations that depend on the duality (rigidity) and pivotal structure
Expand Down Expand Up @@ -1015,10 +1015,6 @@ function permute(f::FusionTree{I,N}, p::NTuple{N,Int}) where {I<:Sector,N}
end

# braid double fusion tree
const braidcache = LRU{Any,Any}(; maxsize=10^5)
const usebraidcache_abelian = Ref{Bool}(false)
const usebraidcache_nonabelian = Ref{Bool}(true)

"""
braid(f₁::FusionTree{I}, f₂::FusionTree{I},
levels1::IndexTuple, levels2::IndexTuple,
Expand All @@ -1043,42 +1039,15 @@ function braid(f₁::FusionTree{I}, f₂::FusionTree{I},
@assert length(f₁) + length(f₂) == N₁ + N₂
@assert length(f₁) == length(levels1) && length(f₂) == length(levels2)
@assert TupleTools.isperm((p1..., p2...))
if FusionStyle(f₁) isa UniqueFusion &&
BraidingStyle(f₁) isa SymmetricBraiding
if usebraidcache_abelian[]
T = Int # do we hardcode this ?
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = SingletonDict{Tuple{F₁,F₂},T}
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
else
return _braid((f₁, f₂, levels1, levels2, p1, p2))
end
else
if usebraidcache_nonabelian[]
T = sectorscalartype(I)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
D = FusionTreeDict{Tuple{F₁,F₂},T}
return _get_braid(D, (f₁, f₂, levels1, levels2, p1, p2))
else
return _braid((f₁, f₂, levels1, levels2, p1, p2))
end
end
return fsbraid((f₁, f₂, levels1, levels2, p1, p2))
end
const FSBraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple,IndexTuple,
IndexTuple{N₁},IndexTuple{N₂}}

@noinline function _get_braid(::Type{D}, @nospecialize(key)) where {D}
d::D = get!(braidcache, key) do
return _braid(key)
end
return d
end

const BraidKey{I<:Sector,N₁,N₂} = Tuple{<:FusionTree{I},<:FusionTree{I},
IndexTuple,IndexTuple,
IndexTuple{N₁},IndexTuple{N₂}}

function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:Sector,N₁,N₂}
@cached function fsbraid(key::FSBraidKey{I,N₁,N₂})::_fsdicttype(I, N₁,
N₂) where {I<:Sector,N₁,N₂}
(f₁, f₂, l1, l2, p1, p2) = key
p = linearizepermutation(p1, p2, length(f₁), length(f₂))
levels = (l1..., reverse(l2)...)
local newtrees
Expand All @@ -1097,6 +1066,14 @@ function _braid((f₁, f₂, l1, l2, p1, p2)::BraidKey{I,N₁,N₂}) where {I<:S
return newtrees
end

function CacheStyle(::typeof(fsbraid), k::FSBraidKey{I}) where {I<:Sector}
if FusionStyle(I) isa UniqueFusion
return NoCache()
else
return GlobalLRUCache()
end
end

"""
permute(f₁::FusionTree{I}, f₂::FusionTree{I},
p1::NTuple{N₁, Int}, p2::NTuple{N₂, Int}) where {I, N₁, N₂}
Expand Down
51 changes: 11 additions & 40 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,17 @@ struct FusionBlockStructure{I,N,F₁,F₂}
fusiontreeindices::FusionTreeDict{Tuple{F₁,F₂},Int}
end

abstract type CacheStyle end
struct NoCache <: CacheStyle end
struct TaskLocalCache{D<:AbstractDict} <: CacheStyle end
struct GlobalLRUCache <: CacheStyle end

function CacheStyle(I::Type{<:Sector})
return GlobalLRUCache()
function fusionblockstructuretype(W::HomSpace)
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
return FusionBlockStructure{I,N,F₁,F₂}
end

fusionblockstructure(W::HomSpace) = fusionblockstructure(W, CacheStyle(sectortype(W)))

function fusionblockstructure(W::HomSpace, ::NoCache)
@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W)
codom = codomain(W)
dom = domain(W)
N₁ = length(codom)
Expand Down Expand Up @@ -323,36 +322,8 @@ function _subblock_strides(subsz, sz, str)
return Strided.StridedViews._computereshapestrides(subsz, sz_simplify...)
end

function fusionblockstructure(W::HomSpace, ::TaskLocalCache{D}) where {D}
cache::D = get!(task_local_storage(), :_local_tensorstructure_cache) do
return D()
end
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
return fusionblockstructure(W, NoCache())
end
return structure
end

const GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE = LRU{Any,Any}(; maxsize=10^4)
# 10^4 different tensor spaces should be enough for most purposes
function fusionblockstructure(W::HomSpace, ::GlobalLRUCache)
cache = GLOBAL_FUSIONBLOCKSTRUCTURE_CACHE
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
structure::FusionBlockStructure{I,N,F₁,F₂} = get!(cache, W) do
return fusionblockstructure(W, NoCache())
end
return structure
function CacheStyle(::typeof(fusionblockstructure), W::HomSpace)
return GlobalLRUCache()
end

# Diagonal ranges
Expand Down
Loading
Loading