Skip to content

Commit 3ef7cb8

Browse files
committed
fix word2vec run_in_root_dir
1 parent 89f532f commit 3ef7cb8

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

models/recall/word2vec/word2vec_infer_reader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy as np
1717
import io
1818
import six
19-
19+
import os
2020
from paddle.io import IterableDataset
2121

2222

@@ -29,6 +29,9 @@ def __init__(self, file_list, config):
2929

3030
def init(self):
3131
dict_path = self.config.get("runner.word_id_dict_path")
32+
pwd = str(os.getcwd())
33+
if pwd[-8:] != 'word2vec':
34+
dict_path = os.path.join(pwd, 'models/recall/word2vec', dict_path)
3235
self.word_to_id = dict()
3336
self.id_to_word = dict()
3437
with io.open(dict_path, 'r', encoding='utf-8') as f:

models/recall/word2vec/word2vec_reader.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import six
1919
import time
2020
import random
21+
import os
2122
from paddle.io import IterableDataset
2223

2324

@@ -49,6 +50,9 @@ def __init__(self, file_list, config):
4950

5051
def init(self):
5152
dict_path = self.config.get("runner.word_count_dict_path")
53+
pwd = str(os.getcwd())
54+
if pwd[-8:] != 'word2vec':
55+
dict_path = os.path.join(pwd, 'models/recall/word2vec', dict_path)
5256
self.window_size = self.config.get("hyper_parameters.window_size")
5357
self.neg_num = self.config.get("hyper_parameters.neg_num")
5458
self.with_shuffle_batch = self.config.get(
@@ -127,6 +131,9 @@ def __init__(self, file_list, config):
127131

128132
def init(self):
129133
dict_path = self.config.get("runner.word_id_dict_path")
134+
pwd = str(os.getcwd())
135+
if pwd[-8:] != 'word2vec':
136+
dict_path = os.path.join(pwd, 'models/recall/word2vec', dict_path)
130137
self.word_to_id = dict()
131138
self.id_to_word = dict()
132139
with io.open(dict_path, 'r', encoding='utf-8') as f:

test_tipc/configs/tisas/train_infer_python.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ inference:-u test_tipc/configs/tisas/paddle_infer.py --model_name=tisas --reader
4242
--enable_mkldnn:True|False
4343
--cpu_threads:1|6
4444
--batchsize:1
45-
--enable_tensorRT:False
45+
--enable_tensorRT:True|False
4646
--precision:fp32
4747
--model_dir:
4848
--data_dir:test_tipc/data/infer

0 commit comments

Comments
 (0)