66from .bullet_scene_renderer import BulletSceneRenderer
77
88
9- def init_renderer (urdf_ds , preload = True ):
9+ def init_renderer (urdf_ds , preload = True , gpu_renderer = True ):
1010 renderer = BulletSceneRenderer (urdf_ds = urdf_ds ,
1111 preload_cache = preload ,
12- background_color = (0 , 0 , 0 ))
12+ background_color = (0 , 0 , 0 ),
13+ gpu_renderer = gpu_renderer )
1314 return renderer
1415
1516
16- def worker_loop (worker_id , in_queue , out_queue , object_set , preload = True ):
17- renderer = init_renderer (object_set , preload = preload )
17+ def worker_loop (worker_id , in_queue , out_queue , object_set , preload = True , gpu_renderer = True ):
18+ renderer = init_renderer (object_set , preload = preload , gpu_renderer = gpu_renderer )
1819 while True :
1920 kwargs = in_queue .get ()
2021 if kwargs is None :
@@ -38,10 +39,11 @@ def worker_loop(worker_id, in_queue, out_queue, object_set, preload=True):
3839
3940
4041class BulletBatchRenderer :
41- def __init__ (self , object_set , n_workers = 8 , preload_cache = True ):
42+ def __init__ (self , object_set , n_workers = 8 , preload_cache = True , gpu_renderer = True ):
4243 self .object_set = object_set
4344 self .n_workers = n_workers
44- self .init_plotters (preload_cache )
45+ self .init_plotters (preload_cache , gpu_renderer )
46+ self .gpu_renderer = gpu_renderer
4547
4648 def render (self , obj_infos , TCO , K , resolution = (240 , 320 ), render_depth = False ):
4749 TCO = torch .as_tensor (TCO ).detach ()
@@ -79,17 +81,23 @@ def render(self, obj_infos, TCO, K, resolution=(240, 320), render_depth=False):
7981 images [data_id ] = im [0 ]
8082 if render_depth :
8183 depths [data_id ] = depth [0 ]
82- images = torch .as_tensor (np .stack (images , axis = 0 )).pin_memory ().cuda (non_blocking = True )
84+ if self .gpu_renderer :
85+ images = torch .as_tensor (np .stack (images , axis = 0 )).pin_memory ().cuda (non_blocking = True )
86+ else :
87+ images = torch .as_tensor (np .stack (images , axis = 0 ))
8388 images = images .float ().permute (0 , 3 , 1 , 2 ) / 255
8489
8590 if render_depth :
86- depths = torch .as_tensor (np .stack (depths , axis = 0 )).pin_memory ().cuda (non_blocking = True )
91+ if self .gpu_renderer :
92+ depths = torch .as_tensor (np .stack (depths , axis = 0 )).pin_memory ().cuda (non_blocking = True )
93+ else :
94+ depths = torch .as_tensor (np .stack (depths , axis = 0 ))
8795 depths = depths .float ()
8896 return images , depths
8997 else :
9098 return images
9199
92- def init_plotters (self , preload_cache ):
100+ def init_plotters (self , preload_cache , gpu_renderer ):
93101 self .plotters = []
94102 self .in_queue = multiprocessing .Queue ()
95103 self .out_queue = multiprocessing .Queue ()
@@ -100,12 +108,13 @@ def init_plotters(self, preload_cache):
100108 kwargs = dict (worker_id = n ,
101109 in_queue = self .in_queue ,
102110 out_queue = self .out_queue ,
111+ object_set = self .object_set ,
103112 preload = preload_cache ,
104- object_set = self . object_set ))
113+ gpu_renderer = gpu_renderer ))
105114 plotter .start ()
106115 self .plotters .append (plotter )
107116 else :
108- self .plotters = [init_renderer (self .object_set , preload_cache )]
117+ self .plotters = [init_renderer (self .object_set , preload_cache , gpu_renderer )]
109118
110119 def stop (self ):
111120 if self .n_workers > 0 :
0 commit comments