Skip to content

Commit bfd9f06

Browse files
committed
added remove_displacement
1 parent b73b479 commit bfd9f06

File tree

12 files changed

+48
-24
lines changed

12 files changed

+48
-24
lines changed

src/calculus/AffineAdd.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,5 @@ end
9494

9595
displacement(A::AffineAdd{L,D,true}) where {L,D} = A.d+displacement(A.A)
9696
displacement(A::AffineAdd{L,D,false}) where {L,D} = -A.d+displacement(A.A)
97+
98+
remove_displacement(A::AffineAdd) = remove_displacement(A.A)

src/calculus/BroadCast.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,4 @@ is_linear( R::BroadCast) = is_linear(R.A)
116116
is_null( R::BroadCast) = is_null(R.A)
117117

118118
fun_name(R::BroadCast) = "."fun_name(R.A)
119+
remove_displacement(B::BroadCast) = BroadCast(remove_displacement(B.A), B.dim_out, B.bufC, B.bufD)

src/calculus/Compose.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,5 @@ function permute(C::Compose, p::AbstractVector{Int})
121121
Compose(AA,C.buf)
122122

123123
end
124+
125+
remove_displacement(C::Compose) = Compose(remove_displacement.(C.A),C.buf)

src/calculus/DCAT.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ function permute(H::DCAT{N,L,P1,P2}, p::AbstractVector{Int}) where {N,L,P1,P2}
210210
DCAT(H.A,new_part,H.idxC)
211211
end
212212

213+
remove_displacement(D::DCAT) = DCAT(remove_displacement.(D.A), D.idxD, D.idxC)
214+
213215
# special cases
214216
# Eye constructor
215217
Eye(x::A) where {N, A <: NTuple{N,AbstractArray}} = DCAT(Eye.(x)...)

src/calculus/HCAT.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,5 @@ function permute(H::HCAT{M,N,L,P,C}, p::AbstractVector{Int}) where {M,N,L,P,C}
320320

321321
HCAT{M,N,L,P,C}(H.A,new_part,H.buf)
322322
end
323+
324+
remove_displacement(H::HCAT{M}) where {M} = HCAT(remove_displacement.(H.A), H.idxs, H.buf, M)

src/calculus/Hadamard.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,22 @@ true
3030

3131
struct Hadamard{M, C, V <: VCAT{M}} <: NonLinearOperator
3232
A::V
33-
mid::C
34-
mid2::C
35-
function Hadamard(A::V, mid::C, mid2::C) where {M, C, V <: VCAT{M}}
33+
buf::C
34+
buf2::C
35+
function Hadamard(A::V, buf::C, buf2::C) where {M, C, V <: VCAT{M}}
3636
any([ai != size(A,1)[1] for ai in size(A,1)]) &&
3737
throw(DimensionMismatch("cannot compose operators"))
3838

39-
new{M, C, V}(A,mid,mid2)
39+
new{M, C, V}(A,buf,buf2)
4040
end
4141
end
4242

4343
struct HadamardJacobian{M, C, V <: VCAT{M}} <: LinearOperator
4444
A::V
45-
mid::C
46-
mid2::C
47-
function HadamardJacobian(A::V,mid::C,mid2::C) where {M, C, V <: VCAT{M}}
48-
new{M, C, V}(A,mid,mid2)
45+
buf::C
46+
buf2::C
47+
function HadamardJacobian(A::V,buf::C,buf2::C) where {M, C, V <: VCAT{M}}
48+
new{M, C, V}(A,buf,buf2)
4949
end
5050
end
5151

@@ -57,32 +57,32 @@ function Hadamard(L1::AbstractOperator,L2::AbstractOperator)
5757

5858
V = VCAT(A,B)
5959

60-
mid = zeros.(codomainType(V), size(V,1))
61-
mid2 = zeros.(codomainType(V), size(V,1))
60+
buf = zeros.(codomainType(V), size(V,1))
61+
buf2 = zeros.(codomainType(V), size(V,1))
6262

63-
Hadamard(V,mid,mid2)
63+
Hadamard(V,buf,buf2)
6464
end
6565

6666
# Mappings
6767
function A_mul_B!(y, H::Hadamard{M,C,V}, b) where {M,C,V}
68-
A_mul_B!(H.mid,H.A,b)
68+
A_mul_B!(H.buf,H.A,b)
6969

70-
y .= H.mid[1]
71-
for i = 2:length(H.mid)
72-
y .*= H.mid[i]
70+
y .= H.buf[1]
71+
for i = 2:length(H.buf)
72+
y .*= H.buf[i]
7373
end
7474
end
7575

7676
# Jacobian
7777
Jacobian(A::H, x::D) where {M, D<:Tuple, C, V, H <: Hadamard{M,C,V}} =
78-
HadamardJacobian(Jacobian(A.A,x),A.mid,A.mid2)
78+
HadamardJacobian(Jacobian(A.A,x),A.buf,A.buf2)
7979

8080
function Ac_mul_B!(y, J::HadamardJacobian{M,C,V}, b) where {M,C,V}
81-
for i = 1:length(J.mid)
82-
c = (J.mid[1:i-1]...,J.mid[i+1:end]...,b)
83-
J.mid2[i] .= (.*)(c...)
81+
for i = 1:length(J.buf)
82+
c = (J.buf[1:i-1]...,J.buf[i+1:end]...,b)
83+
J.buf2[i] .= (.*)(c...)
8484
end
85-
Ac_mul_B!(y, J.A, J.mid2)
85+
Ac_mul_B!(y, J.A, J.buf2)
8686

8787
end
8888

@@ -104,5 +104,7 @@ import Base: permute
104104

105105
function permute(H::Hadamard, p::AbstractVector{Int})
106106
A = VCAT([permute(a,p) for a in H.A.A]...)
107-
Hadamard(A,H.mid,H.mid2)
107+
Hadamard(A,H.buf,H.buf2)
108108
end
109+
110+
remove_displacement(N::Hadamard) = Hadamard(remove_displacement(N.A), N.buf, N.buf2)

src/calculus/NonLinearCompose.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,4 @@ function permute(P::NonLinearCompose{N,L,C,D}, p::AbstractVector{Int}) where {N,
143143
NonLinearCompose(permute(P.A,p),permute(P.B,p),P.buf,P.bufx)
144144
end
145145

146-
147-
146+
remove_displacement(N::NonLinearCompose) = NonLinearCompose(remove_displacement(N.A), remove_displacement(N.B), N.buf, N.bufx)

src/calculus/Reshape.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,4 @@ is_full_row_rank( R::Reshape) = is_full_row_rank( R.A)
6868
is_full_column_rank( R::Reshape) = is_full_column_rank( R.A)
6969

7070
fun_name(R::Reshape) = ""*fun_name(R.A)
71+
remove_displacement(R::Reshape) = Reshape(remove_displacement(R.A), R.dim_out)

src/calculus/Scale.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,4 @@ fun_type(L::Scale) = fun_type(L.A)
100100
diag(L::Scale) = L.coeff*diag(L.A)
101101
diag_AcA(L::Scale) = (L.coeff)^2*diag_AcA(L.A)
102102
diag_AAc(L::Scale) = (L.coeff)^2*diag_AAc(L.A)
103+
remove_displacement(S::Scale) = Scale(S.coeff, S.coeff_conj, remove_displacement(S.A) )

src/calculus/Sum.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,4 @@ function permute(S::Sum{M,N}, p::AbstractVector{Int}) where {M,N}
117117
return Sum(AA,S.bufC,S.bufD[p],M,N)
118118
end
119119

120+
remove_displacement(S::Sum{M,N}) where {M,N} = Sum(remove_displacement.(S.A), S.bufC, S.bufD, M, N)

0 commit comments

Comments
 (0)