1616import numpy as np
1717import io
1818import six
19-
19+ import time
20+ import random
2021from paddle .io import IterableDataset
2122
2223
@@ -35,7 +36,7 @@ def __call__(self):
3536 self .idx = 0
3637
3738 result = self .buffer [self .idx ]
38- self .idx += 1
39+ self .idx = self . idx + 1
3940 return result
4041
4142
@@ -52,7 +53,9 @@ def init(self):
5253 self .neg_num = self .config .get ("hyper_parameters.neg_num" )
5354 self .with_shuffle_batch = self .config .get (
5455 "hyper_parameters.with_shuffle_batch" )
55- self .random_generator = NumpyRandomInt (1 , self .window_size + 1 )
56+ #self.random_generator = NumpyRandomInt(1, self.window_size + 1)
57+ np .random .seed (12345 )
58+ self .random_generator = np .random .randint (1 , self .window_size + 1 )
5659 self .batch_size = self .config .get ("runner.batch_size" )
5760
5861 self .cs = None
@@ -78,7 +81,7 @@ def get_context_words(self, words, idx):
7881 idx: input word index
7982 window_size: window size
8083 """
81- target_window = self .random_generator ()
84+ target_window = self .random_generator
8285 # if (idx - target_window) > 0 else 0
8386 start_point = idx - target_window
8487 if start_point < 0 :
@@ -102,11 +105,15 @@ def __iter__(self):
102105 np .array ([int (target_id )]).astype ('int64' ))
103106 output .append (
104107 np .array ([int (context_id )]).astype ('int64' ))
105- np .random .seed (12345 )
106- neg_array = self .cs .searchsorted (
107- np .random .sample (self .neg_num ))
108+
109+ tmp = []
110+ random .seed (12345 )
111+ for i in range (self .neg_num ):
112+ tmp .append (random .random ())
113+ neg_array = self .cs .searchsorted (tmp )
114+
108115 output .append (
109- np .array ([int (str ( i ) )
116+ np .array ([int (i )
110117 for i in neg_array ]).astype ('int64' ))
111118 yield output
112119
0 commit comments