@@ -309,12 +309,47 @@ function LinearAlgebra.tr(t::AbstractTensorMap)
309
309
end
310
310
311
311
# TensorMap multiplication
312
- function LinearAlgebra. mul! (tC:: AbstractTensorMap ,
313
- tA:: AbstractTensorMap ,
314
- tB:: AbstractTensorMap , α= true , β= false )
312
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
313
+ tB:: AbstractTensorMap ,
314
+ α:: Number , β:: Number ,
315
+ backend:: AbstractBackend = TO. DefaultBackend ())
316
+ if backend isa TO. DefaultBackend
317
+ newbackend = TO. select_backend (mul!, tC, tA, tB)
318
+ return mul! (tC, tA, tB, α, β, newbackend)
319
+ elseif backend isa TO. NoBackend # error for missing backend
320
+ TC = typeof (tC)
321
+ TA = typeof (tA)
322
+ TB = typeof (tB)
323
+ throw (ArgumentError (" No suitable backend found for `mul!` and tensor types $TC , $TA and $TB " ))
324
+ else # error for unknown backend
325
+ TC = typeof (tC)
326
+ TA = typeof (tA)
327
+ TB = typeof (tB)
328
+ throw (ArgumentError (" Unknown backend for `mul!` and tensor types $TC , $TA and $TB " ))
329
+ end
330
+ end
331
+
332
+ function TO. select_backend (:: typeof (mul!), C:: AbstractTensorMap , A:: AbstractTensorMap ,
333
+ B:: AbstractTensorMap )
334
+ return TensorKitBackend ()
335
+ end
336
+
337
+ function LinearAlgebra. mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
338
+ tB:: AbstractTensorMap , α:: Number , β:: Number ,
339
+ backend:: TensorKitBackend )
315
340
compose (space (tA), space (tB)) == space (tC) ||
316
341
throw (SpaceMismatch (lazy " $(space(tC)) ≠ $(space(tA)) * $(space(tB))" ))
317
342
343
+ scheduler = backend. blockscheduler
344
+ if isnothing (scheduler)
345
+ return sequential_mul! (tC, tA, tB, α, β)
346
+ else
347
+ return threaded_mul! (tC, tA, tB, α, β, scheduler)
348
+ end
349
+ end
350
+
351
+ function sequential_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap ,
352
+ tB:: AbstractTensorMap , α:: Number , β:: Number )
318
353
iterC = blocks (tC)
319
354
iterA = blocks (tA)
320
355
iterB = blocks (tB)
@@ -336,13 +371,13 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
336
371
elseif cB < cC
337
372
nextB = iterate (iterB, stateB)
338
373
else
339
- if β != one (β)
374
+ if ! isone (β)
340
375
rmul! (C, β)
341
376
end
342
377
nextC = iterate (iterC, stateC)
343
378
end
344
379
else
345
- if β != one (β)
380
+ if ! isone (β)
346
381
rmul! (C, β)
347
382
end
348
383
nextC = iterate (iterC, stateC)
@@ -351,7 +386,21 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
351
386
return tC
352
387
end
353
388
354
- # TODO : consider spawning threads for different blocks, support backends
389
+ function threaded_mul! (tC:: AbstractTensorMap , tA:: AbstractTensorMap , tB:: AbstractTensorMap ,
390
+ α:: Number , β:: Number , scheduler:: Scheduler )
391
+ # obtain cached data before multithreading
392
+ bCs, bAs, bBs = blocks (tC), blocks (tA), blocks (tB)
393
+
394
+ tforeach (blocksectors (tC); scheduler) do c
395
+ if haskey (bAs, c) # then also bBs should have it
396
+ mul! (bCs[c], bAs[c], bBs[c], α, β)
397
+ elseif ! isone (β)
398
+ scale! (bCs[c], β)
399
+ end
400
+ end
401
+
402
+ return tC
403
+ end
355
404
356
405
# TensorMap inverse
357
406
function Base. inv (t:: AbstractTensorMap )
0 commit comments