Skip to content

Commit 8ea9103

Browse files
committed
clean up code
1 parent eca22fd commit 8ea9103

File tree

2 files changed

+99
-43
lines changed

2 files changed

+99
-43
lines changed

src/algorithms/derivatives/mpo_derivatives.jl

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -173,47 +173,7 @@ function (H::PrecomputedDerivative)(x::AbstractTensorMap)
173173
TC = TensorOperations.promote_contract(scalartype(x_fused), scalartype(R_fused))
174174
xR = TensorOperations.tensoralloc_contract(TC, x_fused, ((1,), (2,)), false, R_fused, ((1,), (2, 3)), false, ((1, 2), (3,)), Val(true), H.allocator)
175175

176-
matrix_contract!(xR, R_fused, x_fused, 1, One(), Zero(), H.backend, H.allocator; transpose = true)
177-
178-
# structure_xR = TensorKit.fusionblockstructure(space(xR))
179-
# structure_R = TensorKit.fusionblockstructure(space(R_fused))
180-
181-
# xblocks = blocks(x_fused)
182-
# for ((f₁, f₂), i1) in structure_xR.fusiontreeindices
183-
# sz, str, offset = structure_xR.fusiontreestructure[i1]
184-
# xr = TensorKit.Strided.StridedView(xR.data, sz, str, offset)
185-
186-
# u = first(f₁.uncoupled)
187-
# x = TensorKit.Strided.StridedView(xblocks[u])
188-
# isempty(x) && (zerovector!(xr); continue)
189-
190-
# if haskey(structure_R.fusiontreeindices, (f₁, f₂))
191-
# @inbounds i = structure_R.fusiontreeindices[(f₁, f₂)]
192-
# @inbounds sz, str, offset = structure_R.fusiontreestructure[i]
193-
# r = TensorKit.Strided.StridedView(R_fused.data, sz, str, offset)
194-
195-
# if TensorOperations.isblascontractable(r, ((1,), (2, 3))) &&
196-
# TensorOperations.isblasdestination(xr, ((1,), (2, 3)))
197-
# C = TensorKit.Strided.sreshape(xr, size(xr, 1), size(xr, 2) * size(xr, 3))
198-
# B = TensorKit.Strided.sreshape(r, size(r, 1), size(r, 2) * size(r, 3))
199-
# LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
200-
# elseif sz[2] < sz[3]
201-
# for k in axes(r, 2)
202-
# C = xr[:, k, :]
203-
# B = r[:, k, :]
204-
# LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
205-
# end
206-
# else
207-
# for k in axes(r, 3)
208-
# C = xr[:, :, k]
209-
# B = r[:, :, k]
210-
# LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
211-
# end
212-
# end
213-
# else
214-
# zerovector!(xr)
215-
# end
216-
# end
176+
mul_front!(xR, x_fused, R_fused, One(), Zero(), H.backend, H.allocator)
217177

218178
LxR = H.leftenv * xR
219179
TensorOperations.tensorfree!(xR, H.allocator)

src/utility/utility.jl

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ function _transpose_as(t1::AbstractTensorMap, t2::AbstractTensorMap; copy::Bool
88
return repartition(t1, numout(t2), numin(t2); copy)
99
end
1010

11-
_mul_front(C, A) = matrix_contract(A, C, 1; transpose = true) # _transpose_front(C * _transpose_tail(A))
12-
_mul_tail(A, C) = matrix_contract(A, C, numind(A)) # A * C
11+
_mul_front(C, A) = mul_front(C, A) # _transpose_front(C * _transpose_tail(A))
12+
_mul_tail(A, C) = mul_tail(A, C) # A * C
1313

1414
function _similar_tail(A::AbstractTensorMap)
1515
cod = _firstspace(A)
@@ -231,6 +231,102 @@ function matrix_contract!(
231231
return C
232232
end
233233

234+
function mul_front(
235+
A::AbstractTensorMap, B::AbstractTensorMap,
236+
α::Number,
237+
backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
238+
)
239+
cod = prod(i -> i == 1 ? space(A, 1) : space(B, i), 1:numout(B))
240+
dom = domain(B)
241+
T = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α))
242+
C = similar(A, T, cod dom)
243+
return mul_front!(C, A, B, α, Zero(), backend, allocator)
244+
end
245+
246+
function mul_front!(
247+
C::AbstractTensorMap, A::AbstractTensorMap, B::AbstractTensorMap,
248+
α::Number, β::Number,
249+
backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
250+
)
251+
(numin(C) == numin(B) && numout(C) == numout(B) && numin(A) == numout(A) == 1) ||
252+
throw(SpaceMismatch())
253+
254+
numout(B) == 1 && return mul!(C, A, B, α, β)
255+
256+
cp = checkpoint(allocator)
257+
258+
Ablocks = blocks(A)
259+
Bstructure = TensorKit.fusionblockstructure(space(B))
260+
for ((f₁, f₂), c) in subblocks(C)
261+
# fetch A block
262+
u = first(f₁.uncoupled)
263+
a = Ablocks[u]
264+
isempty(a) && (scale!(c, β); continue)
265+
266+
# fetch B block
267+
haskey(Bstructure.fusiontreeindices, (f₁, f₂)) || (scale!(c, β); continue)
268+
b = B[f₁, f₂]
269+
270+
tensorcontract!(
271+
c,
272+
a, ((1,), (2,)), false,
273+
b, ((1,), ntuple(i -> i + 1, numind(B) - 1)), false,
274+
(ntuple(identity, numout(C)), ntuple(i -> i + numout(C), numin(C))),
275+
α, β, backend, allocator
276+
)
277+
end
278+
279+
return reset!(allocator, cp)
280+
end
281+
282+
function mul_tail(
283+
A::AbstractTensorMap, B::AbstractTensorMap,
284+
α::Number,
285+
backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
286+
)
287+
cod = codomain(A)
288+
dom = prod(i -> i == 1 ? domain(B)[1] : domain(A)[i], 1:numin(A))
289+
T = TensorOperations.promote_contract(scalartype(A), scalartype(B), scalartype(α))
290+
C = similar(A, T, cod dom)
291+
return mul_tail!(C, A, B, α, Zero(), backend, allocator)
292+
end
293+
294+
function mul_tail!(
295+
C::AbstractTensorMap, A::AbstractTensorMap, B::AbstractTensorMap,
296+
α::Number, β::Number,
297+
backend::AbstractBackend = DefaultBackend(), allocator = DefaultAllocator()
298+
)
299+
(numin(C) == numin(A) && numout(C) == numout(A) && numin(B) == numout(B) == 1) ||
300+
throw(SpaceMismatch())
301+
302+
numin(A) == 1 && return mul!(C, A, B, α, β)
303+
304+
cp = checkpoint(allocator)
305+
306+
Astructure = TensorKit.fusionblockstructure(space(A))
307+
Bblocks = blocks(B)
308+
for ((f₁, f₂), c) in subblocks(C)
309+
# fetch B block
310+
u = first(f₂.uncoupled)
311+
b = Bblocks[u]
312+
isempty(b) && (scale!(c, β); continue)
313+
314+
# fetch A block
315+
haskey(Astructure.fusiontreeindices, (f₁, f₂)) || (scale!(c, β); continue)
316+
a = A[f₁, f₂]
317+
318+
tensorcontract!(
319+
c,
320+
a, (ntuple(identity, numind(A) - 1), (1,)), false,
321+
b, ((1,), (2,)), false,
322+
(ntuple(identity, numout(C)), ntuple(i -> i + numout(C), numin(C))),
323+
α, β, backend, allocator
324+
)
325+
end
326+
327+
return reset!(allocator, cp)
328+
end
329+
234330
@inline fuse_legs(x::TensorMap, N₁::Int, N₂::Int) = fuse_legs(x, Val(N₁), Val(N₂))
235331
function fuse_legs(x::TensorMap, ::Val{N₁}, ::Val{N₂}) where {N₁, N₂}
236332
((0 <= N₁ <= numout(x)) && (0 <= N₂ <= numin(x))) || throw(ArgumentError("invalid fusing scheme"))

0 commit comments

Comments
 (0)