@@ -8,8 +8,8 @@ function _transpose_as(t1::AbstractTensorMap, t2::AbstractTensorMap; copy::Bool
88 return repartition (t1, numout (t2), numin (t2); copy)
99end
1010
11- _mul_front (C, A) = matrix_contract (A, C, 1 ; transpose = true ) # _transpose_front(C * _transpose_tail(A))
12- _mul_tail (A, C) = matrix_contract (A, C, numind (A) ) # A * C
# Contract the matrix `C` onto the first leg of `A`, i.e.
# `_transpose_front(C * _transpose_tail(A))`. The unit scale factor is spelled out so
# the call matches `mul_front`'s positional `α` argument.
_mul_front(C, A) = mul_front(C, A, true)
# Contract the matrix `C` onto the last leg of `A`, i.e. `A * C`. The unit scale factor
# is spelled out so the call matches `mul_tail`'s positional `α` argument.
_mul_tail(A, C) = mul_tail(A, C, true)
1313
1414function _similar_tail (A:: AbstractTensorMap )
1515 cod = _firstspace (A)
@@ -231,6 +231,102 @@ function matrix_contract!(
231231 return C
232232end
233233
"""
    mul_front(A, B, α = true, backend = DefaultBackend(), allocator = DefaultAllocator())

Allocate a new tensor and fill it with `α` times the contraction of the one-leg-in,
one-leg-out map `A` with the first codomain leg of `B`.

The output has the domain of `B` and the codomain of `B` with its first space replaced
by `space(A, 1)`. Its scalar type is the contraction promotion of the scalar types of
`A`, `B` and `α`. The actual work is done by `mul_front!` with `β = Zero()`.

Returns the newly allocated tensor.
"""
function mul_front(
        A::AbstractTensorMap, B::AbstractTensorMap,
        α::Number = true,
        backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
    )
    # Codomain of the result: that of `B`, with leg 1 swapped for the leg `A` maps into.
    cod = prod(i -> i == 1 ? space(A, 1) : space(B, i), 1:numout(B))
    dom = domain(B)
    T = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α))
    C = similar(A, T, cod ← dom)
    mul_front!(C, A, B, α, Zero(), backend, allocator)
    # Return the tensor itself rather than relying on the in-place routine's return value.
    return C
end
245+
"""
    mul_front!(C, A, B, α, β, backend = DefaultBackend(), allocator = DefaultAllocator())

In-place contraction of the one-leg-in, one-leg-out map `A` with the first codomain leg
of `B`, accumulated into `C` as `C ← β * C + α * (A applied to leg 1 of B)`.

The contraction is carried out subblock by subblock: for every fusion-tree pair of `C`
the block of `A` belonging to the first uncoupled sector is contracted with the matching
subblock of `B`. Subblocks of `C` without a counterpart in `A` or `B` are only rescaled
by `β`.

Throws `SpaceMismatch` when `C` and `B` differ in their number of codomain or domain
legs, or when `A` does not have exactly one leg on each side. Returns `C`.
"""
function mul_front!(
        C::AbstractTensorMap, A::AbstractTensorMap, B::AbstractTensorMap,
        α::Number, β::Number,
        backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
    )
    (numin(C) == numin(B) && numout(C) == numout(B) && numin(A) == numout(A) == 1) ||
        throw(SpaceMismatch())

    # With a single codomain leg this is an ordinary composition of maps.
    numout(B) == 1 && return mul!(C, A, B, α, β)

    cp = checkpoint(allocator)

    Ablocks = blocks(A)
    Bstructure = TensorKit.fusionblockstructure(space(B))

    # Index partitions are the same for every subblock, so build them once:
    # `a` keeps index 1 and contracts index 2; `b` contracts index 1 and keeps the rest;
    # the result is stored in natural order, split into codomain and domain indices.
    pA = ((1,), (2,))
    pB = ((1,), ntuple(i -> i + 1, numind(B) - 1))
    pAB = (ntuple(identity, numout(C)), ntuple(i -> i + numout(C), numin(C)))

    for ((f₁, f₂), c) in subblocks(C)
        # fetch A block: selected by the sector on the first codomain leg
        u = first(f₁.uncoupled)
        a = Ablocks[u]
        isempty(a) && (scale!(c, β); continue)

        # fetch B block: absent fusion-tree pairs contribute nothing to `c`
        haskey(Bstructure.fusiontreeindices, (f₁, f₂)) || (scale!(c, β); continue)
        b = B[f₁, f₂]

        tensorcontract!(c, a, pA, false, b, pB, false, pAB, α, β, backend, allocator)
    end

    # Roll the allocator back to the state recorded before the loop (presumably this
    # releases the temporaries handed out meanwhile; confirm for the allocator in use),
    # then hand back the output tensor like the single-leg path above does.
    reset!(allocator, cp)
    return C
end
281+
"""
    mul_tail(A, B, α = true, backend = DefaultBackend(), allocator = DefaultAllocator())

Allocate a new tensor and fill it with `α` times the contraction of the last leg of `A`
with the one-leg-in, one-leg-out map `B`.

The output has the codomain of `A` and the domain of `A` with its last space replaced by
`domain(B)[1]`. Its scalar type is the contraction promotion of the scalar types of `A`,
`B` and `α`. The actual work is done by `mul_tail!` with `β = Zero()`.

Returns the newly allocated tensor.
"""
function mul_tail(
        A::AbstractTensorMap, B::AbstractTensorMap,
        α::Number = true,
        backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
    )
    cod = codomain(A)
    # Domain of the result: that of `A`, with the last leg swapped for the leg `B` acts on.
    dom = prod(i -> i == numin(A) ? domain(B)[1] : domain(A)[i], 1:numin(A))
    T = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α))
    C = similar(A, T, cod ← dom)
    mul_tail!(C, A, B, α, Zero(), backend, allocator)
    return C
end

"""
    mul_tail!(C, A, B, α, β, backend = DefaultBackend(), allocator = DefaultAllocator())

In-place contraction of the last leg of `A` with the one-leg-in, one-leg-out map `B`,
accumulated into `C` as `C ← β * C + α * (B applied to the last leg of A)`.

The contraction is carried out subblock by subblock: for every fusion-tree pair of `C`
the block of `B` belonging to the last uncoupled domain sector is contracted with the
matching subblock of `A`. Subblocks of `C` without a counterpart in `A` or `B` are only
rescaled by `β`.

Throws `SpaceMismatch` when `C` and `A` differ in their number of codomain or domain
legs, or when `B` does not have exactly one leg on each side. Returns `C`.
"""
function mul_tail!(
        C::AbstractTensorMap, A::AbstractTensorMap, B::AbstractTensorMap,
        α::Number, β::Number,
        backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
    )
    (numin(C) == numin(A) && numout(C) == numout(A) && numin(B) == numout(B) == 1) ||
        throw(SpaceMismatch())

    # With a single domain leg this is an ordinary composition of maps.
    numin(A) == 1 && return mul!(C, A, B, α, β)

    cp = checkpoint(allocator)

    Astructure = TensorKit.fusionblockstructure(space(A))
    Bblocks = blocks(B)

    # Index partitions are the same for every subblock, so build them once:
    # `a` keeps indices 1..N-1 and contracts its last index N (an index may not be listed
    # as both open and contracted); `b` contracts index 1 and keeps index 2; the result
    # is stored in natural order, split into codomain and domain indices.
    pA = (ntuple(identity, numind(A) - 1), (numind(A),))
    pB = ((1,), (2,))
    pAB = (ntuple(identity, numout(C)), ntuple(i -> i + numout(C), numin(C)))

    for ((f₁, f₂), c) in subblocks(C)
        # fetch B block: selected by the sector on the last domain leg
        u = last(f₂.uncoupled)
        b = Bblocks[u]
        isempty(b) && (scale!(c, β); continue)

        # fetch A block: absent fusion-tree pairs contribute nothing to `c`
        haskey(Astructure.fusiontreeindices, (f₁, f₂)) || (scale!(c, β); continue)
        a = A[f₁, f₂]

        tensorcontract!(c, a, pA, false, b, pB, false, pAB, α, β, backend, allocator)
    end

    # Roll the allocator back to the state recorded before the loop (presumably this
    # releases the temporaries handed out meanwhile; confirm for the allocator in use),
    # then hand back the output tensor like the single-leg path above does.
    reset!(allocator, cp)
    return C
end
329+
# Convenience method: lift the integer leg counts `N₁` and `N₂` into `Val`s and forward
# to the `Val`-based `fuse_legs` method.
@inline function fuse_legs(x::TensorMap, N₁::Int, N₂::Int)
    return fuse_legs(x, Val(N₁), Val(N₂))
end
235331function fuse_legs (x:: TensorMap , :: Val{N₁} , :: Val{N₂} ) where {N₁, N₂}
236332 ((0 <= N₁ <= numout (x)) && (0 <= N₂ <= numin (x))) || throw (ArgumentError (" invalid fusing scheme" ))
0 commit comments