Skip to content

Commit bace949

Browse files
authored
Check output size for FunctionMaps (#161)
1 parent 744800a commit bace949

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LinearMaps"
22
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
3-
version = "3.4.0"
3+
version = "3.4.1"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/functionmap.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@ const TransposeFunctionMap = TransposeMap{<:Any, <:FunctionMap}
3737
const AdjointFunctionMap = AdjointMap{<:Any, <:FunctionMap}
3838

3939
function Base.:(*)(A::FunctionMap, x::AbstractVector)
40-
length(x) == A.N || throw(DimensionMismatch())
40+
length(x) == size(A, 2) || throw(DimensionMismatch())
4141
if ismutating(A)
42-
y = similar(x, promote_type(eltype(A), eltype(x)), A.M)
42+
y = similar(x, promote_type(eltype(A), eltype(x)), size(A, 1))
4343
A.f(y, x)
4444
else
4545
y = A.f(x)
46+
length(y) == size(A, 1) || throw(DimensionMismatch())
4647
end
4748
return y
4849
end
@@ -56,6 +57,7 @@ function Base.:(*)(A::AdjointFunctionMap, x::AbstractVector)
5657
Afun.fc(y, x)
5758
else
5859
y = Afun.fc(x)
60+
length(y) == size(A, 1) || throw(DimensionMismatch())
5961
end
6062
return y
6163
elseif issymmetric(Afun) # but !isreal(A), Afun.f can be used
@@ -65,6 +67,7 @@ function Base.:(*)(A::AdjointFunctionMap, x::AbstractVector)
6567
Afun.f(y, x)
6668
else
6769
y = Afun.f(x)
70+
length(y) == size(A, 1) || throw(DimensionMismatch())
6871
end
6972
conj!(y)
7073
return y
@@ -85,6 +88,7 @@ function Base.:(*)(A::TransposeFunctionMap, x::AbstractVector)
8588
Afun.fc(y, x)
8689
else
8790
y = Afun.fc(x)
91+
length(y) == size(A, 1) || throw(DimensionMismatch())
8892
end
8993
if !isreal(A)
9094
conj!(y)
@@ -97,6 +101,7 @@ function Base.:(*)(A::TransposeFunctionMap, x::AbstractVector)
97101
Afun.f(y, x)
98102
else
99103
y = Afun.f(x)
104+
length(y) == size(A, 1) || throw(DimensionMismatch())
100105
end
101106
conj!(y)
102107
return y

0 commit comments

Comments
 (0)