Skip to content

Commit 5512ac4

Browse files
authored
Merge pull request #747 from wangzhen38/tipc_fix
fix tipc bugs from qa
2 parents b28c1a2 + 68ba6a5 commit 5512ac4

File tree

7 files changed

+38
-11
lines changed

7 files changed

+38
-11
lines changed

models/rank/dlrm/config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ runner:
3737
split_file_list: False
3838
thread_num: 1
3939

40+
# use inference save model
41+
inference: False # save as an inference model during static-graph training
42+
save_inference_feed_varnames: ["C1","C2","C3","C4","C5","C6","C7","C8","C9","C10","C11","C12","C13","C14","C15","C16","C17","C18","C19","C20","C21","C22","C23","C24","C25","C26","dense_input"] # names of the inference model's feed variables
43+
save_inference_fetch_varnames: ["sigmoid_0.tmp_0"] # names of the inference model's fetch variables
4044

4145
# hyper parameters of user-defined network
4246
hyper_parameters:

models/rank/dlrm/criteo_reader.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,25 @@ class RecDataset(IterableDataset):
2222
def __init__(self, file_list, config):
2323
super(RecDataset, self).__init__()
2424
self.file_list = file_list
25+
if config:
26+
use_fleet = config.get("runner.use_fleet", False)
27+
self.inference = config.get("runner.inference", False)
28+
else:
29+
use_fleet = False
30+
if use_fleet:
31+
worker_id = paddle.distributed.get_rank()
32+
worker_num = paddle.distributed.get_world_size()
33+
file_num = len(file_list)
34+
if file_num < worker_num:
35+
raise ValueError(
36+
"The number of data files is less than the number of workers"
37+
)
38+
blocksize = int(file_num / worker_num)
39+
self.file_list = file_list[worker_id * blocksize:(worker_id + 1) *
40+
blocksize]
41+
remainder = file_num - (blocksize * worker_num)
42+
if worker_id < remainder:
43+
self.file_list.append(file_list[-(worker_id + 1)])
2544
self.init()
2645

2746
def init(self):
@@ -78,4 +97,8 @@ def __iter__(self):
7897
output_list.append(
7998
np.array(output[-1][1]).astype("float32"))
8099
# list
81-
yield output_list
100+
#yield output_list
101+
if self.inference:
102+
yield output_list[1:]
103+
else:
104+
yield output_list

models/recall/ensfm/config_bigdata.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ runner:
1616
train_data_dir: "../../../datasets/ml-1m_ensfm/data/ml-1m-ensfm"
1717
train_reader_path: "movielens_reader" # importlib format
1818
train_batch_size: 512
19-
model_save_path: "output_model_ensfm"
19+
model_save_path: "output_model_ensfm_all"
2020
mode: "train"
2121
use_gpu: True
2222
epochs: 501
@@ -25,7 +25,7 @@ runner:
2525
test_data_dir: "../../../datasets/ml-1m_ensfm/data/ml-1m-ensfm"
2626
infer_reader_path: "movielens_reader" # importlib format
2727
infer_batch_size: 512
28-
infer_load_path: "output_model_ensfm"
28+
infer_load_path: "output_model_ensfm_all"
2929
infer_start_epoch: 100
3030
infer_end_epoch: 501
3131

test_tipc/configs/tisas/paddle_infer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from importlib import import_module
2424

2525
__dir__ = os.path.dirname(os.path.abspath(__file__))
26-
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
26+
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../tools')))
2727
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model
2828
from utils.save_load import save_model, load_model
2929
from paddle.io import DistributedBatchSampler, DataLoader

test_tipc/configs/tisas/to_static.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
__dir__ = os.path.dirname(os.path.abspath(__file__))
2424
# sys.path.append(__dir__)
25-
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
25+
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../../tools')))
2626

2727
from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
2828
from utils.save_load import load_model, save_model, save_jit_model

test_tipc/configs/tisas/train_infer_python.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ runner.model_save_path
99
runner.train_batch_size:lite_train_lite_infer=2|whole_train_whole_infer=128|whole_infer=1|lite_train_whole_infer=2
1010
runner.infer_load_path:null
1111
train_model_name:lite_train_lite_infer=0|whole_train_whole_infer=101|whole_infer=101|lite_train_whole_infer=0
12-
runner.test_data_dir:test_tipc/data
13-
runner.train_data_dir:../../../test_tipc/data
12+
runner.test_data_dir:test_tipc/data/infer
13+
runner.train_data_dir:../../../test_tipc/data/train
1414
##
1515
trainer:norm_train
1616
norm_train:-u tools/trainer.py -m ./models/recall/tisas/config.yaml -o runner.print_interval=2
@@ -27,7 +27,7 @@ null:null
2727
===========================infer_params===========================
2828
runner.model_save_path:
2929
runner.model_init_path:
30-
norm_export:-u ./to_static.py -m ./models/recall/tisas/config.yaml -o runner.CE=true
30+
norm_export:-u test_tipc/configs/tisas/to_static.py -m ./models/recall/tisas/config.yaml -o runner.CE=true
3131
quant_export:null
3232
fpgm_export:null
3333
distill_export:null
@@ -37,15 +37,15 @@ null:null
3737
infer_model:test_tipc/save_tisas_model
3838
infer_export:null
3939
infer_quant:False
40-
inference:-u ./paddle_infer.py --model_name=tisas --reader_file=models/recall/tisas/movielens_reader.py
40+
inference:-u test_tipc/configs/tisas/paddle_infer.py --model_name=tisas --reader_file=models/recall/tisas/movielens_reader.py
4141
--use_gpu:True|False
4242
--enable_mkldnn:True|False
4343
--cpu_threads:1|6
4444
--batchsize:1
4545
--enable_tensorRT:False
4646
--precision:fp32
4747
--model_dir:
48-
--data_dir:test_tipc/data
48+
--data_dir:test_tipc/data/infer
4949
--save_log_path:./test_tipc/output/
5050
--benchmark:True
5151
null:null

test_tipc/test_train_inference_python.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ else
359359
#run inference
360360
eval $env
361361
save_infer_path="${save_log}"
362-
if [ ${inference_dir} != "null" ] && [ ${inference_dir} != '##' ]; then
362+
if [ "${inference_dir}" != "null" ] && [ "${inference_dir}" != '##' ]; then
363363
infer_model_dir="${save_infer_path}/${inference_dir}"
364364
else
365365
infer_model_dir=${save_infer_path}

0 commit comments

Comments
 (0)