Skip to content

Commit 6ed3ca0

Browse files
Fix __getitem__ in LabelTensor (#546)
* Fix LabelTensor * Cleaning label_tensor.py --------- Co-authored-by: Dario Coscia <[email protected]>
1 parent 0a60ed4 commit 6ed3ca0

File tree

2 files changed

+85
-28
lines changed

2 files changed

+85
-28
lines changed

pina/label_tensor.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -505,50 +505,40 @@ def vstack(tensors):
505505
return LabelTensor.cat(tensors, dim=0)
506506

507507
# This method is used to update labels
508-
def _update_single_label(
509-
self, old_labels, to_update_labels, index, dim, to_update_dim
510-
):
508+
def _update_single_label(self, index, dim):
511509
"""
512510
Update the labels of the tensor based on the index (or list of indices).
513511
514-
:param dict old_labels: Labels from which retrieve data.
515-
:param dict to_update_labels: Labels to update.
516512
:param index: Index of dof to retain.
517513
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
518-
:param int dim: The dimension to update.
519-
514+
:param int dim: Dimension of the indexes in the original tensor.
515+
:return: The updated labels for the specified dimension.
516+
:rtype: list[int]
520517
:raises: ValueError: If the index type is not supported.
521518
"""
522-
523-
old_dof = old_labels[to_update_dim]["dof"]
524-
label_name = old_labels[dim]["name"]
519+
old_dof = self._labels[dim]["dof"]
525520
# Handle slicing
526521
if isinstance(index, slice):
527-
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
522+
new_dof = old_dof[index]
528523
# Handle single integer index
529524
elif isinstance(index, int):
530-
to_update_labels[dim] = {
531-
"dof": [old_dof[index]],
532-
"name": label_name,
533-
}
525+
new_dof = [old_dof[index]]
534526
# Handle lists or tensors
535527
elif isinstance(index, (list, torch.Tensor)):
536528
# Handle list of bools
537529
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
538530
index = index.nonzero().squeeze()
539-
to_update_labels[dim] = {
540-
"dof": (
541-
[old_dof[i] for i in index]
542-
if isinstance(old_dof, list)
543-
else index
544-
),
545-
"name": label_name,
546-
}
531+
new_dof = (
532+
[old_dof[i] for i in index]
533+
if isinstance(old_dof, list)
534+
else index
535+
)
547536
else:
548537
raise NotImplementedError(
549538
f"Unsupported index type: {type(index)}. Expected slice, int, "
550539
f"list, or torch.Tensor."
551540
)
541+
return new_dof
552542

553543
def __getitem__(self, index):
554544
""" "
@@ -589,14 +579,20 @@ def __getitem__(self, index):
589579

590580
# Update labels based on the index
591581
offset = 0
582+
removed = 0
592583
for dim, idx in enumerate(index):
593-
if dim in self.stored_labels:
584+
if dim in original_labels:
594585
if isinstance(idx, int):
595-
selected_tensor = selected_tensor.unsqueeze(dim)
586+
# Compute the working dimension considering the removed
587+
# dimensions due to int index on a non labled dimension
588+
dim_ = dim - removed
589+
selected_tensor = selected_tensor.unsqueeze(dim_)
596590
if idx != slice(None):
597-
self._update_single_label(
598-
original_labels, updated_labels, idx, dim, offset
599-
)
591+
# Update the labels for the selected dimension
592+
updated_labels[offset] = {
593+
"dof": self._update_single_label(idx, dim),
594+
"name": original_labels[dim]["name"],
595+
}
600596
else:
601597
# Adjust label keys if dimension is reduced (case of integer
602598
# index on a non-labeled dimension)
@@ -605,6 +601,7 @@ def __getitem__(self, index):
605601
key - 1 if key > dim else key: value
606602
for key, value in updated_labels.items()
607603
}
604+
removed += 1
608605
continue
609606
offset += 1
610607

tests/test_label_tensor/test_label_tensor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,63 @@ def test_cat_bool(labels):
278278
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
279279
if isinstance(labels, dict):
280280
assert selected.stored_labels[0]["dof"] == ["a", "b"]
281+
282+
283+
def test_getitem_int():
284+
data = torch.rand(20, 3)
285+
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
286+
lt = LabelTensor(data, labels)
287+
new = lt[0, 0]
288+
assert new.ndim == 1
289+
assert new.shape[0] == 1
290+
assert torch.all(torch.isclose(data[0, 0], new))
291+
292+
data = torch.rand(20, 3, 2)
293+
labels = {
294+
1: {"name": 1, "dof": ["x", "y", "z"]},
295+
2: {"name": 2, "dof": ["a", "b"]},
296+
}
297+
lt = LabelTensor(data, labels)
298+
new = lt[0, 0, 0]
299+
assert new.ndim == 2
300+
assert new.shape[0] == 1
301+
assert new.shape[1] == 1
302+
assert torch.all(torch.isclose(data[0, 0, 0], new))
303+
assert new.stored_labels[0]["dof"] == ["x"]
304+
assert new.stored_labels[1]["dof"] == ["a"]
305+
306+
new = lt[0, 0, :]
307+
assert new.ndim == 2
308+
assert new.shape[0] == 1
309+
assert new.shape[1] == 2
310+
assert torch.all(torch.isclose(data[0, 0, :], new))
311+
assert new.stored_labels[0]["dof"] == ["x"]
312+
assert new.stored_labels[1]["dof"] == ["a", "b"]
313+
314+
new = lt[0, :, 1]
315+
assert new.ndim == 2
316+
assert new.shape[0] == 3
317+
assert new.shape[1] == 1
318+
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
319+
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
320+
assert new.stored_labels[1]["dof"] == ["b"]
321+
322+
labels.pop(2)
323+
lt = LabelTensor(data, labels)
324+
new = lt[0, 0, 0]
325+
assert new.ndim == 1
326+
assert new.shape[0] == 1
327+
assert new.stored_labels[0]["dof"] == ["x"]
328+
329+
new = lt[:, 0, 0]
330+
assert new.ndim == 2
331+
assert new.shape[0] == 20
332+
assert new.shape[1] == 1
333+
assert new.stored_labels[1]["dof"] == ["x"]
334+
335+
new = lt[:, 0, :]
336+
assert new.ndim == 3
337+
assert new.shape[0] == 20
338+
assert new.shape[1] == 1
339+
assert new.shape[2] == 2
340+
assert new.stored_labels[1]["dof"] == ["x"]

0 commit comments

Comments
 (0)