import os

import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib

from lib.growth_net import GrowthNet
import lib.utils as utils
from lib.viz_scrna import trajectory_to_video, save_vectors
from lib.viz_scrna import (
    save_trajectory_density,
    save_2d_trajectory,
    save_2d_trajectory_v2,
)

# from train_misc import standard_normal_logprob
from train_misc import set_cnf_options, count_nfe, count_parameters
from train_misc import count_total_time
from train_misc import add_spectral_norm, spectral_norm_power_iteration
from train_misc import create_regularization_fns, get_regularization
from train_misc import append_regularization_to_log
from train_misc import build_model_tabular

import eval_utils
import dataset

def makedirs(dirname):
    # Create the directory if it does not already exist.
    os.makedirs(dirname, exist_ok=True)


def save_trajectory(
    prior_logdensity,
    prior_sampler,
    model,
    data_samples,
    savedir,
    ntimes=101,
    end_times=None,
    memory=0.01,
    device="cpu",
):
    model.eval()

    # Sample from the prior
    z_samples = prior_sampler(1000, 2).to(device)

    # Sample from a regular grid over [-4, 4]^2; z_grid has shape (npts**2, 2)
    npts = 100
    side = np.linspace(-4, 4, npts)
    xx, yy = np.meshgrid(side, side)
    xx = torch.from_numpy(xx).type(torch.float32).to(device)
    yy = torch.from_numpy(yy).type(torch.float32).to(device)
    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)

    with torch.no_grad():
        # We expect the model to be a chain of CNF layers wrapped in a
        # SequentialFlow container.
        logp_samples = prior_logdensity(z_samples)
        logp_grid = prior_logdensity(z_grid)
        t = 0  # global frame counter across all CNFs in the chain
        for cnf in model.chain:

            # Construct integration_list: one linspace of ntimes points per
            # interval between consecutive end_times, starting from 0.
            if end_times is None:
                end_times = [(cnf.sqrt_end_time * cnf.sqrt_end_time)]
            integration_list = [torch.linspace(0, end_times[0], ntimes).to(device)]
            for i, et in enumerate(end_times[1:]):
                # end_times[i] is the element preceding et.
                integration_list.append(
                    torch.linspace(end_times[i], et, ntimes).to(device)
                )
            full_times = torch.cat(integration_list, 0)
            print(full_times.shape)

            # Integrate over evenly spaced samples, chaining each segment from
            # the final state of the previous one.
            z_traj, logpz = cnf(
                z_samples,
                logp_samples,
                integration_times=integration_list[0],
                reverse=True,
            )
            full_traj = [(z_traj, logpz)]
            for int_times in integration_list[1:]:
                prev_z, prev_logp = full_traj[-1]
                z_traj, logpz = cnf(
                    prev_z[-1], prev_logp[-1], integration_times=int_times, reverse=True
                )
                # Drop each segment's first state; it duplicates the previous
                # segment's last state.
                full_traj.append((z_traj[1:], logpz[1:]))
            full_zip = list(zip(*full_traj))
            z_traj = torch.cat(full_zip[0], 0)
            # z_logp = torch.cat(full_zip[1], 0)
            z_traj = z_traj.cpu().numpy()

            grid_z_traj, grid_logpz_traj = [], []
            # Integrate the grid in chunks (a `memory` fraction of the grid at
            # a time) to bound peak memory use.
            inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
            for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
                _grid_z_traj, _grid_logpz_traj = cnf(
                    z_grid[ii],
                    logp_grid[ii],
                    integration_times=integration_list[0],
                    reverse=True,
                )
                full_traj = [(_grid_z_traj, _grid_logpz_traj)]
                for int_times in integration_list[1:]:
                    prev_z, prev_logp = full_traj[-1]
                    _grid_z_traj, _grid_logpz_traj = cnf(
                        prev_z[-1],
                        prev_logp[-1],
                        integration_times=int_times,
                        reverse=True,
                    )
                    # Drop the duplicated first state here as well so grid
                    # frames stay aligned with the sample trajectory above.
                    full_traj.append((_grid_z_traj[1:], _grid_logpz_traj[1:]))
                full_zip = list(zip(*full_traj))
                _grid_z_traj = torch.cat(full_zip[0], 0).cpu().numpy()
                _grid_logpz_traj = torch.cat(full_zip[1], 0).cpu().numpy()
                print(_grid_z_traj.shape)
                grid_z_traj.append(_grid_z_traj)
                grid_logpz_traj.append(_grid_logpz_traj)

            grid_z_traj = np.concatenate(grid_z_traj, axis=1)
            grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)

            plt.figure(figsize=(8, 8))
            for _ in range(z_traj.shape[0]):

                plt.clf()

                # plot target potential function
                ax = plt.subplot(1, 1, 1, aspect="equal")

                """
                ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Target", fontsize=32)
                """

                # plot the density
                # ax = plt.subplot(2, 2, 2, aspect="equal")

                z, logqz = grid_z_traj[t], grid_logpz_traj[t]

                xx = z[:, 0].reshape(npts, npts)
                yy = z[:, 1].reshape(npts, npts)
                qz = np.exp(logqz).reshape(npts, npts)
                # Color each frame by its position along the trajectory.
                rgb = plt.cm.Spectral(t / z_traj.shape[0])
                print(t, rgb)
                background_color = "white"
                # Two-point ramp: zero density maps to the background color,
                # densities at or above the 0.1th percentile to the frame color.
                cvals = [0, np.percentile(qz, 0.1)]
                colors = [
                    background_color,
                    rgb,
                ]
                norm = plt.Normalize(min(cvals), max(cvals))
                tuples = list(zip(map(norm, cvals), colors))
                cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)

                plt.pcolormesh(
                    xx,
                    yy,
                    qz,
                    # norm=matplotlib.colors.LogNorm(vmin=qz.min(), vmax=qz.max()),
                    cmap=cmap,
                )
                ax.set_xlim(-4, 4)
                ax.set_ylim(-4, 4)
                ax.set_facecolor(background_color)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Density", fontsize=32)

                """
                # plot the samples
                ax = plt.subplot(2, 2, 3, aspect="equal")

                zk = z_traj[t]
                ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
                ax.invert_yaxis()
                ax.get_xaxis().set_ticks([])
                ax.get_yaxis().set_ticks([])
                ax.set_title("Samples", fontsize=32)

                # plot vector field
                ax = plt.subplot(2, 2, 4, aspect="equal")

                K = 13j
                y, x = np.mgrid[-4:4:K, -4:4:K]
                K = int(K.imag)
                zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
                logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
                dydt = cnf.odefunc(full_times[t], (zs, logps))[0]
                dydt = -dydt.cpu().detach().numpy()
                dydt = dydt.reshape(K, K, 2)

                logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
                ax.quiver(
                    x, y, dydt[:, :, 0], -dydt[:, :, 1],
                    # x, y, dydt[:, :, 0], dydt[:, :, 1],
                    np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
                )
                ax.set_xlim(-4, 4)
                ax.set_ylim(4, -4)
                # ax.set_ylim(-4, 4)
                ax.axis("off")
                ax.set_title("Vector Field", fontsize=32)
                """

                makedirs(savedir)
                plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
                t += 1

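
# A minimal usage sketch, not part of the original pipeline: it assumes
# `model` is a SequentialFlow of CNF layers (as built by build_model_tabular)
# and swaps in a hypothetical standard-normal prior for illustration.
def _example_save_trajectory(model, data_samples, savedir, device="cpu"):
    import math

    def prior_logdensity(z):
        # Log-density of a standard normal, returned with shape (n, 1).
        return -0.5 * (z ** 2).sum(1, keepdim=True) - 0.5 * z.shape[1] * math.log(
            2 * math.pi
        )

    def prior_sampler(n, d):
        return torch.randn(n, d)

    save_trajectory(
        prior_logdensity, prior_sampler, model, data_samples, savedir, device=device
    )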


def get_trajectory_samples(device, model, data, n=2000):
    # NOTE: relies on the module-level `args` for int_tps; as written this
    # only builds and prints the integration grids.
    ntimes = 5
    model.eval()
    z_samples = data.base_sample()(n, 2).to(device)

    integration_list = [torch.linspace(0, args.int_tps[0], ntimes).to(device)]
    for i, et in enumerate(args.int_tps[1:]):
        integration_list.append(torch.linspace(args.int_tps[i], et, ntimes).to(device))
    print(integration_list)


def plot_output(device, args, model, data):
    # logger.info('Plotting trajectory to {}'.format(save_traj_dir))
    data_samples = data.get_data()[data.sample_index(2000, 0)]
    start_points = data.base_sample()(1000, 2)
    # start_points = data.get_data()[idx]
    # start_points = torch.from_numpy(start_points).type(torch.float32)
    """
    save_vectors(
        data.base_density(),
        model,
        start_points,
        data.get_data()[data.get_times() == 1],
        data.get_times()[data.get_times() == 1],
        args.save,
        device=device,
        end_times=args.int_tps,
        ntimes=100,
        memory=1.0,
        lim=1.5,
    )
    save_traj_dir = os.path.join(args.save, "trajectory_2d")
    save_2d_trajectory_v2(
        data.base_density(),
        data.base_sample(),
        model,
        data_samples,
        save_traj_dir,
        device=device,
        end_times=args.int_tps,
        ntimes=3,
        memory=1.0,
        limit=2.5,
    )
    """

    density_dir = os.path.join(args.save, "density2")
    save_trajectory_density(
        data.base_density(),
        model,
        data_samples,
        density_dir,
        device=device,
        end_times=args.int_tps,
        ntimes=100,
        memory=1,
    )
    trajectory_to_video(density_dir)


def integrate_backwards(
    end_samples, model, savedir, ntimes=100, memory=0.1, device="cpu"
):
    """Integrate some samples backwards through the first CNF and save the
    results.

    NOTE: relies on the module-level `args` for the integration timepoints.
    """
    with torch.no_grad():
        z = torch.from_numpy(end_samples).type(torch.float32).to(device)
        zero = torch.zeros(z.shape[0], 1).to(z)
        cnf = model.chain[0]

        zs = [z]
        deltas = []
        int_tps = np.linspace(args.int_tps[0], args.int_tps[-1], ntimes)
        for i, itp in enumerate(int_tps[::-1][:-1]):
            # itp counts down from the last timepoint
            timescale = int_tps[1] - int_tps[0]
            integration_times = torch.tensor([itp - timescale, itp])
            # integration_times = torch.tensor([np.linspace(itp - args.time_scale, itp, ntimes)])
            integration_times = integration_times.type(torch.float32).to(device)

            # transform to the previous timepoint
            z, delta_logp = cnf(zs[-1], zero, integration_times=integration_times)
            zs.append(z)
            deltas.append(delta_logp)
        zs = torch.stack(zs, 0)
        zs = zs.cpu().numpy()
        makedirs(savedir)  # ensure the output directory exists
        np.save(os.path.join(savedir, "backward_trajectories.npy"), zs)
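

# A minimal sketch of driving integrate_backwards, assuming `data` follows the
# dataset.SCData interface used elsewhere in this file (get_data/get_times);
# the module-level `args` must already be populated.
def _example_integrate_backwards(data, model, savedir, device="cpu"):
    times = data.get_times()
    # Cells observed at the final measured timepoint, as a numpy array.
    end_samples = data.get_data()[times == times.max()]
    integrate_backwards(end_samples, model, savedir, ntimes=100, device=device)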


def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = data.get_unique_times()

    # Use the maximum timepoint to establish integration_times,
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale
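    # For example (illustrative numbers): timepoints [0, 1, 2, 3] with
    # time_scale = 0.5 give int_tps = [0.5, 1.0, 1.5, 2.0], one integration
    # endpoint per timepoint.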

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(
        device
    )
    growth_model_path = data.get_growth_net_path()
    # growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
    growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    # plot_output(device, args, model, data)
    # exit()
    # get_trajectory_samples(device, model, data)

    args.data = data
    # args.timepoints and args.int_tps were already computed above.

    print("integrating backwards")
    end_time_data = data.data_dict["mphate_expression"]
    # end_time_data = data.get_data()[args.data.get_times() == np.max(args.data.get_times())]
    # np.random.permutation(end_time_data)
    # rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
    # end_time_data = end_time_data[rand_idx, :]
    integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
    # NOTE: the early exit below makes the rest of main unreachable.
    exit()
    losses_list = []
    # for factor in np.linspace(0.05, 0.95, 19):
    # for factor in np.linspace(0.91, 0.99, 9):
    if args.dataset == "CHAFFER":  # Do timepoint adjustment
        print("adjusting timepoints")
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.6799872494335812
            factor = 0.95
        elif lt == 2:
            factor = 0.2905983814032348
            factor = 0.01
        else:
            raise RuntimeError("Unknown timepoint %d" % args.leaveout_timepoint)
        # Linearly interpolate the left-out timepoint between its neighbors.
        args.int_tps[lt] = (
            (1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
        )
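        # Worked example (illustrative numbers): with int_tps = [1.0, 2.0, 3.0],
        # lt = 1 and factor = 0.95, the left-out timepoint moves to
        # 0.05 * 1.0 + 0.95 * 3.0 = 2.9, close to its right neighbor.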
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, "emd_list"), np.array(losses_list))
    # zs = np.load(os.path.join(args.save, 'backward_trajectories'))
    # losses = eval_utils.evaluate_mse(device, args, model)
    # losses = eval_utils.evaluate_kantorovich(device, args, model)
    # print(losses)
    # eval_utils.generate_samples(device, args, model, growth_model, timepoint=args.timepoints[-1])
    # eval_utils.calculate_path_length(device, args, model, data, args.int_tps[-1])


if __name__ == "__main__":
    from parse import parser

    args = parser.parse_args()
    main(args)