@@ -27,7 +27,7 @@ def data_vstack(blocks: Sequence[modALinput]) -> modALinput:
2727 return np .concatenate (blocks )
2828 elif isinstance (blocks [0 ], list ):
2929 return np .concatenate (blocks ).tolist ()
30- elif torch .is_tensor (blocks [0 ]):
30+ elif torch .is_tensor (blocks [0 ]):
3131 return torch .cat (blocks )
3232
3333 raise TypeError ('%s datatype is not supported' % type (blocks [0 ]))
@@ -51,23 +51,22 @@ def data_hstack(blocks: Sequence[modALinput]) -> modALinput:
5151 return np .hstack (blocks )
5252 elif isinstance (blocks [0 ], list ):
5353 return np .hstack (blocks ).tolist ()
54- elif torch .is_tensor (blocks [0 ]):
54+ elif torch .is_tensor (blocks [0 ]):
5555 return torch .cat (blocks , dim = 1 )
5656
5757 TypeError ('%s datatype is not supported' % type (blocks [0 ]))
5858
5959
60- def add_row (X :modALinput , row : modALinput ):
60+ def add_row (X : modALinput , row : modALinput ):
6161 """
6262 Returns X' =
6363
6464 [X
6565
66- row]
67- """
66+ row] """
6867 if isinstance (X , np .ndarray ):
6968 return np .vstack ((X , row ))
70- elif torch .is_tensor (X ):
69+ elif torch .is_tensor (X ):
7170 return torch .cat ((X , row ))
7271 elif isinstance (X , list ):
7372 return np .vstack ((X , row )).tolist ()
@@ -102,7 +101,7 @@ def retrieve_rows(X: modALinput,
102101 return X .iloc [I ]
103102 elif isinstance (X , list ):
104103 return np .array (X )[I ].tolist ()
105- elif isinstance (X , dict ):
104+ elif isinstance (X , dict ):
106105 X_return = {}
107106 for key , value in X .items ():
108107 X_return [key ] = retrieve_rows (value , I )
@@ -118,7 +117,6 @@ def retrieve_rows(X: modALinput,
118117def drop_rows (X : modALinput ,
119118 I : Union [int , List [int ], np .ndarray ]) -> Union [sp .csc_matrix , np .ndarray , pd .DataFrame ]:
120119 """
121- TODO: Add pytorch support
122120 Returns X without the row(s) at index/indices I
123121 """
124122 if sp .issparse (X ):
@@ -131,6 +129,9 @@ def drop_rows(X: modALinput,
131129 return np .delete (X , I , axis = 0 )
132130 elif isinstance (X , list ):
133131 return np .delete (X , I , axis = 0 ).tolist ()
132+ elif torch .is_tensor (X ):
133+ return X [[True if row not in I else False
134+ for row in range (X .size (0 ))]]
134135
135136 raise TypeError ('%s datatype is not supported' % type (X ))
136137
@@ -165,7 +166,7 @@ def data_shape(X: modALinput):
165166 return X .shape
166167 elif isinstance (X , list ):
167168 return np .array (X ).shape
168- elif torch .is_tensor (X ):
169+ elif torch .is_tensor (X ):
169170 return tuple (X .size ())
170171
171172 raise TypeError ('%s datatype is not supported' % type (X ))
0 commit comments