Skip to content

Commit 02e91b2

Browse files
kshyattlkdvos
andauthored
Update ext/TensorKitCUDAExt/cutensormap.jl
Co-authored-by: Lukas Devos <[email protected]>
1 parent 622b707 commit 02e91b2

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,9 @@ end
153153

154154
# Scalar implementation
155155
#-----------------------
156-
function TensorKit.scalar(t::CuTensorMap)
157-
# TODO: should scalar only work if N₁ == N₂ == 0?
158-
return @allowscalar dim(codomain(t)) == dim(domain(t)) == 1 ?
159-
first(blocks(t))[2][1, 1] : throw(DimensionMismatch())
156+
function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S}
157+
inds = findall(!iszero, t.data)
158+
return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)
160159
end
161160

162161
function Base.convert(

0 commit comments

Comments
 (0)