Skip to content

Commit 4a6e73f

Browse files
authored
Fix basis device transfer in PODBlock (#650)
* fix gpu data moving
1 parent 87c5c6a commit 4a6e73f

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

pina/model/block/pod_block.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for Base Continuous Convolution class."""
22

3-
import torch
43
import warnings
4+
import torch
55

66

77
class PODBlock(torch.nn.Module):
@@ -29,9 +29,10 @@ def __init__(self, rank, scale_coefficients=True):
2929
"""
3030
super().__init__()
3131
self.__scale_coefficients = scale_coefficients
32-
self._basis = None
32+
self.register_buffer("_basis", None)
3333
self._singular_values = None
34-
self._scaler = None
34+
self.register_buffer("_std", None)
35+
self.register_buffer("_mean", None)
3536
self._rank = rank
3637

3738
@property
@@ -94,12 +95,12 @@ def scaler(self):
9495
:return: The scaler dictionary.
9596
:rtype: dict
9697
"""
97-
if self._scaler is None:
98+
if self._std is None:
9899
return None
99100

100101
return {
101-
"mean": self._scaler["mean"][: self.rank],
102-
"std": self._scaler["std"][: self.rank],
102+
"mean": self._mean[: self.rank],
103+
"std": self._std[: self.rank],
103104
}
104105

105106
@property
@@ -119,6 +120,10 @@ def fit(self, X, randomized=True):
119120
are scaled after the projection to have zero mean and unit variance.
120121
121122
:param torch.Tensor X: The input tensor to be reduced.
123+
:param bool randomized: If ``True``, a randomized algorithm is used to
124+
compute the POD basis. In general, this leads to faster
125+
computations, but the results may be less accurate. Default is
126+
``True``.
122127
"""
123128
self._fit_pod(X, randomized)
124129

@@ -132,10 +137,8 @@ def _fit_scaler(self, coeffs):
132137
133138
:param torch.Tensor coeffs: The coefficients to be scaled.
134139
"""
135-
self._scaler = {
136-
"std": torch.std(coeffs, dim=1),
137-
"mean": torch.mean(coeffs, dim=1),
138-
}
140+
self._std = torch.std(coeffs, dim=1) # pylint: disable=W0201
141+
self._mean = torch.mean(coeffs, dim=1) # pylint: disable=W0201
139142

140143
def _fit_pod(self, X, randomized):
141144
"""
@@ -154,13 +157,14 @@ def _fit_pod(self, X, randomized):
154157
else:
155158
if randomized:
156159
warnings.warn(
157-
"Considering a randomized algorithm to compute the POD basis"
160+
"Considering a randomized algorithm to compute the POD "
161+
"basis"
158162
)
159163
u, s, _ = torch.svd_lowrank(X.T, q=X.shape[0])
160164

161165
else:
162166
u, s, _ = torch.svd(X.T)
163-
self._basis = u.T
167+
self._basis = u.T # pylint: disable=W0201
164168
self._singular_values = s
165169

166170
def forward(self, X):

tests/test_block/test_pod.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ def test_fit(rank, scale, randomized):
4242
assert pod.singular_values.shape == (rank,)
4343
assert pod._singular_values.shape == (n_snap,)
4444
if scale is True:
45-
assert pod._scaler["mean"].shape == (n_snap,)
46-
assert pod._scaler["std"].shape == (n_snap,)
45+
assert pod._mean.shape == (n_snap,)
46+
assert pod._std.shape == (n_snap,)
4747
assert pod.scaler["mean"].shape == (rank,)
4848
assert pod.scaler["std"].shape == (rank,)
4949
assert pod.scaler["mean"].shape[0] == pod.basis.shape[0]
5050
else:
51-
assert pod._scaler == None
51+
assert pod._std == None
52+
assert pod._mean == None
5253
assert pod.scaler == None
5354

5455

0 commit comments

Comments
 (0)