Skip to content

Commit 2ec982a

Browse files
authored
Merge pull request #741 from yinhaofeng/fix_tipc
fix_tipc
2 parents e74853e + 64bcf86 commit 2ec982a

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

models/rank/flen/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ os : windows/linux/macos
7171
python -u ../../../tools/trainer.py -m config.yaml # 全量数据运行config_bigdata.yaml
7272
# 动态图预测
7373
python -u ../../../tools/infer.py -m config.yaml # 全量数据运行config_bigdata.yaml
74-
74+
```
7575

7676
## 模型组网
7777

models/recall/word2vec/dygraph_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def create_metrics(self):
8585
# construct train forward phase
8686
def train_forward(self, dy_model, metrics_list, batch_data, config):
8787
input_word, true_word, neg_word = self.create_feeds(batch_data, config)
88-
8988
true_logits, neg_logits = dy_model.forward(
9089
[input_word, true_word, neg_word])
9190
loss = self.create_loss(true_logits, neg_logits, config)

models/recall/word2vec/static_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def net(self, inputs, is_infer=False):
9595

9696
def create_optimizer(self, strategy=None):
9797
optimizer = paddle.optimizer.SGD(learning_rate=self.learning_rate)
98+
optimizer.minimize(self._cost)
9899
# learning_rate=paddle.fluid.layers.exponential_decay(
99100
# learning_rate=self.learning_rate,
100101
# decay_steps=self.decay_steps,

models/recall/word2vec/word2vec_reader.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717
import io
1818
import six
19-
19+
import time
20+
import random
2021
from paddle.io import IterableDataset
2122

2223

@@ -35,7 +36,7 @@ def __call__(self):
3536
self.idx = 0
3637

3738
result = self.buffer[self.idx]
38-
self.idx += 1
39+
self.idx = self.idx + 1
3940
return result
4041

4142

@@ -52,7 +53,9 @@ def init(self):
5253
self.neg_num = self.config.get("hyper_parameters.neg_num")
5354
self.with_shuffle_batch = self.config.get(
5455
"hyper_parameters.with_shuffle_batch")
55-
self.random_generator = NumpyRandomInt(1, self.window_size + 1)
56+
#self.random_generator = NumpyRandomInt(1, self.window_size + 1)
57+
np.random.seed(12345)
58+
self.random_generator = np.random.randint(1, self.window_size + 1)
5659
self.batch_size = self.config.get("runner.batch_size")
5760

5861
self.cs = None
@@ -78,7 +81,7 @@ def get_context_words(self, words, idx):
7881
idx: input word index
7982
window_size: window size
8083
"""
81-
target_window = self.random_generator()
84+
target_window = self.random_generator
8285
# if (idx - target_window) > 0 else 0
8386
start_point = idx - target_window
8487
if start_point < 0:
@@ -102,11 +105,15 @@ def __iter__(self):
102105
np.array([int(target_id)]).astype('int64'))
103106
output.append(
104107
np.array([int(context_id)]).astype('int64'))
105-
np.random.seed(12345)
106-
neg_array = self.cs.searchsorted(
107-
np.random.sample(self.neg_num))
108+
109+
tmp = []
110+
random.seed(12345)
111+
for i in range(self.neg_num):
112+
tmp.append(random.random())
113+
neg_array = self.cs.searchsorted(tmp)
114+
108115
output.append(
109-
np.array([int(str(i))
116+
np.array([int(i)
110117
for i in neg_array]).astype('int64'))
111118
yield output
112119

test_tipc/prepare.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ elif [ ${model_name} == "tisas" ]; then
272272
cp -r ./models/recall/tisas/data/sample_data/* ./test_tipc/data/train
273273
cp -r ./models/recall/tisas/data/sample_data/* ./test_tipc/data/infer
274274
echo "demo data ready"
275+
fi
275276

276277
elif [ ${model_name} == "dselect_k" ]; then
277278
mkdir -p ./test_tipc/data/train
@@ -300,5 +301,4 @@ elif [ ${model_name} == "dselect_k" ]; then
300301
cp -r ./models/multitask/dselect_k/data/* ./test_tipc/data/train
301302
cp -r ./datasets/Multi_MNIST_DselectK/test/* ./test_tipc/data/infer
302303
fi
303-
304304
fi

0 commit comments

Comments
 (0)