@@ -276,7 +276,7 @@ struct EventItem {
276
276
// Print results
277
277
void PrintProfiler (const std::vector<std::vector<EventItem>>& events_table,
278
278
const std::string& sorted_domain, const size_t name_width,
279
- const size_t data_width, double total ) {
279
+ const size_t data_width, bool merge_thread ) {
280
280
// Output header information
281
281
std::cout << " \n ------------------------->"
282
282
<< " Profiling Report "
@@ -292,6 +292,10 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
292
292
PADDLE_THROW (" Invalid profiler state" , g_state);
293
293
}
294
294
295
+ if (merge_thread) {
296
+ std::cout << " Note! This Report merge all thread info into one."
297
+ << std::endl;
298
+ }
295
299
std::cout << " Place: " << place << std::endl;
296
300
std::cout << " Time unit: ms" << std::endl;
297
301
std::cout << " Sorted by " << sorted_domain
@@ -312,17 +316,18 @@ void PrintProfiler(const std::vector<std::vector<EventItem>>& events_table,
312
316
<< std::setw (data_width) << event_item.min_time
313
317
<< std::setw (data_width) << event_item.max_time
314
318
<< std::setw (data_width) << event_item.ave_time
315
- << std::setw (data_width) << event_item.total_time / total
316
- << std::endl;
319
+ << std::setw (data_width) << event_item.ratio << std::endl;
317
320
}
318
321
}
319
322
std::cout << std::endl;
320
323
}
321
324
322
325
// Parse the event list and output the profiling report
323
326
void ParseEvents (const std::vector<std::vector<Event>>& events,
327
+ bool merge_thread,
324
328
EventSortingKey sorted_by = EventSortingKey::kDefault ) {
325
329
if (g_state == ProfilerState::kDisabled ) return ;
330
+ if (merge_thread && events.size () < 2 ) return ;
326
331
327
332
std::string sorted_domain;
328
333
std::function<bool (const EventItem&, const EventItem&)> sorted_func;
@@ -361,34 +366,55 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
361
366
sorted_domain = " event first end time" ;
362
367
}
363
368
369
+ const std::vector<std::vector<Event>>* analyze_events;
370
+ std::vector<std::vector<Event>> merged_events_list;
371
+ if (merge_thread) {
372
+ std::vector<Event> merged_events;
373
+ for (int i = 0 ; i < events.size (); ++i) {
374
+ for (int j = 0 ; j < events[i].size (); ++j) {
375
+ merged_events.push_back (events[i][j]);
376
+ }
377
+ }
378
+ merged_events_list.push_back (merged_events);
379
+ analyze_events = &merged_events_list;
380
+ } else {
381
+ analyze_events = &events;
382
+ }
383
+
364
384
std::vector<std::vector<EventItem>> events_table;
365
385
size_t max_name_width = 0 ;
366
- double total = 0 .; // the total time
367
- for ( size_t i = 0 ; i < events. size (); i++) {
386
+ for ( size_t i = 0 ; i < (*analyze_events). size (); i++) {
387
+ double total = 0 .; // the total time in one thread
368
388
std::list<Event> pushed_events;
369
389
std::vector<EventItem> event_items;
370
390
std::unordered_map<std::string, int > event_idx;
371
391
372
- for (size_t j = 0 ; j < events [i].size (); j++) {
373
- if (events [i][j].type () == EventType::kPushRange ) {
374
- pushed_events.push_back (events [i][j]);
375
- } else if (events [i][j].type () == EventType::kPopRange ) {
392
+ for (size_t j = 0 ; j < (*analyze_events) [i].size (); j++) {
393
+ if ((*analyze_events) [i][j].type () == EventType::kPushRange ) {
394
+ pushed_events.push_back ((*analyze_events) [i][j]);
395
+ } else if ((*analyze_events) [i][j].type () == EventType::kPopRange ) {
376
396
std::list<Event>::reverse_iterator rit = pushed_events.rbegin ();
377
397
while (rit != pushed_events.rend () &&
378
- rit->name () != events [i][j].name ()) {
398
+ rit->name () != (*analyze_events) [i][j].name ()) {
379
399
++rit;
380
400
}
381
401
382
402
if (rit != pushed_events.rend ()) {
383
403
double event_time = (g_state == ProfilerState::kCUDA ||
384
404
g_state == ProfilerState::kAll )
385
- ? rit->CudaElapsedMs (events [i][j])
386
- : rit->CpuElapsedMs (events [i][j]);
405
+ ? rit->CudaElapsedMs ((*analyze_events) [i][j])
406
+ : rit->CpuElapsedMs ((*analyze_events) [i][j]);
387
407
total += event_time;
388
408
389
- std::string event_name =
390
- " thread" + std::to_string (rit->thread_id ()) + " ::" + rit->name ();
391
- max_name_width = std::max (max_name_width, event_name.size ());
409
+ std::string event_name;
410
+ if (merge_thread) {
411
+ event_name = rit->name ();
412
+ max_name_width = std::max (max_name_width, event_name.size ());
413
+ } else {
414
+ event_name = " thread" + std::to_string (rit->thread_id ()) + " ::" +
415
+ rit->name ();
416
+ max_name_width = std::max (max_name_width, event_name.size ());
417
+ }
392
418
393
419
if (event_idx.find (event_name) == event_idx.end ()) {
394
420
event_idx[event_name] = event_items.size ();
@@ -413,14 +439,15 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
413
439
pushed_events.erase ((++rit).base ());
414
440
} else {
415
441
LOG (WARNING) << " Cannot find the push marker of event \' "
416
- << events [i][j].name ()
442
+ << (*analyze_events) [i][j].name ()
417
443
<< " \' , which will be ignored in profiling report." ;
418
444
}
419
445
}
420
446
}
421
447
// average time
422
448
for (auto & item : event_items) {
423
449
item.ave_time = item.total_time / item.calls ;
450
+ item.ratio = item.total_time / total;
424
451
}
425
452
// sort
426
453
if (sorted_by != EventSortingKey::kDefault ) {
@@ -438,7 +465,8 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,
438
465
}
439
466
440
467
// Print report
441
- PrintProfiler (events_table, sorted_domain, max_name_width + 4 , 12 , total);
468
+ PrintProfiler (events_table, sorted_domain, max_name_width + 4 , 12 ,
469
+ merge_thread);
442
470
}
443
471
444
472
void DisableProfiler (EventSortingKey sorted_key,
@@ -449,7 +477,8 @@ void DisableProfiler(EventSortingKey sorted_key,
449
477
Mark (" _stop_profiler_" , nullptr );
450
478
451
479
std::vector<std::vector<Event>> all_events = GetAllEvents ();
452
- ParseEvents (all_events, sorted_key);
480
+ ParseEvents (all_events, true , sorted_key);
481
+ ParseEvents (all_events, false , sorted_key);
453
482
ResetProfiler ();
454
483
DeviceTracer* tracer = GetDeviceTracer ();
455
484
if (tracer->IsEnabled ()) {
0 commit comments