Skip to content

Commit 5ccd9f5

Browse files
committed
changed model plotting routine
1 parent 40c4432 commit 5ccd9f5

File tree

1 file changed

+83
-13
lines changed

1 file changed

+83
-13
lines changed

scripts/plotting/plot_model.py

Lines changed: 83 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,107 @@
33

44
import argparse, yaml, pprint, os, shutil, datetime, sys, pickle
55
import numpy as np
6+
import matplotlib.pyplot as plt
7+
from matplotlib.backends.backend_pdf import PdfPages
8+
from matplotlib.patches import FancyArrowPatch
69
from cyclejet.cyclegan import CycleGAN
7-
from cyclejet.tools import loss_calc, plot_model
10+
from cyclejet.tools import loss_calc, plot_model, xval, yval
811
from cyclejet.scripts.run import load_yaml
12+
from random import randrange
13+
14+
def plot_event(fn, refA, refB, predictA, predictB,
15+
predictA2, predictB2, averager, titleA=None, titleB=None):
16+
with PdfPages(fn) as pdf:
17+
fig, axs = plt.subplots(2,3, figsize=(7.5,5.5))
18+
plt.subplots_adjust(wspace=0.5,hspace=0.52)
19+
i = randrange(len(refA))
20+
# figtr = fig.transFigure.inverted()
21+
# ptB = figtr.transform(ax0tr.transform((225., -10.)))
22+
# ptE = figtr.transform(ax1tr.transform((225., 1.)))
23+
# arrow=FancyArrowPatch(
24+
# ptB, ptE, transform=fig.transFigure, # Place arrow in figure coord system
25+
# fc = "g", connectionstyle="arc3,rad=0.2", arrowstyle='simple', alpha = 0.3,
26+
# mutation_scale = 40.)
27+
# fig.patches.append(arrow)
28+
axs[0,0].imshow(refA[i].transpose(),vmin=0.0,vmax=0.2,origin='lower',
29+
aspect='auto', extent=[xval[0], xval[1], yval[0], yval[1]])
30+
axs[0,0].set_title('A' if not titleA else titleA)
31+
axs[0,0].set_xticks([])
32+
axs[0,0].set_yticks([])
33+
axs[0,1].imshow(predictB[i].transpose(),vmin=0.0,vmax=0.2,origin='lower',
34+
aspect='auto', extent=[xval[0], xval[1], yval[0], yval[1]])
35+
# axs[0,1].set_xlabel('$\ln(1 / \Delta_{ab})$')
36+
# axs[0,1].set_ylabel('$\ln(k_{t} / \mathrm{GeV})$',labelpad=-2)
37+
axs[0,1].set_title('B' if not titleB else titleB)
38+
axs[0,1].set_xticks([])
39+
axs[0,1].set_yticks([])
40+
axs[0,2].imshow(predictB2[i].transpose(),vmin=0.0,vmax=0.2,origin='lower',
41+
aspect='auto', extent=[xval[0], xval[1], yval[0], yval[1]])
42+
axs[0,2].set_title('A' if not titleA else titleA)
43+
44+
axs[0,2].set_xticks([])
45+
axs[0,2].set_yticks([])
46+
axs[1,0].imshow(averager.inverse(refA)[i].transpose(),
47+
vmin=0.0,vmax=0.2,origin='lower',aspect='auto',
48+
extent=[xval[0], xval[1], yval[0], yval[1]])
49+
#axs[1,0].set_title('A' if not titleA else titleA)
50+
axs[1,0].set_xticks([])
51+
axs[1,0].set_yticks([])
52+
axs[1,1].imshow(averager.inverse(predictB)[i].transpose(),
53+
vmin=0.0,vmax=0.2,origin='lower',aspect='auto',
54+
extent=[xval[0], xval[1], yval[0], yval[1]])
55+
# axs[0,1].set_xlabel('$\ln(1 / \Delta_{ab})$')
56+
# axs[0,1].set_ylabel('$\ln(k_{t} / \mathrm{GeV})$',labelpad=-2)
57+
#axs[1,1].set_title('B' if not titleB else titleB)
58+
axs[1,1].set_xticks([])
59+
axs[1,1].set_yticks([])
60+
axs[1,2].imshow(averager.inverse(predictB2)[i].transpose(),
61+
vmin=0.0,vmax=0.2,origin='lower',aspect='auto',
62+
extent=[xval[0], xval[1], yval[0], yval[1]])
63+
#axs[1,2].set_title('A' if not titleA else titleA)
64+
axs[1,2].set_xticks([])
65+
axs[1,2].set_yticks([])
66+
# pdf.savefig()
67+
# plt.close()
68+
69+
# fig, axs = plt.subplots(3, 2, figsize=(6,8))
70+
plt.close()
71+
pdf.savefig(fig)
72+
973

1074
def main(args):
11-
if os.path.isfile(args.model.strip('/')+'/best-model.yaml'):
12-
fn=args.model.strip('/')+'/best-model.yaml'
75+
model=args.model.strip('/')
76+
if os.path.isfile(model+'/best-model.yaml'):
77+
fn=model+'/best-model.yaml'
1378
else:
14-
fn=args.model.strip('/')+'/input-runcard.json'
79+
fn=model+'/input-runcard.json'
1580
hps=load_yaml(fn)
1681
cgan = CycleGAN(hps)
17-
cgan.load(args.model.strip('/'))
82+
cgan.load(model)
1883
refA=np.array(cgan.imagesA)
1984
refB=np.array(cgan.imagesB)
2085
# generating predicted sample
2186
predictA=cgan.g_BA.predict(refA)
2287
predictB=cgan.g_AB.predict(refB)
88+
predictA2=cgan.g_AB.predict(predictA)
89+
predictB2=cgan.g_BA.predict(predictB)
2390
refA = cgan.preproc.inverse(refA)
2491
refB = cgan.preproc.inverse(refB)
2592
predictA = cgan.preproc.inverse(predictA)
2693
predictB = cgan.preproc.inverse(predictB)
2794
if args.savefull:
28-
np.save('%s/referenceA'%args.model.strip('/'), refA)
29-
np.save('%s/referenceB'%args.model.strip('/'), refB)
30-
np.save('%s/predictedA'%args.model.strip('/'), predictA)
31-
np.save('%s/predictedB'%args.model.strip('/'), predictB)
95+
np.save('%s/referenceA'%model, refA)
96+
np.save('%s/referenceB'%model, refB)
97+
np.save('%s/predictedA'%model, predictA)
98+
np.save('%s/predictedB'%model, predictB)
3299

33-
# now create diagnostic plots
34-
figfn='%s/result.pdf' % args.model.strip('/')
35-
plot_model(figfn, refA, refB, predictA, predictB,
100+
# now create plots
101+
figfn1='%s/result.pdf' % model
102+
plot_model(figfn1, refA, refB, predictA, predictB,
103+
titleA=args.titleA, titleB=args.titleB)
104+
figfn2='%s/result_event.pdf' % model
105+
plot_event(figfn2, refA, refB, predictA, predictB,
106+
predictA, predictB, cgan.avg,
36107
titleA=args.titleA, titleB=args.titleB)
37108

38109
#----------------------------------------------------------------------
@@ -46,7 +117,6 @@ def main(args):
46117
help='Title of sample A.')
47118
parser.add_argument('--titleB', type=str, default=None,
48119
help='Title of sample A.')
49-
50120
parser.add_argument('--savefull', action='store_true')
51121
args = parser.parse_args()
52122
main(args)

0 commit comments

Comments
 (0)