@@ -309,12 +309,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
309309end
310310
311311# TensorMap multiplication
312- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
313- tA:: AbstractTensorMap ,
314- tB:: AbstractTensorMap , α= true , β= false )
312+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
313+ tB:: AbstractTensorMap ,
314+ α:: Number , β:: Number ,
315+ backend:: AbstractBackend = TO. DefaultBackend ())
316+ if backend isa TO. DefaultBackend
317+ newbackend = TO. select_backend (mul!, tC, tA, tB)
318+ return mul! (tC, tA, tB, α, β, newbackend)
319+ elseif backend isa TO. NoBackend # error for missing backend
320+ TC = typeof (tC)
321+ TA = typeof (tA)
322+ TB = typeof (tB)
323+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
324+ else # error for unknown backend
325+ TC = typeof (tC)
326+ TA = typeof (tA)
327+ TB = typeof (tB)
328+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
329+ end
330+ end
331+
332+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
333+ B:: AbstractTensorMap )
334+ return TensorKitBackend ()
335+ end
336+
337+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
338+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
339+ backend:: TensorKitBackend )
315340 compose (space (tA), space (tB)) == space (tC) ||
316341 throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
317342
343+ scheduler = backend. blockscheduler
344+ if isnothing (scheduler)
345+ return sequential_mul! (tC, tA, tB, α, β)
346+ else
347+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
348+ end
349+ end
350+
351+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
352+ tB:: AbstractTensorMap , α:: Number , β:: Number )
318353 iterC = blocks (tC)
319354 iterA = blocks (tA)
320355 iterB = blocks (tB)
@@ -336,13 +371,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
336371 elseif cB < cC
337372 nextB = iterate (iterB, stateB)
338373 else
339- if β != one (β)
374+ if ! isone (β)
340375 rmul! (C, β)
341376 end
342377 nextC = iterate (iterC, stateC)
343378 end
344379 else
345- if β != one (β)
380+ if ! isone (β)
346381 rmul! (C, β)
347382 end
348383 nextC = iterate (iterC, stateC)
@@ -351,7 +386,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
351386 return tC
352387end
353388
354- # TODO : consider spawning threads for different blocks, support backends
389+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
390+ α:: Number , β:: Number , scheduler:: Scheduler )
391+ # obtain cached data before multithreading
392+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
393+
394+ tforeach (blocksectors (tC); scheduler) do c
395+ if haskey (bAs, c) # then also bBs should have it
396+ mul! (bCs[c], bAs[c], bBs[c], α, β)
397+ elseif ! isone (β)
398+ scale! (bCs[c], β)
399+ end
400+ end
401+
402+ return tC
403+ end
355404
356405# TensorMap inverse
357406function Base. inv (t:: AbstractTensorMap )
0 commit comments