11"""Module for Base Continuous Convolution class."""
22
3- import torch
43import warnings
4+ import torch
55
66
77class PODBlock (torch .nn .Module ):
@@ -29,9 +29,10 @@ def __init__(self, rank, scale_coefficients=True):
2929 """
3030 super ().__init__ ()
3131 self .__scale_coefficients = scale_coefficients
32- self ._basis = None
32+ self .register_buffer ( " _basis" , None )
3333 self ._singular_values = None
34- self ._scaler = None
34+ self .register_buffer ("_std" , None )
35+ self .register_buffer ("_mean" , None )
3536 self ._rank = rank
3637
3738 @property
@@ -94,12 +95,12 @@ def scaler(self):
9495 :return: The scaler dictionary.
9596 :rtype: dict
9697 """
97- if self ._scaler is None :
98+ if self ._std is None :
9899 return None
99100
100101 return {
101- "mean" : self ._scaler [ "mean" ] [: self .rank ],
102- "std" : self ._scaler [ "std" ] [: self .rank ],
102+ "mean" : self ._mean [: self .rank ],
103+ "std" : self ._std [: self .rank ],
103104 }
104105
105106 @property
@@ -119,6 +120,10 @@ def fit(self, X, randomized=True):
119120 are scaled after the projection to have zero mean and unit variance.
120121
121122 :param torch.Tensor X: The input tensor to be reduced.
123+ :param bool randomized: If ``True``, a randomized algorithm is used to
124+ compute the POD basis. In general, this leads to faster
125+ computations, but the results may be less accurate. Default is
126+ ``True``.
122127 """
123128 self ._fit_pod (X , randomized )
124129
@@ -132,10 +137,8 @@ def _fit_scaler(self, coeffs):
132137
133138 :param torch.Tensor coeffs: The coefficients to be scaled.
134139 """
135- self ._scaler = {
136- "std" : torch .std (coeffs , dim = 1 ),
137- "mean" : torch .mean (coeffs , dim = 1 ),
138- }
140+ self ._std = torch .std (coeffs , dim = 1 ) # pylint: disable=W0201
141+ self ._mean = torch .mean (coeffs , dim = 1 ) # pylint: disable=W0201
139142
140143 def _fit_pod (self , X , randomized ):
141144 """
@@ -154,13 +157,14 @@ def _fit_pod(self, X, randomized):
154157 else :
155158 if randomized :
156159 warnings .warn (
157- "Considering a randomized algorithm to compute the POD basis"
160+ "Considering a randomized algorithm to compute the POD "
161+ "basis"
158162 )
159163 u , s , _ = torch .svd_lowrank (X .T , q = X .shape [0 ])
160164
161165 else :
162166 u , s , _ = torch .svd (X .T )
163- self ._basis = u .T
167+ self ._basis = u .T # pylint: disable=W0201
164168 self ._singular_values = s
165169
166170 def forward (self , X ):
0 commit comments