Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ base_distribution = distributions.StandardNormal(shape=[2])
flow = flows.Flow(transform=transform, distribution=base_distribution)
```

or use already implemented flow architectures:

```python
from nflows.flows import MaskedAutoregressiveFlow, SimpleRealNVP

features=2
hidden_features=4

maf = MaskedAutoregressiveFlow(features=features, hidden_features=hidden_features)
rnvp = SimpleRealNVP(features=features, hidden_features=hidden_features)
nice = SimpleRealNVP(feautres=features, hidden_features=hidden_features, use_volumne_perserving=True)
```

To evaluate log probabilities of inputs:
```python
log_prob = flow.log_prob(inputs)
Expand Down
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,3 @@ dependencies:
- pyyaml
- tensorboard
- tqdm

237 changes: 237 additions & 0 deletions examples/Example-Conditional-MAF.ipynb

Large diffs are not rendered by default.

237 changes: 237 additions & 0 deletions examples/Example-Conditional-Real-NVP.ipynb

Large diffs are not rendered by default.

234 changes: 234 additions & 0 deletions examples/Example-MAF.ipynb

Large diffs are not rendered by default.

234 changes: 234 additions & 0 deletions examples/Example-Real-NVP.ipynb

Large diffs are not rendered by default.

252 changes: 0 additions & 252 deletions examples/conditional_moons.ipynb

This file was deleted.

245 changes: 0 additions & 245 deletions examples/moons.ipynb

This file was deleted.

28 changes: 15 additions & 13 deletions nflows/flows/autoregressive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Implementations of autoregressive flows."""

from torch.nn import functional as F
import torch.nn.functional as F

from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
Expand All @@ -19,18 +19,19 @@ class MaskedAutoregressiveFlow(Flow):
"""

def __init__(
self,
features,
hidden_features,
num_layers,
num_blocks_per_layer,
use_residual_blocks=True,
use_random_masks=False,
use_random_permutations=False,
activation=F.relu,
dropout_probability=0.0,
batch_norm_within_layers=False,
batch_norm_between_layers=False,
self,
features,
hidden_features,
context_features=None,
num_layers=5,
num_blocks_per_layer=2,
use_residual_blocks=True,
use_random_masks=False,
use_random_permutations=False,
activation=F.relu,
dropout_probability=0.0,
batch_norm_within_layers=False,
batch_norm_between_layers=False,
):

if use_random_permutations:
Expand All @@ -45,6 +46,7 @@ def __init__(
MaskedAffineAutoregressiveTransform(
features=features,
hidden_features=hidden_features,
context_features=context_features,
num_blocks=num_blocks_per_layer,
use_residual_blocks=use_residual_blocks,
random_mask=use_random_masks,
Expand Down
22 changes: 12 additions & 10 deletions nflows/flows/realnvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ class SimpleRealNVP(Flow):
"""

def __init__(
self,
features,
hidden_features,
num_layers,
num_blocks_per_layer,
use_volume_preserving=False,
activation=F.relu,
dropout_probability=0.0,
batch_norm_within_layers=False,
batch_norm_between_layers=False,
self,
features,
hidden_features,
context_features=None,
num_layers=5,
num_blocks_per_layer=2,
use_volume_preserving=False,
activation=F.relu,
dropout_probability=0.0,
batch_norm_within_layers=False,
batch_norm_between_layers=False,
):

if use_volume_preserving:
Expand All @@ -49,6 +50,7 @@ def create_resnet(in_features, out_features):
in_features,
out_features,
hidden_features=hidden_features,
context_features=context_features,
num_blocks=num_blocks_per_layer,
activation=activation,
dropout_probability=dropout_probability,
Expand Down
27 changes: 27 additions & 0 deletions tests/flows/autoregressive_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ def test_sample(self):
self.assertIsInstance(samples, torch.Tensor)
self.assertEqual(samples.shape, torch.Size([num_samples, features]))

def test_conditional_log_prob(self):
batch_size = 10
features = 20
context_features = 5
flow = ar.MaskedAutoregressiveFlow(
features=features, hidden_features=30, context_features=context_features
)
inputs = torch.randn(batch_size, features)
context = torch.randn(batch_size, context_features)
log_prob = flow.log_prob(inputs, context)
self.assertIsInstance(log_prob, torch.Tensor)
self.assertEqual(log_prob.shape, torch.Size([batch_size]))

def test_conditional_sample(self):
num_samples = 10

batch_size = 10
features = 20
context_features = 5
flow = ar.MaskedAutoregressiveFlow(
features=features, hidden_features=30, context_features=context_features
)
context = torch.randn(batch_size, context_features)
samples = flow.sample(num_samples, context)
self.assertIsInstance(samples, torch.Tensor)
self.assertEqual(samples.shape, torch.Size([batch_size, num_samples, features]))


if __name__ == "__main__":
unittest.main()
27 changes: 27 additions & 0 deletions tests/flows/realnvp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ def test_sample(self):
self.assertIsInstance(samples, torch.Tensor)
self.assertEqual(samples.shape, torch.Size([num_samples, features]))

def test_conditional_log_prob(self):
batch_size = 10
features = 20
context_features = 5
flow = realnvp.SimpleRealNVP(
features=features, hidden_features=30, context_features=context_features
)
inputs = torch.randn(batch_size, features)
context = torch.randn(batch_size, context_features)
log_prob = flow.log_prob(inputs, context)
self.assertIsInstance(log_prob, torch.Tensor)
self.assertEqual(log_prob.shape, torch.Size([batch_size]))

def test_conditional_sample(self):
num_samples = 10

batch_size = 10
features = 20
context_features = 5
flow = realnvp.SimpleRealNVP(
features=features, hidden_features=30, context_features=context_features
)
context = torch.randn(batch_size, context_features)
samples = flow.sample(num_samples, context)
self.assertIsInstance(samples, torch.Tensor)
self.assertEqual(samples.shape, torch.Size([batch_size, num_samples, features]))


if __name__ == "__main__":
unittest.main()