Skip to content

Commit 8ee2125

Browse files
authored
Data module id swapping (#170)
1 parent: 83b3df7 · commit: 8ee2125

File tree

3 files changed, with 6 additions and 4 deletions.

cornac/data/module.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,13 @@ def feature_dim(self):
8181
return self.features.shape[1]
8282

8383
def _swap_feature(self, id_map):
84-
for old_idx, raw_id in enumerate(self._ids):
84+
for old_idx, raw_id in enumerate(self._ids.copy()):
8585
new_idx = id_map.get(raw_id, None)
8686
if new_idx is None:
8787
continue
8888
assert new_idx < self.features.shape[0]
8989
self.features[[new_idx, old_idx]] = self.features[[old_idx, new_idx]]
90+
self._ids[old_idx], self._ids[new_idx] = self._ids[new_idx], self._ids[old_idx]
9091

9192
def build(self, id_map=None):
9293
"""Build the feature matrix.

cornac/data/text.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -491,12 +491,13 @@ def __init__(self,
491491
self.count_matrix = None
492492

493493
def _swap_text(self, id_map: Dict):
494-
for old_idx, raw_id in enumerate(self._ids):
494+
for old_idx, raw_id in enumerate(self._ids.copy()):
495495
new_idx = id_map.get(raw_id, None)
496496
if new_idx is None:
497497
continue
498498
assert new_idx < len(self.corpus)
499499
self.corpus[old_idx], self.corpus[new_idx] = self.corpus[new_idx], self.corpus[old_idx]
500+
self._ids[old_idx], self._ids[new_idx] = self._ids[new_idx], self._ids[old_idx]
500501

501502
def _build_text(self, id_map: Dict):
502503
"""Build the text based on provided global id map

tests/cornac/data/test_module.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def test_init(self):
3333
self.assertIsNone(md.features)
3434

3535
md = FeatureModule(features=np.asarray(list(self.id_feature.values())),
36-
ids=self.id_feature.keys(),
36+
ids=list(self.id_feature.keys()),
3737
normalized=True)
3838

3939
global_iid_map = OrderedDict()
@@ -46,7 +46,7 @@ def test_init(self):
4646

4747
def test_batch_feature(self):
4848
md = FeatureModule(features=np.asarray(list(self.id_feature.values())),
49-
ids=self.id_feature.keys(),
49+
ids=list(self.id_feature.keys()),
5050
normalized=True)
5151

5252
global_iid_map = OrderedDict({'a': 0, 'b': 1})

0 commit comments

Comments (0)