1515import torch .nn .functional as F
1616from torch .utils .data import DataLoader
1717torch .backends .cudnn .benchmark = True
18+ torch .backends .cuda .matmul .allow_tf32 = True # PyTorch 1.7+
1819
1920import cryodrgn
2021from cryodrgn import mrc
@@ -95,7 +96,10 @@ def add_args(parser):
9596 group .add_argument ('--pose-enc' , action = 'store_true' , help = 'predict pose parameter using encoder' )
9697 group .add_argument ('--pose-only' , action = 'store_true' , help = 'train pose encoder only' )
9798 group .add_argument ('--plot' , action = 'store_true' , help = 'plot intermediate result' )
98- group .add_argument ('--estpose' , default = True , action = 'store_true' , help = 'estimate pose (default: %(default)s)' )
99+ group .add_argument ('--estpose' , default = False , action = 'store_true' , help = 'estimate pose' )
100+ group .add_argument ('--warp' , default = False , action = 'store_true' , help = 'using subtomograms from warp' )
101+ group .add_argument ('--tilt-step' , type = int , default = 2 , help = 'the interval between successive tilts (default: %(default)s)' )
102+ group .add_argument ('--tilt-range' , type = int , default = 50 , help = 'the range of tilt angles (default: %(default)s)' )
99103
100104 group = parser .add_argument_group ('Encoder Network' )
101105 group .add_argument ('--enc-layers' , dest = 'qlayers' , type = int , default = 3 , help = 'Number of hidden layers (default: %(default)s)' )
@@ -279,10 +283,10 @@ def run_batch(model, lattice, y, yt, rot, tilt=None, ind=None, ctf_params=None,
279283 z_mu , z_logvar , z = 0. , 0. , 0.
280284
281285 # add bfactors to ctf_params, the second from last column stores bfactor, the last column stores scale
282- #random_b = np.random.rand()*1.5
283- random_b = np .random .gamma (1. , 0.6 )
284- # random_b = torch.randn_like(c[..., 0, -2])/3.
285- c [...,- 2 ] = c [...,- 2 ] + (args .bfactor + random_b )* (4 * np .pi ** 2 )
286+ #random_b = ( np.random.normal())/3.
287+ # random_b = np.random.gamma(1., 0.6)
288+ random_b = torch .randn_like (c [..., 0 , - 2 ])/ 3.
289+ c [...,- 2 ] = c [...,- 2 ] + (args .bfactor + random_b . unsqueeze ( - 1 ) )* (4 * np .pi ** 2 )
286290
287291 plot = args .plot and it % (args .log_interval ) == B
288292 if plot :
@@ -333,6 +337,7 @@ def run_batch(model, lattice, y, yt, rot, tilt=None, ind=None, ctf_params=None,
333337 decout = model .vanilla_decode (rot , trans , z = z , save_mrc = save_image , eulers = euler ,
334338 ref_fft = y , ctf_param = c , encout = encout , mask = mask_real , body_poses = body_poses ,
335339 ctf_grid = ctf_grid , estpose = args .estpose , ctf_filename = ctf_filename , write_ctf = args .write_ctf )
340+
336341 if decout ["affine" ] is not None :
337342 posetracker .set_pose (decout ["affine" ][0 ].detach (), decout ["affine" ][1 ].detach (), ind )
338343
@@ -712,11 +717,17 @@ def flog(msg): # HACK: switch to logging module
712717 args .use_real = args .encode_mode == 'conv'
713718 args .real_data = args .pe_type == 'vanilla'
714719
715- if args .lazy_single :
720+ if args .lazy_single and not args . warp :
716721 data = dataset .LazyTomoMRCData (args .particles , norm = args .norm ,
717722 real_data = args .real_data , invert_data = args .invert_data ,
718723 ind = ind , keepreal = args .use_real , window = False ,
719724 datadir = args .datadir , relion31 = args .relion31 , window_r = args .window_r , downfrac = args .downfrac )
725+ elif args .lazy_single and args .warp :
726+ data = dataset .LazyTomoWARPMRCData (args .particles , norm = args .norm ,
727+ real_data = args .real_data , invert_data = args .invert_data ,
728+ ind = ind , keepreal = args .use_real , window = False ,
729+ datadir = args .datadir , relion31 = args .relion31 , window_r = args .window_r , downfrac = args .downfrac ,
730+ tilt_step = args .tilt_step , tilt_range = args .tilt_range )
720731 else :
721732 raise NotImplementedError ("Use --lazy-single for on-the-fly image loading" )
722733
@@ -751,8 +762,6 @@ def flog(msg): # HACK: switch to logging module
751762
752763 # load ctf
753764 if args .ctf is not None :
754- #if args.use_real:
755- # raise NotImplementedError("Not implemented with real-space encoder. Use phase-flipped images instead")
756765 flog ('Loading ctf params from {}' .format (args .ctf ))
757766 ctf_params = ctf .load_ctf_for_training (D - 1 , args .ctf )
758767 log ('first ctf params is: {}' .format (ctf_params [0 ,:]))
@@ -824,7 +833,6 @@ def flog(msg): # HACK: switch to logging module
824833 model_parameters = list (model .encoder .parameters ()) + list (model .decoder .parameters ()) #+ list(group_stat.parameters())
825834 pose_encoder = None
826835 optim = torch .optim .AdamW (model_parameters , lr = args .lr , weight_decay = args .wd )
827- assert args .accum_step >= 1
828836
829837 #if args.encode_mode == "grad":
830838 # discriminator_parameters = list(model.shape_encoder.parameters())
@@ -946,7 +954,8 @@ def flog(msg): # HACK: switch to logging module
946954 bfactor = args .bfactor
947955 lamb = args .lamb
948956 if args .log_interval % args .batch_size != 0 :
949- args .log_interval = args .batch_size * 8
957+ args .log_interval = args .batch_size * 16
958+ assert args .accum_step >= 1
950959
951960 for epoch in range (start_epoch , num_epochs ):
952961 t2 = dt .now ()
@@ -979,6 +988,7 @@ def flog(msg): # HACK: switch to logging module
979988 ind = minibatch [- 1 ]#.to(device)
980989 y = minibatch [0 ][0 ].to (device , non_blocking = True )
981990 ctf_param = minibatch [0 ][1 ].float ().to (device , non_blocking = True )
991+ ctf_filename = minibatch [0 ][2 ]
982992 #apixs = torch.ones(ctf_param.shape[:-1]).to(device)*args.angpix
983993 #ctf_param = torch.cat([apixs.unsqueeze(-1), ctf_param], dim=-1)
984994 # compute ctf!
@@ -1009,6 +1019,20 @@ def flog(msg): # HACK: switch to logging module
10091019 if body_euler is not None :
10101020 body_euler = body_euler .to (device )
10111021 body_trans = body_trans .to (device )
1022+
1023+ o_rot = lie_tools .hopf_to_SO3 (euler [:, :3 ])
1024+ ## perturb rotation by symm ops
1025+ #samples = torch.multinomial(symm_ops_weights, o_rot.shape[0], replacement=True)
1026+
1027+ ###rand_z = o_rot @ symm_ops[samples].to(o_rot.get_device())
1028+ ###print(rand_z)
1029+ ####pixrad = hp.max_pixrad(64)
1030+ #rand_z = lie_tools.random_biased_SO3(o_rot.shape[0], bias=256*np.sqrt(3)).to(o_rot.get_device())
1031+ #rand_z = o_rot @ rand_z
1032+ #rand_e = lie_tools.so3_to_hopf(rand_z)
1033+ ##print(rand_e - euler[:, :3])
1034+ #euler = rand_e
1035+
10121036 #print("euler, trans: ", euler.shape, tran.shape, y.shape)
10131037 #ctf_param = ctf_params[ind] if ctf_params is not None else None
10141038 z_mu , loss , gen_loss , snr , l1_loss , tv_loss , mu2 , std2 , mmd , c_mmd , mse , body_poses_pred = \
@@ -1020,7 +1044,7 @@ def flog(msg): # HACK: switch to logging module
10201044 it = batch_it , enc = None ,
10211045 args = args , euler = euler ,
10221046 posetracker = posetracker , data = data , update_params = (update_it % args .accum_step == args .accum_step - 1 ),
1023- snr2 = snr_ema , body_poses = (body_euler , body_trans ))
1047+ snr2 = snr_ema , body_poses = (body_euler , body_trans ), ctf_filename = ctf_filename )
10241048 update_it += 1
10251049 if do_pose_sgd and epoch >= args .pretrain :
10261050 pose_optimizer .step ()
@@ -1114,7 +1138,7 @@ def flog(msg): # HACK: switch to logging module
11141138 it = batch_it , enc = None ,
11151139 args = args , euler = euler ,
11161140 posetracker = posetracker , data = data , backward = False , update_params = False ,
1117- snr2 = snr_ema , body_poses = (body_euler , body_trans ))
1141+ snr2 = snr_ema , body_poses = (body_euler , body_trans ), ctf_filename = ctf_filename )
11181142 if do_pose_sgd and epoch >= args .pretrain :
11191143 pose_optimizer .step ()
11201144 # logging
@@ -1124,8 +1148,8 @@ def flog(msg): # HACK: switch to logging module
11241148
11251149 flog ('# =====> Epoch: {} Average validation gen_loss = {:.6}, SNR2 = {:.6f}, ' \
11261150 'total loss = {:.6f}; Finished in {}' .format (epoch + 1 ,
1127- gen_loss_accum / Nimg_test ,
1128- snr_accum / Nimg_test , loss_accum / Nimg_test , dt .now ()- t2 ))
1151+ gen_loss_accum / ( Nimg_test + 1 ) ,
1152+ snr_accum / ( Nimg_test + 1 ) , loss_accum / ( Nimg_test + 1 ) , dt .now ()- t2 ))
11291153
11301154
11311155 if args .checkpoint and epoch % args .checkpoint == 0 :
0 commit comments