Commit 0b50a32

Add eval code and example

committed · 1 parent d29cfe1 · commit 0b50a32

File tree: 3 files changed (+1483 −0 lines changed)

README.md

Lines changed: 3 additions & 0 deletions

````diff
@@ -43,6 +43,9 @@ python main.py --dataset [PATH_TO_NPZ_FILE] --embedding_name [EMBEDDING_NAME]
 ```
 
 
+See `notebooks/EB-Eval.ipynb` for an example of how to use TrajectoryNet on a PCA embedding to get trajectories in the gene space.
+
+
 ### References
 [1] Tong, A., Huang, J., Wolf, G., van Dijk, D., and Krishnaswamy, S. TrajectoryNet: A Dynamic Optimal Transport Network for Modeling Cellular Dynamics. In International Conference on Machine Learning, 2020. [[arxiv]](http://arxiv.org/abs/2002.04461)
 
````
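The notebook itself is not part of this commit, but the gene-space mapping the README line describes can be sketched: TrajectoryNet runs in a low-dimensional PCA embedding, and trajectory points are mapped back to gene space by inverting the PCA. A minimal sketch, assuming an sklearn `PCA` fit on the expression matrix; the file paths and the `X`, `pca`, and `zs` names are hypothetical, with `zs` shaped like the `backward_trajectories.npy` array saved by `integrate_backwards` in `eval.py` below:

```python
import numpy as np
from sklearn.decomposition import PCA

# Hypothetical inputs: a cells-by-genes expression matrix and a fitted PCA.
X = np.load("data/eb_expression.npy")  # assumed path, shape (cells, genes)
pca = PCA(n_components=5).fit(X)

# Trajectories saved by eval.py live in the PCA embedding:
# shape (ntimes, ncells, n_components).
zs = np.load("results/eb/backward_trajectories.npy")  # assumed path

# Invert the embedding timepoint by timepoint to get gene-space trajectories.
gene_traj = np.stack([pca.inverse_transform(z) for z in zs])
print(gene_traj.shape)  # (ntimes, ncells, genes)
```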
TrajectoryNet/eval.py

Lines changed: 381 additions & 0 deletions (new file)

```python
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib

from lib.growth_net import GrowthNet
import lib.utils as utils
from lib.viz_scrna import trajectory_to_video, save_vectors
from lib.viz_scrna import (
    save_trajectory_density,
    save_2d_trajectory,
    save_2d_trajectory_v2,
)

# from train_misc import standard_normal_logprob
from train_misc import set_cnf_options, count_nfe, count_parameters
from train_misc import count_total_time
from train_misc import add_spectral_norm, spectral_norm_power_iteration
from train_misc import create_regularization_fns, get_regularization
from train_misc import append_regularization_to_log
from train_misc import build_model_tabular

import eval_utils
import dataset


def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def save_trajectory(
    prior_logdensity,
    prior_sampler,
    model,
    data_samples,
    savedir,
    ntimes=101,
    end_times=None,
    memory=0.01,
    device="cpu",
):
    model.eval()

    # Sample from prior
    z_samples = prior_sampler(1000, 2).to(device)

    # sample from a grid
    npts = 100
    side = np.linspace(-4, 4, npts)
    xx, yy = np.meshgrid(side, side)
    xx = torch.from_numpy(xx).type(torch.float32).to(device)
    yy = torch.from_numpy(yy).type(torch.float32).to(device)
    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)

    with torch.no_grad():
        # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container.
        logp_samples = prior_logdensity(z_samples)
        logp_grid = prior_logdensity(z_grid)
        t = 0
        for cnf in model.chain:

            # Construct integration_list
            if end_times is None:
                end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
            integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
            for i, et in enumerate(end_times[1:]):
                integration_list.append(
                    torch.linspace(end_times[i], et, ntimes).to(device)
                )
            full_times = torch.cat(integration_list, 0)
            print(full_times.shape)

            # Integrate over evenly spaced samples
            z_traj, logpz = cnf(
                z_samples,
                logp_samples,
                integration_times=integration_list[0],
                reverse=True,
            )
            full_traj = [(z_traj, logpz)]
            for int_times in integration_list[1:]:
                prev_z, prev_logp = full_traj[-1]
                z_traj, logpz = cnf(
                    prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
                )
                full_traj.append((z_traj[1:], logpz[1:]))
            full_zip = list(zip(*full_traj))
            z_traj = torch.cat(full_zip[0], 0)
            # z_logp = torch.cat(full_zip[1], 0)
            z_traj = z_traj.cpu().numpy()

            grid_z_traj, grid_logpz_traj = [], []
            inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
            for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
                _grid_z_traj, _grid_logpz_traj = cnf(
                    z_grid[ii],
                    logp_grid[ii],
                    integration_times=integration_list[0],
                    reverse=True,
                )
                full_traj = [(_grid_z_traj, _grid_logpz_traj)]
                for int_times in integration_list[1:]:
                    prev_z, prev_logp = full_traj[-1]
                    _grid_z_traj, _grid_logpz_traj = cnf(
                        prev_z[-1],
                        prev_logp[-1],
                        integration_times=int_times,
                        reverse=True,
                    )
                    full_traj.append((_grid_z_traj, _grid_logpz_traj))
                full_zip = list(zip(*full_traj))
                _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
                _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
                print(_grid_z_traj.shape)
                grid_z_traj.append(_grid_z_traj)
                grid_logpz_traj.append(_grid_logpz_traj)

            grid_z_traj = np.concatenate(grid_z_traj, axis=1)
            grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)

            plt.figure(figsize=(8, 8))
            for _ in range(z_traj.shape[0]):

                plt.clf()

                # plot target potential function
                ax = plt.subplot(1, 1, 1, aspect="equal")

                """
                ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Target", fontsize=32)
                """

                # plot the density
                # ax = plt.subplot(2, 2, 2, aspect="equal")

                z, logqz = grid_z_traj[t], grid_logpz_traj[t]

                xx = z[:, 0].reshape(npts, npts)
                yy = z[:, 1].reshape(npts, npts)
                qz = np.exp(logqz).reshape(npts, npts)
                rgb = plt.cm.Spectral(t / z_traj.shape[0])
                print(t, rgb)
                background_color = "white"
                cvals = [0, np.percentile(qz, 0.1)]
                colors = [
                    background_color,
                    rgb,
                ]
                norm = plt.Normalize(min(cvals), max(cvals))
                tuples = list(zip(map(norm, cvals), colors))
                cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
                from matplotlib.colors import LogNorm

                plt.pcolormesh(
                    xx,
                    yy,
                    qz,
                    # norm=LogNorm(vmin=qz.min(), vmax=qz.max()),
                    cmap=cmap,
                )
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                cmap = matplotlib.cm.get_cmap(None)
                ax.set_facecolor(background_color)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Density", fontsize=32)

                """
                # plot the samples
                ax = plt.subplot(2, 2, 3, aspect="equal")

                zk = z_traj[t]
                ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Samples", fontsize=32)

                # plot vector field
                ax = plt.subplot(2, 2, 4, aspect="equal")

                K = 13j
                y, x = np.mgrid[-4:4:K, -4:4:K]
                K = int(K.imag)
                zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
                logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
                dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
                dydt = -dydt.cpu().detach().numpy()
                dydt = dydt.reshape(K, K, 2)

                logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
                ax.quiver(
                    x, y, dydt[:, :, 0], -dydt[:, :, 1],
                    # x, y, dydt[:, :, 0], dydt[:, :, 1],
                    np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
                )
                ax.set_xlim(-4, 4)
                ax.set_ylim(4, -4)
                # ax.set_ylim(-4, 4)
                ax.axis("off")
                ax.set_title("Vector Field", fontsize=32)
                """

                makedirs(savedir)
                plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
                t += 1


def get_trajectory_samples(device, model, data, n=2000):
    # NOTE: reads the module-level `args` bound in the __main__ block below.
    ntimes = 5
    model.eval()
    z_samples = data.base_sample()(n, 2).to(device)

    integration_list = [torch.linspace(0, args.int_tps[0], ntimes).to(device)]
    for i, et in enumerate(args.int_tps[1:]):
        integration_list.append(torch.linspace(args.int_tps[i], et, ntimes).to(device))
    print(integration_list)


def plot_output(device, args, model, data):
    # logger.info('Plotting trajectory to {}'.format(save_traj_dir))
    data_samples = data.get_data()[data.sample_index(2000, 0)]
    start_points = data.base_sample()(1000, 2)
    # start_points = data.get_data()[idx]
    # start_points = torch.from_numpy(start_points).type(torch.float32)
    """
    save_vectors(
        data.base_density(),
        model,
        start_points,
        data.get_data()[data.get_times() == 1],
        data.get_times()[data.get_times() == 1],
        args.save,
        device=device,
        end_times=args.int_tps,
        ntimes=100,
        memory=1.0,
        lim=1.5,
    )
    save_traj_dir = os.path.join(args.save, "trajectory_2d")
    save_2d_trajectory_v2(
        data.base_density(),
        data.base_sample(),
        model,
        data_samples,
        save_traj_dir,
        device=device,
        end_times=args.int_tps,
        ntimes=3,
        memory=1.0,
        limit=2.5,
    )
    """

    density_dir = os.path.join(args.save, "density2")
    save_trajectory_density(
        data.base_density(),
        model,
        data_samples,
        density_dir,
        device=device,
        end_times=args.int_tps,
        ntimes=100,
        memory=1,
    )
    trajectory_to_video(density_dir)


def integrate_backwards(end_samples, model, savedir, ntimes=100, memory=0.1, device='cpu'):
    """Integrate some samples backwards and save the results.

    NOTE: reads `args.int_tps` from the module-level `args` bound in the
    __main__ block below.
    """
    with torch.no_grad():
        z = torch.from_numpy(end_samples).type(torch.float32).to(device)
        zero = torch.zeros(z.shape[0], 1).to(z)
        cnf = model.chain[0]

        zs = [z]
        deltas = []
        int_tps = np.linspace(args.int_tps[0], args.int_tps[-1], ntimes)
        for i, itp in enumerate(int_tps[::-1][:-1]):
            # tp counts down from last
            timescale = int_tps[1] - int_tps[0]
            integration_times = torch.tensor([itp - timescale, itp])
            # integration_times = torch.tensor([np.linspace(itp - args.time_scale, itp, ntimes)])
            integration_times = integration_times.type(torch.float32).to(device)

            # transform to previous timepoint
            z, delta_logp = cnf(zs[-1], zero, integration_times=integration_times)
            zs.append(z)
            deltas.append(delta_logp)
        zs = torch.stack(zs, 0)
        zs = zs.cpu().numpy()
        np.save(os.path.join(savedir, 'backward_trajectories.npy'), zs)


def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = data.get_unique_times()

    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(
        device
    )
    growth_model_path = data.get_growth_net_path()
    # growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
    growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    # plot_output(device, args, model, data)
    # exit()
    # get_trajectory_samples(device, model, data)

    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    print('integrating backwards')
    end_time_data = data.data_dict['mphate_expression']
    # end_time_data = data.get_data()[args.data.get_times()==np.max(args.data.get_times())]
    # np.random.permutation(end_time_data)
    # rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
    # end_time_data = end_time_data[rand_idx,:]
    integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
    exit()
    # NOTE: everything below this exit() is currently unreachable.
    losses_list = []
    # for factor in np.linspace(0.05, 0.95, 19):
    # for factor in np.linspace(0.91, 0.99, 9):
    if args.dataset == 'CHAFFER':  # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.6799872494335812
            factor = 0.95
        elif lt == 2:
            factor = 0.2905983814032348
            factor = 0.01
        else:
            raise RuntimeError('Unknown timepoint %d' % args.leaveout_timepoint)
        args.int_tps[lt] = (1 - factor) * args.int_tps[lt-1] + factor * args.int_tps[lt+1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))
    # zs = np.load(os.path.join(args.save, 'backward_trajectories'))
    # losses = eval_utils.evaluate_mse(device, args, model)
    # losses = eval_utils.evaluate_kantorovich(device, args, model)
    # print(losses)
    # eval_utils.generate_samples(device, args, model, growth_model, timepoint=args.timepoints[-1])
    # eval_utils.calculate_path_length(device, args, model, data, args.int_tps[-1])


if __name__ == "__main__":
    from parse import parser

    args = parser.parse_args()
    main(args)
```
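The recurring pattern in this file is piecewise integration of the CNF between consecutive timepoints: each segment is integrated with `reverse=True`, and the next segment resumes from the previous segment's endpoint. A stripped-down sketch of that pattern, assuming the same `cnf(z, logp, integration_times=..., reverse=True)` call signature used above; `model`, `z0`, and `int_tps` are stand-ins, not names from this commit:

```python
import torch

def integrate_piecewise(model, z0, int_tps, ntimes=101, device="cpu"):
    """Chain CNF integrations across [0, t0], [t0, t1], ... segments.

    Mirrors the pattern in save_trajectory above; like integrate_backwards,
    it uses the first CNF layer in the SequentialFlow chain.
    """
    cnf = model.chain[0]
    logp0 = torch.zeros(z0.shape[0], 1).to(z0)

    # Per-segment integration times, built as in save_trajectory.
    segments = [torch.linspace(0, int_tps[0], ntimes).to(device)]
    for i, et in enumerate(int_tps[1:]):
        segments.append(torch.linspace(int_tps[i], et, ntimes).to(device))

    traj, z, logp = [], z0, logp0
    with torch.no_grad():
        for seg in segments:
            z_seg, logp_seg = cnf(z, logp, integration_times=seg, reverse=True)
            # Drop the duplicated boundary point on all but the first segment.
            traj.append(z_seg if not traj else z_seg[1:])
            # The next segment resumes from this segment's final state.
            z, logp = z_seg[-1], logp_seg[-1]
    return torch.cat(traj, 0)  # (total timepoints, n_samples, dim)
```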
