
Commit 9c9c8a1

✨ Add Glow-like multi-scale flow
1 parent 997b9e6 commit 9c9c8a1

6 files changed: +687 −18 lines

tests/test_flows.py

Lines changed: 48 additions & 0 deletions
@@ -121,3 +121,51 @@ def test_autoregressive_transforms():
         assert (torch.triu(J, diagonal=1) == 0).all(), t
         assert (torch.tril(J[:4, :4], diagonal=-1) == 0).all(), t
         assert (torch.tril(J[4:, 4:], diagonal=-1) == 0).all(), t
+
+
+def test_Glow(tmp_path):
+    flow = Glow((3, 32, 32), context=[5, 0, 5])
+
+    # Evaluation of log_prob
+    x, y = randn(8, 3, 32, 32), [randn(5, 16, 16), None, randn(8, 5, 4, 4)]
+    log_p = flow(y).log_prob(x)
+
+    assert log_p.shape == (8,)
+    assert log_p.requires_grad
+
+    flow.zero_grad(set_to_none=True)
+    loss = -log_p.mean()
+    loss.backward()
+
+    for p in flow.parameters():
+        assert p.grad is not None
+
+    # Sampling
+    x = flow(y).sample()
+
+    assert x.shape == (8, 3, 32, 32)
+
+    # Reparameterization trick
+    x = flow(y).rsample()
+
+    flow.zero_grad(set_to_none=True)
+    loss = x.square().sum().sqrt()
+    loss.backward()
+
+    for p in flow.parameters():
+        assert p.grad is not None
+
+    # Saving
+    torch.save(flow, tmp_path / 'flow.pth')
+
+    # Loading
+    flow_bis = torch.load(tmp_path / 'flow.pth')
+
+    x, y = randn(3, 32, 32), [randn(5, 16, 16), None, randn(5, 4, 4)]
+
+    seed = torch.seed()
+    log_p = flow(y).log_prob(x)
+    torch.manual_seed(seed)
+    log_p_bis = flow_bis(y).log_prob(x)
+
+    assert torch.allclose(log_p, log_p_bis)
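The test covers the full Glow API: conditional log_prob with gradients, sample, rsample, and round-trip serialization. For orientation, here is a minimal maximum-likelihood training sketch built from the same calls; the unconditional setting, the optimizer choice, and the loader are illustrative assumptions, not part of this commit:

import torch
from zuko.flows import Glow

flow = Glow((3, 32, 32), context=0)  # unconditional variant
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

for x in loader:  # hypothetical loader of (B, 3, 32, 32) image batches
    loss = -flow(None).log_prob(x).mean()  # negative log-likelihood

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()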

tests/test_nn.py

Lines changed: 17 additions & 0 deletions
@@ -70,3 +70,20 @@ def test_MonotonicMLP():
     J = torch.autograd.functional.jacobian(net, x)
 
     assert (J >= 0).all()
+
+
+def test_FCN():
+    net = FCN(3, 5)
+
+    # Non-batched
+    x = randn(3, 64, 64)
+    y = net(x)
+
+    assert y.shape == (5, 64, 64)
+    assert y.requires_grad
+
+    # Batched
+    x = randn(8, 3, 32, 32)
+    y = net(x)
+
+    assert y.shape == (8, 5, 32, 32)
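FCN is validated here as a shape-preserving convolutional network: it changes the channel count (3 → 5) while leaving spatial dimensions untouched, with or without a batch dimension. That property is exactly what a convolutional coupling layer needs from its hyper network. A self-contained sketch of the general pattern, assuming batched (B, C, H, W) input (this is not zuko's implementation; the real one follows in zuko/flows.py below):

import torch

def affine_coupling(x: torch.Tensor, hyper, d: int) -> torch.Tensor:
    # Split along the channel axis: the first d channels pass through
    # unchanged and parameterize the transform of the remaining ones.
    x_a, x_b = x[:, :d], x[:, d:]

    # `hyper` is assumed to preserve spatial shape and output
    # 2 * (C - d) channels, e.g. FCN(d, 2 * (x.shape[1] - d)).
    shift, log_scale = hyper(x_a).chunk(2, dim=1)

    return torch.cat((x_a, x_b * log_scale.exp() + shift), dim=1)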

tests/test_transforms.py

Lines changed: 57 additions & 10 deletions
@@ -58,6 +58,49 @@ def test_univariate_transforms():
         assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t
 
 
+def test_multivariate_transforms():
+    ts = [
+        LULinearTransform(randn(3, 3), dim=-2),
+        PermutationTransform(torch.randperm(3), dim=-2),
+        PixelShuffleTransform(dim=-2),
+    ]
+
+    for t in ts:
+        # Shapes
+        x = randn(256, 3, 8)
+        y = t(x)
+
+        assert t.forward_shape(x.shape) == y.shape, t
+        assert t.inverse_shape(y.shape) == x.shape, t
+
+        # Inverse
+        z = t.inv(y)
+
+        assert x.shape == z.shape, t
+        assert torch.allclose(x, z, atol=1e-4), t
+
+        # Jacobian
+        x = randn(3, 8)
+        y = t(x)
+
+        jacobian = torch.autograd.functional.jacobian(t, x)
+        jacobian = jacobian.reshape(3 * 8, 3 * 8)
+
+        _, ladj = torch.slogdet(jacobian)
+
+        assert torch.allclose(t.log_abs_det_jacobian(x, y), ladj, atol=1e-4), t
+
+        # Inverse Jacobian
+        z = t.inv(y)
+
+        jacobian = torch.autograd.functional.jacobian(t.inv, y)
+        jacobian = jacobian.reshape(3 * 8, 3 * 8)
+
+        _, ladj = torch.slogdet(jacobian)
+
+        assert torch.allclose(t.inv.log_abs_det_jacobian(y, z), ladj, atol=1e-4), t
+
+
 def test_FFJTransform():
     a = torch.randn(3)
     f = lambda x, t: a * x
@@ -80,20 +123,24 @@ def test_FFJTransform():
     assert ladj.shape == x.shape[:-1]
 
 
-def test_PermutationTransform():
-    t = PermutationTransform(torch.randperm(8))
+def test_DropTransform():
+    dist = Normal(randn(3), abs(randn(3)) + 1)
+    t = DropTransform(dist)
 
-    x = torch.randn(256, 8)
+    # Call
+    x = randn(256, 5)
     y = t(x)
 
-    assert x.shape == y.shape
-
-    match = x[:, :, None] == y[:, None, :]
-
-    assert (match.sum(dim=-1) == 1).all()
-    assert (match.sum(dim=-2) == 1).all()
+    assert t.forward_shape(x.shape) == y.shape
+    assert t.inverse_shape(y.shape) == x.shape
 
+    # Inverse
     z = t.inv(y)
 
     assert x.shape == z.shape
-    assert (x == z).all()
+    assert not torch.allclose(x, z)
+
+    # Jacobian
+    ladj = t.log_abs_det_jacobian(x, y)
+
+    assert ladj.shape == (256,)
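As the assertions indicate, DropTransform is not a plain bijection: the forward pass drops the components modeled by dist (here (256, 5) → (256, 2), since dist covers 3 variables), the inverse re-samples them from dist (hence z differs from x), and log_abs_det_jacobian contributes a per-sample term of shape (256,) accounting for the dropped components. This is the mechanism behind the multi-scale structure of Glow below, where half of the channels are factored out against a diagonal Gaussian at each scale.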

zuko/flows.py

Lines changed: 196 additions & 0 deletions
@@ -13,6 +13,8 @@
     'NAF',
     'FreeFormJacobianTransform',
     'CNF',
+    'ConvCouplingTransform',
+    'Glow',
 ]
 
 import abc
@@ -753,3 +755,197 @@ def __init__(
         )
 
         super().__init__(transforms, base)
+
+
+class ConvCouplingTransform(TransformModule):
+    r"""Creates a convolution coupling transformation.
+
+    Arguments:
+        channels: The number of channels.
+        context: The number of context channels.
+        spatial: The number of spatial dimensions.
+        univariate: The univariate transformation constructor.
+        shapes: The shapes of the univariate transformation parameters.
+        kwargs: Keyword arguments passed to :class:`zuko.nn.FCN`.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        context: int = 0,
+        spatial: int = 2,
+        univariate: Callable[..., Transform] = MonotonicAffineTransform,
+        shapes: List[Size] = [(), ()],
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.d = channels // 2
+        self.dim = -(spatial + 1)
+
+        # Univariate transformation
+        self.univariate = univariate
+        self.shapes = list(map(Size, shapes))
+        self.sizes = [s.numel() for s in self.shapes]
+
+        # Hyper network
+        kwargs.setdefault('activation', nn.ELU)
+        kwargs.setdefault('normalize', True)
+
+        self.hyper = FCN(
+            in_channels=self.d + context,
+            out_channels=(channels - self.d) * sum(self.sizes),
+            spatial=spatial,
+            **kwargs,
+        )
+
+    def extra_repr(self) -> str:
+        base = self.univariate(*map(torch.randn, self.shapes))
+
+        return f'(base): {base}'
+
+    def meta(self, y: Tensor, x: Tensor) -> Transform:
+        if y is not None:
+            x = torch.cat(broadcast(x, y, ignore=abs(self.dim)), dim=self.dim)
+
+        total = sum(self.sizes)
+
+        phi = self.hyper(x)
+        phi = phi.unflatten(self.dim, (phi.shape[self.dim] // total, total))
+        phi = phi.movedim(self.dim, -1)
+        phi = phi.split(self.sizes, -1)
+        phi = (p.unflatten(-1, s + (1,)) for p, s in zip(phi, self.shapes))
+        phi = (p.squeeze(-1) for p in phi)
+
+        return self.univariate(*phi)
+
+    def forward(self, y: Tensor = None) -> Transform:
+        return CouplingTransform(partial(self.meta, y), self.d, self.dim)
+
+
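A note on the parameter plumbing in meta: with the default MonotonicAffineTransform, shapes = [(), ()] gives total = 2, so hyper outputs (channels - self.d) * 2 channels. The unflatten/movedim/split chain carves these into a shift and a scale, each of shape (..., channels - self.d, *space), that is, one parameter per transformed element. Concretely, for channels = 12 and a 16×16 input, phi starts as 12 channels, is unflattened to (..., 6, 2, 16, 16), moved to (..., 6, 16, 16, 2), and split into two tensors of shape (..., 6, 16, 16).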
+class Glow(DistributionModule):
+    r"""Creates a Glow-like multi-scale flow.
+
+    References:
+        | Glow: Generative Flow with Invertible 1x1 Convolutions (Kingma et al., 2018)
+        | https://arxiv.org/abs/1807.03039
+
+    Arguments:
+        shape: The shape of a sample.
+        context: The number of context channels at each scale.
+        transforms: The number of coupling transformations at each scale.
+        kwargs: Keyword arguments passed to :class:`ConvCouplingTransform`.
+    """
+
+    def __init__(
+        self,
+        shape: Size,
+        context: Union[int, List[int]] = 0,
+        transforms: List[int] = [8, 8, 8],
+        **kwargs,
+    ):
+        super().__init__()
+
+        channels, *space = shape
+        spatial = len(space)
+        dim = -len(shape)
+        scales = len(transforms)
+
+        assert all(s % 2**scales == 0 for s in space), (
+            f"'shape' cannot be downscaled {scales} times"
+        )
+
+        if isinstance(context, int):
+            context = [context] * len(transforms)
+
+        self.flows = nn.ModuleList()
+        self.bases = nn.ModuleList()
+
+        for i, K in enumerate(transforms):
+            flow = []
+            flow.append(Unconditional(PixelShuffleTransform, dim=dim))
+
+            channels = channels * 2**spatial
+            space = [s // 2 for s in space]
+
+            for _ in range(K):
+                flow.extend([
+                    Unconditional(
+                        PermutationTransform,
+                        torch.randperm(channels),
+                        dim=dim,
+                        buffer=True,
+                    ),
+                    Unconditional(
+                        LULinearTransform,
+                        torch.eye(channels),
+                        dim=dim,
+                    ),
+                    ConvCouplingTransform(
+                        channels=channels,
+                        context=context[i],
+                        spatial=spatial,
+                        **kwargs,
+                    ),
+                ])
+
+            self.flows.append(nn.ModuleList(flow))
+            self.bases.append(
+                Unconditional(
+                    DiagNormal,
+                    torch.zeros(channels // 2, *space),
+                    torch.ones(channels // 2, *space),
+                    ndims=spatial + 1,
+                    buffer=True,
+                )
+            )
+
+            channels = channels // 2
+
+        self.bases.pop()
+        self.bases.append(
+            Unconditional(
+                DiagNormal,
+                torch.zeros(channels * 2, *space),
+                torch.ones(channels * 2, *space),
+                ndims=spatial + 1,
+                buffer=True,
+            )
+        )
+
+    def forward(self, y: Iterable[Tensor] = None) -> NormalizingFlow:
+        r"""
+        Arguments:
+            y: A sequence of contexts :math:`y`. There should be one element
+                :math:`y_i` per scale, but elements can be :py:`None`.
+
+        Returns:
+            A multi-scale flow :math:`p(X | y)`.
+        """
+
+        if y is None:
+            y = [None] * len(self.flows)
+
+        # Transforms
+        transforms = []
+        context_shapes = []
+
+        for flow, base, y_i in zip(self.flows, self.bases, y):
+            for t in flow:
+                transforms.append(t(y_i))
+
+            transforms.append(DropTransform(base(y_i)))
+
+            if y_i is not None:
+                context_shapes.append(y_i.shape)
+
+        # Base
+        base = transforms.pop().dist
+        dim = -len(base.event_shape)
+
+        batch_shapes = (shape[:dim] for shape in context_shapes)
+        batch_shape = torch.broadcast_shapes(*batch_shapes)
+
+        base = base.expand(batch_shape)
+
+        return NormalizingFlow(transforms, base)
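Putting the pieces together: forward expects one context per scale, and each context must live at that scale's resolution, which halves at every PixelShuffleTransform (here 32 → 16 → 8 → 4). A usage sketch mirroring test_Glow above:

import torch
from zuko.flows import Glow

# Three scales; scales 1 and 3 get 5 context channels, scale 2 is unconditional.
flow = Glow((3, 32, 32), context=[5, 0, 5])

x = torch.randn(8, 3, 32, 32)
y = [torch.randn(5, 16, 16), None, torch.randn(8, 5, 4, 4)]

log_p = flow(y).log_prob(x)  # shape (8,), contexts broadcast over the batch
x_new = flow(y).sample()     # shape (8, 3, 32, 32)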
