Skip to content

Commit 3405adb

Browse files
committed
Adding state batching
1 parent e88ecbb commit 3405adb

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

examples/torus.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ sampleX1(n_samples) = Flowfusion.random_literal_cat(n_samples, sigma = T(0.05))[
3939
n_samples = 500
4040

4141
M = Torus(2)
42-
P = ManifoldProcess(0.2f0)
43-
#P = Deterministic()
42+
P = FProcess(ManifoldProcess(0.2f0), t -> 1-(1-t)^2)
43+
#P = FProcess(Deterministic(), t -> 1-(1-t)^2)
4444

4545
eta = 0.01
4646
opt_state = Flux.setup(AdamW(eta = eta, lambda = 0.00001), model)
@@ -57,8 +57,7 @@ for i in 1:iters
5757
ξ = Guide(Xt, X1)
5858
#Gradient
5959
l,g = Flux.withgradient(model) do m
60-
#tcloss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t)) #GOING TO HAVE TO ADD GUIDE HERE, AND CHANGE IT TO FLOSS
61-
floss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t)) #GOING TO HAVE TO ADD GUIDE HERE, AND CHANGE IT TO FLOSS
60+
floss(P, m(t,tensor(Xt)), ξ, scalefloss(P, t))
6261
end
6362
#Update
6463
Flux.update!(opt_state, model, g[1])

src/Flowfusion.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ export
4747
FProcess,
4848
floss,
4949
tcloss,
50-
dense
50+
dense,
51+
batch
5152

5253

5354
#Useful for demos etc:

src/bridge.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,20 @@ function apply_tangent_coordinates(Xt::ManifoldState, ξ; retraction_method=defa
179179
end
180180
return X̂₁
181181
end
182+
183+
#As some point we should move these to ForwardBackward.jl. Low prio because they're mostly not needed for other applications.
184+
tensor_cat(Xs::Vector{T}; dims_from_end = 1) where T = cat(tensor.(Xs)..., dims = ndims(tensor(Xs[1])) - dims_from_end + 1)
185+
tensor_cat(Xs::Vector{Nothing}) = nothing
186+
187+
"""
188+
batch(Xs::Vector{T}; dims_from_end = 1)
189+
190+
Doesn't handle padding. Add option to pad if batching along dims that don't have the same length.
191+
"""
192+
batch(Xs::Vector{T}; dims_from_end = 1) where T<:ContinuousState = T(tensor_cat(Xs; dims_from_end))
193+
batch(Xs::Vector{T}; dims_from_end = 1) where T<:DiscreteState = T(Xs[1].K, tensor_cat(Xs; dims_from_end))
194+
batch(Xs::Vector{<:ManifoldState{<:M,<:A}}; dims_from_end = 1) where {M, A} = ManifoldState(Xs[1].M, eachslice(tensor_cat(Xs; dims_from_end), dims = Tuple((ndims(Xs[1].state[1])+1:ndims(tensor(Xs[1])))))) #Only tested for rotations.
195+
batch(Xs::Vector{<:Tuple{Vararg{UState}}}, dims_from_end = 1) = Tuple([batch([x[i] for x in Xs], dims_from_end = dims_from_end) for i in 1:length(Xs[1])])
196+
197+
#Should never move to ForwardBackward.jl
198+
batch(Xs::Vector{<:MaskedState}; dims_from_end = 1) = MaskedState(batch(unmask.(Xs); dims_from_end), tensor_cat([X.cmask for X in Xs]; dims_from_end), tensor_cat([X.lmask for X in Xs]; dims_from_end))

0 commit comments

Comments
 (0)