Skip to content

Commit 8e168bf

Browse files
committed
Support WARP's subtomogram and CTF parameters in CSV
1 parent 527353e commit 8e168bf

File tree

8 files changed

+285
-61
lines changed

8 files changed

+285
-61
lines changed

analysis_scripts/prepare_multi.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
python -m cryodrgn.commands.parse_multi_pose_star $1/$2.star -D $3 --Apix $4 -o $1/$2_pose_euler.pkl --masks $5.star $6 --bodies $7 --volumes $8
2-
python -m cryodrgn.commands.parse_ctf_star $1/$2.star -D $3 --Apix $4 -o $1/$2_ctf.pkl $6
1+
starname=$(basename $1)
2+
dirn=$(dirname $1)
3+
filename=$(basename $starname .star)
4+
echo $dirn $filename
5+
python -m cryodrgn.commands.parse_multi_pose_star $1 -D $2 --Apix $3 -o $dirn/$filename\_pose_euler.pkl --masks $4 --bodies $5 $6 $7 $8 $9
6+
#python -m cryodrgn.commands.parse_ctf_star $1/$2.star -D $3 --Apix $4 -o $1/$2_ctf.pkl $6

cryodrgn/commands/analyze.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def add_args(parser):
4141
group.add_argument('-d','--downsample', type=int, help='Downsample volumes to this box size (pixels)')
4242
group.add_argument('--pc', type=int, default=2, help='Number of principal component traversals to generate (default: %(default)s)')
4343
group.add_argument('--ksample', type=int, default=20, help='Number of kmeans samples to generate (default: %(default)s)')
44+
group.add_argument('--kpc', type=str, default=None, help='Perform PCA within the kpc cluster (default: %(default)s)')
4445
return parser
4546

4647
def analyze_z1(z, outdir, vg):
@@ -74,7 +75,7 @@ def analyze_zN(z, outdir, vg, groups, skip_umap=False, num_pcs=2, num_ksamples=2
7475
print(pc[:4, :])
7576
log('Generating volumes...')
7677
for i in range(num_pcs):
77-
start, end = np.percentile(pc[:,i],(5,95))
78+
start, end = np.percentile(pc[:,i],(1,99))
7879
log(f'traversing pc {i} from {start} to {end}')
7980
z_pc = analysis.get_pc_traj(pca, z.shape[1], 10, i+1, start, end)
8081
if not os.path.exists(f'{outdir}/pc{i+1}'):
@@ -155,6 +156,8 @@ def analyze_zN(z, outdir, vg, groups, skip_umap=False, num_pcs=2, num_ksamples=2
155156
ymin = np.min(umap_emb[:, 1])
156157
pmax = max(xmax, ymax)
157158
pmin = min(xmin, ymin)
159+
ymax = max(xmax - xmin, ymax - ymin) + ymin
160+
xmax = max(xmax - xmin, ymax - ymin) + xmin
158161
plt.figure(3)
159162
g = sns.jointplot(x=umap_emb[:,0], y=umap_emb[:,1], hue=groups, palette="inferno", s=3., alpha=.3, xlim=(xmin, xmax), ylim=(ymin, ymax))
160163
g.ax_joint.set_aspect('equal')
@@ -197,7 +200,7 @@ def analyze_zN(z, outdir, vg, groups, skip_umap=False, num_pcs=2, num_ksamples=2
197200

198201
analysis.scatter_annotate(umap_emb[:,0], umap_emb[:,1], centers_ind=centers_ind, annotate=True,
199202
xlim=(xmin, xmax), ylim=(ymin, ymax),
200-
alpha=.15, s=1.)
203+
alpha=.15, s=0.5)
201204
plt.xlabel('UMAP1', fontsize=14, weight='bold')
202205
plt.ylabel('UMAP2', fontsize=14, weight='bold')
203206
plt.savefig(f'{outdir}/kmeans{K}/umap.png')
@@ -215,6 +218,7 @@ def analyze_zN(z, outdir, vg, groups, skip_umap=False, num_pcs=2, num_ksamples=2
215218
plt.tight_layout()
216219

217220
plt.savefig(f'{outdir}/pc{i+1}/umap.png')
221+
return kmeans_labels, umap_emb
218222

219223
class VolumeGenerator:
220224
'''Helper class to call analysis.gen_volumes'''
@@ -256,18 +260,19 @@ def main(args):
256260

257261

258262
if args.vanilla:
259-
losses = analysis.parse_loss_vanilla(f"{workdir}/run.log", "validation")
260-
#plt.ylabel('validation loss')
261-
#plt.xlabel('step')
262-
plt.plot(np.arange(1,len(losses)+1), losses, label="validation")
263-
#plt.savefig(f"{workdir}/val_losses.png")
264-
losses = analysis.parse_loss_vanilla(f"{workdir}/run.log", "training")
265-
plt.ylabel('loss')
266-
plt.xlabel('epoch')
267-
plt.plot(np.arange(1,len(losses)+1), losses, label="training")
268-
plt.xticks(range(1, len(losses)+1))
269-
plt.legend(loc="upper right")
270-
plt.savefig(f"{workdir}/train_losses.png")
263+
if os.path.isfile(f"{workdir}/run.log"):
264+
losses = analysis.parse_loss_vanilla(f"{workdir}/run.log", "validation")
265+
#plt.ylabel('validation loss')
266+
#plt.xlabel('step')
267+
plt.plot(np.arange(1,len(losses)+1), losses, label="validation")
268+
#plt.savefig(f"{workdir}/val_losses.png")
269+
losses = analysis.parse_loss_vanilla(f"{workdir}/run.log", "training")
270+
plt.ylabel('loss')
271+
plt.xlabel('epoch')
272+
plt.plot(np.arange(1,len(losses)+1), losses, label="training")
273+
plt.xticks(range(1, len(losses)+1))
274+
plt.legend(loc="upper right")
275+
plt.savefig(f"{workdir}/train_losses.png")
271276

272277
z = torch.load(zfile)["mu"].cpu().numpy()
273278
log("loading {}, z shape {}".format(zfile, z.shape))

cryodrgn/commands/parse_multi_pose_star.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@ def center_of_mass(volume):
3131
#center = torch.where(center > 0, (center + 0.5).int(), (center - 0.5).int()).float()
3232
centered = (grid - center)
3333
radius = (centered).pow(2)*vol
34-
r = torch.sqrt(radius.sum(dim=(0,1,2))/mass)
34+
r0 = torch.sqrt(radius.sum(dim=(0,1,2))/mass)
3535
#principal axes
3636
matrix = -centered.unsqueeze(-1) * centered.unsqueeze(-2)
3737
radius_sum = torch.eye(3) * (radius.sum(dim=-1, keepdim=True).unsqueeze(-1))
38-
matrix = ((matrix+radius_sum)*vol.unsqueeze(-1)).sum(dim=(0, 1, 2))
38+
matrix = ((-matrix)*vol.unsqueeze(-1)).sum(dim=(0, 1, 2))
3939
eigvals, eigvecs = np.linalg.eig(matrix.numpy())
4040
indices = np.argsort(eigvals)
4141
#print(matrix, eigvals[indices])
4242
eigvecs = torch.from_numpy(eigvecs[:, indices].T) # eigvecs[0] is the first eigen vector with largest eigenvalues
43+
r = np.sqrt(eigvals[indices]/mass)
44+
print("r0 vs r: ", r0, r)
4345

4446
return center, r, eigvecs
4547

@@ -53,6 +55,7 @@ def add_args(parser):
5355
parser.add_argument('--masks', metavar='PKL', type=os.path.abspath, required=False, help='masks for multi-body')
5456
parser.add_argument('--volumes', metavar='PKL', type=os.path.abspath, required=False, help='Output label.pkl')
5557
parser.add_argument('--bodies', type=int, required=True, help='Number of bodies')
58+
parser.add_argument('--outmasks', default="mask_params", help="the name of pkl file storing masks related parameters")
5659
parser.add_argument('--outdir', type=os.path.abspath)
5760
return parser
5861

@@ -77,14 +80,16 @@ def main(args):
7780
log(rot[0])
7881

7982
# parse translations
80-
trans = np.empty((N,2))
81-
if '_rlnOriginX' in s.headers and '_rlnOriginY' in s.headers:
83+
trans = np.zeros((N,3))
84+
if '_rlnOriginX' in s.headers and '_rlnOriginY' in s.headers and '_rlnOriginZ' in s.headers:
8285
trans[:,0] = s.df['_rlnOriginX']
8386
trans[:,1] = s.df['_rlnOriginY']
84-
elif '_rlnOriginXAngst' in s.headers and '_rlnOriginYAngst' in s.headers:
87+
trans[:,2] = s.df['_rlnOriginZ']
88+
elif '_rlnOriginXAngst' in s.headers and '_rlnOriginYAngst' in s.headers and '_rlnOriginZAngst' in s.headers:
8589
assert args.Apix is not None, "Must provide --Apix argument to convert _rlnOriginXAngst and _rlnOriginYAngst translation units"
8690
trans[:,0] = s.df['_rlnOriginXAngst']
8791
trans[:,1] = s.df['_rlnOriginYAngst']
92+
trans[:,2] = s.df['_rlnOriginZAngst']
8893
trans /= args.Apix
8994

9095
log('Translations (pixels):')
@@ -95,7 +100,7 @@ def main(args):
95100

96101
#process multibody
97102
log(f"there are {args.bodies} bodies")
98-
if s.multibodies is not None:
103+
if s.multibodies is not None and len(s.multibodies) != 0:
99104
assert len(s.multibodies) == args.bodies
100105
body_eulers = []
101106
body_trans = []
@@ -109,7 +114,7 @@ def main(args):
109114
log('Euler angles (Rot, Tilt, Psi):')
110115
log(euler_body[0])
111116
body_eulers.append(euler_body)
112-
trans_body = np.empty((N,1,2))
117+
trans_body = np.empty((N,1,3))
113118
body_header = s.multibody_headers[b_i]
114119
if '_rlnOriginX' in body_header and '_rlnOriginY' in body_header:
115120
trans_body[:,0,0] = body['_rlnOriginX']
@@ -129,7 +134,7 @@ def main(args):
129134
for b_i in range(args.bodies):
130135
euler_body = np.zeros((N,1,3))
131136
euler_body[:,0,1] = 90.
132-
trans_body = np.zeros((N,1,2))
137+
trans_body = np.zeros((N,1,3))
133138
body_eulers.append(euler_body)
134139
body_trans.append(trans_body)
135140

@@ -249,16 +254,21 @@ def main(args):
249254
relats = []
250255
print("in_relatives: ", in_relatives)
251256
#print("com_bodies: ", com_bodies - vol_coms, "radii_bodies: ", radii_bodies)
252-
origin_rel = np.bincount(in_relatives).argmax()
257+
origin_rel = 1 #np.bincount(in_relatives).argmax()
258+
print("origin_rel:", origin_rel)
253259
for b_i in range(len(s_mask.df)):
254260
rotate_directions.append(com_bodies[in_relatives[b_i]] - com_bodies[b_i])
255261
rotate_directions_ori.append(com_bodies[b_i] - com_bodies[in_relatives[b_i]])
256262
rotate_directions[-1] = F.normalize(rotate_directions[-1], dim=0)
257-
orient_bodies.append(utils.align_with_z(-rotate_directions[-1]))
263+
if b_i != origin_rel:
264+
orient_bodies.append(utils.align_with_z(-rotate_directions[-1]))
265+
else:
266+
orient_bodies.append(utils.align_with_z(rotate_directions[-1]))
267+
print(rotate_directions[-1].shape, orient_bodies[-1] @ rotate_directions[-1])
258268
relats.append(com_bodies[in_relatives[b_i]])
259269
#reset rotation axis for center
260-
if b_i == origin_rel:
261-
rotate_directions_ori[b_i] = com_bodies[b_i] - com_bodies[b_i]
270+
#if b_i == origin_rel:
271+
# rotate_directions_ori[b_i] = com_bodies[b_i] - com_bodies[b_i]
262272
#normalize direction
263273
A_rot90 = lie_tools.yrot(torch.tensor(-90))
264274
rotate_directions = torch.stack(rotate_directions, dim=0)
@@ -268,19 +278,20 @@ def main(args):
268278
#print((orientations@rotate_directions_ori.unsqueeze(-1)).squeeze(), rot_axes, orientations)
269279
#print((orientations@rot_radii.unsqueeze(-1)).squeeze())
270280
#print(orientations@torch.transpose(principal_axes, -1, -2))
271-
print("rot_radii: ", rot_radii)
281+
print("rotate_directions from volumes: ", rot_radii)
272282
orient_bodies = torch.stack(orient_bodies, dim=0)
273283
relats = torch.stack(relats, dim=0)
274284
axes = torch.stack(axes, dim=0)
275285
#print("A_rot90: ", A_rot90)
276286
#print("relats: ", relats)
277287
print("rotate_directions: ", rotate_directions_ori)
278288
print("orient_bodies: ", orient_bodies)
279-
output_name = prefix + "/masks.pkl"
289+
output_name = prefix + f"/{args.outmasks}.pkl"
280290
log(f'Writing {output_name}')
281291
if not args.volumes:
282-
torch.save({"in_relatives": in_relatives, "com_bodies": com_bodies,
283-
"orient_bodies": orient_bodies, "rotate_directions": rotate_directions_ori, "radii_bodies": radii_bodies}, \
292+
print("principal_axes: ", axes)
293+
torch.save({"in_relatives": relats, "com_bodies": com_bodies,
294+
"orient_bodies": orient_bodies, "rotate_directions": rotate_directions_ori, "radii_bodies": radii_bodies, "principal_axes": axes}, \
284295
# #"weights": weights, "consensus_mask": consensus_mask},
285296
output_name)
286297
else:

cryodrgn/commands/train_tomo.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch.nn.functional as F
1616
from torch.utils.data import DataLoader
1717
torch.backends.cudnn.benchmark = True
18+
torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.7+
1819

1920
import cryodrgn
2021
from cryodrgn import mrc
@@ -95,7 +96,10 @@ def add_args(parser):
9596
group.add_argument('--pose-enc', action='store_true', help='predict pose parameter using encoder')
9697
group.add_argument('--pose-only', action='store_true', help='train pose encoder only')
9798
group.add_argument('--plot', action='store_true', help='plot intermediate result')
98-
group.add_argument('--estpose', default=True, action='store_true', help='estimate pose (default: %(default)s)')
99+
group.add_argument('--estpose', default=False, action='store_true', help='estimate pose')
100+
group.add_argument('--warp', default=False, action='store_true', help='using subtomograms from warp')
101+
group.add_argument('--tilt-step', type=int, default=2, help='the interval between successive tilts (default: %(default)s)')
102+
group.add_argument('--tilt-range', type=int, default=50, help='the range of tilt angles (default: %(default)s)')
99103

100104
group = parser.add_argument_group('Encoder Network')
101105
group.add_argument('--enc-layers', dest='qlayers', type=int, default=3, help='Number of hidden layers (default: %(default)s)')
@@ -279,10 +283,10 @@ def run_batch(model, lattice, y, yt, rot, tilt=None, ind=None, ctf_params=None,
279283
z_mu, z_logvar, z = 0., 0., 0.
280284

281285
# add bfactors to ctf_params, the second from last column stores bfactor, the last column stores scale
282-
#random_b = np.random.rand()*1.5
283-
random_b = np.random.gamma(1., 0.6)
284-
#random_b = torch.randn_like(c[..., 0, -2])/3.
285-
c[...,-2] = c[...,-2] + (args.bfactor+random_b)*(4*np.pi**2)
286+
#random_b = (np.random.normal())/3.
287+
#random_b = np.random.gamma(1., 0.6)
288+
random_b = torch.randn_like(c[..., 0, -2])/3.
289+
c[...,-2] = c[...,-2] + (args.bfactor+random_b.unsqueeze(-1))*(4*np.pi**2)
286290

287291
plot = args.plot and it % (args.log_interval) == B
288292
if plot:
@@ -333,6 +337,7 @@ def run_batch(model, lattice, y, yt, rot, tilt=None, ind=None, ctf_params=None,
333337
decout = model.vanilla_decode(rot, trans, z=z, save_mrc=save_image, eulers=euler,
334338
ref_fft=y, ctf_param=c, encout=encout, mask=mask_real, body_poses=body_poses,
335339
ctf_grid=ctf_grid, estpose=args.estpose, ctf_filename=ctf_filename, write_ctf=args.write_ctf)
340+
336341
if decout["affine"] is not None:
337342
posetracker.set_pose(decout["affine"][0].detach(), decout["affine"][1].detach(), ind)
338343

@@ -712,11 +717,17 @@ def flog(msg): # HACK: switch to logging module
712717
args.use_real = args.encode_mode == 'conv'
713718
args.real_data = args.pe_type == 'vanilla'
714719

715-
if args.lazy_single:
720+
if args.lazy_single and not args.warp:
716721
data = dataset.LazyTomoMRCData(args.particles, norm=args.norm,
717722
real_data=args.real_data, invert_data=args.invert_data,
718723
ind=ind, keepreal=args.use_real, window=False,
719724
datadir=args.datadir, relion31=args.relion31, window_r=args.window_r, downfrac=args.downfrac)
725+
elif args.lazy_single and args.warp:
726+
data = dataset.LazyTomoWARPMRCData(args.particles, norm=args.norm,
727+
real_data=args.real_data, invert_data=args.invert_data,
728+
ind=ind, keepreal=args.use_real, window=False,
729+
datadir=args.datadir, relion31=args.relion31, window_r=args.window_r, downfrac=args.downfrac,
730+
tilt_step=args.tilt_step, tilt_range=args.tilt_range)
720731
else:
721732
raise NotImplementedError("Use --lazy-single for on-the-fly image loading")
722733

@@ -751,8 +762,6 @@ def flog(msg): # HACK: switch to logging module
751762

752763
# load ctf
753764
if args.ctf is not None:
754-
#if args.use_real:
755-
# raise NotImplementedError("Not implemented with real-space encoder. Use phase-flipped images instead")
756765
flog('Loading ctf params from {}'.format(args.ctf))
757766
ctf_params = ctf.load_ctf_for_training(D-1, args.ctf)
758767
log('first ctf params is: {}'.format(ctf_params[0,:]))
@@ -824,7 +833,6 @@ def flog(msg): # HACK: switch to logging module
824833
model_parameters = list(model.encoder.parameters()) + list(model.decoder.parameters()) #+ list(group_stat.parameters())
825834
pose_encoder = None
826835
optim = torch.optim.AdamW(model_parameters, lr=args.lr, weight_decay=args.wd)
827-
assert args.accum_step >= 1
828836

829837
#if args.encode_mode == "grad":
830838
# discriminator_parameters = list(model.shape_encoder.parameters())
@@ -946,7 +954,8 @@ def flog(msg): # HACK: switch to logging module
946954
bfactor = args.bfactor
947955
lamb = args.lamb
948956
if args.log_interval % args.batch_size != 0:
949-
args.log_interval = args.batch_size*8
957+
args.log_interval = args.batch_size*16
958+
assert args.accum_step >= 1
950959

951960
for epoch in range(start_epoch, num_epochs):
952961
t2 = dt.now()
@@ -979,6 +988,7 @@ def flog(msg): # HACK: switch to logging module
979988
ind = minibatch[-1]#.to(device)
980989
y = minibatch[0][0].to(device, non_blocking=True)
981990
ctf_param = minibatch[0][1].float().to(device, non_blocking=True)
991+
ctf_filename = minibatch[0][2]
982992
#apixs = torch.ones(ctf_param.shape[:-1]).to(device)*args.angpix
983993
#ctf_param = torch.cat([apixs.unsqueeze(-1), ctf_param], dim=-1)
984994
# compute ctf!
@@ -1009,6 +1019,20 @@ def flog(msg): # HACK: switch to logging module
10091019
if body_euler is not None:
10101020
body_euler = body_euler.to(device)
10111021
body_trans = body_trans.to(device)
1022+
1023+
o_rot = lie_tools.hopf_to_SO3(euler[:, :3])
1024+
## perturb rotation by symm ops
1025+
#samples = torch.multinomial(symm_ops_weights, o_rot.shape[0], replacement=True)
1026+
1027+
###rand_z = o_rot @ symm_ops[samples].to(o_rot.get_device())
1028+
###print(rand_z)
1029+
####pixrad = hp.max_pixrad(64)
1030+
#rand_z = lie_tools.random_biased_SO3(o_rot.shape[0], bias=256*np.sqrt(3)).to(o_rot.get_device())
1031+
#rand_z = o_rot @ rand_z
1032+
#rand_e = lie_tools.so3_to_hopf(rand_z)
1033+
##print(rand_e - euler[:, :3])
1034+
#euler = rand_e
1035+
10121036
#print("euler, trans: ", euler.shape, tran.shape, y.shape)
10131037
#ctf_param = ctf_params[ind] if ctf_params is not None else None
10141038
z_mu, loss, gen_loss, snr, l1_loss, tv_loss, mu2, std2, mmd, c_mmd, mse, body_poses_pred = \
@@ -1020,7 +1044,7 @@ def flog(msg): # HACK: switch to logging module
10201044
it=batch_it, enc=None,
10211045
args=args, euler=euler,
10221046
posetracker=posetracker, data=data, update_params=(update_it%args.accum_step == args.accum_step - 1),
1023-
snr2=snr_ema, body_poses=(body_euler, body_trans))
1047+
snr2=snr_ema, body_poses=(body_euler, body_trans), ctf_filename=ctf_filename)
10241048
update_it += 1
10251049
if do_pose_sgd and epoch >= args.pretrain:
10261050
pose_optimizer.step()
@@ -1114,7 +1138,7 @@ def flog(msg): # HACK: switch to logging module
11141138
it=batch_it, enc=None,
11151139
args=args, euler=euler,
11161140
posetracker=posetracker, data=data, backward=False, update_params=False,
1117-
snr2=snr_ema, body_poses = (body_euler, body_trans))
1141+
snr2=snr_ema, body_poses = (body_euler, body_trans), ctf_filename=ctf_filename)
11181142
if do_pose_sgd and epoch >= args.pretrain:
11191143
pose_optimizer.step()
11201144
# logging
@@ -1124,8 +1148,8 @@ def flog(msg): # HACK: switch to logging module
11241148

11251149
flog('# =====> Epoch: {} Average validation gen_loss = {:.6}, SNR2 = {:.6f}, '\
11261150
'total loss = {:.6f}; Finished in {}'.format(epoch+1,
1127-
gen_loss_accum/Nimg_test,
1128-
snr_accum/Nimg_test, loss_accum/Nimg_test, dt.now()-t2))
1151+
gen_loss_accum/(Nimg_test+1),
1152+
snr_accum/(Nimg_test+1), loss_accum/(Nimg_test+1), dt.now()-t2))
11291153

11301154

11311155
if args.checkpoint and epoch % args.checkpoint == 0:

0 commit comments

Comments (0)