diff --git a/data.py b/data.py index 60c180f..5222765 100644 --- a/data.py +++ b/data.py @@ -44,7 +44,7 @@ def __init__(self, root, split = 'train'): idxs = np.arange(len(self.strokes)) np.random.shuffle(idxs) self.strokes = self.strokes[idxs] - self.sentences = np.asarray(self.sentences)[idxs].tolist() + self.sentences = [self.sentences[i] for i in idxs] c = Counter() for line in self.sentences: