Skip to content

Commit 974cbe4

Browse files
committed
Fix rendering LT
1 parent 9d9a01b commit 974cbe4

File tree

1 file changed

+62
-41
lines changed

1 file changed

+62
-41
lines changed

pina/label_tensor.py

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,26 @@ def __new__(cls, x, labels, *args, **kwargs):
3333
@property
3434
def tensor(self):
3535
"""
36-
Give the tensor part of the :class:`~pina.label_tensor.LabelTensor`
36+
Returns the tensor part of the :class:`~pina.label_tensor.LabelTensor`
3737
object.
3838
39-
:return: tensor part of the :class:`~pina.label_tensor.LabelTensor`.
39+
:return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
4040
:rtype: torch.Tensor
4141
"""
4242

4343
return self.as_subclass(Tensor)
4444

4545
def __init__(self, x, labels):
4646
"""
47-
Construct a :class:`~pina.label_tensor.LabelTensor` by passing a dict of
48-
the labels and a :class:`torch.Tensor`. Internally, the initialization
49-
method will store check the compatibility of the labels with the tensor
50-
shape.
47+
Initialize the :class:`~pina.label_tensor.LabelTensor` instance, by
48+
checking the consistency of the labels and the tensor. Specifically, the
49+
labels must match the following conditions:
50+
51+
- At each dimension, the number of labels must match the size of the \
52+
dimension.
53+
- At each dimension, the labels must be unique.
54+
55+
The labels can be passed in the following formats:
5156
5257
:Example:
5358
>>> from pina import LabelTensor
@@ -57,11 +62,18 @@ def __init__(self, x, labels):
5762
>>> tensor = LabelTensor(
5863
>>> torch.rand((2000, 3)),
5964
... ["a", "b", "c"])
65+
66+
The keys of the dictionary are the dimension indices, and the values are
67+
dictionaries containing the labels and the name of the dimension. If
68+
the labels are passed as a list, these are assigned to the last
69+
dimension.
6070
71+
:param torch.Tensor x: The tensor to be casted as a
72+
:class:`~pina.label_tensor.LabelTensor`.
73+
:param labels: Labels to assign to the tensor.
74+
:type labels: str | list[str] | dict
75+
:raises ValueError: If the labels are not consistent with the tensor.
6176
"""
62-
# Avoid unused argument warning. x is not used in the constructor
63-
# of the parent class.
64-
# pylint: disable=unused-argument
6577
super().__init__()
6678
if labels is not None:
6779
self.labels = labels
@@ -71,7 +83,7 @@ def __init__(self, x, labels):
7183
@property
7284
def full_labels(self):
7385
"""
74-
Gives the full labels of the tensor, even for the dimensions that are
86+
Returns the full labels of the tensor, even for the dimensions that are
7587
not labeled.
7688
7789
:return: The full labels of the tensor
@@ -89,7 +101,7 @@ def full_labels(self):
89101
@property
90102
def stored_labels(self):
91103
"""
92-
Gives the labels stored inside the instance.
104+
Returns the labels stored inside the instance.
93105
94106
:return: The labels stored inside the instance.
95107
:rtype: dict
@@ -99,7 +111,7 @@ def stored_labels(self):
99111
@property
100112
def labels(self):
101113
"""
102-
Give the labels of the last dimension of the instance.
114+
Returns the labels of the last dimension of the instance.
103115
104116
:return: labels of last dimension
105117
:rtype: list
@@ -111,8 +123,9 @@ def labels(self):
111123
@labels.setter
112124
def labels(self, labels):
113125
"""
114-
Set the parameter ``_labels`` by checking the type of the input labels
115-
and handling it accordingly. The following types are accepted:
126+
Set labels stored insider the instance by checking the type of the
127+
input labels and handling it accordingly. The following types are
128+
accepted:
116129
117130
- **list**: The list of labels is assigned to the last dimension.
118131
- **dict**: The dictionary of labels is assigned to the tensor.
@@ -134,7 +147,7 @@ def labels(self, labels):
134147
else:
135148
raise ValueError("labels must be list, dict or string.")
136149

137-
def _init_labels_from_dict(self, labels: dict):
150+
def _init_labels_from_dict(self, labels):
138151
"""
139152
Store the internal label representation according to the values
140153
passed as input.
@@ -146,7 +159,7 @@ def _init_labels_from_dict(self, labels: dict):
146159

147160
tensor_shape = self.shape
148161

149-
def validate_dof(dof_list, dim_size: int):
162+
def validate_dof(dof_list, dim_size):
150163
"""Validate the 'dof' list for uniqueness and size."""
151164
if len(dof_list) != len(set(dof_list)):
152165
raise ValueError("dof must be unique")
@@ -187,7 +200,7 @@ def validate_dof(dof_list, dim_size: int):
187200

188201
def _init_labels_from_list(self, labels):
189202
"""
190-
Given a ``list`` of dof, this method update the internal label
203+
Given a list of dof, this method update the internal label
191204
representation by assigning the dof to the last dimension.
192205
193206
:param labels: The label(s) to update.
@@ -203,17 +216,25 @@ def _init_labels_from_list(self, labels):
203216
def extract(self, labels_to_extract):
204217
"""
205218
Extract the subset of the original tensor by returning all the positions
206-
corresponding to the passed ``label_to_extract``.
219+
corresponding to the passed ``label_to_extract``. If ``label_to_extract``
220+
is a dictionary, the keys are the dimension names and the values are the
221+
labels to extract. If a single label or a list of labels is passed, the
222+
last dimension is considered.
207223
208-
:param labels_to_extract: The label(s) to extract. If a single label or
209-
a list of labels is passed, the last dimension is considered.
210-
If a dictionary is passed, the keys are the dimension names and the
211-
values are the labels to extract.
224+
:Example:
225+
>>> from pina import LabelTensor
226+
>>> labels = {1: {'dof': ["a", "b", "c"], 'name': 'space'}}
227+
>>> tensor = LabelTensor(torch.rand((2000, 3)), labels)
228+
>>> tensor.extract("a")
229+
>>> tensor.extract(["a", "b"])
230+
>>> tensor.extract({"space": ["a", "b"]})
231+
232+
:param labels_to_extract: The label(s) to extract.
212233
:type labels_to_extract: str | list[str] | tuple[str] | dict
213234
:return: The extracted tensor with the updated labels.
214235
:rtype: LabelTensor
215236
216-
:raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
237+
:raises TypeError: Labels are not ``str``, ``list[str]`` or ``dict``
217238
properly setted.
218239
:raises ValueError: Label to extract is not in the labels ``list``.
219240
"""
@@ -298,13 +319,13 @@ def cat(tensors, dim=0):
298319
299320
:param list[LabelTensor] tensors:
300321
:class:`~pina.label_tensor.LabelTensor` instances to concatenate
301-
:param int dim: dimensions on which you want to perform the operation
322+
:param int dim: Dimensions on which you want to perform the operation
302323
(default is 0)
303-
:return: A new :class:`LabelTensor' instance obtained by concatenating
304-
the input instances, with the updated labels.
324+
:return: A new :class:`LabelTensor` instance obtained by concatenating
325+
the input instances.
305326
306327
:rtype: LabelTensor
307-
:raises ValueError: either number dof or dimensions names differ
328+
:raises ValueError: either number dof or dimensions names differ.
308329
"""
309330

310331
if not tensors:
@@ -353,7 +374,7 @@ def stack(tensors):
353374
:param list[LabelTensor] tensors: A list of tensors to stack.
354375
All tensors must have the same shape.
355376
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
356-
by stacking the input tensors, with the updated labels.
377+
by stacking the input tensors.
357378
:rtype: LabelTensor
358379
"""
359380

@@ -389,7 +410,7 @@ def dtype(self):
389410
Give the ``dtype`` of the tensor. For more details, see
390411
:meth:`torch.dtype`.
391412
392-
:return: dtype of the tensor
413+
:return: The data type of the tensor.
393414
:rtype: torch.dtype
394415
"""
395416

@@ -427,19 +448,19 @@ def clone(self, *args, **kwargs):
427448
def append(self, tensor, mode="std"):
428449
"""
429450
Appends a given tensor to the current tensor along the last dimension.
430-
431451
This method supports two types of appending operations:
432-
1. **Standard append** ("std"): Concatenates the input tensor with the
452+
453+
1. **Standard append** ("std"): Concatenates the input tensor with the \
433454
current tensor along the last dimension.
434-
2. **Cross append** ("cross"): Creates a cross-product of the current
435-
tensor and the input tensor by repeating them in a cross-product
436-
fashion, then concatenates the result along the last dimension.
455+
2. **Cross append** ("cross"): Creates a cross-product of the current \
456+
tensor and the input tensor.
437457
438458
:param tensor: The tensor to append to the current tensor.
439459
:type tensor: LabelTensor
440-
:param mode: The append mode to use. Defaults to "std".
460+
:param mode: The append mode to use. Defaults to ``st``.
441461
:type mode: str, optional
442-
:return: A new `LabelTensor` obtained by appending the input tensor.
462+
:return: A new :class:`LabelTensor` instance obtained by appending the
463+
input tensor.
443464
:rtype: LabelTensor
444465
445466
:raises ValueError: If the mode is not "std" or "cross".
@@ -468,7 +489,7 @@ def append(self, tensor, mode="std"):
468489
raise ValueError('mode must be either "std" or "cross"')
469490

470491
@staticmethod
471-
def vstack(label_tensors):
492+
def vstack(tensors):
472493
"""
473494
Stack tensors vertically. For more details, see :meth:`torch.vstack`.
474495
@@ -480,7 +501,7 @@ def vstack(label_tensors):
480501
:rtype: LabelTensor
481502
"""
482503

483-
return LabelTensor.cat(label_tensors, dim=0)
504+
return LabelTensor.cat(tensors, dim=0)
484505

485506
# This method is used to update labels
486507
def _update_single_label(
@@ -592,11 +613,11 @@ def __getitem__(self, index):
592613

593614
def sort_labels(self, dim=None):
594615
"""
595-
Sort the labels along the specified dimension and apply the same sorting
596-
to the :class:`torch.Tensor` part of the instance.
616+
Sort the labels along the specified dimension and apply. It applies the
617+
same sorting to the tensor part of the instance.
597618
598619
:param int dim: The dimension along which to sort the labels.
599-
If ``None``, the last dimension (``ndim - 1``) is used.
620+
If ``None``, the last dimension is used.
600621
:return: A new tensor with sorted labels along the specified dimension.
601622
:rtype: LabelTensor
602623
"""

0 commit comments

Comments
 (0)