@@ -283,12 +283,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
283283end
284284
285285# TensorMap multiplication
286- function LinearAlgebra. mul!(tC:: AbstractTensorMap ,
287- tA:: AbstractTensorMap ,
288- tB:: AbstractTensorMap , α= true , β= false )
286+ function LinearAlgebra. mul!(tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
287+ tB:: AbstractTensorMap ,
288+ α:: Number , β:: Number ,
289+ backend:: AbstractBackend = TO. DefaultBackend())
290+ if backend isa TO. DefaultBackend
291+ newbackend = TO. select_backend(mul!, tC, tA, tB)
292+ return mul!(tC, tA, tB, α, β, newbackend)
293+ elseif backend isa TO. NoBackend # error for missing backend
294+ TC = typeof(tC)
295+ TA = typeof(tA)
296+ TB = typeof(tB)
297+ throw(ArgumentError(" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
298+ else # error for unknown backend
299+ TC = typeof(tC)
300+ TA = typeof(tA)
301+ TB = typeof(tB)
302+ throw(ArgumentError(" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
303+ end
304+ end
305+
306+ function TO. select_backend(:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
307+ B:: AbstractTensorMap )
308+ return TensorKitBackend()
309+ end
310+
311+ function LinearAlgebra. mul!(tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
312+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
313+ backend:: TensorKitBackend )
289314 compose(space(tA), space(tB)) == space(tC) ||
290315 throw(SpaceMismatch(lazy" $(space(tC)) ≠ $(space(tA)) * $(space(tB)) " ))
291316
317+ scheduler = backend. blockscheduler
318+ if isnothing(scheduler)
319+ return sequential_mul!(tC, tA, tB, α, β)
320+ else
321+ return threaded_mul!(tC, tA, tB, α, β, scheduler)
322+ end
323+ end
324+
325+ function sequential_mul!(tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
326+ tB:: AbstractTensorMap , α:: Number , β:: Number )
292327 iterC = blocks(tC)
293328 iterA = blocks(tA)
294329 iterB = blocks(tB)
@@ -310,13 +345,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
310345 elseif cB < cC
311346 nextB = iterate(iterB, stateB)
312347 else
313- if β != one (β)
348+ if ! isone (β)
314349 rmul!(C, β)
315350 end
316351 nextC = iterate(iterC, stateC)
317352 end
318353 else
319- if β != one (β)
354+ if ! isone (β)
320355 rmul!(C, β)
321356 end
322357 nextC = iterate(iterC, stateC)
@@ -325,7 +360,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
325360 return tC
326361end
327362
328- # TODO : consider spawning threads for different blocks, support backends
363+ function threaded_mul!(tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
364+ α:: Number , β:: Number , scheduler:: Scheduler )
365+ # obtain cached data before multithreading
366+ bCs, bAs, bBs = blocks(tC), blocks(tA), blocks(tB)
367+
368+ tforeach(blocksectors(tC); scheduler) do c
369+ if haskey(bAs, c) # then also bBs should have it
370+ mul!(bCs[c], bAs[c], bBs[c], α, β)
371+ elseif ! isone(β)
372+ scale!(bCs[c], β)
373+ end
374+ end
375+
376+ return tC
377+ end
329378
330379# TensorMap inverse
331380function Base. inv(t:: AbstractTensorMap )
0 commit comments