@@ -33,21 +33,26 @@ def __new__(cls, x, labels, *args, **kwargs):
3333 @property
3434 def tensor (self ):
3535 """
36- Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
36+ Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor`
3737 object.
3838
39- :return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
39+ :return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
4040 :rtype: torch.Tensor
4141 """
4242
4343 return self .as_subclass (Tensor )
4444
4545 def __init__ (self , x , labels ):
4646 """
47- Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
48- the labels and a :class:`torch.Tensor`. Internally, the initialization
49- method will store check the compatibility of the labels with the tensor
50- shape.
47+ Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by
48+ checking the consistency of the labels and the tensor. Specifically, the
49+ labels must match the following conditions:
50+
51+ - At each dimension, the number of labels must match the size of the \
52+ dimension.
53+ - At each dimension, the labels must be unique.
54+
55+ The labels can be passed in the following formats:
5156
5257 :Example:
5358 >>> from pina import LabelTensor
@@ -57,11 +62,18 @@ def __init__(self, x, labels):
5762 >>> tensor = LabelTensor(
5863 >>> torch.rand((2000, 3)),
5964 ... ["a", "b", "c"])
65+
66+ The keys of the dictionary are the dimension indices, and the values are
67+ dictionaries containing the labels and the name of the dimension. If
68+ the labels are passed as a list, these are assigned to the last
69+ dimension.
6070
71+ :param torch.Tensor x: The tensor to be casted as a
72+ :class:`~pina.label_tensor.LabelTensor`.
73+ :param labels: Labels to assign to the tensor.
74+ :type labels: str | list[str] | dict
75+ :raises ValueError: If the labels are not consistent with the tensor.
6176 """
62- # Avoid unused argument warning. x is not used in the constructor
63- # of the parent class.
64- # pylint: disable=unused-argument
6577 super ().__init__ ()
6678 if labels is not None :
6779 self .labels = labels
@@ -71,7 +83,7 @@ def __init__(self, x, labels):
7183 @property
7284 def full_labels (self ):
7385 """
74- Gives the full labels of the tensor, even for the dimensions that are
86+ Returns the full labels of the tensor, even for the dimensions that are
7587 not labeled.
7688
7789 :return: The full labels of the tensor
@@ -89,7 +101,7 @@ def full_labels(self):
89101 @property
90102 def stored_labels (self ):
91103 """
92- Gives the labels stored inside the instance.
104+ Returns the labels stored inside the instance.
93105
94106 :return: The labels stored inside the instance.
95107 :rtype: dict
@@ -99,7 +111,7 @@ def stored_labels(self):
99111 @property
100112 def labels (self ):
101113 """
102- Give the labels of the last dimension of the instance.
114+ Returns the labels of the last dimension of the instance.
103115
104116 :return: labels of last dimension
105117 :rtype: list
@@ -111,8 +123,9 @@ def labels(self):
111123 @labels .setter
112124 def labels (self , labels ):
113125 """
114- Set the parameter ``_labels`` by checking the type of the input labels
115- and handling it accordingly. The following types are accepted:
126+ Set labels stored insider the instance by checking the type of the
127+ input labels and handling it accordingly. The following types are
128+ accepted:
116129
117130 - **list**: The list of labels is assigned to the last dimension.
118131 - **dict**: The dictionary of labels is assigned to the tensor.
@@ -134,7 +147,7 @@ def labels(self, labels):
134147 else :
135148 raise ValueError ("labels must be list, dict or string." )
136149
137- def _init_labels_from_dict (self , labels : dict ):
150+ def _init_labels_from_dict (self , labels ):
138151 """
139152 Store the internal label representation according to the values
140153 passed as input.
@@ -146,7 +159,7 @@ def _init_labels_from_dict(self, labels: dict):
146159
147160 tensor_shape = self .shape
148161
149- def validate_dof (dof_list , dim_size : int ):
162+ def validate_dof (dof_list , dim_size ):
150163 """Validate the 'dof' list for uniqueness and size."""
151164 if len (dof_list ) != len (set (dof_list )):
152165 raise ValueError ("dof must be unique" )
@@ -187,7 +200,7 @@ def validate_dof(dof_list, dim_size: int):
187200
188201 def _init_labels_from_list (self , labels ):
189202 """
190- Given a `` list`` of dof, this method update the internal label
203+ Given a list of dof, this method update the internal label
191204 representation by assigning the dof to the last dimension.
192205
193206 :param labels: The label(s) to update.
@@ -203,17 +216,25 @@ def _init_labels_from_list(self, labels):
203216 def extract (self , labels_to_extract ):
204217 """
205218 Extract the subset of the original tensor by returning all the positions
206- corresponding to the passed ``label_to_extract``.
219+ corresponding to the passed ``label_to_extract``. If ``label_to_extract``
220+ is a dictionary, the keys are the dimension names and the values are the
221+ labels to extract. If a single label or a list of labels is passed, the
222+ last dimension is considered.
207223
208- :param labels_to_extract: The label(s) to extract. If a single label or
209- a list of labels is passed, the last dimension is considered.
210- If a dictionary is passed, the keys are the dimension names and the
211- values are the labels to extract.
224+ :Example:
225+ >>> from pina import LabelTensor
226+ >>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}}
227+ >>> tensor = LabelTensor(torch.rand((2000, 3)), labels)
228+ >>> tensor.extract("a")
229+ >>> tensor.extract(["a", "b"])
230+ >>> tensor.extract({"space": ["a", "b"]})
231+
232+ :param labels_to_extract: The label(s) to extract.
212233 :type labels_to_extract: str | list[str] | tuple[str] | dict
213234 :return: The extracted tensor with the updated labels.
214235 :rtype: LabelTensor
215236
216- :raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
237+ :raises TypeError: Labels are not ``str``, ``list[ str] `` or ``dict``
217238 properly setted.
218239 :raises ValueError: Label to extract is not in the labels ``list``.
219240 """
@@ -298,13 +319,13 @@ def cat(tensors, dim=0):
298319
299320 :param list[LabelTensor] tensors:
300321 :class:`~pina.label_tensor.LabelTensor` instances to concatenate
301- :param int dim: dimensions on which you want to perform the operation
322+ :param int dim: Dimensions on which you want to perform the operation
302323 (default is 0)
303- :return: A new :class:`LabelTensor' instance obtained by concatenating
304- the input instances, with the updated labels .
324+ :return: A new :class:`LabelTensor` instance obtained by concatenating
325+ the input instances.
305326
306327 :rtype: LabelTensor
307- :raises ValueError: either number dof or dimensions names differ
328+ :raises ValueError: either number dof or dimensions names differ.
308329 """
309330
310331 if not tensors :
@@ -353,7 +374,7 @@ def stack(tensors):
353374 :param list[LabelTensor] tensors: A list of tensors to stack.
354375 All tensors must have the same shape.
355376 :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
356- by stacking the input tensors, with the updated labels .
377+ by stacking the input tensors.
357378 :rtype: LabelTensor
358379 """
359380
@@ -389,7 +410,7 @@ def dtype(self):
389410 Give the ``dtype`` of the tensor. For more details, see
390411 :meth:`torch.dtype`.
391412
392- :return: dtype of the tensor
413+ :return: The data type of the tensor.
393414 :rtype: torch.dtype
394415 """
395416
@@ -427,19 +448,19 @@ def clone(self, *args, **kwargs):
427448 def append (self , tensor , mode = "std" ):
428449 """
429450 Appends a given tensor to the current tensor along the last dimension.
430-
431451 This method supports two types of appending operations:
432- 1. **Standard append** ("std"): Concatenates the input tensor with the
452+
453+ 1. **Standard append** ("std"): Concatenates the input tensor with the \
433454 current tensor along the last dimension.
434- 2. **Cross append** ("cross"): Creates a cross-product of the current
435- tensor and the input tensor by repeating them in a cross-product
436- fashion, then concatenates the result along the last dimension.
455+ 2. **Cross append** ("cross"): Creates a cross-product of the current \
456+ tensor and the input tensor.
437457
438458 :param tensor: The tensor to append to the current tensor.
439459 :type tensor: LabelTensor
440- :param mode: The append mode to use. Defaults to "std" .
460+ :param mode: The append mode to use. Defaults to ``st`` .
441461 :type mode: str, optional
442- :return: A new `LabelTensor` obtained by appending the input tensor.
462+ :return: A new :class:`LabelTensor` instance obtained by appending the
463+ input tensor.
443464 :rtype: LabelTensor
444465
445466 :raises ValueError: If the mode is not "std" or "cross".
@@ -468,7 +489,7 @@ def append(self, tensor, mode="std"):
468489 raise ValueError ('mode must be either "std" or "cross"' )
469490
470491 @staticmethod
471- def vstack (label_tensors ):
492+ def vstack (tensors ):
472493 """
473494 Stack tensors vertically. For more details, see :meth:`torch.vstack`.
474495
@@ -480,7 +501,7 @@ def vstack(label_tensors):
480501 :rtype: LabelTensor
481502 """
482503
483- return LabelTensor .cat (label_tensors , dim = 0 )
504+ return LabelTensor .cat (tensors , dim = 0 )
484505
485506 # This method is used to update labels
486507 def _update_single_label (
@@ -592,11 +613,11 @@ def __getitem__(self, index):
592613
593614 def sort_labels (self , dim = None ):
594615 """
595- Sort the labels along the specified dimension and apply the same sorting
596- to the :class:`torch.Tensor` part of the instance.
616+ Sort the labels along the specified dimension and apply. It applies the
617+ same sorting to the tensor part of the instance.
597618
598619 :param int dim: The dimension along which to sort the labels.
599- If ``None``, the last dimension (``ndim - 1``) is used.
620+ If ``None``, the last dimension is used.
600621 :return: A new tensor with sorted labels along the specified dimension.
601622 :rtype: LabelTensor
602623 """
0 commit comments