Skip to content

Commit 4f9b0fe

Browse files
committed
modify infer_meta.py
1 parent 625c047 commit 4f9b0fe

File tree

1 file changed

+243
-0
lines changed

1 file changed

+243
-0
lines changed
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import os
17+
import time
18+
import logging
19+
import sys
20+
import numpy as np
21+
22+
__dir__ = os.path.dirname(os.path.abspath(__file__))
23+
print(os.path.abspath('/'.join(__dir__.split('/')[:-3])))
24+
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
25+
sys.path.append(os.path.abspath('/'.join(__dir__.split('/')[:-3])))
26+
27+
from tools.utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
28+
from tools.utils.save_load import save_model, load_model
29+
from paddle.io import DataLoader
30+
import argparse
31+
32+
logging.basicConfig(
33+
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
34+
logger = logging.getLogger(__name__)
35+
36+
37+
def parse_args():
38+
parser = argparse.ArgumentParser(description='paddle-rec run')
39+
parser.add_argument("-m", "--config_yaml", type=str)
40+
parser.add_argument("-o", "--opt", nargs='*', type=str)
41+
args = parser.parse_args()
42+
args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
43+
args.config_yaml = get_abs_model(args.config_yaml)
44+
return args
45+
46+
47+
def main(args):
48+
paddle.seed(2021)
49+
# load config
50+
config = load_yaml(args.config_yaml)
51+
dy_model_class = load_dy_model_class(args.abs_dir)
52+
config["config_abs_dir"] = args.abs_dir
53+
# modify config from command
54+
if args.opt:
55+
for parameter in args.opt:
56+
parameter = parameter.strip()
57+
key, value = parameter.split("=")
58+
if type(config.get(key)) is int:
59+
value = int(value)
60+
if type(config.get(key)) is float:
61+
value = float(value)
62+
if type(config.get(key)) is bool:
63+
value = (True if value.lower() == "true" else False)
64+
config[key] = value
65+
66+
# tools.vars
67+
use_gpu = config.get("runner.use_gpu", True)
68+
use_xpu = config.get("runner.use_xpu", False)
69+
use_npu = config.get("runner.use_npu", False)
70+
use_visual = config.get("runner.use_visual", False)
71+
test_data_dir = config.get("runner.test_data_dir", None)
72+
print_interval = config.get("runner.print_interval", None)
73+
infer_batch_size = config.get("runner.infer_batch_size", None)
74+
model_load_path = config.get("runner.infer_load_path", "model_output")
75+
start_epoch = config.get("runner.infer_start_epoch", 0)
76+
end_epoch = config.get("runner.infer_end_epoch", 10)
77+
infer_train_epoch = config.get("runner.infer_train_epoch", 2)
78+
batchsize = config.get("hyper_parameters.batch_size", 32)
79+
80+
logger.info("**************common.configs**********")
81+
logger.info(
82+
"use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
83+
format(use_gpu, use_xpu, use_npu, use_visual, infer_batch_size,
84+
test_data_dir, start_epoch, end_epoch, print_interval,
85+
model_load_path))
86+
logger.info("**************common.configs**********")
87+
88+
if use_xpu:
89+
xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
90+
place = paddle.set_device(xpu_device)
91+
elif use_npu:
92+
npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
93+
place = paddle.set_device(npu_device)
94+
else:
95+
place = paddle.set_device('gpu' if use_gpu else 'cpu')
96+
97+
dy_model = dy_model_class.create_model(config)
98+
99+
# Create a log_visual object and store the data in the path
100+
if use_visual:
101+
from visualdl import LogWriter
102+
log_visual = LogWriter(args.abs_dir + "/visualDL_log/infer")
103+
104+
# to do : add optimizer function
105+
#optimizer = dy_model_class.create_optimizer(dy_model, config)
106+
107+
logger.info("read data")
108+
infer_dataloader = create_data_loader(
109+
config=config, place=place, mode="test")
110+
111+
epoch_begin = time.time()
112+
interval_begin = time.time()
113+
114+
metric_list, metric_list_name = dy_model_class.create_metrics()
115+
step_num = 0
116+
print_interval = 1
117+
118+
for epoch_id in range(start_epoch, end_epoch):
119+
logger.info("load model epoch {}".format(epoch_id))
120+
model_path = os.path.join(model_load_path, str(epoch_id))
121+
122+
infer_reader_cost = 0.0
123+
infer_run_cost = 0.0
124+
reader_start = time.time()
125+
126+
assert any(infer_dataloader(
127+
)), "test_dataloader is null, please ensure batch size < dataset size!"
128+
129+
aid_flag = -1
130+
131+
for batch_id, batch in enumerate(infer_dataloader()):
132+
infer_reader_cost += time.time() - reader_start
133+
infer_start = time.time()
134+
135+
aid_flag = batch[0][0].item()
136+
x_spt, y_spt, x_qry, y_qry = batch[1], batch[2], batch[3], batch[4]
137+
138+
load_model(model_path, dy_model)
139+
# 对每个子任务进行训练
140+
optimizer = dy_model_class.create_optimizer(dy_model, config,
141+
"infer")
142+
dy_model.train()
143+
144+
for i in range(infer_train_epoch):
145+
n_samples = y_spt.shape[0]
146+
n_batch = int(np.ceil(n_samples / batchsize))
147+
optimizer.clear_grad()
148+
149+
for i_batch in range(n_batch):
150+
batch_input = list()
151+
batch_x = []
152+
batch_x.append(x_spt[0][i_batch * batchsize:(i_batch + 1) *
153+
batchsize])
154+
batch_x.append(x_spt[1][i_batch * batchsize:(i_batch + 1) *
155+
batchsize])
156+
batch_x.append(x_spt[2][i_batch * batchsize:(i_batch + 1) *
157+
batchsize])
158+
batch_x.append(x_spt[3][i_batch * batchsize:(i_batch + 1) *
159+
batchsize])
160+
161+
batch_y = y_spt[i_batch * batchsize:(i_batch + 1) *
162+
batchsize]
163+
164+
batch_input.append(batch_x)
165+
batch_input.append(batch_y)
166+
167+
loss = dy_model_class.infer_train_forward(
168+
dy_model, batch_input, config)
169+
170+
dy_model.clear_gradients()
171+
loss.backward()
172+
optimizer.step()
173+
# 对每个子任务进行测试
174+
dy_model.eval()
175+
metric_list_local, metric_list_local_name = dy_model_class.create_metrics(
176+
)
177+
with paddle.no_grad():
178+
n_samples = y_qry.shape[0]
179+
n_batch = int(np.ceil(n_samples / batchsize))
180+
181+
for i_batch in range(n_batch):
182+
batch_input = list()
183+
batch_x = []
184+
batch_x.append(x_qry[0][i_batch * batchsize:(i_batch + 1) *
185+
batchsize])
186+
batch_x.append(x_qry[1][i_batch * batchsize:(i_batch + 1) *
187+
batchsize])
188+
batch_x.append(x_qry[2][i_batch * batchsize:(i_batch + 1) *
189+
batchsize])
190+
batch_x.append(x_qry[3][i_batch * batchsize:(i_batch + 1) *
191+
batchsize])
192+
193+
batch_y = y_qry[i_batch * batchsize:(i_batch + 1) *
194+
batchsize]
195+
196+
batch_input.append(batch_x)
197+
batch_input.append(batch_y)
198+
199+
metric_list, metric_list_local = dy_model_class.infer_forward(
200+
dy_model, metric_list, metric_list_local, batch_input,
201+
config)
202+
203+
infer_run_cost += time.time() - infer_start
204+
205+
metric_str_local = ""
206+
for metric_id in range(len(metric_list_local_name)):
207+
metric_str_local += (
208+
metric_list_local_name[metric_id] + ": {:.6f},".format(
209+
metric_list_local[metric_id].accumulate()))
210+
if use_visual:
211+
log_visual.add_scalar(
212+
tag="infer/" + metric_list_local_name[metric_id],
213+
step=step_num,
214+
value=metric_list_local[metric_id].accumulate())
215+
logger.info(
216+
"epoch: {}, batch_id: {}, aid: {} ".format(
217+
epoch_id, batch_id, aid_flag) + metric_str_local +
218+
" avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.2f} ins/s".
219+
format(infer_reader_cost / print_interval, (
220+
infer_reader_cost + infer_run_cost) / print_interval,
221+
batchsize, print_interval * batchsize / (time.time(
222+
) - interval_begin)))
223+
224+
interval_begin = time.time()
225+
infer_reader_cost = 0.0
226+
infer_run_cost = 0.0
227+
step_num = step_num + 1
228+
reader_start = time.time()
229+
230+
metric_str = ""
231+
for metric_id in range(len(metric_list_name)):
232+
metric_str += (
233+
metric_list_name[metric_id] +
234+
": {:.6f},".format(metric_list[metric_id].accumulate()))
235+
236+
logger.info("epoch: {} done, ".format(epoch_id) + metric_str +
237+
" epoch time: {:.2f} s".format(time.time() - epoch_begin))
238+
epoch_begin = time.time()
239+
240+
241+
if __name__ == '__main__':
242+
args = parse_args()
243+
main(args)

0 commit comments

Comments
 (0)