@@ -734,13 +734,10 @@ def from_table(cls, domain, source, row_indices=...):
734734 table = assure_domain_conversion_sparsity (table , source )
735735 return table
736736
737- if row_indices is ...:
738- n_rows = len (source )
739- elif isinstance (row_indices , slice ):
740- row_indices_range = range (* row_indices .indices (source .X .shape [0 ]))
741- n_rows = len (row_indices_range )
742- else :
743- n_rows = len (row_indices )
737+ # avoid boolean indices; also convert to slices if possible
738+ row_indices = _optimize_indices (row_indices , len (source ))
739+
740+ n_rows = _selection_length (row_indices , len (source ))
744741
745742 self = cls ()
746743 self .domain = domain
@@ -783,13 +780,8 @@ def from_table(cls, domain, source, row_indices=...):
783780
784781 while i_done < n_rows :
785782 target_indices = slice (i_done , min (n_rows , i_done + PART ))
786- if row_indices is ...:
787- source_indices = target_indices
788- elif isinstance (row_indices , slice ):
789- r = row_indices_range [target_indices ]
790- source_indices = slice (r .start , r .stop , r .step )
791- else :
792- source_indices = row_indices [target_indices ]
783+ source_indices = _select_from_selection (row_indices , target_indices ,
784+ len (source ))
793785 part_rows = min (n_rows , i_done + PART ) - i_done
794786
795787 for array_conv in table_conversion .columnwise :
@@ -810,15 +802,9 @@ def from_table(cls, domain, source, row_indices=...):
810802 out = cparts if not array_conv .is_sparse else sp .vstack (cparts )
811803 setattr (self , array_conv .target , out )
812804
813- if source .has_weights ():
814- self .W = source .W [row_indices ]
815- else :
816- self .W = np .empty ((n_rows , 0 ))
805+ self .W = source .W [row_indices ]
817806 self .name = getattr (source , 'name' , '' )
818- if hasattr (source , 'ids' ):
819- self .ids = source .ids [row_indices ]
820- else :
821- cls ._init_ids (self )
807+ self .ids = source .ids [row_indices ]
822808 self .attributes = deepcopy (getattr (source , 'attributes' , {}))
823809 _idcache_save (_thread_local .conversion_cache , (domain , source ), self )
824810 return self
@@ -876,7 +862,7 @@ def from_table_rows(cls, source, row_indices):
876862 self .metas = self .metas .reshape (- 1 , len (self .domain .metas ))
877863 self .W = source .W [row_indices ]
878864 self .name = getattr (source , 'name' , '' )
879- self .ids = np . array ( source .ids [row_indices ])
865+ self .ids = source .ids [row_indices ]
880866 self .attributes = deepcopy (getattr (source , 'attributes' , {}))
881867 return self
882868
@@ -2421,19 +2407,24 @@ def _subarray(arr, rows, cols):
24212407 # so they need to be reshaped to produce an open mesh
24222408 return arr [np .ix_ (rows , cols )]
24232409
2424- def _optimize_indices (indices , maxlen ):
2410+
2411+ def _optimize_indices (indices , size ):
24252412 """
2426- Convert integer indices to slice if possible. It only converts increasing
2427- integer ranges with positive steps and valid starts and ends.
2428- Only convert valid ends so that invalid ranges will still raise
2429- an exception.
2413+ Convert boolean indices to integer indices and convert these to a slice
2414+ if possible.
2415+
2416+ A slice is created from only from indices with positive steps and
2417+ valid starts and ends (so that invalid ranges will still raise an
2418+ exception. An IndexError is raised if boolean indices do not conform
2419+ to input size.
24302420
24312421 Allows numpy to reuse the data array, because it defaults to copying
24322422 if given indices.
24332423
24342424 Parameters
24352425 ----------
24362426 indices : 1D sequence, slice or Ellipsis
2427+ size : int
24372428 """
24382429 if isinstance (indices , slice ):
24392430 return indices
@@ -2450,19 +2441,58 @@ def _optimize_indices(indices, maxlen):
24502441
24512442 if len (indices ) >= 1 :
24522443 indices = np .asarray (indices )
2453- if indices .dtype != bool :
2454- begin = indices [0 ]
2455- end = indices [- 1 ]
2456- steps = np .diff (indices ) if len (indices ) > 1 else np .array ([1 ])
2457- step = steps [0 ]
2444+ if indices .dtype == bool :
2445+ if len (indices ) == size :
2446+ indices = np .nonzero (indices )[0 ]
2447+ else :
2448+ # raise an exception that numpy would if boolean indices were used
2449+ raise IndexError ("boolean indices did not match dimension" )
2450+
2451+ if len (indices ) >= 1 : # conversion from boolean indices could result in an empty array
2452+ begin = indices [0 ]
2453+ end = indices [- 1 ]
2454+ steps = np .diff (indices ) if len (indices ) > 1 else np .array ([1 ])
2455+ step = steps [0 ]
24582456
2459- # continuous ranges with constant step and valid start and stop index can be slices
2460- if np .all (steps == step ) and step > 0 and begin >= 0 and end < maxlen :
2461- return slice (begin , end + step , step )
2457+ # continuous ranges with constant step and valid start and stop index can be slices
2458+ if np .all (steps == step ) and step > 0 and begin >= 0 and end < size :
2459+ return slice (begin , end + step , step )
24622460
24632461 return indices
24642462
24652463
2464+ def _selection_length (indices , maxlen ):
2465+ """ Return the selection length.
2466+ Args:
2467+ indices: 1D sequence, slice or Ellipsis
2468+ maxlen: maximum length of the sequence
2469+ """
2470+ if indices is ...:
2471+ return maxlen
2472+ elif isinstance (indices , slice ):
2473+ return len (range (* indices .indices (maxlen )))
2474+ else :
2475+ return len (indices )
2476+
2477+
2478+ def _select_from_selection (source_indices , selection_indices , maxlen ):
2479+ """
2480+ Create efficient selection indices from a previous selection.
2481+ Try to keep slices as slices.
2482+ Args:
2483+ source_indices: 1D sequence, slice or Ellipsis
2484+ selection_indices: 1D sequence or slice
2485+ maxlen: maximum length of the sequence
2486+ """
2487+ if source_indices is ...:
2488+ return selection_indices
2489+ elif isinstance (source_indices , slice ):
2490+ r = range (* source_indices .indices (maxlen ))[selection_indices ]
2491+ return slice (r .start , r .stop , r .step )
2492+ else :
2493+ return source_indices [selection_indices ]
2494+
2495+
24662496def assure_domain_conversion_sparsity (target , source ):
24672497 """
24682498 Assure that the table obeys the domain conversion's suggestions about sparsity.
0 commit comments