Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ version = "0.14.4"
[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
TensorKitSectors = "13a9c161-d5da-41f0-bcbd-e1a08ae0647f"
Expand All @@ -31,8 +33,10 @@ Combinatorics = "1"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
OhMyThreads = "0.7.0"
PackageExtensionCompat = "1"
Random = "1"
ScopedValues = "1.3.0"
SparseArrays = "1"
Strided = "2"
TensorKitSectors = "0.1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/lib/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ TensorKit.add_transpose!
```@docs
compose(::AbstractTensorMap, ::AbstractTensorMap)
trace_permute!
contract!
⊗(::AbstractTensorMap, ::AbstractTensorMap)
```

Expand Down
5 changes: 4 additions & 1 deletion src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export OrthogonalFactorizationAlgorithm, QR, QRpos, QL, QLpos, LQ, LQpos, RQ, RQ

# tensor operations
export @tensor, @tensoropt, @ncon, ncon, @planar, @plansor
export scalar, add!, contract!
export scalar, add!

# truncation schemes
export notrunc, truncerr, truncdim, truncspace, truncbelow
Expand All @@ -101,6 +101,8 @@ using TensorOperations: IndexTuple, Index2Tuple, linearize, AbstractBackend
const TO = TensorOperations

using LRUCache
using OhMyThreads
using ScopedValues

using TensorKitSectors
import TensorKitSectors: dim, BraidingStyle, FusionStyle, ⊠, ⊗
Expand Down Expand Up @@ -184,6 +186,7 @@ include("spaces/vectorspaces.jl")
#-------------------------------------
# general definitions
include("tensors/abstracttensor.jl")
include("tensors/backends.jl")
include("tensors/blockiterator.jl")
include("tensors/tensor.jl")
include("tensors/adjoint.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/planar/planaroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ function planarcontract!(C::AbstractTensorMap,
α::Number, β::Number,
backend, allocator)
if BraidingStyle(sectortype(C)) == Bosonic()
return contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator)
return TO.tensorcontract!(C, A, pA, false, B, pB, false, pAB,
α, β, backend, allocator)
end

codA, domA = codomainind(A), domainind(A)
Expand Down
107 changes: 107 additions & 0 deletions src/tensors/backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Scheduler implementation
# ------------------------
"""
    const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different blocks in the matrix representation of a
tensor.

For controlling this value, see also [`set_blockscheduler!`](@ref) and [`with_blockscheduler`](@ref).
"""
const blockscheduler = ScopedValue{Scheduler}(SerialScheduler())

"""
    const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

The default scheduler used when looping over different subblocks in a tensor.

For controlling this value, see also [`set_subblockscheduler!`](@ref) and [`with_subblockscheduler`](@ref).
"""
const subblockscheduler = ScopedValue{Scheduler}(SerialScheduler())

"""
    select_scheduler([scheduler]; kwargs...) -> Scheduler

Resolve `scheduler` and `kwargs...` into a concrete `OhMyThreads.Scheduler` instance.

When no arguments are given, this defaults to `SerialScheduler()` if Julia is running
with a single thread and `DynamicScheduler()` otherwise. Any other input is forwarded to
`OhMyThreads.Implementation._scheduler_from_userinput`, which accepts either a `Scheduler`
instance or a `Symbol` together with scheduler-specific keyword arguments.
"""
function select_scheduler(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
    return if scheduler == OhMyThreads.Implementation.NotGiven() && isempty(kwargs)
        # no user input: choose a sensible default based on the available threads
        Threads.nthreads() == 1 ? SerialScheduler() : DynamicScheduler()
    else
        # delegate parsing of `Scheduler`/`Symbol` input to OhMyThreads
        # NOTE(review): `_scheduler_from_userinput` is an internal OhMyThreads function;
        # this may break on OhMyThreads minor releases — confirm a public alternative.
        OhMyThreads.Implementation._scheduler_from_userinput(scheduler; kwargs...)
    end
end

"""
    set_blockscheduler!([scheduler]; kwargs...) -> previous

Set the default scheduler used in looping over the different blocks in the matrix representation
of a tensor, and return the previously selected scheduler.
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
set of keywords arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).

See also [`with_blockscheduler`](@ref).
"""
function set_blockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
    previous = blockscheduler[]
    # NOTE(review): `ScopedValue` does not define `setindex!`, so this assignment will
    # throw a `MethodError` at runtime. Either store the global default in a `Ref` or
    # restrict the API to the dynamically-scoped `with_blockscheduler` — TODO confirm.
    blockscheduler[] = select_scheduler(scheduler; kwargs...)
    return previous
end

"""
    with_blockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the [`blockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.

See also [`set_blockscheduler!`](@ref).
"""
function with_blockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
    # dynamically scope the scheduler: only code executed within `f` observes the new value
    return @with blockscheduler => select_scheduler(scheduler; kwargs...) f()
end

"""
    set_subblockscheduler!([scheduler]; kwargs...) -> previous

Set the default scheduler used in looping over the different subblocks in a tensor, and
return the previously selected scheduler.
The arguments to this function are either an `OhMyThreads.Scheduler` or a `Symbol` with optional
set of keywords arguments. For a detailed description, consult the
[`OhMyThreads` documentation](https://juliafolds2.github.io/OhMyThreads.jl/stable/refs/api/#Schedulers).

See also [`with_subblockscheduler`](@ref).
"""
function set_subblockscheduler!(scheduler=OhMyThreads.Implementation.NotGiven(); kwargs...)
    previous = subblockscheduler[]
    # NOTE(review): `ScopedValue` does not define `setindex!`, so this assignment will
    # throw a `MethodError` at runtime — same concern as `set_blockscheduler!`; the
    # global default likely needs to live in a `Ref` instead. TODO confirm.
    subblockscheduler[] = select_scheduler(scheduler; kwargs...)
    return previous
end

"""
    with_subblockscheduler(f, [scheduler]; kwargs...)

Run `f` in a scope where the [`subblockscheduler`](@ref) is determined by `scheduler` and `kwargs...`.

See also [`set_subblockscheduler!`](@ref).
"""
function with_subblockscheduler(f, scheduler=OhMyThreads.Implementation.NotGiven();
                                kwargs...)
    # dynamically scope the scheduler: only code executed within `f` observes the new value
    return @with subblockscheduler => select_scheduler(scheduler; kwargs...) f()
end

# Backend implementation
# ----------------------
# TODO: figure out a name
# TODO: what should be the default scheduler?
"""
    TensorKitBackend(; arraybackend, blockscheduler, subblockscheduler)

Backend for TensorOperations.jl primitives acting on tensor maps, bundling the backend
used for the underlying array operations with the schedulers that control how the
(sub)block loops are executed. Scheduler defaults are read from the currently scoped
`blockscheduler`/`subblockscheduler` values at construction time.
"""
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
    arraybackend::B = TO.DefaultBackend()         # backend for the plain array kernels
    blockscheduler::BS = blockscheduler[]         # scheduler for looping over matrix blocks
    subblockscheduler::SBS = subblockscheduler[]  # scheduler for looping over subblocks
end

# Route `tensoradd!` on tensor maps to the TensorKit-specific backend.
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,
                           A::AbstractTensorMap)
    return TensorKitBackend()
end
# Route `tensortrace!` on tensor maps to the TensorKit-specific backend.
function TO.select_backend(::typeof(TO.tensortrace!), C::AbstractTensorMap,
                           A::AbstractTensorMap)
    return TensorKitBackend()
end
# Route `tensorcontract!` on tensor maps to the TensorKit-specific backend.
function TO.select_backend(::typeof(TO.tensorcontract!), C::AbstractTensorMap,
                           A::AbstractTensorMap, B::AbstractTensorMap)
    return TensorKitBackend()
end
2 changes: 2 additions & 0 deletions src/tensors/blockiterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
# Element type of the iterator is the block type of the underlying tensor type.
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
# Number of blocks equals the number of entries in the stored block structure.
Base.length(iter::BlockIterator) = length(iter.structure)
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)

# Forward key lookup to the block structure, so `haskey(blocks(t), c)` can be used to
# test for the presence of a block with charge `c`.
Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)
8 changes: 4 additions & 4 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ function add_transform!(tdst::AbstractTensorMap,
fusiontreetransform,
α::Number,
β::Number,
backend::AbstractBackend...)
backend::TensorKitBackend, allocator)
return add_transform!(tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β,
backend...)
backend, allocator)
end

# VectorInterface
Expand All @@ -173,8 +173,8 @@ end

# Materialize the lazy `BraidingTensor` as a dense `TensorMap` and dispatch to the
# generic `tensoradd!` implementation.
function TO.tensoradd!(C::AbstractTensorMap,
                       A::BraidingTensor, pA::Index2Tuple, conjA::Symbol,
                       α::Number, β::Number, backend::AbstractBackend,
                       allocator)
    return TO.tensoradd!(C, TensorMap(A), pA, conjA, α, β, backend, allocator)
end

Expand Down
Loading
Loading