Skip to content

Commit 312644a

Browse files
authored
Merge pull request #160 from karllark/enhance_model_plot
Enhance model plot
2 parents 8e73103 + 9db7e32 commit 312644a

File tree

2 files changed

+56
-11
lines changed

2 files changed

+56
-11
lines changed

measure_extinction/model.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import numpy as np
33
import matplotlib.pyplot as plt
4+
from matplotlib.ticker import ScalarFormatter
45
import astropy.units as u
56
from astropy.table import QTable
67
import scipy.optimize as op
@@ -417,7 +418,7 @@ def stellar_sed(self, moddata):
417418
# check for any zero distance cases, requested parameters are directly on a model
418419
# in this case set the distance to 0.01 of the min distance so this model dominates
419420
tvals = dist2[gsindxs] == 0.0
420-
if (sum(tvals) > 0):
421+
if sum(tvals) > 0:
421422
dist2[gsindxs[tvals]] = 0.01 * np.min(dist2[gsindxs[~tvals]])
422423

423424
weights = 1.0 / np.sqrt(dist2[gsindxs])
@@ -964,7 +965,14 @@ def fit_sampler(
964965

965966
return (outmod, flat_samples, sampler)
966967

967-
def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
968+
def plot(
969+
self,
970+
obsdata,
971+
modinfo,
972+
resid_range=10.0,
973+
lyaplot=False,
974+
xticks=[0.1, 0.2, 0.3, 0.5, 0.7, 1.0, 2.0],
975+
):
968976
"""
969977
Standard plot showing the data and best fit.
970978
@@ -978,6 +986,12 @@ def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
978986
979987
resid_range : float
980988
percentage value for the +/- range for the residual plot
989+
990+
lyaplot : boolean
991+
set to add two panels giving the Ly-alpha fit and residuals
992+
993+
xticks : vector
994+
set to a vector of floats giving the values for the xticks
981995
"""
982996
# plotting setup for easier to read plots
983997
fontsize = 16
@@ -1035,7 +1049,7 @@ def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
10351049
for cspec in obsdata.data.keys():
10361050
if cspec == "BAND":
10371051
ptype = "o"
1038-
rcolor = "g"
1052+
rcolor = "k"
10391053
else:
10401054
ptype = "-"
10411055
rcolor = "k"
@@ -1052,7 +1066,13 @@ def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
10521066
multval = multlam * nvals
10531067

10541068
if first_pass:
1055-
plabs = ["Obs", "Star", "w/ Foreground", "w/ Dust Ext", "w/ Dust+Gas Ext"]
1069+
plabs = [
1070+
"Obs",
1071+
"Star",
1072+
"w/ Foreground",
1073+
"w/ Dust Ext",
1074+
"w/ Dust+Gas Ext",
1075+
]
10561076
else:
10571077
plabs = [None, None, None, None, None]
10581078

@@ -1061,15 +1081,28 @@ def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
10611081
cax.plot(
10621082
cwaves, modsed_nofore[cspec] * multlam, "c" + ptype, alpha=0.2
10631083
)
1064-
cax.plot(cwaves, modsed_nofore[cspec] * multval, "c" + ptype, label=plabs[1])
1065-
cax.plot(cwaves, modsed[cspec] * multval, "b" + ptype, label=plabs[2])
1084+
cax.plot(
1085+
cwaves,
1086+
modsed_nofore[cspec] * multval,
1087+
"c" + ptype,
1088+
label=plabs[1],
1089+
)
1090+
cax.plot(
1091+
cwaves, modsed[cspec] * multval, "b" + ptype, label=plabs[2]
1092+
)
10661093
else:
1067-
cax.plot(cwaves, modsed[cspec] * multval, "b" + ptype, label=plabs[1])
1094+
cax.plot(
1095+
cwaves, modsed[cspec] * multval, "b" + ptype, label=plabs[1]
1096+
)
10681097
cax.plot(cwaves, modsed[cspec] * multlam, "b" + ptype, alpha=0.2)
10691098
cax.plot(cwaves, ext_modsed[cspec] * multlam, "g" + ptype, alpha=0.2)
1070-
cax.plot(cwaves, ext_modsed[cspec] * multval, "g" + ptype, label=plabs[3])
1099+
cax.plot(
1100+
cwaves, ext_modsed[cspec] * multval, "g" + ptype, label=plabs[3]
1101+
)
10711102
cax.plot(cwaves, hi_ext_modsed[cspec] * multlam, "r" + ptype, alpha=0.2)
1072-
cax.plot(cwaves, hi_ext_modsed[cspec] * multval, "r" + ptype, label=plabs[4])
1103+
cax.plot(
1104+
cwaves, hi_ext_modsed[cspec] * multval, "r" + ptype, label=plabs[4]
1105+
)
10731106

10741107
gvals = obsdata.data[cspec].fluxes > 0.0
10751108
cax.plot(
@@ -1151,6 +1184,13 @@ def plot(self, obsdata, modinfo, resid_range=10.0, lyaplot=False):
11511184
axes[1].set_xscale("log")
11521185
ax.set_yscale("log")
11531186

1187+
if xticks is not None:
1188+
for tax in [ax, axes[1]]:
1189+
tax.xaxis.set_major_formatter(ScalarFormatter())
1190+
tax.xaxis.set_minor_formatter(ScalarFormatter())
1191+
tax.set_xticks(xticks, minor=True)
1192+
tax.tick_params(axis="x", which="minor", labelsize=fontsize * 0.8)
1193+
11541194
ydelt = yrange[1] - yrange[0]
11551195
yrange[0] = 10 ** (yrange[0] - 0.1 * ydelt)
11561196
yrange[1] = 10 ** (yrange[1] + 0.1 * ydelt)

measure_extinction/plotting/plot_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def main():
1818
choices=["chains", "corner", "bestfit"],
1919
default="bestfit",
2020
)
21-
parser.add_argument("--residrange", help="residual range in percentage", default=50.0, type=float)
21+
parser.add_argument(
22+
"--residrange", help="residual range in percentage", default=50.0, type=float
23+
)
2224
parser.add_argument("--burnfrac", help="burn fraction", default=0.5, type=float)
2325
parser.add_argument(
2426
"--obspath",
@@ -117,7 +119,9 @@ def main():
117119
fig = memod.plot_sampler_corner(flat_samples)
118120
save_str = "_mefit_corner"
119121
else:
120-
fig = memod.plot(reddened_star, modinfo, lyaplot=True, resid_range=args.residrange)
122+
fig = memod.plot(
123+
reddened_star, modinfo, lyaplot=True, resid_range=args.residrange
124+
)
121125
save_str = "_mefit_mcmc"
122126

123127
# plot or save to a file
@@ -128,5 +132,6 @@ def main():
128132
else:
129133
plt.show()
130134

135+
131136
if __name__ == "__main__":
132137
main()

0 commit comments

Comments
 (0)