Skip to content

Commit 2a6897e

Browse files
authored
Merge pull request #4 from MedericFourmy/add_cpu_option
Add cpu option
2 parents 8e284ac + fd2e555 commit 2a6897e

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

cosypose/integrated/detector.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,17 @@ def __init__(self, model):
1111
self.model = model
1212
self.config = model.config
1313
self.category_id_to_label = {v: k for k, v in self.config.label_to_category_id.items()}
14-
15-
def cast(self, obj):
16-
return obj.cuda()
14+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1715

1816
@torch.no_grad()
1917
def get_detections(self, images, detection_th=None,
2018
output_masks=False, mask_th=0.8,
2119
one_instance_per_class=False):
22-
images = self.cast(images).float()
20+
images = images.to(self.device).float()
2321
if images.shape[-1] == 3:
2422
images = images.permute(0, 3, 1, 2)
2523
if images.max() > 1:
2624
images = images / 255.
27-
images = images.float().cuda()
2825
outputs_ = self.model([image_n for image_n in images])
2926

3027
infos = []
@@ -46,12 +43,12 @@ def get_detections(self, images, detection_th=None,
4643
infos.append(info)
4744

4845
if len(bboxes) > 0:
49-
bboxes = torch.stack(bboxes).cuda().float()
50-
masks = torch.stack(masks).cuda()
46+
bboxes = torch.stack(bboxes).to(self.device).float()
47+
masks = torch.stack(masks).to(self.device)
5148
else:
5249
infos = dict(score=[], label=[], batch_im_id=[])
53-
bboxes = torch.empty(0, 4).cuda().float()
54-
masks = torch.empty(0, images.shape[1], images.shape[2], dtype=torch.bool).cuda()
50+
bboxes = torch.empty(0, 4).to(self.device).float()
51+
masks = torch.empty(0, images.shape[1], images.shape[2], dtype=torch.bool).to(self.device)
5552

5653
outputs = tc.PandasTensorCollection(
5754
infos=pd.DataFrame(infos),

cosypose/rendering/bullet_batch_renderer.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
from .bullet_scene_renderer import BulletSceneRenderer
77

88

9-
def init_renderer(urdf_ds, preload=True):
9+
def init_renderer(urdf_ds, preload=True, gpu_renderer=True):
1010
renderer = BulletSceneRenderer(urdf_ds=urdf_ds,
1111
preload_cache=preload,
12-
background_color=(0, 0, 0))
12+
background_color=(0, 0, 0),
13+
gpu_renderer=gpu_renderer)
1314
return renderer
1415

1516

16-
def worker_loop(worker_id, in_queue, out_queue, object_set, preload=True):
17-
renderer = init_renderer(object_set, preload=preload)
17+
def worker_loop(worker_id, in_queue, out_queue, object_set, preload=True, gpu_renderer=True):
18+
renderer = init_renderer(object_set, preload=preload, gpu_renderer=gpu_renderer)
1819
while True:
1920
kwargs = in_queue.get()
2021
if kwargs is None:
@@ -38,10 +39,11 @@ def worker_loop(worker_id, in_queue, out_queue, object_set, preload=True):
3839

3940

4041
class BulletBatchRenderer:
41-
def __init__(self, object_set, n_workers=8, preload_cache=True):
42+
def __init__(self, object_set, n_workers=8, preload_cache=True, gpu_renderer=True):
4243
self.object_set = object_set
4344
self.n_workers = n_workers
44-
self.init_plotters(preload_cache)
45+
self.init_plotters(preload_cache, gpu_renderer)
46+
self.gpu_renderer = gpu_renderer
4547

4648
def render(self, obj_infos, TCO, K, resolution=(240, 320), render_depth=False):
4749
TCO = torch.as_tensor(TCO).detach()
@@ -79,17 +81,23 @@ def render(self, obj_infos, TCO, K, resolution=(240, 320), render_depth=False):
7981
images[data_id] = im[0]
8082
if render_depth:
8183
depths[data_id] = depth[0]
82-
images = torch.as_tensor(np.stack(images, axis=0)).pin_memory().cuda(non_blocking=True)
84+
if self.gpu_renderer:
85+
images = torch.as_tensor(np.stack(images, axis=0)).pin_memory().cuda(non_blocking=True)
86+
else:
87+
images = torch.as_tensor(np.stack(images, axis=0))
8388
images = images.float().permute(0, 3, 1, 2) / 255
8489

8590
if render_depth:
86-
depths = torch.as_tensor(np.stack(depths, axis=0)).pin_memory().cuda(non_blocking=True)
91+
if self.gpu_renderer:
92+
depths = torch.as_tensor(np.stack(depths, axis=0)).pin_memory().cuda(non_blocking=True)
93+
else:
94+
depths = torch.as_tensor(np.stack(depths, axis=0))
8795
depths = depths.float()
8896
return images, depths
8997
else:
9098
return images
9199

92-
def init_plotters(self, preload_cache):
100+
def init_plotters(self, preload_cache, gpu_renderer):
93101
self.plotters = []
94102
self.in_queue = multiprocessing.Queue()
95103
self.out_queue = multiprocessing.Queue()
@@ -100,12 +108,13 @@ def init_plotters(self, preload_cache):
100108
kwargs=dict(worker_id=n,
101109
in_queue=self.in_queue,
102110
out_queue=self.out_queue,
111+
object_set=self.object_set,
103112
preload=preload_cache,
104-
object_set=self.object_set))
113+
gpu_renderer=gpu_renderer))
105114
plotter.start()
106115
self.plotters.append(plotter)
107116
else:
108-
self.plotters = [init_renderer(self.object_set, preload_cache)]
117+
self.plotters = [init_renderer(self.object_set, preload_cache, gpu_renderer)]
109118

110119
def stop(self):
111120
if self.n_workers > 0:

0 commit comments

Comments
 (0)