1919# See the License for the specific language governing permissions and
2020# limitations under the License.
2121#
22- """Multiobjective contrastive learning."""
22+ """Multiobjective contrastive learning.
23+
24+ Starting in CEBRA 0.6.0, we have added support for subspace contrastive learning.
25+ This is a method for training models that are able to learn multiple subspaces of the
26+ feature space simultaneously.
27+
28+ Subspace contrastive learning requires to use specialized models and criterions.
29+ This module specifies a test of classes required for training CEBRA models with multiple objectives.
30+ The objectives are defined by the wrapper class :py:class:`cebra.models.multicriterions.MultiCriterions`.
31+
32+ Two solvers are currently implemented:
33+
34+ - :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA`
35+ - :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA`
36+
37+ See Also:
38+ :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA`
39+ :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig`
40+ :py:class:`cebra.models.multicriterions.MultiCriterions`
41+ """
2342
2443import logging
2544import time
4362class MultiObjectiveConfig :
4463 """Configuration class for setting up multi-objective learning with Cebra.
4564
65+
66+
4667 Args:
4768 loader: Data loader used for configurations.
4869 """
@@ -458,7 +479,11 @@ def transform(self, inputs: torch.Tensor) -> torch.Tensor:
458479@register ("supervised-solver-xcebra" )
459480@dataclasses .dataclass
460481class SupervisedMultiobjectiveSolverxCEBRA (MultiobjectiveSolverBase ):
461- """Supervised neural network training with MSE loss"""
482+ """Supervised neural network training using the MSE loss.
483+
484+ This solver can be used as a baseline variant instead of the contrastive solver,
485+ :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA`.
486+ """
462487
463488 _variant_name = "supervised-solver-xcebra"
464489
@@ -477,6 +502,15 @@ def _inference(self, batch):
477502@register ("multiobjective-solver" )
478503@dataclasses .dataclass
479504class ContrastiveMultiobjectiveSolverxCEBRA (MultiobjectiveSolverBase ):
505+ """Multi-objective solver for CEBRA.
506+
507+ This solver is used for training CEBRA models with multiple objectives.
508+
509+ See Also:
510+ :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA`
511+ :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig`
512+ :py:class:`cebra.models.multicriterions.MultiCriterions`
513+ """
480514
481515 _variant_name = "contrastive-solver-xcebra"
482516
0 commit comments