File tree Expand file tree Collapse file tree 3 files changed +6
-4
lines changed
Expand file tree Collapse file tree 3 files changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -81,12 +81,13 @@ def feature_dim(self):
8181 return self .features .shape [1 ]
8282
8383 def _swap_feature (self , id_map ):
84- for old_idx , raw_id in enumerate (self ._ids ):
84+ for old_idx , raw_id in enumerate (self ._ids . copy () ):
8585 new_idx = id_map .get (raw_id , None )
8686 if new_idx is None :
8787 continue
8888 assert new_idx < self .features .shape [0 ]
8989 self .features [[new_idx , old_idx ]] = self .features [[old_idx , new_idx ]]
90+ self ._ids [old_idx ], self ._ids [new_idx ] = self ._ids [new_idx ], self ._ids [old_idx ]
9091
9192 def build (self , id_map = None ):
9293 """Build the feature matrix.
Original file line number Diff line number Diff line change @@ -491,12 +491,13 @@ def __init__(self,
491491 self .count_matrix = None
492492
493493 def _swap_text (self , id_map : Dict ):
494- for old_idx , raw_id in enumerate (self ._ids ):
494+ for old_idx , raw_id in enumerate (self ._ids . copy () ):
495495 new_idx = id_map .get (raw_id , None )
496496 if new_idx is None :
497497 continue
498498 assert new_idx < len (self .corpus )
499499 self .corpus [old_idx ], self .corpus [new_idx ] = self .corpus [new_idx ], self .corpus [old_idx ]
500+ self ._ids [old_idx ], self ._ids [new_idx ] = self ._ids [new_idx ], self ._ids [old_idx ]
500501
501502 def _build_text (self , id_map : Dict ):
502503 """Build the text based on provided global id map
Original file line number Diff line number Diff line change @@ -33,7 +33,7 @@ def test_init(self):
3333 self .assertIsNone (md .features )
3434
3535 md = FeatureModule (features = np .asarray (list (self .id_feature .values ())),
36- ids = self .id_feature .keys (),
36+ ids = list ( self .id_feature .keys () ),
3737 normalized = True )
3838
3939 global_iid_map = OrderedDict ()
@@ -46,7 +46,7 @@ def test_init(self):
4646
4747 def test_batch_feature (self ):
4848 md = FeatureModule (features = np .asarray (list (self .id_feature .values ())),
49- ids = self .id_feature .keys (),
49+ ids = list ( self .id_feature .keys () ),
5050 normalized = True )
5151
5252 global_iid_map = OrderedDict ({'a' : 0 , 'b' : 1 })
You can’t perform that action at this time.
0 commit comments