@@ -183,24 +183,24 @@ function (H::PrecomputedDerivative)(x::AbstractTensorMap)
183183 @inbounds sz, str, offset = structure_R. fusiontreestructure[i]
184184 r = TensorKit. Strided. StridedView (R_fused. data, sz, str, offset)
185185
186- # if sz[2] < sz[3]
187- # for k in axes(r, 2 )
188- # C = xr[:, k, :]
189- # B = r[:, k, :]
190- # LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
191- # end
192- # else
193- # for k in axes(r, 3)
194- # C = xr [:, :, k ]
195- # B = r[:, :, k]
196- # LinearAlgebra.BLAS.gemm!('N', 'N', one(TC), x, B, zero(TC), C)
197- # end
198- # end
199-
200- TensorOperations . tensorcontract! (
201- xr, x, (( 1 , ), ( 2 ,)), false ,
202- r, (( 1 ,), ( 2 , 3 )), false , (( 1 , 2 ), ( 3 ,)), One (), Zero (), H . backend, H . allocator
203- )
186+ if TensorOperations . isblascontractable (r, (( 1 ,), ( 2 , 3 ))) &&
187+ TensorOperations . isblasdestination (xr, (( 1 ,), ( 2 , 3 )) )
188+ C = TensorKit . Strided . sreshape (xr, size (xr, 1 ), size (xr, 2 ) * size (xr, 3 ))
189+ B = TensorKit . Strided . sreshape (r, size (r, 1 ), size (r, 2 ) * size (r, 3 ))
190+ LinearAlgebra. BLAS. gemm! (' N' , ' N' , one (TC), x, B, zero (TC), C)
191+ elseif sz[ 2 ] < sz[ 3 ]
192+ for k in axes (r, 2 )
193+ C = xr[:, k, :]
194+ B = r [:, k, : ]
195+ LinearAlgebra . BLAS . gemm! ( ' N ' , ' N ' , one (TC), x, B, zero (TC), C)
196+ end
197+ else
198+ for k in axes (r, 3 )
199+ C = xr[:, :, k]
200+ B = r[:, :, k]
201+ LinearAlgebra . BLAS . gemm! ( ' N ' , ' N ' , one (TC ), x, B, zero (TC), C)
202+ end
203+ end
204204 else
205205 zerovector! (xr)
206206 end
0 commit comments