Skip to content

Commit dce0732

Browse files
authored
Merge pull request #10380 from panyx0718/dist_timeline
timeline for distributed training
2 parents 0c51888 + d1ea74d commit dce0732

File tree

8 files changed

+146
-52
lines changed

8 files changed

+146
-52
lines changed

benchmark/cluster/vgg16/vgg16_fluid.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def str2bool(v):
8080
type=str,
8181
default="",
8282
help="Comma-separated list of hostname:port pairs")
83+
parser.add_argument(
84+
"--profile", action='store_true', help="If set, profile a few steps.")
8385

8486
# Flags for defining the tf.train.Server
8587
parser.add_argument(
@@ -183,8 +185,8 @@ def train_loop(exe, trainer_prog):
183185
start_time = time.time()
184186
num_samples = 0
185187
train_pass_acc.reset()
186-
for batch_id, data in enumerate(train_reader()):
187-
ts = time.time()
188+
189+
def run_step(batch_id, data):
188190
img_data = np.array(
189191
map(lambda x: x[0].reshape(data_shape), data)).astype(
190192
"float32")
@@ -196,14 +198,28 @@ def train_loop(exe, trainer_prog):
196198
feed={"pixel": img_data,
197199
"label": y_data},
198200
fetch_list=[avg_cost, batch_acc, batch_size])
201+
return loss, acc, b_size
202+
203+
if args.profile and args.task_index == 0:
204+
# warmup.
205+
for batch_id, data in enumerate(train_reader()):
206+
if batch_id > 5: break
207+
run_step(batch_id, data)
208+
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
209+
for batch_id, data in enumerate(train_reader()):
210+
if batch_id > 5: break
211+
run_step(batch_id, data)
212+
213+
for batch_id, data in enumerate(train_reader()):
214+
ts = time.time()
215+
loss, acc, b_size = run_step(batch_id, data)
199216
iters += 1
200217
num_samples += len(data)
201218
train_pass_acc.add(value=acc, weight=b_size)
202219
print(
203-
"Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
204-
"Speed = %.2f img/s " % (args.task_index, pass_id, iters,
205-
loss, acc,
206-
len(data) / (time.time() - ts))
220+
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
221+
"Speed = %.2f img/s" % (pass_id, iters, loss, acc,
222+
len(data) / (time.time() - ts))
207223
) # The accuracy is the accumulation of batches, but not the current batch.
208224

209225
pass_elapsed = time.time() - start_time

paddle/fluid/operators/detail/send_recv.proto

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ message VariableMessage {
6969
bytes rows = 9;
7070
// Look up table block execution output variable name.
7171
string out_varname = 10;
72+
// If true, the ps server will start profiling. The ps
73+
// server stops profiling and generates a profile to /tmp/profile_ps_*
74+
// when profile switches from true to false.
75+
bool profile = 11;
7276
}
7377

7478
message VoidMessage {}

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
2424
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
2525
#include "paddle/fluid/operators/detail/variable_response.h"
26+
#include "paddle/fluid/platform/profiler.h"
2627

2728
namespace paddle {
2829
namespace operators {
@@ -45,6 +46,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
4546
void* payload = nullptr;
4647
size_t payload_size;
4748
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
49+
// Note: normally the profiler is enabled in 1 trainer, hence only
50+
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
51+
// servers the trainer's profiling state so that PS can follow the
52+
// trainer.
53+
if (platform::ShouldSendProfileState()) {
54+
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
55+
}
4856
e.WriteString(VarMsg::kVarnameFieldNumber, name);
4957
if (var->IsType<framework::LoDTensor>()) {
5058
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);

paddle/fluid/operators/detail/variable_response.cc

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <string>
1818
#include <utility>
1919
#include <vector>
20+
#include "paddle/fluid/platform/profiler.h"
2021

2122
#include "paddle/fluid/operators/detail/send_recv.pb.h"
2223
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
@@ -427,7 +428,26 @@ int VariableResponse::Parse(Source* source) {
427428
meta_.set_out_varname(temp);
428429
break;
429430
}
430-
431+
case sendrecv::VariableMessage::kProfileFieldNumber: {
432+
bool profiling;
433+
if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) {
434+
return tag;
435+
}
436+
meta_.set_profile(profiling);
437+
int64_t listener_id = platform::ListenerId();
438+
if (listener_id <= 0) {
439+
break;
440+
}
441+
if (profiling && !platform::IsProfileEnabled()) {
442+
platform::EnableProfiler(platform::ProfilerState::kCPU);
443+
} else if (!profiling && platform::IsProfileEnabled()) {
444+
// TODO(panyx0718): Should we allow customizing the output file dir?
445+
platform::DisableProfiler(
446+
platform::EventSortingKey::kDefault,
447+
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
448+
}
449+
break;
450+
}
431451
default: {
432452
// Unknown tag, return unknown error.
433453
return -1;

paddle/fluid/operators/listen_and_serv_op.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include <vector>
1919

2020
#include "paddle/fluid/operators/listen_and_serv_op.h"
21+
#include "paddle/fluid/platform/profiler.h"
2122

2223
namespace paddle {
2324
namespace operators {
@@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
294295

295296
void ListenAndServOp::RunImpl(const framework::Scope &scope,
296297
const platform::Place &dev_place) const {
298+
// Mark this process as a PS so that it decides whether to profile by listening to the trainer.
299+
platform::SetProfileListener();
297300
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
298301
auto &dev_ctx = *pool.Get(dev_place);
299302
framework::Scope &recv_scope = scope.NewScope();

paddle/fluid/platform/profiler.cc

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/platform/profiler.h"
16+
1617
#include <sys/time.h>
1718
#include <time.h>
1819
#include <algorithm>
1920
#include <iomanip>
21+
#include <limits>
2022
#include <map>
2123
#include <mutex> // NOLINT
24+
#include <random>
2225
#include <string>
2326
#ifdef PADDLE_WITH_CUDA
2427
#include <cuda.h>
@@ -33,6 +36,9 @@ namespace platform {
3336

3437
struct EventList;
3538

39+
static int64_t profiler_lister_id = 0;
40+
static bool should_send_profile_state = false;
41+
3642
// The profiler state, the initial value is ProfilerState::kDisabled
3743
static ProfilerState g_state = ProfilerState::kDisabled;
3844
// The thread local event list only can be accessed by the specific thread
@@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) {
219225
PADDLE_ENFORCE(state != ProfilerState::kDisabled,
220226
"Can't enbale profling, since the input state is ",
221227
"ProfilerState::kDisabled");
222-
PADDLE_ENFORCE(g_state == ProfilerState::kDisabled,
223-
"The profiling state should be disabled when calling ",
224-
"EnableProfiler.");
225-
g_state = state;
226-
if (g_state == ProfilerState::kAll) {
227-
GetDeviceTracer()->Enable();
228+
if (state == g_state) {
229+
return;
228230
}
231+
g_state = state;
232+
should_send_profile_state = true;
233+
GetDeviceTracer()->Enable();
229234
#ifdef PADDLE_WITH_CUDA
230235
if (g_state == ProfilerState::kCUDA) {
231236
// Generate some dummy events first to reduce the startup overhead.
@@ -435,21 +440,33 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
435440

436441
void DisableProfiler(EventSortingKey sorted_key,
437442
const std::string& profile_path) {
438-
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
439-
"Can't disable profiling, since it's not starting.");
443+
if (g_state == ProfilerState::kDisabled) return;
440444
// Mark the profiling stop.
441445
Mark("_stop_profiler_", nullptr);
442446

443447
std::vector<std::vector<Event>> all_events = GetAllEvents();
444448
ParseEvents(all_events, sorted_key);
445449
ResetProfiler();
446450
DeviceTracer* tracer = GetDeviceTracer();
447-
if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) {
451+
if (tracer->IsEnabled()) {
448452
tracer->Disable();
449453
tracer->GenProfile(profile_path);
450454
}
451455
g_state = ProfilerState::kDisabled;
456+
should_send_profile_state = true;
457+
}
458+
459+
// Returns true when the profiler state is anything other than
// ProfilerState::kDisabled.
bool IsProfileEnabled() { return g_state != ProfilerState::kDisabled; }
460+
// Returns the should_send_profile_state flag. It starts out false and is set
// to true by EnableProfiler/DisableProfiler once they change the profiler
// state in this process.
bool ShouldSendProfileState() { return should_send_profile_state; }
461+
462+
// Marks the current process as a profile listener by assigning it a random,
// strictly positive listener id. Callers treat ListenerId() <= 0 as "not a
// listener", so the lower bound of the distribution is 1.
void SetProfileListener() {
  // Seed from the system entropy source so that different listener
  // processes are unlikely to pick the same id.
  std::mt19937 rng(std::random_device{}());
  // Draw directly as int64_t so the distribution's type matches both its
  // upper bound and profiler_lister_id. Typing it on
  // std::mt19937::result_type would silently narrow the bound on platforms
  // where that type is only 32 bits wide.
  std::uniform_int_distribution<int64_t> dist(
      1, std::numeric_limits<int64_t>::max());
  profiler_lister_id = dist(rng);
}
469+
// Returns the id assigned by SetProfileListener(), or 0 if it has not been
// called in this process.
int64_t ListenerId() { return profiler_lister_id; }
453470

454471
} // namespace platform
455472
} // namespace paddle

paddle/fluid/platform/profiler.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,13 @@ void ResetProfiler();
114114
void DisableProfiler(EventSortingKey sorted_key,
115115
const std::string& profile_path);
116116

117+
// Test if the profiler is currently enabled.
118+
bool IsProfileEnabled();
119+
// Whether the trainer should send profiling state to PS.
120+
bool ShouldSendProfileState();
121+
// Mark current process as PS by assigning a listener id.
122+
void SetProfileListener();
123+
int64_t ListenerId();
124+
117125
} // namespace platform
118126
} // namespace paddle

tools/timeline.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222

2323
parser = argparse.ArgumentParser(description=__doc__)
2424
parser.add_argument(
25-
'--profile_path', type=str, default='', help='Input profile file name.')
25+
'--profile_path',
26+
type=str,
27+
default='',
28+
help='Input profile file name. If there are multiple file, the format '
29+
'should be trainer1=file1,trainer2=file2,ps=file3')
2630
parser.add_argument(
2731
'--timeline_path', type=str, default='', help='Output timeline file name.')
2832
args = parser.parse_args()
@@ -108,8 +112,8 @@ def format_to_string(self, pretty=False):
108112

109113

110114
class Timeline(object):
111-
def __init__(self, profile_pb):
112-
self._profile_pb = profile_pb
115+
def __init__(self, profile_dict):
116+
self._profile_dict = profile_dict
113117
self._pid = 0
114118
self._devices = dict()
115119
self._chrome_trace = _ChromeTraceFormatter()
@@ -120,35 +124,37 @@ def _allocate_pid(self):
120124
return cur_pid
121125

122126
def _allocate_pids(self):
123-
for event in self._profile_pb.events:
124-
if event.type == profiler_pb2.Event.CPU:
125-
if (event.device_id, "CPU") not in self._devices:
126-
pid = self._allocate_pid()
127-
self._devices[(event.device_id, "CPU")] = pid
128-
self._chrome_trace.emit_pid("cpu:block:%d" %
129-
(event.device_id), pid)
130-
elif event.type == profiler_pb2.Event.GPUKernel:
131-
if (event.device_id, "GPUKernel") not in self._devices:
132-
pid = self._allocate_pid()
133-
self._devices[(event.device_id, "GPUKernel")] = pid
134-
self._chrome_trace.emit_pid("gpu:%d" % (event.device_id),
135-
pid)
127+
for k, profile_pb in self._profile_dict.iteritems():
128+
for event in profile_pb.events:
129+
if event.type == profiler_pb2.Event.CPU:
130+
if (k, event.device_id, "CPU") not in self._devices:
131+
pid = self._allocate_pid()
132+
self._devices[(k, event.device_id, "CPU")] = pid
133+
self._chrome_trace.emit_pid("%s:cpu:block:%d" %
134+
(k, event.device_id), pid)
135+
elif event.type == profiler_pb2.Event.GPUKernel:
136+
if (k, event.device_id, "GPUKernel") not in self._devices:
137+
pid = self._allocate_pid()
138+
self._devices[(k, event.device_id, "GPUKernel")] = pid
139+
self._chrome_trace.emit_pid("%s:gpu:%d" %
140+
(k, event.device_id), pid)
136141

137142
def _allocate_events(self):
138-
for event in self._profile_pb.events:
139-
if event.type == profiler_pb2.Event.CPU:
140-
type = "CPU"
141-
elif event.type == profiler_pb2.Event.GPUKernel:
142-
type = "GPUKernel"
143-
pid = self._devices[(event.device_id, type)]
144-
args = {'name': event.name}
145-
if event.memcopy.bytes > 0:
146-
args = {'mem_bytes': event.memcopy.bytes}
147-
# TODO(panyx0718): Chrome tracing only handles ms. However, some
148-
# ops takes micro-seconds. Hence, we keep the ns here.
149-
self._chrome_trace.emit_region(
150-
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
151-
event.sub_device_id, 'Op', event.name, args)
143+
for k, profile_pb in self._profile_dict.iteritems():
144+
for event in profile_pb.events:
145+
if event.type == profiler_pb2.Event.CPU:
146+
type = "CPU"
147+
elif event.type == profiler_pb2.Event.GPUKernel:
148+
type = "GPUKernel"
149+
pid = self._devices[(k, event.device_id, type)]
150+
args = {'name': event.name}
151+
if event.memcopy.bytes > 0:
152+
args = {'mem_bytes': event.memcopy.bytes}
153+
# TODO(panyx0718): Chrome tracing only handles ms. However, some
154+
# ops take microseconds. Hence, we keep the ns here.
155+
self._chrome_trace.emit_region(
156+
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
157+
event.sub_device_id, 'Op', event.name, args)
152158

153159
def generate_chrome_trace(self):
154160
self._allocate_pids()
@@ -163,11 +169,23 @@ def generate_chrome_trace(self):
163169
if args.timeline_path:
164170
timeline_path = args.timeline_path
165171

166-
with open(profile_path, 'r') as f:
167-
profile_s = f.read()
168-
profile_pb = profiler_pb2.Profile()
169-
profile_pb.ParseFromString(profile_s)
170-
171-
tl = Timeline(profile_pb)
172+
def _load_profile(path):
    """Read the serialized profiler_pb2.Profile stored at `path`."""
    # The profile is a binary protobuf message, so read it in binary mode.
    with open(path, 'rb') as f:
        profile_pb = profiler_pb2.Profile()
        profile_pb.ParseFromString(f.read())
    return profile_pb


# Build a {name: Profile} dict from --profile_path, which is either a single
# file name or a comma-separated list of name=file pairs, e.g.
# trainer1=file1,trainer2=file2,ps=file3.
profile_paths = profile_path.split(',')
profile_dict = dict()
# Count the split entries, not the characters of the raw string: a single
# plain path must be loaded as-is under the default 'trainer' key.
if len(profile_paths) == 1:
    profile_dict['trainer'] = _load_profile(profile_path)
else:
    for entry in profile_paths:
        if '=' not in entry:
            raise ValueError(
                "Expected name=file in --profile_path, got '%s'" % entry)
        # Split on the first '=' only so file names may contain '='.
        name, path = entry.split('=', 1)
        profile_dict[name] = _load_profile(path)

tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f:
    f.write(tl.generate_chrome_trace())

0 commit comments

Comments
 (0)