Commit 37c1bd7

fix pep8

1 parent 9d3c1bc commit 37c1bd7

File tree

setup.cfg
stochman/curves.py
stochman/nnj.py
tests/test_nnj.py

4 files changed: 53 additions & 46 deletions


setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ doctests = True
 verbose = 2
 # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
 format = pylint
-ignore = E731,W504,F401,F841,E722,W503
+ignore = E731,W504,F401,F841,E722,W503,E203
 
 [build_sphinx]
 source-dir = doc/source
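
The newly ignored E203 ("whitespace before ':'") corresponds to the Black-style slice spacing introduced in the stochman/curves.py hunks below: pycodestyle flags the space before the colon, while Black deliberately inserts it around complex slice bounds. A minimal illustration of the pattern (standalone sketch with invented names; assumes standard flake8/pycodestyle semantics):

    fill = [1.0, 2.0, 3.0, 4.0]
    row = [0.0] * 12
    si = 4  # start index

    # Black formats complex slice bounds with spaces around the colon;
    # flake8 would report E203 on these lines unless the code is ignored.
    row[si : (si + 4)] = fill
    row[(si + 4) : (si + 8)] = [-f for f in fill]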

stochman/curves.py

Lines changed: 16 additions & 13 deletions
@@ -167,8 +167,9 @@ def _init_params(self, params, *args, **kwargs) -> None:
             .expand(self.begin.shape[0], -1, self.begin.shape[1]),  # Bx(_num_nodes-2)xD
         )
         if params is None:
-            params = self.t * self.end.unsqueeze(1) + \
-                (1 - self.t) * self.begin.unsqueeze(1)  # Bx(_num_nodes)xD
+            params = self.t * self.end.unsqueeze(1) + (1 - self.t) * self.begin.unsqueeze(
+                1
+            )  # Bx(_num_nodes)xD
         if self._requires_grad:
             self.register_parameter("params", nn.Parameter(params))
         else:
@@ -184,7 +185,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
                 self.t,
                 torch.ones(B, 1, D, dtype=self.t.dtype, device=self.device),
             ),
-            dim=1
+            dim=1,
         )  # Bx(num_nodes)xD
         a = (end_nodes - start_nodes) / (t0[:, 1:] - t0[:, :-1])  # Bx(num_edges)xD
         b = start_nodes - a * t0[:, :-1]  # Bx(num_edges)xD
@@ -194,10 +195,12 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
         elif t.ndim == 2:
             tt = t  # Bx|t|
         else:
-            raise Exception('t must have at most 2 dimensions')
+            raise Exception("t must have at most 2 dimensions")
         idx = (
-            torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long()  # Bx|t|
-        ).unsqueeze(2).repeat(1, 1, D)  # Bx|t|xD, this assumes that nodes are equi-distant
+            (torch.floor(tt * num_edges).clamp(min=0, max=num_edges - 1).long())  # Bx|t|
+            .unsqueeze(2)
+            .repeat(1, 1, D)
+        )  # Bx|t|xD, this assumes that nodes are equi-distant
         result = torch.gather(a, 1, idx) * tt.unsqueeze(2) + torch.gather(b, 1, idx)  # Bx|t|xD
         if B == 1:
             result = result.squeeze(0)  # |t|xD
@@ -316,29 +319,29 @@ def _compute_basis(self, num_edges) -> torch.Tensor:
         for i in range(num_edges - 1):
             si = 4 * i  # start index
             fill = torch.tensor([1.0, t[i], t[i] ** 2, t[i] ** 3], dtype=self.begin.dtype)
-            zeroth[i, si:(si + 4)] = fill
-            zeroth[i, (si + 4):(si + 8)] = -fill
+            zeroth[i, si : (si + 4)] = fill
+            zeroth[i, (si + 4) : (si + 8)] = -fill
 
         first = torch.zeros(num_edges - 1, 4 * num_edges, dtype=self.begin.dtype)
         for i in range(num_edges - 1):
             si = 4 * i  # start index
             fill = torch.tensor([0.0, 1.0, 2.0 * t[i], 3.0 * t[i] ** 2], dtype=self.begin.dtype)
-            first[i, si:(si + 4)] = fill
-            first[i, (si + 4):(si + 8)] = -fill
+            first[i, si : (si + 4)] = fill
+            first[i, (si + 4) : (si + 8)] = -fill
 
         second = torch.zeros(num_edges - 1, 4 * num_edges, dtype=self.begin.dtype)
         for i in range(num_edges - 1):
             si = 4 * i  # start index
             fill = torch.tensor([0.0, 0.0, 6.0 * t[i], 2.0], dtype=self.begin.dtype)
-            second[i, si:(si + 4)] = fill
-            second[i, (si + 4):(si + 8)] = -fill
+            second[i, si : (si + 4)] = fill
+            second[i, (si + 4) : (si + 8)] = -fill
 
         constraints = torch.cat((end_points, zeroth, first, second))
         self.constraints = constraints
 
         # Compute null space, which forms our basis
         _, S, V = torch.svd(constraints, some=False)
-        basis = V[:, S.numel():]  # (num_coeffs)x(intr_dim)
+        basis = V[:, S.numel() :]  # (num_coeffs)x(intr_dim)
 
         return basis
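
These hunks are quote- and slice-spacing changes only; the behaviour of `_compute_basis` is untouched. For readers unfamiliar with the null-space idiom behind the reformatted `basis = V[:, S.numel() :]` line, a standalone sketch (not stochman's API, just the same torch.svd pattern):

    import torch

    # For a wide constraint matrix, torch.svd(..., some=False) returns a square V
    # whose columns beyond the first S.numel() span the null space, so
    # constraints @ basis is numerically zero.
    constraints = torch.randn(3, 8)  # 3 constraints on 8 spline coefficients
    _, S, V = torch.svd(constraints, some=False)
    basis = V[:, S.numel() :]  # 8x5 null-space basis
    print(torch.allclose(constraints @ basis, torch.zeros(3, 5), atol=1e-5))  # True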

stochman/nnj.py

Lines changed: 30 additions & 28 deletions
@@ -3,10 +3,10 @@
 from enum import Enum
 from typing import Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-import numpy as np
 
 
 class JacType(Enum):
@@ -17,40 +17,43 @@ class JacType(Enum):
     FULL: The Jacobian is a matrix of whatever size.
     """
 
-    DIAG = 'diag'
-    FULL = 'full'
-    CONV = 'conv'
-
+    DIAG = "diag"
+    FULL = "full"
+    CONV = "conv"
+
     def __eq__(self, other: Union[str, Enum]) -> bool:
         other = other.value if isinstance(other, Enum) else str(other)
         return self.value.lower() == other.lower()
-
+
 
 class Jacobian(torch.Tensor):
-    """ Class representing a jacobian tensor, subclasses from torch.Tensor
-    Requires the additional `jactype` parameter to initialize, which
-    is a string indicating the jacobian type
+    """Class representing a jacobian tensor, subclasses from torch.Tensor
+    Requires the additional `jactype` parameter to initialize, which
+    is a string indicating the jacobian type
     """
+
     def __init__(self, tensor, jactype):
         available_jactype = [item.value for item in JacType]
         if jactype not in available_jactype:
-            raise ValueError(f'Tried to initialize jacobian tensor with unknown jacobian type {jactype}.'
-                             f' Please choose between {available_jactype}')
+            raise ValueError(
+                f"Tried to initialize jacobian tensor with unknown jacobian type {jactype}."
+                f" Please choose between {available_jactype}"
+            )
         self.jactype = jactype
-
+
     @staticmethod
     def __new__(cls, x, jactype, *args, **kwargs):
         cls.jactype = jactype
         return super().__new__(cls, x, *args, **kwargs)
-
+
     def __repr__(self):
         tensor_repr = super().__repr__()
-        tensor_repr = tensor_repr.replace('tensor', 'jacobian')
-        tensor_repr += f'\n jactype={self.jactype.value if isinstance(self.jactype, Enum) else self.jactype}'
+        tensor_repr = tensor_repr.replace("tensor", "jacobian")
+        tensor_repr += f"\n jactype={self.jactype.value if isinstance(self.jactype, Enum) else self.jactype}"
         return tensor_repr
-
+
     def __add__(self, other):
-        if isinstance(other, Jacobian):
+        if isinstance(other, Jacobian):
             if self.jactype == other.jactype:
                 res = torch.add(self, other)
                 return jacobian(res, self.jactype)
@@ -59,14 +62,14 @@ def __add__(self, other):
                 return jacobian(res, JacType.FULL)
             if self.jactype == JacType.DIAG and other.jactype == JacType.FULL:
                 res = torch.add(torch.diag_embed(self), other)
-                return jacobian(res, JacType.FULL)
+                return jacobian(res, JacType.FULL)
             if self.jactype == JacType.CONV and other.jactype == JacType.CONV:
                 res = torch.add(self, other)
                 return jacobian(res, JacType.CONV)
-            raise ValueError('Unknown addition of jacobian matrices')
-
+            raise ValueError("Unknown addition of jacobian matrices")
+
         return super().__add__(other)
-
+
     def __matmul__(self, other):
         if isinstance(other, Jacobian):
             # diag * diag
@@ -90,9 +93,9 @@ def __matmul__(self, other):
             if other == JacType.CONV:
                 res = self * other
                 return jacobian(res, JacType.CONV)
-
-        raise ValueError('Unknown matrix multiplication of jacobian matrices')
-
+
+        raise ValueError("Unknown matrix multiplication of jacobian matrices")
+
 
 def jacobian(tensor, jactype):
     """ Initialize a jacobian tensor by a specified jacobian type """
@@ -126,8 +129,7 @@ def _jacobian(self, x: torch.Tensor, val: torch.Tensor) -> Jacobian:
         attains value val."""
         pass
 
-    def _jac_mul(
-        self, x: torch.Tensor, val: torch.Tensor, jac_in: torch.Tensor) -> Jacobian:
+    def _jac_mul(self, x: torch.Tensor, val: torch.Tensor, jac_in: torch.Tensor) -> Jacobian:
         """Multiply the Jacobian at x with M.
         This can potentially be done more efficiently than
         first computing the Jacobian, and then performing the
@@ -557,9 +559,9 @@ def forward(self, x: torch.Tensor, jacobian: bool = False):
     def _jacobian(self, x: torch.Tensor, val: torch.Tensor) -> Jacobian:
         w = self._conv_to_toeplitz(x.shape[1:])
         w = w.unsqueeze(0).repeat(x.shape[0], 1, 1)
-        return jacobian(w, JacType.CONV)
+        return jacobian(w, JacType.CONV)
+
 
-
 class Conv1d(_BaseJacConv, nn.Conv1d):
     def _conv_to_toeplitz(self, input_shape):
         identity = torch.eye(np.prod(input_shape).item()).reshape([-1] + list(input_shape))
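
Apart from the import reordering, these hunks are quote, blank-line, and line-wrapping fixes; the Jacobian arithmetic itself is unchanged. A hedged usage sketch of the DIAG + FULL promotion rule visible in `__add__` above (the `from stochman import nnj` import path is assumed, mirroring tests/test_nnj.py):

    import torch
    from stochman import nnj  # assumed import path

    # Adding a DIAG jacobian to a FULL one promotes the diagonal via diag_embed
    # and returns a FULL jacobian, per Jacobian.__add__ above.
    j_full = nnj.jacobian(torch.ones(5, 10, 10), "full")  # B x D x D
    j_diag = nnj.jacobian(torch.ones(5, 10), "diag")      # B x D (stored diagonal)
    j_sum = j_diag + j_full                               # full jacobian, shape 5x10x10
    print(j_sum.jactype)                                  # JacType.FULL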

tests/test_nnj.py

Lines changed: 6 additions & 4 deletions
@@ -124,11 +124,12 @@ def test_jac_return(model, return_jac):
     else:
         assert isinstance(output, torch.Tensor)
 
+
 _testcases = [
-    (nnj.jacobian(torch.ones(5, 10, 10), 'full'), nnj.jacobian(torch.ones(5, 10, 10), 'full')),
-    (nnj.jacobian(torch.ones(5, 10, 10), 'full'), nnj.jacobian(torch.ones(5, 10), 'diag')),
-    (nnj.jacobian(torch.ones(5, 10), 'diag'), nnj.jacobian(torch.ones(5, 10, 10), 'full')),
-    (nnj.jacobian(torch.ones(5, 10), 'diag'), nnj.jacobian(torch.ones(5, 10), 'diag')),
+    (nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
+    (nnj.jacobian(torch.ones(5, 10, 10), "full"), nnj.jacobian(torch.ones(5, 10), "diag")),
+    (nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10, 10), "full")),
+    (nnj.jacobian(torch.ones(5, 10), "diag"), nnj.jacobian(torch.ones(5, 10), "diag")),
 ]
 
 
@@ -147,6 +148,7 @@ def test_add(cases):
         j_out_diag = torch.stack([jo.diag() for jo in j_out]).flatten()
         assert all(j_out_diag == 2 * torch.ones_like(j_out_diag))
 
+
 @pytest.mark.parametrize("cases", _testcases)
 def test_matmul(cases):
     j_out = cases[0] @ cases[1]
