Skip to content

Commit d526053

Browse files
committed
Add scheduler support in mul!
1 parent aa1e6d1 commit d526053

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

src/tensors/blockiterator.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ Base.IteratorEltype(::BlockIterator) = Base.HasEltype()
1313
Base.eltype(::Type{<:BlockIterator{T}}) where {T} = blocktype(T)
1414
Base.length(iter::BlockIterator) = length(iter.structure)
1515
Base.isdone(iter::BlockIterator, state...) = Base.isdone(iter.structure, state...)
16+
17+
Base.haskey(iter::BlockIterator, c) = haskey(iter.structure, c)

src/tensors/linalg.jl

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
309309
end
310310

311311
# TensorMap multiplication
312-
function LinearAlgebra.mul!(tC::AbstractTensorMap,
313-
tA::AbstractTensorMap,
314-
tB::AbstractTensorMap, α=true, β=false)
312+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
313+
tB::AbstractTensorMap,
314+
α::Number, β::Number,
315+
backend::AbstractBackend=TO.DefaultBackend())
316+
if backend isa TO.DefaultBackend
317+
newbackend = TO.select_backend(mul!, tC, tA, tB)
318+
return mul!(tC, tA, tB, α, β, newbackend)
319+
elseif backend isa TO.NoBackend # error for missing backend
320+
TC = typeof(tC)
321+
TA = typeof(tA)
322+
TB = typeof(tB)
323+
throw(ArgumentError("No suitable backend found for `mul!` and tensor types $TC, $TA and $TB"))
324+
else # error for unknown backend
325+
TC = typeof(tC)
326+
TA = typeof(tA)
327+
TB = typeof(tB)
328+
throw(ArgumentError("Unknown backend for `mul!` and tensor types $TC, $TA and $TB"))
329+
end
330+
end
331+
332+
function TO.select_backend(::typeof(mul!), C::AbstractTensorMap, A::AbstractTensorMap,
333+
B::AbstractTensorMap)
334+
return TensorKitBackend()
335+
end
336+
337+
function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
338+
tB::AbstractTensorMap, α::Number, β::Number,
339+
backend::TensorKitBackend)
315340
compose(space(tA), space(tB)) == space(tC) ||
316341
throw(SpaceMismatch(lazy"$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
317342

343+
scheduler = backend.blockscheduler
344+
if isnothing(scheduler)
345+
return sequential_mul!(tC, tA, tB, α, β)
346+
else
347+
return threaded_mul!(tC, tA, tB, α, β, scheduler)
348+
end
349+
end
350+
351+
function sequential_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap,
352+
tB::AbstractTensorMap, α::Number, β::Number)
318353
iterC = blocks(tC)
319354
iterA = blocks(tA)
320355
iterB = blocks(tB)
@@ -336,13 +371,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
336371
elseif cB < cC
337372
nextB = iterate(iterB, stateB)
338373
else
339-
if β != one(β)
374+
if !isone(β)
340375
rmul!(C, β)
341376
end
342377
nextC = iterate(iterC, stateC)
343378
end
344379
else
345-
if β != one(β)
380+
if !isone(β)
346381
rmul!(C, β)
347382
end
348383
nextC = iterate(iterC, stateC)
@@ -351,7 +386,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
351386
return tC
352387
end
353388

354-
# TODO: consider spawning threads for different blocks, support backends
389+
function threaded_mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, tB::AbstractTensorMap,
390+
α::Number, β::Number, scheduler::Scheduler)
391+
# obtain cached data before multithreading
392+
bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
393+
394+
tforeach(blocksectors(tC); scheduler) do c
395+
if haskey(bAs, c) # then also bBs should have it
396+
mul!(bCs[c], bAs[c], bBs[c], α, β)
397+
elseif !isone(β)
398+
scale!(bCs[c], β)
399+
end
400+
end
401+
402+
return tC
403+
end
355404

356405
# TensorMap inverse
357406
function Base.inv(t::AbstractTensorMap)

src/tensors/tensoroperations.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ TO.tensorcost(t::AbstractTensorMap, i::Int) = dim(space(t, i))
149149
# TODO: figure out a name
150150
# TODO: what should be the default scheduler?
151151
# TODO: should we allow a separate scheduler for "blocks" and "subblocks"
152-
@kwdef struct TensorKitBackend{B<:AbstractBackend,S<:Scheduler} <: AbstractBackend
152+
@kwdef struct TensorKitBackend{B<:AbstractBackend,BS,SBS} <: AbstractBackend
153153
arraybackend::B = TO.DefaultBackend()
154-
scheduler::S = SerialScheduler()
154+
blockscheduler::BS = SerialScheduler()
155+
subblockscheduler::SBS = SerialScheduler()
155156
end
156157

157158
function TO.select_backend(::typeof(TO.tensoradd!), C::AbstractTensorMap,

0 commit comments

Comments
 (0)