Skip to content

Commit 5d5af38

Browse files
committed
Avoid copy in matrix contract
This reverts commit c9db523.
1 parent c9db523 commit 5d5af38

File tree

2 files changed: +22 additions, -26 deletions

2 files changed: +22 additions, -26 deletions

src/algorithms/derivatives/hamiltonian_derivatives.jl

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -113,10 +113,8 @@ function prepare_operator!!(
113113
H.E
114114
end
115115

116-
# O3′ = Core.Compiler.return_type(prepare_operator!!, Tuple{O3, typeof(backend), typeof(allocator)})
117-
# A = ismissing(H.A) ? H.A : prepare_operator!!(H.A, backend, allocator)
118-
O3′ = O3
119-
A = H.A
116+
O3′ = Core.Compiler.return_type(prepare_operator!!, Tuple{O3, typeof(backend), typeof(allocator)})
117+
A = ismissing(H.A) ? H.A : prepare_operator!!(H.A, backend, allocator)
120118

121119
return JordanMPO_AC_Hamiltonian{O1, O2, O3′}(D, I, E, C, B, A)
122120
end
@@ -330,10 +328,8 @@ function prepare_operator!!(
330328
H.EE
331329
end
332330

333-
# O4′ = Core.Compiler.return_type(prepare_operator!!, Tuple{O4, typeof(backend), typeof(allocator)})
334-
# AA = ismissing(H.AA) ? H.AA : prepare_operator!!(H.AA, backend, allocator)
335-
O4′ = O4
336-
AA = H.AA
331+
O4′ = Core.Compiler.return_type(prepare_operator!!, Tuple{O4, typeof(backend), typeof(allocator)})
332+
AA = prepare_operator!!(H.AA, backend, allocator)
337333

338334
return JordanMPO_AC2_Hamiltonian{O1, O2, O3, O4′}(II, IC, ID, CB, CA, AB, AA, BE, DE, EE)
339335
end

src/algorithms/derivatives/mpo_derivatives.jl

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -183,24 +183,24 @@ function (H::PrecomputedDerivative)(x::AbstractTensorMap)
183183
@inbounds sz, str, offset = structure_R.fusiontreestructure[i]
184184
r = TensorKit.Strided.StridedView(R_fused.data, sz, str, offset)
185185

186-
# if sz[2] < sz[3]
187-
# for k in axes(r, 2)
188-
# C = xr[:, k, :]
189-
# B = r[:, k, :]
190-
# LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
191-
# end
192-
# else
193-
# for k in axes(r, 3)
194-
# C = xr[:, :, k]
195-
# B = r[:, :, k]
196-
# LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
197-
# end
198-
# end
199-
200-
TensorOperations.tensorcontract!(
201-
xr, x, ((1,), (2,)), false,
202-
r, ((1,), (2, 3)), false, ((1, 2), (3,)), One(), Zero(), H.backend, H.allocator
203-
)
186+
if TensorOperations.isblascontractable(r, ((1,), (2, 3))) &&
187+
TensorOperations.isblasdestination(xr, ((1,), (2, 3)))
188+
C = TensorKit.Strided.sreshape(xr, size(xr, 1), size(xr, 2) * size(xr, 3))
189+
B = TensorKit.Strided.sreshape(r, size(r, 1), size(r, 2) * size(r, 3))
190+
LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
191+
elseif sz[2] < sz[3]
192+
for k in axes(r, 2)
193+
C = xr[:, k, :]
194+
B = r[:, k, :]
195+
LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
196+
end
197+
else
198+
for k in axes(r, 3)
199+
C = xr[:, :, k]
200+
B = r[:, :, k]
201+
LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
202+
end
203+
end
204204
else
205205
zerovector!(xr)
206206
end

0 commit comments

Comments
 (0)