Skip to content

Commit 2966f1d

Browse files
committed
fix word2vec reader
1 parent 3ef7cb8 commit 2966f1d

File tree

3 files changed

+5
-7
lines changed

3 files changed

+5
-7
lines changed

models/rank/dlrm/criteo_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import print_function
1616
import numpy as np
17-
17+
import paddle
1818
from paddle.io import IterableDataset
1919

2020

models/recall/word2vec/word2vec_infer_reader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@ def __init__(self, file_list, config):
2525
super(RecDataset, self).__init__()
2626
self.file_list = file_list
2727
self.config = config
28+
self.config_abs_dir = config.get("config_abs_dir", None)
2829
self.init()
2930

3031
def init(self):
3132
dict_path = self.config.get("runner.word_id_dict_path")
32-
pwd = str(os.getcwd())
33-
if pwd[-8:] != 'word2vec':
34-
dict_path = os.path.join(pwd, 'models/recall/word2vec', dict_path)
33+
dict_path = os.path.join(self.config_abs_dir, dict_path)
3534
self.word_to_id = dict()
3635
self.id_to_word = dict()
3736
with io.open(dict_path, 'r', encoding='utf-8') as f:

models/recall/word2vec/word2vec_reader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,12 @@ def __init__(self, file_list, config):
4646
super(RecDataset, self).__init__()
4747
self.file_list = file_list
4848
self.config = config
49+
self.config_abs_dir = config.get("config_abs_dir", None)
4950
self.init()
5051

5152
def init(self):
5253
dict_path = self.config.get("runner.word_count_dict_path")
53-
pwd = str(os.getcwd())
54-
if pwd[-8:] != 'word2vec':
55-
dict_path = os.path.join(pwd, 'models/recall/word2vec', dict_path)
54+
dict_path = os.path.join(self.config_abs_dir, dict_path)
5655
self.window_size = self.config.get("hyper_parameters.window_size")
5756
self.neg_num = self.config.get("hyper_parameters.neg_num")
5857
self.with_shuffle_batch = self.config.get(

0 commit comments

Comments
 (0)