@@ -505,50 +505,40 @@ def vstack(tensors):
505505 return LabelTensor .cat (tensors , dim = 0 )
506506
507507 # This method is used to update labels
508- def _update_single_label (
509- self , old_labels , to_update_labels , index , dim , to_update_dim
510- ):
508+ def _update_single_label (self , index , dim ):
511509 """
512510 Update the labels of the tensor based on the index (or list of indices).
513511
514- :param dict old_labels: Labels from which retrieve data.
515- :param dict to_update_labels: Labels to update.
516512 :param index: Index of dof to retain.
517513 :type index: int | slice | list[int] | tuple[int] | torch.Tensor
518- :param int dim: The dimension to update.
519-
514+ :param int dim: Dimension of the indexes in the original tensor.
515+ :return: The updated labels for the specified dimension.
516+ :rtype: list[int]
520517 :raises: ValueError: If the index type is not supported.
521518 """
522-
523- old_dof = old_labels [to_update_dim ]["dof" ]
524- label_name = old_labels [dim ]["name" ]
519+ old_dof = self ._labels [dim ]["dof" ]
525520 # Handle slicing
526521 if isinstance (index , slice ):
527- to_update_labels [ dim ] = { "dof" : old_dof [index ], "name" : label_name }
522+ new_dof = old_dof [index ]
528523 # Handle single integer index
529524 elif isinstance (index , int ):
530- to_update_labels [dim ] = {
531- "dof" : [old_dof [index ]],
532- "name" : label_name ,
533- }
525+ new_dof = [old_dof [index ]]
534526 # Handle lists or tensors
535527 elif isinstance (index , (list , torch .Tensor )):
536528 # Handle list of bools
537529 if isinstance (index , torch .Tensor ) and index .dtype == torch .bool :
538530 index = index .nonzero ().squeeze ()
539- to_update_labels [dim ] = {
540- "dof" : (
541- [old_dof [i ] for i in index ]
542- if isinstance (old_dof , list )
543- else index
544- ),
545- "name" : label_name ,
546- }
531+ new_dof = (
532+ [old_dof [i ] for i in index ]
533+ if isinstance (old_dof , list )
534+ else index
535+ )
547536 else :
548537 raise NotImplementedError (
549538 f"Unsupported index type: { type (index )} . Expected slice, int, "
550539 f"list, or torch.Tensor."
551540 )
541+ return new_dof
552542
553543 def __getitem__ (self , index ):
554544 """ "
@@ -589,14 +579,20 @@ def __getitem__(self, index):
589579
590580 # Update labels based on the index
591581 offset = 0
582+ removed = 0
592583 for dim , idx in enumerate (index ):
593- if dim in self . stored_labels :
584+ if dim in original_labels :
594585 if isinstance (idx , int ):
595- selected_tensor = selected_tensor .unsqueeze (dim )
586+ # Compute the working dimension considering the removed
587+ # dimensions due to int index on a non labled dimension
588+ dim_ = dim - removed
589+ selected_tensor = selected_tensor .unsqueeze (dim_ )
596590 if idx != slice (None ):
597- self ._update_single_label (
598- original_labels , updated_labels , idx , dim , offset
599- )
591+ # Update the labels for the selected dimension
592+ updated_labels [offset ] = {
593+ "dof" : self ._update_single_label (idx , dim ),
594+ "name" : original_labels [dim ]["name" ],
595+ }
600596 else :
601597 # Adjust label keys if dimension is reduced (case of integer
602598 # index on a non-labeled dimension)
@@ -605,6 +601,7 @@ def __getitem__(self, index):
605601 key - 1 if key > dim else key : value
606602 for key , value in updated_labels .items ()
607603 }
604+ removed += 1
608605 continue
609606 offset += 1
610607
0 commit comments