Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added experiment/figures/52/binary_contours.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/binary_curves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/binary_scatter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/deblend_bins.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/deblend_bins_medians.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/deblend_ellips_scatter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/deblend_flux_scatter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/deblend_size_scatter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/samples_bld_res.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/samples_snr_res.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/snr_detection.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/figures/52/toy_residuals.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified experiment/models/deblender_52.pt
Binary file not shown.
Binary file added experiment/models/deblender_52_best.pt
Binary file not shown.
Binary file removed experiment/models/deblender_52_version5.pt
Binary file not shown.
61 changes: 58 additions & 3 deletions experiment/scripts/figures/binary_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from einops import rearrange, reduce
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from tqdm import tqdm

from bliss.catalog import FullCatalog, TileCatalog
Expand Down Expand Up @@ -80,6 +81,7 @@ def all_rcs(self) -> dict:
return {
"binary_scatter": {"fontsize": 34},
"binary_curves": {"fontsize": 36, "major_tick_size": 12, "minor_tick_size": 7},
"binary_contours": {"fontsize": 34},
}

@property
Expand All @@ -88,7 +90,7 @@ def cache_name(self) -> str:

@property
def fignames(self) -> tuple[str, ...]:
return ("binary_scatter", "binary_curves")
return ("binary_scatter", "binary_curves", "binary_contours")

def compute_data(self, ds_path: str, binary: BinaryEncoder):
# metadata
Expand Down Expand Up @@ -238,7 +240,7 @@ def _get_binary_scatter_figure(self, data: dict):
ax.legend(markerscale=6, fontsize=28)
ax.set_xlabel(r"\rm SNR")
ax.set_ylabel(r"\rm Galaxy Classification Probability")
ax.set_xlim(1e-2, 1e4)
ax.set_xlim(1, 1e3)

return fig

Expand All @@ -250,7 +252,7 @@ def _get_binary_curves(self, data: dict):
egbools = data[0.5]["egbools"]
esbools = data[0.5]["esbools"]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 20))

c1 = CLR_CYCLE[0]
c2 = CLR_CYCLE[1]
Expand Down Expand Up @@ -287,9 +289,62 @@ def _get_binary_curves(self, data: dict):

return fig

def _get_binary_contours(self, data: dict):
    """Build a two-panel 2D histogram figure of classification probability
    versus log10(SNR), split into true galaxies (left) and true stars (right).

    Counts are shown on a shared log color scale; a colorbar is attached to
    the right-hand panel. Returns the matplotlib Figure.
    """
    snr_vals = data["snr"]
    prob_vals = data["probs"]
    # restrict each panel to sources with positive SNR of the matching true type
    is_galaxy = data["tgbools"].astype(bool) & (snr_vals > 0)
    is_star = data["tsbools"].astype(bool) & (snr_vals > 0)

    fig, (gal_ax, star_ax) = plt.subplots(1, 2, figsize=(20, 10), sharey=True)

    # identical binning / color normalization for both panels so they are comparable
    hist_kwargs = {
        "bins": 20,
        "range": [[0, 3], [0, 1]],
        "cmap": "PuBu",
        "norm": "log",
        "vmin": 1,
        "vmax": 2e4,
    }
    gal_ax.hist2d(np.log10(snr_vals[is_galaxy]), prob_vals[is_galaxy], **hist_kwargs)
    *_, pcm = star_ax.hist2d(np.log10(snr_vals[is_star]), prob_vals[is_star], **hist_kwargs)

    # x-axis is log10(SNR); label only the integer decades
    tick_positions = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
    tick_labels = [f"$10^{int(x)}$" if x in [0.0, 1.0, 2.0, 3.0] else "" for x in tick_positions]
    for axis in (gal_ax, star_ax):
        axis.set_xticks(ticks=tick_positions, labels=tick_labels)
        axis.set_xlabel(r"\rm SNR")
    gal_ax.set_ylabel(r"\rm Galaxy Classification Probability")
    gal_ax.set_title(r"\rm Galaxies")
    star_ax.set_title(r"\rm Stars")

    # colorbar in its own axes carved off the right of the star panel
    cax = make_axes_locatable(star_ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(pcm, cax=cax, orientation="vertical")

    plt.tight_layout()

    return fig

def create_figure(self, fname: str, data):
    """Dispatch to the figure builder matching `fname`.

    Raises ValueError when `fname` is not one of the known figure names.
    """
    builders = {
        "binary_scatter": self._get_binary_scatter_figure,
        "binary_curves": self._get_binary_curves,
        "binary_contours": self._get_binary_contours,
    }
    builder = builders.get(fname)
    if builder is None:
        raise ValueError(f"Unknown figure name: {fname}")
    return builder(data)
45 changes: 34 additions & 11 deletions experiment/scripts/figures/detection_figures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Script to create detection encoder related figures."""

import numpy as np
import torch
from einops import rearrange, reduce
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -244,13 +245,14 @@ def compute_data(self, ds_path: str, detection: DetectionEncoder):
return {
"snr": snr_dict,
"blendedness": bld_dict,
"truth": {"snr": truth["snr"], "bld": truth["bld"]},
}

def _get_snr_detection_figure(self, data):
# make a 3 column figure with precision, recall, f1 for all thresholds + sep
# colors for thresholds hsould go from blue (low) to red (high) threshold

fig, axs = plt.subplots(1, 3, figsize=(30, 10))
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
axs = axs.flatten()
ds = data["snr"]

Expand Down Expand Up @@ -290,28 +292,49 @@ def _get_snr_detection_figure(self, data):
ax.set_ylim(0, 1.02)
ax.legend()

# snr distribution
ax = axs[3]
snr_bins = ds["snr_bins"]
true_snr = data["truth"]["snr"].flatten()
_snr = true_snr[true_snr > 0]
_bins = [x.item() for x in snr_bins[:, 0]] + [snr_bins[-1, 1].item()]
# _, bins = np.histogram(np.log10(_snr), bins=_bins, range=(1, 1000))
# logbins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins))
ax.hist(_snr, bins=_bins, histtype="step")
ax.set_xscale("log")
ax.set_xlabel(r"\rm SNR")

plt.tight_layout()

return fig

def _get_blendedness_detection_figure(self, data):
fig, ax = plt.subplots(figsize=(6, 6))
fig, ax1 = plt.subplots(1, 1, figsize=(6, 6))
ds = data["blendedness"]
bld_bins = ds["bld_bins"]
bld_middle = bld_bins.mean(axis=-1)

# recall
for tsh, out in ds["thresh_out"].items():
color = plt.cm.coolwarm(tsh)
ax.plot(bld_middle, out["recall"], c=color, label=f"${tsh:.2f}$")
ax.plot(bld_middle, ds["sep"]["recall"], "--k", lw=2, label=r"\rm SEP")
ax.set_xlabel(r"\rm Blendedness")
ax.set_ylabel(r"\rm Recall")
ax.set_ylim(0, 1.02)
ax.set_xlim(1e-2, 1)
ax.set_xticks([1e-2, 1e-1, 1])
ax.set_xscale("log")
ax.legend()
ax1.plot(bld_middle, out["recall"], c=color, label=f"${tsh:.2f}$")
ax1.plot(bld_middle, ds["sep"]["recall"], "--k", lw=2, label=r"\rm SEP")
ax1.set_xlabel(r"\rm Blendedness")
ax1.set_ylabel(r"\rm Recall")
ax1.set_ylim(0, 1.02)
ax1.set_xlim(1e-2, 1)
ax1.set_xticks([1e-2, 1e-1, 1])
ax1.set_xscale("log")
ax1.legend()

# blendedness histogram
# true_bld = data["truth"]["bld"].flatten()
# _bld = true_bld[true_bld > 0]
# # bld_bins = data["blendedness"]["bld_bins"]
# # _bins = [x.item() for x in bld_bins[:, 0]] + [bld_bins[-1, 1].item()]
# ax2.hist(_bld, bins=21, histtype="step")
# # ax2.set_xscale("log")
# ax2.set_xlabel(r"\rm Blendedness")

plt.tight_layout()
return fig
Expand Down
78 changes: 34 additions & 44 deletions experiment/scripts/figures/sampling_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ def _get_diagnostic_figures(*, out_dir: Path, results: dict, tag_txt: str):


def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None:
n_bins = 11

# need to sort things first!!!!
sorted_indices = [out["idx"] for out in rslts["outs"]]
true_fluxes = rslts["true_flux"][sorted_indices][:, 0, 0]
Expand Down Expand Up @@ -470,14 +472,13 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None:
res2 = (map_fluxes - true_fluxes) / true_fluxes
res3 = (sep_fluxes - true_fluxes) / true_fluxes

print("# of images used:", sum(mask))
print("# of discarded images (non-detections):", len(mask) - sum(mask))
print("# of images used:", int(sum(mask)))
print("# of discarded images (non-detections):", int(len(mask) - sum(mask)))

# get snr figure
set_rc_params()

# now snr
n_bins = 20
out1 = equal_sized_bin_statistic(
x=true_snr.log10(), y=res1, n_bins=n_bins, xlims=(0.5, 3), statistic="median"
)
Expand All @@ -491,32 +492,32 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None:
x = 10 ** out1["middles"]

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(x, out3["stats"], label=r"\rm SEP", marker="", color=CLR_CYCLE[2])
ax.plot(x, np.abs(out3["stats"]), label=r"\rm SEP", marker="", color=CLR_CYCLE[2])
ax.fill_between(
x,
out3["stats"] - out3["errs"],
out3["stats"] + out3["errs"],
np.abs(out3["stats"]) - out3["errs"],
np.abs(out3["stats"]) + out3["errs"],
alpha=0.2,
color=CLR_CYCLE[2],
)
ax.plot(x, out2["stats"], label=r"\rm MAP", marker="", color=CLR_CYCLE[0])
ax.plot(x, np.abs(out2["stats"]), label=r"\rm MAP", marker="", color=CLR_CYCLE[0])
ax.fill_between(
x,
out2["stats"] - out2["errs"],
out2["stats"] + out2["errs"],
np.abs(out2["stats"]) - out2["errs"],
np.abs(out2["stats"]) + out2["errs"],
alpha=0.2,
color=CLR_CYCLE[0],
)
ax.plot(x, out1["stats"], label=r"\rm Samples", marker="", color=CLR_CYCLE[1])
ax.plot(x, np.abs(out1["stats"]), label=r"\rm Samples", marker="", color=CLR_CYCLE[1])
ax.fill_between(
x,
out1["stats"] - out1["errs"],
out1["stats"] + out1["errs"],
np.abs(out1["stats"]) - out1["errs"],
np.abs(out1["stats"]) + out1["errs"],
alpha=0.2,
color=CLR_CYCLE[1],
)
ax.set_xlabel(r"\rm SNR", fontsize=28)
ax.set_ylabel(r"$\frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}}$", fontsize=32)
ax.set_ylabel(r"$\lvert \frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}} \rvert$", fontsize=32)
ax.axhline(0, color="k", linestyle="--", label=r"\rm Zero Residual")
ax.legend()
ax.set_xlim(5, 1000)
Expand All @@ -525,58 +526,48 @@ def _make_final_results_figures(*, out_dir: Path, rslts: dict) -> None:

# as a function of blendedness
# first define bins (as described in paper)
n_bins = 21
qs = torch.linspace(0.12, 0.99, 21)
edges = bld.quantile(qs)
bins = torch.tensor([0.0, *edges[1:-1], 1.0])
print(f"Edge BLD 1: {edges[1]:.10f}")
print(f"Edge BLD -2: {edges[-2]:.10f}")

out1 = binned_statistic(
x=bld,
y=res1,
bins=bins,
statistic="median",
)
out2 = binned_statistic(
x=bld,
y=res2,
bins=bins,
statistic="median",
)
out3 = binned_statistic(
x=bld,
y=res3,
bins=bins,
statistic="median",
)
out1 = binned_statistic(x=bld, y=res1, bins=bins, statistic="median")
out2 = binned_statistic(x=bld, y=res2, bins=bins, statistic="median")
out3 = binned_statistic(x=bld, y=res3, bins=bins, statistic="median")

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(out3["middles"], out3["stats"], label=r"\rm SEP", marker="", color=CLR_CYCLE[2])
ax.plot(out3["middles"], np.abs(out3["stats"]), label=r"\rm SEP", marker="", color=CLR_CYCLE[2])
ax.fill_between(
out3["middles"],
out3["stats"] - out3["errs"],
out3["stats"] + out3["errs"],
np.abs(out3["stats"]) - out3["errs"],
np.abs(out3["stats"]) + out3["errs"],
alpha=0.2,
color=CLR_CYCLE[2],
)
ax.plot(out2["middles"], out2["stats"], label=r"\rm MAP", marker="", color=CLR_CYCLE[0])
ax.plot(out2["middles"], np.abs(out2["stats"]), label=r"\rm MAP", marker="", color=CLR_CYCLE[0])
ax.fill_between(
out2["middles"],
out2["stats"] - out2["errs"],
out2["stats"] + out2["errs"],
np.abs(out2["stats"]) - out2["errs"],
np.abs(out2["stats"]) + out2["errs"],
alpha=0.2,
color=CLR_CYCLE[0],
)
ax.plot(out1["middles"], out1["stats"], label=r"\rm Samples", marker="", color=CLR_CYCLE[1])
ax.plot(
out1["middles"], np.abs(out1["stats"]), label=r"\rm Samples", marker="", color=CLR_CYCLE[1]
)
ax.fill_between(
out1["middles"],
out1["stats"] - out1["errs"],
out1["stats"] + out1["errs"],
np.abs(out1["stats"]) - out1["errs"],
np.abs(out1["stats"]) + out1["errs"],
alpha=0.2,
color=CLR_CYCLE[1],
)
ax.set_xlabel(r"\rm Blendedness", fontsize=28)
ax.set_ylabel(r"$\frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}}$", fontsize=32)
ax.set_ylabel(r"$\lvert \frac{f_{\rm pred} - f_{\rm true}}{f_{\rm true}} \rvert$", fontsize=32)
ax.set_yscale("log")
ax.set_ylim(0.004, 10)
ax.set_ylim(0.004, 20)
ax.legend(prop={"size": 22})
fig.savefig(out_dir / "samples_bld_res.png", dpi=500, bbox_inches="tight")

Expand Down Expand Up @@ -639,7 +630,6 @@ def main(
sources=ds["uncentered_sources"],
no_bar=False,
)
true_snr = true_meas["snr"]

# lets get models
detection = DetectionEncoder().to(device).eval()
Expand Down Expand Up @@ -671,7 +661,7 @@ def main(
{
"outs": outs,
"bld": bld,
"true_snr": true_snr,
"true_snr": true_meas["snr"],
"true_flux": true_meas["flux"],
"true_plocs": truth.plocs,
"true_n_sources": truth.n_sources,
Expand Down
Loading