def setup_loader(args):
    ds = Hdf5PoseDataset(
        args.filename,
        transform=Compose(
            [
                dtr.batch.offset_points_by_half_pixel,  # For when pixels are considered grid cell centers
            ]
        ),
        monochrome=True,
    )
    if args.dryrun:
        ds = Subset(ds, np.arange(10))
    N = len(ds)
    loader = dtr.PostprocessingLoader(
        ds,
        args.batchsize,
        shuffle=False,
        num_workers=utils.num_workers(),
        postprocess=None,
        collate_fn=lambda samples: samples,  # identity collate: a batch stays a list of sample dicts
    )
    return loader, ds
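
# Note: offset_points_by_half_pixel is imported from dtr.batch and not shown
# here. The comment above refers to the usual convention change: if pixel
# (i, j) is the *center* of its grid cell, the continuous position of index i
# is i + 0.5. A minimal sketch of such a transform (hypothetical name, not
# this project's implementation):
def _offset_points_by_half_pixel_sketch(points):
    # points: array of shape (..., 2) holding pixel-index coordinates;
    # shift them to the corresponding cell-center coordinates.
    return points + 0.5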


def fit_batch(net: InferenceNetwork, batch: List[Batch]):
    images = [s["image"] for s in batch]
    rois = torch.stack([s["roi"] for s in batch])
    indices = torch.stack([s["index"] for s in batch])  # dataset indices, used later to restore sample order
    out = Predictor(net, focus_roi_expansion_factor=1.2).predict_batch(
        images,
        rois,
    )
    out = {k: out[k] for k in "unnormalized_quat coord pt3d_68 shapeparam".split()}
    out.update(index=indices)
    return out
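
# Note: utils.list_of_dicts_to_dict_of_lists, used by fitall() below, is not
# part of this diff. Assuming every dict shares the same keys, a minimal
# equivalent would be:
def _list_of_dicts_to_dict_of_lists_sketch(dicts):
    # Transpose a list of records into a dict mapping each key to the
    # list of that key's values, preserving order.
    return {k: [d[k] for d in dicts] for k in dicts[0]}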


def test_quats_average():
    def positivereal(q):
        # Resolve the q / -q sign ambiguity so that quaternions representing
        # the same rotation compare equal.
        s = np.sign(q[..., 3])
        return q * s[..., None]

    from scipy.spatial.transform import Rotation

    expected_quats = Rotation.random(10).as_quat()
    quats = Rotation.from_quat(np.repeat(expected_quats, 10, axis=0))
    offsets = Rotation.random(10 * 10).as_rotvec() * 0.01
    quats = quats * Rotation.from_rotvec(offsets)
    quats = quats.as_quat().reshape((10, 10, 4)).transpose(1, 0, 2)
    out = quat_average(quats)
    # print(positivereal(out) - positivereal(expected_quats))
    assert np.allclose(positivereal(out), positivereal(expected_quats), atol=0.02)
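
# Note: quat_average is imported from elsewhere and its implementation is not
# shown in this diff. A standard choice consistent with the test above is
# Markley's eigenvector method: the average is the principal eigenvector of
# the summed outer products q q^T, which is insensitive to the q / -q sign
# ambiguity. A minimal sketch (hypothetical, not this project's code):
def _quat_average_sketch(quats):
    # quats: array of shape (N, ..., 4); averages over axis 0.
    m = np.einsum("n...i,n...j->...ij", quats, quats)  # (..., 4, 4) accumulators
    _, v = np.linalg.eigh(m)  # eigenvalues in ascending order
    return v[..., :, -1]  # principal eigenvector = unit-norm average quaternion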


@torch.no_grad()
def fitall(args):
    assert all(isfile(f) for f in args.checkpoints)
    print("Inferring from networks:", args.checkpoints)

    with h5py.File(args.filename, "r+") as f:
        g = f.require_group(args.hdfgroupname) if args.hdfgroupname else f
        # Clear any stale outputs from a previous run.
        for key in "coords quats pt3d_68 shapeparams".split():
            try:
                del g[key]
            except KeyError:
                pass

    # ... (lines elided in the source diff; judging by later use they include
    # loader, ds = setup_loader(args) and num_samples = len(ds)) ...

    outputs_per_net = defaultdict(list)
    for modelfile in tqdm.tqdm(args.checkpoints, desc="Network"):
        net = load_pose_network(modelfile, "cuda")
        outputs = [fit_batch(net, batch) for batch in tqdm.tqdm(loader, "Batch")]
        outputs = utils.list_of_dicts_to_dict_of_lists(outputs)
        outputs = {k: np.concatenate(v, axis=0) for k, v in outputs.items()}
        # Restore dataset order using the per-sample indices from fit_batch.
        ordering = np.argsort(outputs.pop("index"))
        outputs = {k: v[ordering] for k, v in outputs.items()}
        for k, v in outputs.items():
            outputs_per_net[k].append(v)
        del outputs
    # Stack per-network results to shape (num_networks, num_samples, ...).
    outputs_per_net = {k: np.stack(v) for k, v in outputs_per_net.items()}

    del loader
    del ds
    gc.collect()  # Ensure the hdf5 file in the data was really closed.
    # There is no way to enforce it. We can only hope the garbage
    # collector will destroy the objects. If there is still a reference
    # left it will be read-only and lead to failure when trying to write
    # ... (comment continues in lines elided from the source diff) ...
    # FIXME: final quats output is busted. Values are more or less garbage.
    # unnormalized_quat looks fine!

    quats = quat_average(outputs_per_net.pop("unnormalized_quat"))
    coords = np.average(outputs_per_net.pop("coord"), axis=0)
    pt3d_68 = np.average(outputs_per_net.pop("pt3d_68"), axis=0)
    shapeparams = np.average(outputs_per_net.pop("shapeparam"), axis=0)

    assert len(quats) == num_samples

    with h5py.File(args.filename, "r+") as f:
        g = f.require_group(args.hdfgroupname) if args.hdfgroupname else f
        ds_quats = create_pose_dataset(
            g, C.quat, count=num_samples, data=quats, exists_ok=args.overwrite
        )
        ds_coords = create_pose_dataset(
            g, C.xys, count=num_samples, data=coords, exists_ok=args.overwrite
        )
        ds_pt3d_68 = create_pose_dataset(
            g,
            C.points,
            name="pt3d_68",
            count=num_samples,
            shape_wo_batch_dim=(68, 3),
            data=pt3d_68,
            exists_ok=args.overwrite,
        )
        ds_shapeparams = create_pose_dataset(
            g,
            C.general,
            name="shapeparams",
            count=num_samples,
            shape_wo_batch_dim=(50,),
            data=shapeparams,
            exists_ok=args.overwrite,
        )


if __name__ == "__main__":
    test_quats_average()
    parser = argparse.ArgumentParser()
    parser.add_argument("filename", type=str, help="the dataset to label")
    parser.add_argument("-c", "--checkpoints", help="model checkpoint", nargs="*", type=str)
    parser.add_argument("-b", "--batchsize", help="The batch size", type=int, default=512)
    parser.add_argument(
        "--hdf-group-name",
        help="Group to store the annotations in",
        type=str,
        default="",
        dest="hdfgroupname",
    )
    parser.add_argument("--dryrun", default=False, action="store_true")
    parser.add_argument("--overwrite", "-f", default=False, action="store_true")
    args = parser.parse_args()
    args.device = "cuda"
    fitall(args)
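
# Example invocation (hypothetical file and checkpoint names): label data.h5
# with predictions averaged over an ensemble of checkpoints, writing into the
# "annotations" group and overwriting any existing datasets:
#
#   python fit_dataset.py data.h5 -c run0.ckpt run1.ckpt run2.ckpt \
#       --hdf-group-name annotations -f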