Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ jobs:
tests:
uses: janosh/workflows/.github/workflows/pytest.yml@main
with:
python-version: "3.10"
python-version: "3.11"
secrets: inherit
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.9
rev: v0.12.2
hooks:
- id: ruff
args: [--fix, --ignore, D]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
rev: v1.16.1
hooks:
- id: mypy

Expand Down
12 changes: 6 additions & 6 deletions examples/half_moons.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"from collections.abc import Callable\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand Down Expand Up @@ -245,7 +245,7 @@
}
],
"source": [
"steps, loss_vals = zip(*losses.items())\n",
"steps, loss_vals = zip(*losses.items(), strict=True)\n",
"plt.plot(steps, loss_vals)"
]
},
Expand Down Expand Up @@ -329,11 +329,11 @@
" # y coords\n",
" p1 = np.reshape(grid[1:, :, :], (n_lines**2 - n_lines, 2))\n",
" p2 = np.reshape(grid[:-1, :, :], (n_lines**2 - n_lines, 2))\n",
" lcy = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n",
" lcy = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n",
" # x coords\n",
" p1 = np.reshape(grid[:, 1:, :], (n_lines**2 - n_lines, 2))\n",
" p2 = np.reshape(grid[:, :-1, :], (n_lines**2 - n_lines, 2))\n",
" lcx = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n",
" lcx = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n",
" # draw the lines\n",
" ax.add_collection(lcx)\n",
" ax.add_collection(lcy)\n",
Expand Down Expand Up @@ -477,7 +477,7 @@
"xs, *_ = model.forward(latent_grid)\n",
"xs = [z.detach().numpy() for z in xs]\n",
"\n",
"for idx, [z0, z1] in enumerate(zip(xs, xs[1:])):\n",
"for idx, [z0, z1] in enumerate(zip(xs, xs[1:], strict=True)):\n",
" _, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))\n",
"\n",
" plot_point_flow(ax1, z0, z1)\n",
Expand Down Expand Up @@ -506,7 +506,7 @@
" fig, axes = plt.subplots(plot_grid_height, 2 * plot_grid_height, figsize=(20, 10))\n",
" fig.subplots_adjust(wspace=0.05, hspace=0.05)\n",
"\n",
" for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat):\n",
" for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat, strict=True):\n",
" plot_point_flow(ax, z0, z1)\n",
" ax.set(xlim=[-4, 4], ylim=[-4, 4], xticks=[], yticks=[])\n",
"\n",
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
urls = { Homepage = "https://github.com/janosh/torch-mnf" }
requires-python = ">=3.8"
requires-python = ">=3.11"
dependencies = [
"matplotlib",
"numpy",
Expand Down Expand Up @@ -68,7 +68,7 @@ no_implicit_optional = false
ignore-words-list = "hist"

[tool.ruff]
target-version = "py38"
target-version = "py311"
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]

[tool.ruff.lint]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_flows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from collections.abc import Sequence

import torch
from torch import Tensor, nn
Expand Down
2 changes: 1 addition & 1 deletion torch_mnf/flows/affine_half_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
(Laurent's extension of NICE)
"""

from typing import Sequence
from collections.abc import Sequence

import torch
from torch import Tensor, nn
Expand Down
2 changes: 1 addition & 1 deletion torch_mnf/flows/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from collections.abc import Sequence

import torch
from torch import Tensor, nn
Expand Down
2 changes: 1 addition & 1 deletion torch_mnf/flows/maf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from __future__ import annotations

from typing import Sequence
from collections.abc import Sequence

import torch
from torch import nn
Expand Down
4 changes: 2 additions & 2 deletions torch_mnf/layers/made.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, n_in, hidden_sizes, n_out, num_masks=1, natural_ordering=Fals
# define a simple MLP neural net
layers = []
hs = [n_in, *list(hidden_sizes), n_out]
for h0, h1 in zip(hs, hs[1:]):
for h0, h1 in zip(hs, hs[1:], strict=False):
layers.extend([MaskedLinear(h0, h1), nn.ReLU()])
super().__init__(*layers[:-1]) # drop last ReLU)

Expand Down Expand Up @@ -89,5 +89,5 @@ def update_masks(self):

# set the masks in all MaskedLinear layers
masked_layers = [lyr for lyr in self if isinstance(lyr, MaskedLinear)]
for lyr, m in zip(masked_layers, masks):
for lyr, m in zip(masked_layers, masks, strict=False):
lyr.set_mask(m)
2 changes: 1 addition & 1 deletion torch_mnf/layers/mnf_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from collections.abc import Sequence

import torch
import torch.nn.functional as F
Expand Down
2 changes: 1 addition & 1 deletion torch_mnf/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class MLP(nn.Sequential):

def __init__(self, *layer_sizes, leaky_a=0.2):
layers = []
for s1, s2 in zip(layer_sizes, layer_sizes[1:]):
for s1, s2 in zip(layer_sizes, layer_sizes[1:], strict=False):
layers.append(nn.Linear(s1, s2))
layers.append(nn.LeakyReLU(leaky_a))
super().__init__(*layers[:-1]) # drop last ReLU
5 changes: 3 additions & 2 deletions torch_mnf/models/mnf_feed_forward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

from torch import nn
from torch.nn import BatchNorm1d, ReLU, Sequential
Expand All @@ -23,7 +24,7 @@ def __init__(
) -> None:
"""Initialize the model."""
layers = []
for s1, s2 in zip(layer_sizes, layer_sizes[1:]):
for s1, s2 in zip(layer_sizes, layer_sizes[1:], strict=False):
layers.extend(
[MNFLinear()(s1, s2, **kwargs), activation(), BatchNorm1d(s2)]
)
Expand Down
2 changes: 1 addition & 1 deletion torch_mnf/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import os
from collections.abc import Callable
from functools import wraps
from typing import Callable

import matplotlib.pyplot as plt
import pandas as pd
Expand Down