Skip to content

Commit d9e76fe

Browse files
authored
support recording batch time during training (#697)
1 parent 7b27d8e commit d9e76fe

File tree

3 files changed

+228
-1
lines changed

3 files changed

+228
-1
lines changed

pcdet/utils/common_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,20 @@ def sa_create(name, var):
245245
x.flags.writeable = False
246246
return x
247247

248+
249+
class AverageMeter(object):
250+
"""Computes and stores the average and current value"""
251+
def __init__(self):
252+
self.reset()
253+
254+
def reset(self):
255+
self.val = 0
256+
self.avg = 0
257+
self.sum = 0
258+
self.count = 0
259+
260+
def update(self, val, n=1):
261+
self.val = val
262+
self.sum += val * n
263+
self.count += n
264+
self.avg = self.sum / self.count

pcdet/utils/commu_utils.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""
2+
This file contains primitives for multi-gpu communication.
3+
This is useful when doing distributed training.
4+
5+
deeply borrow from maskrcnn-benchmark and ST3D
6+
"""
7+
8+
import pickle
9+
import time
10+
11+
import torch
12+
import torch.distributed as dist
13+
14+
15+
def get_world_size():
16+
if not dist.is_available():
17+
return 1
18+
if not dist.is_initialized():
19+
return 1
20+
return dist.get_world_size()
21+
22+
23+
def get_rank():
24+
if not dist.is_available():
25+
return 0
26+
if not dist.is_initialized():
27+
return 0
28+
return dist.get_rank()
29+
30+
31+
def is_main_process():
32+
return get_rank() == 0
33+
34+
35+
def synchronize():
36+
"""
37+
Helper function to synchronize (barrier) among all processes when
38+
using distributed training
39+
"""
40+
if not dist.is_available():
41+
return
42+
if not dist.is_initialized():
43+
return
44+
world_size = dist.get_world_size()
45+
if world_size == 1:
46+
return
47+
dist.barrier()
48+
49+
50+
def all_gather(data):
51+
"""
52+
Run all_gather on arbitrary picklable data (not necessarily tensors)
53+
Args:
54+
data: any picklable object
55+
Returns:
56+
list[data]: list of data gathered from each rank
57+
"""
58+
world_size = get_world_size()
59+
if world_size == 1:
60+
return [data]
61+
62+
# serialized to a Tensor
63+
origin_size = None
64+
if not isinstance(data, torch.Tensor):
65+
buffer = pickle.dumps(data)
66+
storage = torch.ByteStorage.from_buffer(buffer)
67+
tensor = torch.ByteTensor(storage).to("cuda")
68+
else:
69+
origin_size = data.size()
70+
tensor = data.reshape(-1)
71+
72+
tensor_type = tensor.dtype
73+
74+
# obtain Tensor size of each rank
75+
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
76+
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
77+
dist.all_gather(size_list, local_size)
78+
size_list = [int(size.item()) for size in size_list]
79+
max_size = max(size_list)
80+
81+
# receiving Tensor from all ranks
82+
# we pad the tensor because torch all_gather does not support
83+
# gathering tensors of different shapes
84+
tensor_list = []
85+
for _ in size_list:
86+
tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
87+
if local_size != max_size:
88+
padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
89+
tensor = torch.cat((tensor, padding), dim=0)
90+
dist.all_gather(tensor_list, tensor)
91+
92+
data_list = []
93+
for size, tensor in zip(size_list, tensor_list):
94+
if origin_size is None:
95+
buffer = tensor.cpu().numpy().tobytes()[:size]
96+
data_list.append(pickle.loads(buffer))
97+
else:
98+
buffer = tensor[:size]
99+
data_list.append(buffer)
100+
101+
if origin_size is not None:
102+
new_shape = [-1] + list(origin_size[1:])
103+
resized_list = []
104+
for data in data_list:
105+
# suppose the difference of tensor size exist in first dimension
106+
data = data.reshape(new_shape)
107+
resized_list.append(data)
108+
109+
return resized_list
110+
else:
111+
return data_list
112+
113+
114+
def reduce_dict(input_dict, average=True):
115+
"""
116+
Args:
117+
input_dict (dict): all the values will be reduced
118+
average (bool): whether to do average or sum
119+
Reduce the values in the dictionary from all processes so that process with rank
120+
0 has the averaged results. Returns a dict with the same fields as
121+
input_dict, after reduction.
122+
"""
123+
world_size = get_world_size()
124+
if world_size < 2:
125+
return input_dict
126+
with torch.no_grad():
127+
names = []
128+
values = []
129+
# sort the keys so that they are consistent across processes
130+
for k in sorted(input_dict.keys()):
131+
names.append(k)
132+
values.append(input_dict[k])
133+
values = torch.stack(values, dim=0)
134+
dist.reduce(values, dst=0)
135+
if dist.get_rank() == 0 and average:
136+
# only main process gets accumulated, so only divide by
137+
# world_size in this case
138+
values /= world_size
139+
reduced_dict = {k: v for k, v in zip(names, values)}
140+
return reduced_dict
141+
142+
143+
def average_reduce_value(data):
144+
data_list = all_gather(data)
145+
return sum(data_list) / len(data_list)
146+
147+
148+
def all_reduce(data, op="sum", average=False):
149+
150+
def op_map(op):
151+
op_dict = {
152+
"SUM": dist.ReduceOp.SUM,
153+
"MAX": dist.ReduceOp.MAX,
154+
"MIN": dist.ReduceOp.MIN,
155+
"PRODUCT": dist.ReduceOp.PRODUCT,
156+
}
157+
return op_dict[op]
158+
159+
world_size = get_world_size()
160+
if world_size > 1:
161+
reduced_data = data.clone()
162+
dist.all_reduce(reduced_data, op=op_map(op.upper()))
163+
if average:
164+
assert op.upper() == 'SUM'
165+
return reduced_data / world_size
166+
else:
167+
return reduced_data
168+
return data
169+
170+
171+
@torch.no_grad()
172+
def concat_all_gather(tensor):
173+
"""
174+
Performs all_gather operation on the provided tensors.
175+
*** Warning ***: torch.distributed.all_gather has no gradient.
176+
"""
177+
tensors_gather = [torch.ones_like(tensor)
178+
for _ in range(torch.distributed.get_world_size())]
179+
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
180+
181+
output = torch.cat(tensors_gather, dim=0)
182+
return output

tools/train_utils/train_utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import torch
55
import tqdm
6+
import time
67
from torch.nn.utils import clip_grad_norm_
8+
from pcdet.utils import common_utils, commu_utils
79

810

911
def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, accumulated_iter, optim_cfg,
@@ -13,14 +15,21 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
1315

1416
if rank == 0:
1517
pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)
18+
data_time = common_utils.AverageMeter()
19+
batch_time = common_utils.AverageMeter()
20+
forward_time = common_utils.AverageMeter()
1621

1722
for cur_it in range(total_it_each_epoch):
23+
end = time.time()
1824
try:
1925
batch = next(dataloader_iter)
2026
except StopIteration:
2127
dataloader_iter = iter(train_loader)
2228
batch = next(dataloader_iter)
2329
print('new iters')
30+
31+
data_timer = time.time()
32+
cur_data_time = data_timer - end
2433

2534
lr_scheduler.step(accumulated_iter)
2635

@@ -37,12 +46,31 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
3746

3847
loss, tb_dict, disp_dict = model_func(model, batch)
3948

49+
forward_timer = time.time()
50+
cur_forward_time = forward_timer - data_timer
51+
4052
loss.backward()
4153
clip_grad_norm_(model.parameters(), optim_cfg.GRAD_NORM_CLIP)
4254
optimizer.step()
4355

4456
accumulated_iter += 1
45-
disp_dict.update({'loss': loss.item(), 'lr': cur_lr})
57+
58+
cur_batch_time = time.time() - end
59+
# average reduce
60+
avg_data_time = commu_utils.average_reduce_value(cur_data_time)
61+
avg_forward_time = commu_utils.average_reduce_value(cur_forward_time)
62+
avg_batch_time = commu_utils.average_reduce_value(cur_batch_time)
63+
64+
if rank == 0:
65+
data_time.update(avg_data_time)
66+
forward_time.update(avg_forward_time)
67+
batch_time.update(avg_batch_time)
68+
69+
70+
disp_dict.update({
71+
'loss': loss.item(), 'lr': cur_lr, 'd_time': f'{data_time.val:.2f}({data_time.avg:.2f})',
72+
'f_time': f'{forward_time.val:.2f}({forward_time.avg:.2f})', 'b_time': f'{batch_time.val:.2f}({batch_time.avg:.2f})'
73+
})
4674

4775
# log to console and tensorboard
4876
if rank == 0:

0 commit comments

Comments
 (0)