Skip to content

Commit c2a6d90

Browse files
ptiedesbrantq
authored andcommitted
Move ReactantPlan to extension
1 parent fcdd20f commit c2a6d90

File tree

2 files changed

+0
-449
lines changed

2 files changed

+0
-449
lines changed

test/integration/Comrade/comimager.jl

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ using Distributions
3535
using VLBIImagePriors
3636
using LogExpFunctions
3737
import TransformVariables as TV
38-
include("reactant_nfft.jl")
3938

4039
using Downloads
4140
using Distributions
@@ -49,38 +48,6 @@ using Test
4948
const dataf = joinpath("/home/ptiede/.julia/dev/Comrade/examples/Data/SR1_M87_2017_096_lo_hops_netcal_StokesI.uvfits")
5049

5150
# TODO upstream to VLBISkyModels
52-
struct ReactantAlg <: VLBISkyModels.NUFT end
53-
54-
function VLBISkyModels.plan_nuft_spatial(
55-
::ReactantAlg,
56-
imgdomain::ComradeBase.AbstractRectiGrid,
57-
visdomain::ComradeBase.UnstructuredDomain,
58-
)
59-
visp = domainpoints(visdomain)
60-
uv2 = similar(visp.U, (2, length(visdomain)))
61-
dpx = pixelsizes(imgdomain)
62-
dx = dpx.X
63-
dy = dpx.Y
64-
rm = ComradeBase.rotmat(imgdomain)'
65-
# Here we flip the sign because the NFFT uses the -2pi convention
66-
uv2[1, :] .= -VLBISkyModels._rotatex.(visp.U, visp.V, Ref(rm)) .* dx
67-
uv2[2, :] .= -VLBISkyModels._rotatey.(visp.U, visp.V, Ref(rm)) .* dy
68-
return ReactantNFFTPlan(uv2, size(imgdomain))
69-
end
70-
71-
function VLBISkyModels.make_phases(
72-
::ReactantAlg,
73-
imgdomain::ComradeBase.AbstractRectiGrid,
74-
visdomain::ComradeBase.UnstructuredDomain,
75-
)
76-
return VLBISkyModels.make_phases(NFFTAlg(), imgdomain, visdomain)
77-
end
78-
79-
function VLBISkyModels._jlnuft!(out, A::ReactantNFFTPlan, inp::Reactant.AnyTracedRArray)
80-
LinearAlgebra.mul!(out, A, inp)
81-
return nothing
82-
end
83-
8451

8552

8653
# TODO Make ReactantLogExpFunctionsExt.
@@ -89,102 +56,6 @@ function LogExpFunctions.logistic(@nospecialize x::Reactant.TracedRNumber)
8956
end
9057
LogExpFunctions.log1pexp(x::Reactant.TracedRNumber) = log(1 + exp(x))
9158

92-
#!!!! TODO Everything in this block needs to be upstreamed to TransformVariables.jl
93-
# The major problem is that most of them require @allowscalar which is not very nice enforce in their
94-
# codebase since it is expensive for CPU code. The other temporary solution is to make a ReactantTransformVariablesExt.jl package.
95-
# function TV.transform_with(
96-
# flag::TV.NoLogJac, t::TV.ScalarTransform, x::Reactant.AnyTracedRVector, index
97-
# )
98-
# return transform(t, @allowscalar x[index]), flag, index + 1
99-
# end
100-
101-
# TODO Upstream to TransformVariables.jl
102-
# function TV.transform_with(
103-
# ::TV.LogJac, t::TV.ScalarTransform, x::Reactant.AnyTracedRVector, index
104-
# )
105-
# return TV.transform_and_logjac(t, @allowscalar x[index])..., index + 1
106-
# end
107-
108-
# TODO This is needed for TransformVariables but @allowscalar is rather annoying here.
109-
# function TV._transform_tuple(flag::TV.LogJacFlag, x::Reactant.AnyTracedRVector, index, ts)
110-
# tfirst = first(ts)
111-
# out = TV.transform_with(flag, tfirst, x, index)
112-
# @allowscalar yfirst = out[1]
113-
# @allowscalar ℓfirst = out[2]
114-
# @allowscalar index′ = out[3]
115-
# # yrest, ℓrest, index′′
116-
# trest = Base.tail(ts)
117-
# outrest = TV._transform_tuple(flag, x, index′, trest)
118-
# @allowscalar yrest = outrest[1]
119-
# @allowscalar ℓrest = outrest[2]
120-
# @allowscalar index′′ = outrest[3]
121-
# return (yfirst, yrest...), ℓfirst + ℓrest, index′′
122-
# end
123-
124-
# TODO Upstream to TransformVariables.jl
125-
# function TV._transform_tuple(
126-
# flag::TV.LogJacFlag, x::Reactant.AnyTracedRVector, index, ::Tuple{}
127-
# )
128-
# return (), TV.logjac_zero(flag, eltype(x)), index
129-
# end
130-
131-
# # TODO Upstream to TransformVariables.jl
132-
# TV.logjac_zero(::TV.LogJac, ::Type{T}) where {T<:Reactant.RNumber} = log(one(T))
133-
134-
# TODO Upstream to TransformVariables.jl (essentially identical just need to loosen types)
135-
# function TV.transform_with(
136-
# flag::TV.LogJacFlag,
137-
# t::TV.ArrayTransformation{TV.Identity},
138-
# x::Reactant.AnyTracedRVector,
139-
# index,
140-
# )
141-
# (; dims) = t
142-
# index′ = index + dimension(t)
143-
# y = reshape(x[index:(index′ - 1)], t.dims)
144-
# return y, TV.logjac_zero(flag, eltype(x)), index′
145-
# end
146-
147-
#!!!! End of TransformVariables.jl upstream block
148-
149-
# TODO Upstream to VLBIImagePriors (allowscalar needed for Reactant)
150-
# function TV.transform_with(
151-
# flag::TV.LogJacFlag, ::AngleTransform, y::Reactant.AnyTracedRVector, index
152-
# )
153-
# T = eltype(y)
154-
# ℓi = TV.logjac_zero(flag, T)
155-
# x1 = @allowscalar y[index]
156-
# x2 = @allowscalar y[index + 1]
157-
# r = sqrt(x1^2 + x2^2)
158-
# # Use log-normal with μ = 0, σ = 1/4
159-
# σ = oftype(r, 1 / 4)
160-
# if !(flag isa TV.NoLogJac)
161-
# lr = log(r)
162-
# ℓi = -lr^2 * inv(2 * σ^2) - lr
163-
# end
164-
165-
# return atan(x1, x2), ℓi, index + 2
166-
# end
167-
168-
# TODO to upstream to VLBIImagePriors
169-
# function TV.transform_with(
170-
# flag::TV.LogJacFlag,
171-
# t::TV.ArrayTransformation{<:AngleTransform},
172-
# y::Reactant.AnyTracedRVector,
173-
# index,
174-
# )
175-
# (; inner_transformation, dims) = t
176-
# T = eltype(y)
177-
# ℓ = TV.logjac_zero(flag, T)
178-
# out = similar(y, dims)
179-
# @trace for i in eachindex(out)
180-
# θ, ℓi, index2 = TV.transform_with(flag, inner_transformation, y, index)
181-
# index = index2
182-
# ℓ += ℓi
183-
# @allowscalar out[i] = θ
184-
# end
185-
# return out, ℓ, index
186-
# end
187-
18859
# TODO Make Distributions package that is compatible with Reactant
18960
Distributions.logpdf(d::Uniform, x::Reactant.TracedRNumber) = oftype(x, -log(d.b - d.a))
19061
function Distributions.logpdf(d::Exponential, x::Reactant.TracedRNumber)

0 commit comments

Comments
 (0)