
Commit 87a8caa

seiriosPlus authored and guru4elephant committed
Refactor fetch handler (#21264) (#21537)
* Fix the fetch handler problem and refactor: when a user defines a FetchHandler class, he or she should initialize the handler with a variable dict. The key of the variable dict is a user-defined name; the value is a Variable generated from the Python API. For each fetch, the user should implement the handler function, in which fetched_result_dict will be available, and access the fetched values with the user-defined keys.
1 parent 20a0937 commit 87a8caa
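
To illustrate the API change described in the commit message, a minimal sketch (hypothetical names; the pattern mirrors the help() text added in python/paddle/fluid/executor.py and the unit test below):

import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import Program

# Build a Variable to fetch (mirrors test_fetch_handler.py below).
prog = Program()
block = prog.current_block()
auc_var = block.create_var(name='auc', type=core.VarDesc.VarType.FP32)

class FetchHandlerExample(fluid.executor.FetchHandler):
    def handler(self, res_dict):
        # res_dict maps each user-defined key to a numpy ndarray,
        # or to None when the variable is not yet initialized.
        print("auc: {}".format(res_dict["auc"]))

handler = FetchHandlerExample(var_dict={"auc": auc_var}, period_secs=60)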

File tree

7 files changed: +112 −75 lines

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 0 deletions
@@ -186,6 +186,9 @@ void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
   // training and finalize training
   VLOG(3) << "Trainer starts to run";
   trainer->Run();
+}
+
+void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
   VLOG(3) << "Trainer going to finalize";
   trainer->Finalize();
 }

paddle/fluid/framework/executor.h

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,8 @@ class Executor {
                       Scope* scope, Dataset* dataset);
   void RunFromDataset(std::shared_ptr<TrainerBase> trainer);

+  void ReleaseTrainer(std::shared_ptr<TrainerBase> trainer);
+
   const platform::Place GetPlace() const { return place_; }

  private:

paddle/fluid/framework/multi_trainer.cc

Lines changed: 2 additions & 4 deletions
@@ -77,14 +77,12 @@ void MultiTrainer::Run() {
                                      workers_[thidx].get()));
     }
   }
-}
-
-void MultiTrainer::Finalize() {
   for (auto& th : threads_) {
     th.join();
   }
-  root_scope_->DropKids();
 }

+void MultiTrainer::Finalize() { root_scope_->DropKids(); }
+
 }  // end namespace framework
 }  // end namespace paddle

paddle/fluid/pybind/pybind.cc

Lines changed: 3 additions & 0 deletions
@@ -1364,10 +1364,13 @@ All parameter, weight, gradient are variables in Paddle.
       .def("close", &Executor::Close)
       .def("run_from_dataset", &Executor::RunFromDataset,
           py::call_guard<py::gil_scoped_release>())
+      .def("release_trainer", &Executor::ReleaseTrainer,
+          py::call_guard<py::gil_scoped_release>())
       .def("init_for_dataset",
           [](Executor &self, const ProgramDesc &prog,
              const std::string &trainer_desc, Scope *scope,
              Dataset *dataset) -> std::shared_ptr<TrainerBase> {
+            pybind11::gil_scoped_release release;
             return self.InitForDataset(prog, trainer_desc, scope, dataset);
           })
       .def("run_from_dataset",

python/paddle/fluid/executor.py

Lines changed: 18 additions & 13 deletions
@@ -395,23 +395,28 @@ def _as_lodtensor(data, place):


 class FetchHandler(object):
-    def __init__(self, fetch_target_names, period_secs=60, return_np=True):
-        self.fetch_target_names = fetch_target_names
+    def __init__(self, var_dict=None, period_secs=60):
+        assert var_dict != None
+        self.var_dict = var_dict
         self.period_secs = period_secs
-        self.return_np = return_np

-    def handler(self, fetch_target_vars):
-        return
+    def handler(self, res_dict):
+        for key in res_dict:
+            if type(res_dict[key]) is np.ndarray:
+                sys.stdout.write("{}[0]: {} ".format(key, res_dict[key][0]))
+        sys.stdout.write("\n")

     @staticmethod
     def help():
         print("""
-class FetchHandlerExamlpe(FetchHandler):
-    def handler(self, fetch_target_vars):
-        b_auc = fetch_target_vars[0]
-        g_auc = fetch_target_vars[1]
-
-        print("b_auc: {}, g_auc: {} at time: {}".format(b_auc, g_auc, time.ctime()))
+class FetchHandlerExample(FetchHandler):
+    def handler(self, res_dict):
+        print(res_dict["auc"])
+        print("auc: {}, {}".format(res_dict["auc"], time.ctime()))
+
+auc = Variable()
+var_dict = {"auc": auc}
+handler = FetchHandlerExample(var_dict=var_dict)
 """)

@@ -1019,13 +1024,13 @@ def _run_from_dataset(self,
             scope0 = trainer_instance.get_worker_scope(0)
             fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler)
             fetch_monitor.start()
-
             self._default_executor.run_from_dataset(trainer_instance)
-
             fetch_monitor.stop()
+            self._default_executor.release_trainer(trainer_instance)
         else:

             self._default_executor.run_from_dataset(trainer_instance)
+            self._default_executor.release_trainer(trainer_instance)

         dataset._dynamic_adjust_after_train()
         dataset._finish_to_run()
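
A sketch of how a handler reaches _run_from_dataset above, assuming the public train_from_dataset entry point forwards a fetch_handler argument (here `dataset` and `handler` are assumed to be already built, e.g. as in the example after the commit message):

import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# The monitor is started before run_from_dataset and stopped after it,
# then release_trainer frees the trainer, as in the diff above.
exe.train_from_dataset(program=fluid.default_main_program(),
                       dataset=dataset,
                       fetch_handler=handler)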

python/paddle/fluid/tests/unittests/test_fetch_handler.py

Lines changed: 21 additions & 5 deletions
@@ -17,6 +17,7 @@
 import time
 import unittest
 import numpy as np
+from paddle.fluid.framework import Program

 import paddle.fluid.core as core
 import paddle.fluid as fluid

@@ -29,20 +30,35 @@ def test_fetch_handler(self):

         table = np.random.random((3, 10)).astype("float32")

+        prog = Program()
+        block = prog.current_block()
+        var_emb = block.create_var(name='emb', type=core.VarDesc.VarType.FP32)
+        var_emb3 = block.create_var(name='emb3', type=core.VarDesc.VarType.FP32)
+
         class FH(fluid.executor.FetchHandler):
-            def handler(self, fetch_target_vars):
-                assert len(fetch_target_vars) == 1
+            def handler(self, fetch_dict):
+                assert len(fetch_dict) == 1

         table_var = scope.var('emb').get_tensor()
         table_var.set(table, place)
-
-        fh = FH(['emb'], period_secs=2, return_np=True)
+        fh = FH(var_dict={'emb': var_emb}, period_secs=2)
         fm = fluid.trainer_factory.FetchHandlerMonitor(scope, fh)

         fm.start()
-        time.sleep(10)
+        time.sleep(3)
         fm.stop()

+        default_fh = fluid.executor.FetchHandler(
+            var_dict={'emb': var_emb,
+                      'emb2': None,
+                      'emb3': var_emb3},
+            period_secs=1)
+        default_fm = fluid.trainer_factory.FetchHandlerMonitor(scope,
+                                                               default_fh)
+        default_fm.start()
+        time.sleep(5)
+        default_fm.stop()
+

 if __name__ == "__main__":
     unittest.main()

python/paddle/fluid/trainer_factory.py

Lines changed: 63 additions & 53 deletions
@@ -15,11 +15,15 @@

 import threading
 import time
-
+import logging
 import numpy as np

+logging.basicConfig()
+
 from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
 from .device_worker import Hogwild, DownpourSGD, Section
+from .framework import Variable
+from multiprocessing import Process, Manager

 __all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]

@@ -93,68 +97,74 @@ class FetchHandlerMonitor(object):
     def __init__(self, scope, handler):
         self.fetch_instance = handler
         self.fetch_thread = threading.Thread(
-            target=self.handler_decorator,
-            args=(scope, self.fetch_instance.handler))
+            target=self.handler_launch_func, args=(scope, self.fetch_instance))
+        self.running_lock = threading.Lock()
         self.running = False

+    def handler_launch_func(self, scope, handler):
+        fetch_instance = handler
+        period_secs = fetch_instance.period_secs
+        var_name_to_key = {}
+        for key in fetch_instance.var_dict:
+            if isinstance(fetch_instance.var_dict[key], Variable):
+                var_name_to_key[fetch_instance.var_dict[key].name] = key
+            else:
+                logging.warning("the value of {} is not a Variable".format(key))
+                var_name_to_key["None.var"] = key
+        elapsed_secs = 0
+        while True:
+            self.running_lock.acquire()
+            if self.running == False:
+                break
+            if elapsed_secs < period_secs:
+                # TODO(guru4elephant): needs customized condition
+                time.sleep(1)
+                elapsed_secs += 1
+            else:
+                elapsed_secs = 0
+                fetch_dict = {}
+                for key in var_name_to_key:
+                    var = scope.find_var(key)
+                    fetch_dict[key] = var
+                    if var == None:
+                        logging.warning("{} value currently not available".
+                                        format(var_name_to_key[key]))
+                res_dict = {}
+                for key in fetch_dict:
+                    user_name = var_name_to_key[key]
+                    if fetch_dict[key] == None:
+                        res_dict[user_name] = None
+                        continue
+                    else:
+                        res_dict[user_name] = fetch_dict[key].get_tensor()
+
+                        lod = res_dict[user_name].lod()
+                        if len(lod) > 0:
+                            raise RuntimeError("Some of your fetched tensors \
+                                                hold LoD information. \
+                                                They can not be completely cast \
+                                                to Python ndarray. We can \
+                                                not return LoDTensor itself directly, \
+                                                please choose another targets")
+                        if res_dict[user_name]._is_initialized():
+                            res_dict[user_name] = np.array(res_dict[user_name])
+                        else:
+                            res_dict[user_name] = None
+                fetch_instance.handler(res_dict)
+            self.running_lock.release()
+
     def start(self):
         """
         start monitor,
         it will start a monitor thread.
         """
+        self.running_lock.acquire()
         self.running = True
+        self.running_lock.release()
         self.fetch_thread.setDaemon(True)
         self.fetch_thread.start()

-    def handler_decorator(self, fetch_scope, fetch_handler):
-        """
-        decorator of handler,
-        Args:
-            fetch_scope(Scope): fetch scope
-            fetch_handler(Handler): fetch handler
-        """
-        fetch_target_names = self.fetch_instance.fetch_target_names
-        period_secs = self.fetch_instance.period_secs
-
-        elapsed_secs = 0
-        while True:
-            while self.running and elapsed_secs >= period_secs:
-                elapsed_secs = 0
-
-                fetch_vars = [
-                    fetch_scope.find_var(varname)
-                    for varname in fetch_target_names
-                ]
-
-                if None in fetch_vars:
-                    continue
-
-                fetch_tensors = [var.get_tensor() for var in fetch_vars]
-
-                if self.fetch_instance.return_np:
-                    fetch_nps = []
-
-                    for tensor in fetch_tensors:
-                        lod = tensor.lod()
-
-                        if len(lod) > 0:
-                            raise RuntimeError(
-                                "Some of your fetched tensors hold LoD information. \
-                                They can not be completely cast to Python ndarray. We can not \
-                                return LoDTensor itself directly, please choose another targets"
-                            )
-
-                        if tensor._is_initialized():
-                            fetch_nps.append(np.array(tensor))
-                        else:
-                            fetch_nps.append(None)
-
-                    fetch_handler(fetch_nps)
-                else:
-                    fetch_handler(fetch_tensors)
-            else:
-                time.sleep(1)
-                elapsed_secs += 1
-
     def stop(self):
+        self.running_lock.acquire()
         self.running = False
+        self.running_lock.release()
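
The refactored monitor guards self.running with running_lock. A simplified, self-contained sketch of the same start/poll/stop pattern (PeriodicMonitor is hypothetical and not part of Paddle; unlike the code above, it sleeps outside the lock so stop() is never blocked by the polling sleep):

import threading
import time

class PeriodicMonitor(object):
    # A daemon thread wakes once per second and fires a callback
    # roughly every period_secs seconds, mirroring FetchHandlerMonitor.
    def __init__(self, period_secs, callback):
        self.period_secs = period_secs
        self.callback = callback
        self.running_lock = threading.Lock()
        self.running = False
        self.thread = threading.Thread(target=self._loop)

    def _loop(self):
        elapsed_secs = 0
        while True:
            with self.running_lock:
                if not self.running:
                    break
                if elapsed_secs < self.period_secs:
                    elapsed_secs += 1
                else:
                    elapsed_secs = 0
                    self.callback()
            time.sleep(1)

    def start(self):
        with self.running_lock:
            self.running = True
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        with self.running_lock:
            self.running = False
        self.thread.join()

# Usage: print a heartbeat every 2 seconds for about 5 seconds.
if __name__ == "__main__":
    m = PeriodicMonitor(2, lambda: print("tick", time.ctime()))
    m.start()
    time.sleep(5)
    m.stop()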
