Skip to content

Commit fe45b14

Browse files
committed
bump min supported python version from 3.8 to 3.11
1 parent fd2cedd commit fe45b14

File tree

15 files changed

+63
-53
lines changed

15 files changed

+63
-53
lines changed

.github/workflows/lint.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
name: Lint
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
prek:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- name: Check out repo
14+
uses: actions/checkout@v5
15+
16+
- name: Run prek
17+
uses: j178/prek-action@v1

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ jobs:
1111
tests:
1212
uses: janosh/workflows/.github/workflows/pytest.yml@main
1313
with:
14-
python-version: "3.10"
14+
python-version: "3.11"
1515
secrets: inherit

.pre-commit-config.yaml

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,21 @@
1-
ci:
2-
autoupdate_schedule: quarterly
3-
skip: [ty]
4-
51
default_stages: [pre-commit]
6-
72
default_install_hook_types: [pre-commit, commit-msg]
83

94
repos:
105
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.11.9
6+
rev: v0.14.3
127
hooks:
13-
- id: ruff
8+
- id: ruff-check
149
args: [--fix, --ignore, D]
1510
- id: ruff-format
1611

17-
- repo: https://github.com/pre-commit/mirrors-mypy
18-
rev: v1.15.0
19-
hooks:
20-
- id: mypy
21-
2212
- repo: https://github.com/janosh/format-ipy-cells
2313
rev: v0.1.11
2414
hooks:
2515
- id: format-ipy-cells
2616

2717
- repo: https://github.com/pre-commit/pre-commit-hooks
28-
rev: v5.0.0
18+
rev: v6.0.0
2919
hooks:
3020
- id: check-case-conflict
3121
- id: check-symlinks
@@ -40,6 +30,7 @@ repos:
4030
rev: v2.4.1
4131
hooks:
4232
- id: codespell
33+
stages: [pre-commit, commit-msg]
4334
exclude_types: [jupyter]
4435
args: [--check-filenames]
4536

@@ -53,5 +44,6 @@ repos:
5344
hooks:
5445
- id: ty
5546
name: ty check
56-
entry: ty check .
47+
entry: ty check
5748
language: python
49+
additional_dependencies: [ty]

examples/half_moons.ipynb

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18-
"from typing import Callable\n",
18+
"from collections.abc import Callable\n",
1919
"\n",
2020
"import matplotlib.pyplot as plt\n",
2121
"import numpy as np\n",
@@ -245,7 +245,7 @@
245245
}
246246
],
247247
"source": [
248-
"steps, loss_vals = zip(*losses.items())\n",
248+
"steps, loss_vals = zip(*losses.items(), strict=True)\n",
249249
"plt.plot(steps, loss_vals)"
250250
]
251251
},
@@ -321,19 +321,23 @@
321321
"outputs": [],
322322
"source": [
323323
"def plot_grid_warp(\n",
324-
" ax: plt.Axes, z: np.ndarray, target_samples: np.ndarray, n_lines: int, idx: int\n",
324+
" ax: plt.Axes,\n",
325+
" z: np.ndarray,\n",
326+
" target_samples: np.ndarray | torch.Tensor,\n",
327+
" n_lines: int,\n",
328+
" idx: int,\n",
325329
"):\n",
326330
" \"\"\"plots how the flow warps space\"\"\"\n",
327331
"\n",
328332
" grid = z.reshape((n_lines, n_lines, 2))\n",
329333
" # y coords\n",
330334
" p1 = np.reshape(grid[1:, :, :], (n_lines**2 - n_lines, 2))\n",
331335
" p2 = np.reshape(grid[:-1, :, :], (n_lines**2 - n_lines, 2))\n",
332-
" lcy = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n",
336+
" lcy = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n",
333337
" # x coords\n",
334338
" p1 = np.reshape(grid[:, 1:, :], (n_lines**2 - n_lines, 2))\n",
335339
" p2 = np.reshape(grid[:, :-1, :], (n_lines**2 - n_lines, 2))\n",
336-
" lcx = LineCollection(tuple(zip(p1, p2)), alpha=0.3)\n",
340+
" lcx = LineCollection(tuple(zip(p1, p2, strict=True)), alpha=0.3)\n",
337341
" # draw the lines\n",
338342
" ax.add_collection(lcx)\n",
339343
" ax.add_collection(lcy)\n",
@@ -477,7 +481,7 @@
477481
"xs, *_ = model.forward(latent_grid)\n",
478482
"xs = [z.detach().numpy() for z in xs]\n",
479483
"\n",
480-
"for idx, [z0, z1] in enumerate(zip(xs, xs[1:])):\n",
484+
"for idx, [z0, z1] in enumerate(zip(xs, xs[1:], strict=True)):\n",
481485
" _, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))\n",
482486
"\n",
483487
" plot_point_flow(ax1, z0, z1)\n",
@@ -506,7 +510,7 @@
506510
" fig, axes = plt.subplots(plot_grid_height, 2 * plot_grid_height, figsize=(20, 10))\n",
507511
" fig.subplots_adjust(wspace=0.05, hspace=0.05)\n",
508512
"\n",
509-
" for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat):\n",
513+
" for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat, strict=True):\n",
510514
" plot_point_flow(ax, z0, z1)\n",
511515
" ax.set(xlim=[-4, 4], ylim=[-4, 4], xticks=[], yticks=[])\n",
512516
"\n",

pyproject.toml

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[build-system]
2-
requires = ["setuptools>=61.2"]
3-
build-backend = "setuptools.build_meta"
2+
requires = ["uv_build>=0.9.5"]
3+
build-backend = "uv_build"
44

55
[project]
66
name = "torch-mnf"
@@ -21,13 +21,13 @@ classifiers = [
2121
"License :: OSI Approved :: MIT License",
2222
"Programming Language :: Python :: 3 :: Only",
2323
"Programming Language :: Python :: 3",
24-
"Programming Language :: Python :: 3.10",
2524
"Programming Language :: Python :: 3.11",
26-
"Programming Language :: Python :: 3.8",
27-
"Programming Language :: Python :: 3.9",
25+
"Programming Language :: Python :: 3.12",
26+
"Programming Language :: Python :: 3.13",
27+
"Programming Language :: Python :: 3.14",
2828
]
2929
urls = { Homepage = "https://github.com/janosh/torch-mnf" }
30-
requires-python = ">=3.8"
30+
requires-python = ">=3.11"
3131
dependencies = [
3232
"matplotlib",
3333
"numpy",
@@ -42,14 +42,10 @@ dependencies = [
4242
[project.optional-dependencies]
4343
test = ["pytest", "pytest-cov"]
4444

45-
[tool.setuptools.packages]
46-
find = {}
47-
48-
[tool.setuptools.package-data]
49-
torch_mnf = ["*.csv"]
50-
51-
[tool.distutils.bdist_wheel]
52-
universal = true
45+
[tool.uv.build-backend]
46+
module-name = "torch_mnf"
47+
module-root = ""
48+
source-include = ["torch_mnf/**/*.csv"]
5349

5450
[tool.pytest.ini_options]
5551
testpaths = ["tests"]
@@ -68,7 +64,7 @@ no_implicit_optional = false
6864
ignore-words-list = "hist"
6965

7066
[tool.ruff]
71-
target-version = "py38"
67+
target-version = "py311"
7268
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
7369

7470
[tool.ruff.lint]

readme.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Torch MNF
22

33
[![Tests](https://github.com/janosh/torch-mnf/actions/workflows/test.yml/badge.svg)](https://github.com/janosh/torch-mnf/actions/workflows/test.yml)
4-
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/janosh/torch-mnf/master.svg)](https://results.pre-commit.ci/latest/github/janosh/torch-mnf/master)
54
![GitHub Repo Size](https://img.shields.io/github/repo-size/janosh/torch-mnf?label=Repo+Size)
65

76
PyTorch implementation of Multiplicative Normalizing Flows [[1]](#mnf-bnn).

tests/test_flows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence
1+
from collections.abc import Sequence
22

33
import torch
44
from torch import Tensor, nn

torch_mnf/flows/affine_half_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
(Laurent's extension of NICE)
1010
"""
1111

12-
from typing import Sequence
12+
from collections.abc import Sequence
1313

1414
import torch
1515
from torch import Tensor, nn

torch_mnf/flows/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Sequence
3+
from collections.abc import Sequence
44

55
import torch
66
from torch import Tensor, nn
@@ -52,4 +52,4 @@ def sample(self, *num_samples: int) -> Tensor:
5252
"""Sample from the flow's base distribution."""
5353
z = self.base.sample(*num_samples)
5454
xs, _ = self.forward(z)
55-
return xs
55+
return xs[-1]

torch_mnf/flows/maf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from typing import Sequence
13+
from collections.abc import Sequence
1414

1515
import torch
1616
from torch import nn

0 commit comments

Comments
 (0)