1- """Module for Base Continuous Convolution class."""
1+ """Module for the Base Continuous Convolution class."""
22
33from abc import ABCMeta , abstractmethod
44import torch
77
88
99class BaseContinuousConv (torch .nn .Module , metaclass = ABCMeta ):
10- """
11- Abstract class
10+ r"""
11+ Base Class for Continuous Convolution.
12+
13+ The class expects the input to be in the form:
14+ :math:`[B \times N_{in} \times N \times D]`, where :math:`B` is the
15+ batch_size, :math:`N_{in}` is the number of input fields, :math:`N`
16+ the number of points in the mesh, :math:`D` the dimension of the problem.
17+ In particular:
18+ * :math:`D` is the number of spatial variables + 1. The last column must
19+ contain the field value.
20+ * :math:`N_{in}` represents the number of function components.
21+ For instance, a vectorial function :math:`f = [f_1, f_2]` has
22+ :math:`N_{in}=2`.
23+
24+ :Note
25+ A 2-dimensional vector-valued function defined on a 3-dimensional input
26+ evaluated on a 100 points input mesh and batch size of 8 is represented
27+ as a tensor of shape ``[8, 2, 100, 4]``, where the columns
28+ ``[:, 0, :, -1]`` and ``[:, 1, :, -1]`` represent the first and second,
29+ components of the function, respectively.
30+
31+ The algorithm returns a tensor of shape:
32+ :math:`[B \times N_{out} \times N \times D]`, where :math:`B` is the
33+ batch_size, :math:`N_{out}` is the number of output fields, :math:`N`
34+ the number of points in the mesh, :math:`D` the dimension of the problem.
1235 """
1336
1437 def __init__ (
@@ -22,56 +45,30 @@ def __init__(
2245 no_overlap = False ,
2346 ):
2447 """
25- Base Class for Continuous Convolution.
26-
27- The algorithm expects input to be in the form:
28- $$[B \t imes N_{in} \t imes N \t imes D]$$
29- where $B$ is the batch_size, $N_{in}$ is the number of input
30- fields, $N$ the number of points in the mesh, $D$ the dimension
31- of the problem. In particular:
32- * $D$ is the number of spatial variables + 1. The last column must
33- contain the field value. For example for 2D problems $D=3$ and
34- the tensor will be something like `[first coordinate, second
35- coordinate, field value]`.
36- * $N_{in}$ represents the number of vectorial function presented.
37- For example a vectorial function $f = [f_1, f_2]$ will have
38- $N_{in}=2$.
39-
40- :Note
41- A 2-dimensional vectorial function $N_{in}=2$ of 3-dimensional
42- input $D=3+1=4$ with 100 points input mesh and batch size of 8
43- is represented as a tensor `[8, 2, 100, 4]`, where the columns
44- `[:, 0, :, -1]` and `[:, 1, :, -1]` represent the first and
45- second filed value respectively
46-
47- The algorithm returns a tensor of shape:
48- $$[B \t imes N_{out} \t imes N' \t imes D]$$
49- where $B$ is the batch_size, $N_{out}$ is the number of output
50- fields, $N'$ the number of points in the mesh, $D$ the dimension
51- of the problem.
52-
53- :param input_numb_field: number of fields in the input
54- :type input_numb_field: int
55- :param output_numb_field: number of fields in the output
56- :type output_numb_field: int
57- :param filter_dim: dimension of the filter
58- :type filter_dim: tuple/ list
59- :param stride: stride for the filter
60- :type stride: dict
61- :param model: neural network for inner parametrization,
62- defaults to None.
63- :type model: torch.nn.Module, optional
64- :param optimize: flag for performing optimization on the continuous
65- filter, defaults to False. The flag `optimize=True` should be
66- used only when the scatter datapoints are fixed through the
67- training. If torch model is in `.eval()` mode, the flag is
68- automatically set to False always.
69- :type optimize: bool, optional
70- :param no_overlap: flag for performing optimization on the transpose
71- continuous filter, defaults to False. The flag set to `True` should
72- be used only when the filter positions do not overlap for different
73- strides. RuntimeError will raise in case of non-compatible strides.
74- :type no_overlap: bool, optional
48+ Initialization of the :class:`BaseContinuousConv` class.
49+
50+ :param int input_numb_field: The number of input fields.
51+ :param int output_numb_field: The number of input fields.
52+ :param filter_dim: The shape of the filter.
53+ :type filter_dim: list[int] | tuple[int]
54+ :param dict stride: The stride of the filter.
55+ :param torch.nn.Module model: The neural network for inner
56+ parametrization. Default is ``None``.
57+ :param bool optimize: If ``True``, optimization is performed on the
58+ continuous filter. It should be used only when the training points
59+ are fixed. If ``model`` is in ``eval`` mode, it is reset to
60+ ``False``. Default is ``False``.
61+ :param bool no_overlap: If ``True``, optimization is performed on the
62+ transposed continuous filter. It should be used only when the filter
63+ positions do not overlap for different strides.
64+ Default is ``False``.
65+ :raises ValueError: If ``input_numb_field`` is not an integer.
66+ :raises ValueError: If ``output_numb_field`` is not an integer.
67+ :raises ValueError: If ``filter_dim`` is not a list or tuple.
68+ :raises ValueError: If ``stride`` is not a dictionary.
69+ :raises ValueError: If ``optimize`` is not a boolean.
70+ :raises ValueError: If ``no_overlap`` is not a boolean.
71+ :raises NotImplementedError: If ``no_overlap`` is ``True``.
7572 """
7673 super ().__init__ ()
7774
@@ -119,12 +116,17 @@ def __init__(
119116
120117 class DefaultKernel (torch .nn .Module ):
121118 """
122- TODO
119+ The default kernel.
123120 """
124121
125122 def __init__ (self , input_dim , output_dim ):
126123 """
127- TODO
124+ Initialization of the :class:`DefaultKernel` class.
125+
126+ :param int input_dim: The input dimension.
127+ :param int output_dim: The output dimension.
128+ :raises ValueError: If ``input_dim`` is not an integer.
129+ :raises ValueError: If ``output_dim`` is not an integer.
128130 """
129131 super ().__init__ ()
130132 assert isinstance (input_dim , int )
@@ -139,65 +141,93 @@ def __init__(self, input_dim, output_dim):
139141
140142 def forward (self , x ):
141143 """
142- TODO
144+ Forward pass.
145+
146+ :param torch.Tensor x: The input data.
147+ :return: The output data.
148+ :rtype: torch.Tensor
143149 """
144150 return self ._model (x )
145151
146152 @property
147153 def net (self ):
148154 """
149- TODO
155+ The neural network for inner parametrization.
156+
157+ :return: The neural network.
158+ :rtype: torch.nn.Module
150159 """
151160 return self ._net
152161
153162 @property
154163 def stride (self ):
155164 """
156- TODO
165+ The stride of the filter.
166+
167+ :return: The stride of the filter.
168+ :rtype: dict
157169 """
158170 return self ._stride
159171
160172 @property
161173 def filter_dim (self ):
162174 """
163- TODO
175+ The shape of the filter.
176+
177+ :return: The shape of the filter.
178+ :rtype: torch.Tensor
164179 """
165180 return self ._dim
166181
167182 @property
168183 def input_numb_field (self ):
169184 """
170- TODO
185+ The number of input fields.
186+
187+ :return: The number of input fields.
188+ :rtype: int
171189 """
172190 return self ._input_numb_field
173191
174192 @property
175193 def output_numb_field (self ):
176194 """
177- TODO
195+ The number of output fields.
196+
197+ :return: The number of output fields.
198+ :rtype: int
178199 """
179200 return self ._output_numb_field
180201
181202 @abstractmethod
182203 def forward (self , X ):
183204 """
184- TODO
205+ Forward pass.
206+
207+ :param torch.Tensor X: The input data.
185208 """
186209
187210 @abstractmethod
188211 def transpose_overlap (self , X ):
189212 """
190- TODO
213+ Transpose the convolution with overlap.
214+
215+ :param torch.Tensor X: The input data.
191216 """
192217
193218 @abstractmethod
194219 def transpose_no_overlap (self , X ):
195220 """
196- TODO
221+ Transpose the convolution without overlap.
222+
223+ :param torch.Tensor X: The input data.
197224 """
198225
199226 @abstractmethod
200227 def _initialize_convolution (self , X , type_ ):
201228 """
202- TODO
229+ Initialize the convolution.
230+
231+ :param torch.Tensor X: The input data.
232+ :param str type_: The type of initialization.
203233 """
0 commit comments