diff --git a/experiment/figures/52/binary_contours.png b/experiment/figures/52/binary_contours.png new file mode 100644 index 000000000..50ae7729c Binary files /dev/null and b/experiment/figures/52/binary_contours.png differ diff --git a/experiment/figures/52/binary_curves.png b/experiment/figures/52/binary_curves.png index 313895156..12d7156ee 100644 Binary files a/experiment/figures/52/binary_curves.png and b/experiment/figures/52/binary_curves.png differ diff --git a/experiment/figures/52/binary_scatter.png b/experiment/figures/52/binary_scatter.png index 76cf80076..b56411fa6 100644 Binary files a/experiment/figures/52/binary_scatter.png and b/experiment/figures/52/binary_scatter.png differ diff --git a/experiment/figures/52/deblend_bins.png b/experiment/figures/52/deblend_bins.png index cdf039a8f..c5a0f7381 100644 Binary files a/experiment/figures/52/deblend_bins.png and b/experiment/figures/52/deblend_bins.png differ diff --git a/experiment/figures/52/deblend_bins_medians.png b/experiment/figures/52/deblend_bins_medians.png index 4db7757c7..862c8605c 100644 Binary files a/experiment/figures/52/deblend_bins_medians.png and b/experiment/figures/52/deblend_bins_medians.png differ diff --git a/experiment/figures/52/deblend_ellips_scatter.png b/experiment/figures/52/deblend_ellips_scatter.png index 1c5a51ce5..2caec55f9 100644 Binary files a/experiment/figures/52/deblend_ellips_scatter.png and b/experiment/figures/52/deblend_ellips_scatter.png differ diff --git a/experiment/figures/52/deblend_flux_scatter.png b/experiment/figures/52/deblend_flux_scatter.png index 4b927bb6d..43ddc00e4 100644 Binary files a/experiment/figures/52/deblend_flux_scatter.png and b/experiment/figures/52/deblend_flux_scatter.png differ diff --git a/experiment/figures/52/deblend_size_scatter.png b/experiment/figures/52/deblend_size_scatter.png index 71499772f..f651e204e 100644 Binary files a/experiment/figures/52/deblend_size_scatter.png and b/experiment/figures/52/deblend_size_scatter.png 
differ diff --git a/experiment/figures/52/samples_bld_res.png b/experiment/figures/52/samples_bld_res.png index 9225f189c..fef78a188 100644 Binary files a/experiment/figures/52/samples_bld_res.png and b/experiment/figures/52/samples_bld_res.png differ diff --git a/experiment/figures/52/samples_snr_res.png b/experiment/figures/52/samples_snr_res.png index 371ea42d2..fc8e0c2c0 100644 Binary files a/experiment/figures/52/samples_snr_res.png and b/experiment/figures/52/samples_snr_res.png differ diff --git a/experiment/figures/52/snr_detection.png b/experiment/figures/52/snr_detection.png index d7629dd11..b16e1120e 100644 Binary files a/experiment/figures/52/snr_detection.png and b/experiment/figures/52/snr_detection.png differ diff --git a/experiment/figures/52/toy_residuals.png b/experiment/figures/52/toy_residuals.png index 8fa1aba0b..8399d49c5 100644 Binary files a/experiment/figures/52/toy_residuals.png and b/experiment/figures/52/toy_residuals.png differ diff --git a/experiment/models/deblender_52.pt b/experiment/models/deblender_52.pt index ee7289b7a..f9869c13f 100644 Binary files a/experiment/models/deblender_52.pt and b/experiment/models/deblender_52.pt differ diff --git a/experiment/models/deblender_52_best.pt b/experiment/models/deblender_52_best.pt new file mode 100644 index 000000000..ee7289b7a Binary files /dev/null and b/experiment/models/deblender_52_best.pt differ diff --git a/experiment/models/deblender_52_version5.pt b/experiment/models/deblender_52_version5.pt deleted file mode 100644 index f9869c13f..000000000 Binary files a/experiment/models/deblender_52_version5.pt and /dev/null differ diff --git a/experiment/scripts/figures/binary_figures.py b/experiment/scripts/figures/binary_figures.py index 7319291d1..500e25d4f 100644 --- a/experiment/scripts/figures/binary_figures.py +++ b/experiment/scripts/figures/binary_figures.py @@ -7,6 +7,7 @@ import torch from einops import rearrange, reduce from matplotlib import pyplot as plt +from 
mpl_toolkits.axes_grid1 import make_axes_locatable from tqdm import tqdm from bliss.catalog import FullCatalog, TileCatalog @@ -80,6 +81,7 @@ def all_rcs(self) -> dict: return { "binary_scatter": {"fontsize": 34}, "binary_curves": {"fontsize": 36, "major_tick_size": 12, "minor_tick_size": 7}, + "binary_contours": {"fontsize": 34}, } @property @@ -88,7 +90,7 @@ def cache_name(self) -> str: @property def fignames(self) -> tuple[str, ...]: - return ("binary_scatter", "binary_curves") + return ("binary_scatter", "binary_curves", "binary_contours") def compute_data(self, ds_path: str, binary: BinaryEncoder): # metadata @@ -238,7 +240,7 @@ def _get_binary_scatter_figure(self, data: dict): ax.legend(markerscale=6, fontsize=28) ax.set_xlabel(r"\rm SNR") ax.set_ylabel(r"\rm Galaxy Classification Probability") - ax.set_xlim(1e-2, 1e4) + ax.set_xlim(1, 1e3) return fig @@ -250,7 +252,7 @@ def _get_binary_curves(self, data: dict): egbools = data[0.5]["egbools"] esbools = data[0.5]["esbools"] - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10)) + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 20)) c1 = CLR_CYCLE[0] c2 = CLR_CYCLE[1] @@ -287,9 +289,62 @@ def _get_binary_curves(self, data: dict): return fig + def _get_binary_contours(self, data: dict): + snr = data["snr"] + probs = data["probs"] + tgbools = data["tgbools"] + tsbools = data["tsbools"] + + galaxy_mask = tgbools.astype(bool) & (snr > 0) + star_mask = tsbools.astype(bool) & (snr > 0) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), sharey=True) + + ax1.hist2d( + np.log10(snr[galaxy_mask]), + probs[galaxy_mask], + bins=20, + range=[[0, 3], [0, 1]], + cmap="PuBu", + norm="log", + vmin=1, + vmax=2e4, + ) + _xticks = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] + _xticks_labels = [f"$10^{int(x)}$" if x in [0.0, 1.0, 2.0, 3.0] else "" for x in _xticks] + ax1.set_xticks(ticks=_xticks, labels=_xticks_labels) + ax1.set_xlabel(r"\rm SNR") + ax1.set_ylabel(r"\rm Galaxy Classification Probability") + ax1.set_title(r"\rm 
Galaxies") + + _, _, _, pcm = ax2.hist2d( + np.log10(snr[star_mask]), + probs[star_mask], + bins=20, + range=[[0, 3], [0, 1]], + cmap="PuBu", + norm="log", + vmin=1, + vmax=2e4, + ) + ax2.set_xticks(ticks=_xticks, labels=_xticks_labels) + ax2.set_xlabel(r"\rm SNR") + ax2.set_title(r"\rm Stars") + + divider = make_axes_locatable(ax2) + cax = divider.append_axes("right", size="5%", pad=0.05) + + fig.colorbar(pcm, cax=cax, orientation="vertical") + + plt.tight_layout() + + return fig + def create_figure(self, fname: str, data): if fname == "binary_scatter": return self._get_binary_scatter_figure(data) if fname == "binary_curves": return self._get_binary_curves(data) + if fname == "binary_contours": + return self._get_binary_contours(data) raise ValueError(f"Unknown figure name: {fname}") diff --git a/experiment/scripts/figures/detection_figures.py b/experiment/scripts/figures/detection_figures.py index 4ea12a35c..2f3e85e79 100644 --- a/experiment/scripts/figures/detection_figures.py +++ b/experiment/scripts/figures/detection_figures.py @@ -1,5 +1,6 @@ """Script to create detection encoder related figures.""" +import numpy as np import torch from einops import rearrange, reduce from matplotlib import pyplot as plt @@ -244,13 +245,14 @@ def compute_data(self, ds_path: str, detection: DetectionEncoder): return { "snr": snr_dict, "blendedness": bld_dict, + "truth": {"snr": truth["snr"], "bld": truth["bld"]}, } def _get_snr_detection_figure(self, data): # make a 3 column figure with precision, recall, f1 for all thresholds + sep # colors for thresholds hsould go from blue (low) to red (high) threshold - fig, axs = plt.subplots(1, 3, figsize=(30, 10)) + fig, axs = plt.subplots(2, 2, figsize=(20, 20)) axs = axs.flatten() ds = data["snr"] @@ -290,12 +292,24 @@ def _get_snr_detection_figure(self, data): ax.set_ylim(0, 1.02) ax.legend() + # snr distribution + ax = axs[3] + snr_bins = ds["snr_bins"] + true_snr = data["truth"]["snr"].flatten() + _snr = true_snr[true_snr > 0] + 
_bins = [x.item() for x in snr_bins[:, 0]] + [snr_bins[-1, 1].item()] + # _, bins = np.histogram(np.log10(_snr), bins=_bins, range=(1, 1000)) + # logbins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins)) + ax.hist(_snr, bins=_bins, histtype="step") + ax.set_xscale("log") + ax.set_xlabel(r"\rm SNR") + plt.tight_layout() return fig def _get_blendedness_detection_figure(self, data): - fig, ax = plt.subplots(figsize=(6, 6)) + fig, ax1 = plt.subplots(1, 1, figsize=(6, 6)) ds = data["blendedness"] bld_bins = ds["bld_bins"] bld_middle = bld_bins.mean(axis=-1) @@ -303,15 +317,24 @@ def _get_blendedness_detection_figure(self, data): # recall for tsh, out in ds["thresh_out"].items(): color = plt.cm.coolwarm(tsh) - ax.plot(bld_middle, out["recall"], c=color, label=f"${tsh:.2f}$") - ax.plot(bld_middle, ds["sep"]["recall"], "--k", lw=2, label=r"\rm SEP") - ax.set_xlabel(r"\rm Blendedness") - ax.set_ylabel(r"\rm Recall") - ax.set_ylim(0, 1.02) - ax.set_xlim(1e-2, 1) - ax.set_xticks([1e-2, 1e-1, 1]) - ax.set_xscale("log") - ax.legend() + ax1.plot(bld_middle, out["recall"], c=color, label=f"${tsh:.2f}$") + ax1.plot(bld_middle, ds["sep"]["recall"], "--k", lw=2, label=r"\rm SEP") + ax1.set_xlabel(r"\rm Blendedness") + ax1.set_ylabel(r"\rm Recall") + ax1.set_ylim(0, 1.02) + ax1.set_xlim(1e-2, 1) + ax1.set_xticks([1e-2, 1e-1, 1]) + ax1.set_xscale("log") + ax1.legend() + + # blendedness histogram + # true_bld = data["truth"]["bld"].flatten() + # _bld = true_bld[true_bld > 0] + # # bld_bins = data["blendedness"]["bld_bins"] + # # _bins = [x.item() for x in bld_bins[:, 0]] + [bld_bins[-1, 1].item()] + # ax2.hist(_bld, bins=21, histtype="step") + # # ax2.set_xscale("log") + # ax2.set_xlabel(r"\rm Blendedness") plt.tight_layout() return fig diff --git a/experiment/scripts/figures/sampling_figures.py b/experiment/scripts/figures/sampling_figures.py index f7833c410..b58d77f0a 100755 --- a/experiment/scripts/figures/sampling_figures.py +++ 
b/experiment/scripts/figures/sampling_figures.py @@ -438,6 +438,8 @@ def _get_diagnostic_figures(*, out_dir: Path, results: dict, tag_txt: str): def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None: + n_bins = 11 + # need to sort things first!!!! sorted_indices = [out["idx"] for out in rslts["outs"]] true_fluxes = rslts["true_flux"][sorted_indices][:, 0, 0] @@ -470,14 +472,13 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None: res2 = (map_fluxes - true_fluxes) / true_fluxes res3 = (sep_fluxes - true_fluxes) / true_fluxes - print("# of images used:", sum(mask)) - print("# of discarded images (non-detections):", len(mask) - sum(mask)) + print("# of images used:", int(sum(mask))) + print("# of discarded images (non-detections):", int(len(mask) - sum(mask))) # get snr figure set_rc_params() # now snr - n_bins = 20 out1 = equal_sized_bin_statistic( x=true_snr.log10(), y=res1, n_bins=n_bins, xlims=(0.5, 3), statistic="median" ) @@ -491,32 +492,32 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None: x = 10 ** out1["middles"] fig, ax = plt.subplots(1, 1, figsize=(8, 6)) - ax.plot(x, out3["stats"], label=r"\rm SEP", marker="", color=CLR_CYCLE[2]) + ax.plot(x, np.abs(out3["stats"]), label=r"\rm SEP", marker="", color=CLR_CYCLE[2]) ax.fill_between( x, - out3["stats"] - out3["errs"], - out3["stats"] + out3["errs"], + np.abs(out3["stats"]) - out3["errs"], + np.abs(out3["stats"]) + out3["errs"], alpha=0.2, color=CLR_CYCLE[2], ) - ax.plot(x, out2["stats"], label=r"\rm MAP", marker="", color=CLR_CYCLE[0]) + ax.plot(x, np.abs(out2["stats"]), label=r"\rm MAP", marker="", color=CLR_CYCLE[0]) ax.fill_between( x, - out2["stats"] - out2["errs"], - out2["stats"] + out2["errs"], + np.abs(out2["stats"]) - out2["errs"], + np.abs(out2["stats"]) + out2["errs"], alpha=0.2, color=CLR_CYCLE[0], ) - ax.plot(x, out1["stats"], label=r"\rm Samples", marker="", color=CLR_CYCLE[1]) + ax.plot(x, np.abs(out1["stats"]), label=r"\rm Samples", 
marker="", color=CLR_CYCLE[1]) ax.fill_between( x, - out1["stats"] - out1["errs"], - out1["stats"] + out1["errs"], + np.abs(out1["stats"]) - out1["errs"], + np.abs(out1["stats"]) + out1["errs"], alpha=0.2, color=CLR_CYCLE[1], ) ax.set_xlabel(r"\rm SNR", fontsize=28) - ax.set_ylabel(r"$\frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}}$", fontsize=32) + ax.set_ylabel(r"$\lvert \frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}} \rvert$", fontsize=32) ax.axhline(0, color="k", linestyle="--", label=r"\rm Zero Residual") ax.legend() ax.set_xlim(5, 1000) @@ -525,58 +526,48 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None: # as a function of blendedness # first define bins (as described in paper) + n_bins = 21 qs = torch.linspace(0.12, 0.99, 21) edges = bld.quantile(qs) bins = torch.tensor([0.0, *edges[1:-1], 1.0]) + print(f"Edge BLD 1: {edges[1]:.10f}") + print(f"Edge BLD -2: {edges[-2]:.10f}") - out1 = binned_statistic( - x=bld, - y=res1, - bins=bins, - statistic="median", - ) - out2 = binned_statistic( - x=bld, - y=res2, - bins=bins, - statistic="median", - ) - out3 = binned_statistic( - x=bld, - y=res3, - bins=bins, - statistic="median", - ) + out1 = binned_statistic(x=bld, y=res1, bins=bins, statistic="median") + out2 = binned_statistic(x=bld, y=res2, bins=bins, statistic="median") + out3 = binned_statistic(x=bld, y=res3, bins=bins, statistic="median") fig, ax = plt.subplots(1, 1, figsize=(8, 6)) - ax.plot(out3["middles"], out3["stats"], label=r"\rm SEP", marker="", color=CLR_CYCLE[2]) + ax.plot(out3["middles"], np.abs(out3["stats"]), label=r"\rm SEP", marker="", color=CLR_CYCLE[2]) ax.fill_between( out3["middles"], - out3["stats"] - out3["errs"], - out3["stats"] + out3["errs"], + np.abs(out3["stats"]) - out3["errs"], + np.abs(out3["stats"]) + out3["errs"], alpha=0.2, color=CLR_CYCLE[2], ) - ax.plot(out2["middles"], out2["stats"], label=r"\rm MAP", marker="", color=CLR_CYCLE[0]) + ax.plot(out2["middles"], np.abs(out2["stats"]), label=r"\rm MAP", 
marker="", color=CLR_CYCLE[0]) ax.fill_between( out2["middles"], - out2["stats"] - out2["errs"], - out2["stats"] + out2["errs"], + np.abs(out2["stats"]) - out2["errs"], + np.abs(out2["stats"]) + out2["errs"], alpha=0.2, color=CLR_CYCLE[0], ) - ax.plot(out1["middles"], out1["stats"], label=r"\rm Samples", marker="", color=CLR_CYCLE[1]) + ax.plot( + out1["middles"], np.abs(out1["stats"]), label=r"\rm Samples", marker="", color=CLR_CYCLE[1] + ) ax.fill_between( out1["middles"], - out1["stats"] - out1["errs"], - out1["stats"] + out1["errs"], + np.abs(out1["stats"]) - out1["errs"], + np.abs(out1["stats"]) + out1["errs"], alpha=0.2, color=CLR_CYCLE[1], ) ax.set_xlabel(r"\rm Blendedness", fontsize=28) - ax.set_ylabel(r"$\frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}}$", fontsize=32) + ax.set_ylabel(r"$\lvert \frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}} \rvert$", fontsize=32) ax.set_yscale("log") - ax.set_ylim(0.004, 10) + ax.set_ylim(0.004, 20) ax.legend(prop={"size": 22}) fig.savefig(out_dir / "samples_bld_res.png", dpi=500, bbox_inches="tight") @@ -639,7 +630,6 @@ def main( sources=ds["uncentered_sources"], no_bar=False, ) - true_snr = true_meas["snr"] # lets get models detection = DetectionEncoder().to(device).eval() @@ -671,7 +661,7 @@ def main( { "outs": outs, "bld": bld, - "true_snr": true_snr, + "true_snr": true_meas["snr"], "true_flux": true_meas["flux"], "true_plocs": truth.plocs, "true_n_sources": truth.n_sources, diff --git a/experiment/scripts/figures/toy_figures.py b/experiment/scripts/figures/toy_figures.py index f5dd4a7bc..0023af2e8 100644 --- a/experiment/scripts/figures/toy_figures.py +++ b/experiment/scripts/figures/toy_figures.py @@ -165,7 +165,7 @@ def compute_data(self, detection: DetectionEncoder, deblender: GalaxyEncoder): assert recon_ptiles.shape[-1] == recon_ptiles.shape[-2] == ptile_slen recon = reconstruct_image_from_ptiles(recon_ptiles, tile_slen) recon = recon.detach().cpu() - residuals = (recon - images) / (recon + bg).sqrt() + 
residuals = (recon - images) / torch.sqrt(bg) # now we need to obtain pred. plocs, prob. of detection in tile and std. of plocs # for each source @@ -302,6 +302,17 @@ def _get_residuals_figure(self, data) -> Figure: pad = 6.0 fig, axes = plt.subplots(nrows=n_examples, ncols=3, figsize=(11, 18)) + vmin, vmax = torch.inf, -torch.inf + vmin_res, vmax_res = torch.inf, -torch.inf + for ii in range(n_examples): + vmin = min(images[ii].min().item(), recons[ii].min().item(), vmin) + vmax = max(images[ii].max().item(), recons[ii].max().item(), vmax) + vmin_res = min(residuals[ii].min().item(), vmin_res) + vmax_res = max(residuals[ii].max().item(), vmax_res) + + vres = max(abs(vmin_res), abs(vmax_res)) + + # vres = max(abs(vmin_res), abs(vmax_res)) for i in range(n_examples): ax_true = axes[i, 0] ax_recon = axes[i, 1] @@ -312,7 +323,7 @@ def _get_residuals_figure(self, data) -> Figure: ax_true.set_title(r"\rm Images $x$", pad=pad) ax_recon.set_title(r"\rm Reconstruction $\tilde{x}$", pad=pad) ax_res.set_title( - r"\rm Residual $\left(\tilde{x} - x\right) / \sqrt{\tilde{x} + b}$", + r"\rm Residual $\left(\tilde{x} - x\right) / \sqrt{b}$", pad=pad, fontsize=18, ) @@ -353,12 +364,6 @@ def _get_residuals_figure(self, data) -> Figure: recon = recons[i] res = residuals[i] - vmin = min(image.min().item(), recon.min().item()) - vmax = max(image.max().item(), recon.max().item()) - vmin_res = res.min().item() - vmax_res = res.max().item() - vres = max(abs(vmin_res), abs(vmax_res)) - # plot images plot_image(fig, ax_true, image, vrange=(vmin, vmax)) plot_image(fig, ax_recon, recon, vrange=(vmin, vmax))