Skip to content

Commit 88f0cbc

Browse files
committed
Workaround for promote_type for GPU arrays
1 parent ea4cb6c commit 88f0cbc

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

src/LinearOperatorCollection.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,35 @@ include("SamplingOp.jl")
108108
include("NormalOp.jl")
109109
include("DiagOp.jl")
110110

111+
function promote_storage_types(A, B)
112+
A_type = storage_type(A)
113+
B_type = storage_type(B)
114+
S = promote_type(A_type, B_type)
115+
if !isconcretetype(S)
116+
# Find common eltype
117+
elType = promote_type(eltype(A), eltype(B))
118+
if !isconcretetype(elType)
119+
throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
120+
end
121+
122+
# Same base type
123+
A_base = Base.typename(A_type).wrapper
124+
B_base = Base.typename(B_type).wrapper
125+
if A_base != B_base
126+
throw(LinearOperatorException("Storage types cannot be promoted to a common base type"))
127+
end
128+
129+
# LinearOperators only accepts DataTypes, so we cant just do A_base{elType}, since that might be a UnionAll
130+
# Check if either A_type or B_type have the fitting eltype
131+
if eltype(A_type) == elType
132+
S = A_type
133+
elseif eltype(B_type) == elType
134+
S = B_type
135+
else
136+
throw(LinearOperatorException("Storage types cannot be promoted to a common eltype"))
137+
end
138+
end
139+
return S
140+
end
141+
111142
end

src/NormalOp.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ end
5252
LinearOperators.storage_type(op::NormalOpImpl) = typeof(op.Mv5)
5353

5454
function NormalOpImpl(parent, weights)
55-
S = promote_type(storage_type(parent), storage_type(weights))
56-
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
55+
S = promote_storage_types(parent, weights)
5756
tmp = S(undef, size(parent, 1))
5857
return NormalOpImpl(parent, weights, tmp)
5958
end

src/ProdOp.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ composition/product of two Operators. Differs with * since it can handle normal
3636
function ProdOp(A, B)
3737
nrow = size(A, 1)
3838
ncol = size(B, 2)
39-
S = promote_type(LinearOperators.storage_type(A), LinearOperators.storage_type(B))
40-
isconcretetype(S) || throw(LinearOperatorException("Storage types cannot be promoted to a concrete type"))
39+
S = promote_storage_types(A, B)
4140
tmp_ = S(undef, size(B, 1))
4241

4342
function produ!(res, x::AbstractVector{T}, tmp) where T<:Union{Real,Complex}

0 commit comments

Comments
 (0)