You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/Flowfusion.jl
+28-5Lines changed: 28 additions & 5 deletions
Original file line number
Diff line number
Diff line change
@@ -1,7 +1,30 @@
1
+
#=
2
+
Need to test/do:
3
+
Urgent:
4
+
- Test tuples!
5
+
- Masking (cmask) on all state types for bridge and gen
6
+
- Masking (lmask) on all state types for both losses
7
+
- tensor on masked states
8
+
- FProcess and whether it matches the target where allowed. Need to come up with a policy on using FProcess with InterpolatingDiscreteProcesses
9
+
- X1 pred for rotations (add angle/axis loss back in just because yolo)
10
+
- self-conditioning
11
+
- GPU use of all state types
12
+
Later:
13
+
- Make a table of Manifolds where you test whether the key functions are defined, with checkboxes and timing for diffusion and flow.
14
+
- Make a table of commands for key types of diffusion/flow. Columns for Process, X0/X1 setup, Xt bridge, loss, gen where things like softmax, Guide, etc are clear.
15
+
- Compute probability velocities for UniformDiscrete and PiQ so these can flow.
16
+
=#
17
+
18
+
19
+
20
+
21
+
1
22
module Flowfusion
2
23
3
24
using ForwardBackward, OneHotArrays, Adapt, Manifolds, NNlib
Returns a view of `a` where slices specified by `m` are selected. `m` can be multidimensional, but the dimensions of m must match the last dimensions of `a`.
61
-
For example, if `m` is a boolean array, then `size(a)[ndims(a)-ndims(m):end] == size(m)`.
22
+
Rerturns a state where `X.state` is a onehot array.
Applies, in place, a conditioning mask, forcing elements (or slices) of `Xt` to be equal to `X1`, where `cmask` is 1.
31
+
Converts `X` to an appropriate dense representation. If `X` is a `DiscreteState`, then `X` is converted to a `CategoricalLikelihood` with default eltype Float32.
32
+
If `X` is a "onehot" CategoricalLikelihood then `X` is converted to a fully dense one.
#copytensor! and predictresolve are used handle the state translation that happens in gen(...).
85
-
#We want the user's X̂₁predictor, which is a DL model, to return a plain tensor (since that will be on the GPU, in the loss, etc).
86
-
#This means we need to automagically create a State (typical for the continuous case) or Likelihood (typical for the discrete case) from the tensor.
87
-
#But the user may return a State in the Discrete case (for massive state spaces with sub-linear sampling), and a Likelihood in the Continuous case (for variance matching models)
88
-
#This also needs to handle MaskedStates (needs testing).
89
-
#We need: X̂₁ = fix(X̂₁predictor(t, Xₜ))
90
-
#Plan: When X̂₁predictor(t, Xₜ) is a State or Likelihood, just pass through.
91
-
#When X̂₁predictor(t, Xₜ) is a plain tensor, we apply default conversion rules.
@@ -124,24 +64,20 @@ If `X1` is a `MaskedState`, then `Xt` will equal `X1` where the conditioning mas
124
64
The same `t` and (optionally) `t0` will be used for all elements. If you need a different `t` for each Proces/State, broadcast with `bridge.(P, X0, X1, t0, t)`.
0 commit comments