33
44import argparse , yaml , pprint , os , shutil , datetime , sys , pickle
55import numpy as np
6+ import matplotlib .pyplot as plt
7+ from matplotlib .backends .backend_pdf import PdfPages
8+ from matplotlib .patches import FancyArrowPatch
69from cyclejet .cyclegan import CycleGAN
7- from cyclejet .tools import loss_calc , plot_model
10+ from cyclejet .tools import loss_calc , plot_model , xval , yval
811from cyclejet .scripts .run import load_yaml
12+ from random import randrange
13+
14+ def plot_event (fn , refA , refB , predictA , predictB ,
15+ predictA2 , predictB2 , averager , titleA = None , titleB = None ):
16+ with PdfPages (fn ) as pdf :
17+ fig , axs = plt .subplots (2 ,3 , figsize = (7.5 ,5.5 ))
18+ plt .subplots_adjust (wspace = 0.5 ,hspace = 0.52 )
19+ i = randrange (len (refA ))
20+ # figtr = fig.transFigure.inverted()
21+ # ptB = figtr.transform(ax0tr.transform((225., -10.)))
22+ # ptE = figtr.transform(ax1tr.transform((225., 1.)))
23+ # arrow=FancyArrowPatch(
24+ # ptB, ptE, transform=fig.transFigure, # Place arrow in figure coord system
25+ # fc = "g", connectionstyle="arc3,rad=0.2", arrowstyle='simple', alpha = 0.3,
26+ # mutation_scale = 40.)
27+ # fig.patches.append(arrow)
28+ axs [0 ,0 ].imshow (refA [i ].transpose (),vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,
29+ aspect = 'auto' , extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
30+ axs [0 ,0 ].set_title ('A' if not titleA else titleA )
31+ axs [0 ,0 ].set_xticks ([])
32+ axs [0 ,0 ].set_yticks ([])
33+ axs [0 ,1 ].imshow (predictB [i ].transpose (),vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,
34+ aspect = 'auto' , extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
35+ # axs[0,1].set_xlabel('$\ln(1 / \Delta_{ab})$')
36+ # axs[0,1].set_ylabel('$\ln(k_{t} / \mathrm{GeV})$',labelpad=-2)
37+ axs [0 ,1 ].set_title ('B' if not titleB else titleB )
38+ axs [0 ,1 ].set_xticks ([])
39+ axs [0 ,1 ].set_yticks ([])
40+ axs [0 ,2 ].imshow (predictB2 [i ].transpose (),vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,
41+ aspect = 'auto' , extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
42+ axs [0 ,2 ].set_title ('A' if not titleA else titleA )
43+
44+ axs [0 ,2 ].set_xticks ([])
45+ axs [0 ,2 ].set_yticks ([])
46+ axs [1 ,0 ].imshow (averager .inverse (refA )[i ].transpose (),
47+ vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,aspect = 'auto' ,
48+ extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
49+ #axs[1,0].set_title('A' if not titleA else titleA)
50+ axs [1 ,0 ].set_xticks ([])
51+ axs [1 ,0 ].set_yticks ([])
52+ axs [1 ,1 ].imshow (averager .inverse (predictB )[i ].transpose (),
53+ vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,aspect = 'auto' ,
54+ extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
55+ # axs[0,1].set_xlabel('$\ln(1 / \Delta_{ab})$')
56+ # axs[0,1].set_ylabel('$\ln(k_{t} / \mathrm{GeV})$',labelpad=-2)
57+ #axs[1,1].set_title('B' if not titleB else titleB)
58+ axs [1 ,1 ].set_xticks ([])
59+ axs [1 ,1 ].set_yticks ([])
60+ axs [1 ,2 ].imshow (averager .inverse (predictB2 )[i ].transpose (),
61+ vmin = 0.0 ,vmax = 0.2 ,origin = 'lower' ,aspect = 'auto' ,
62+ extent = [xval [0 ], xval [1 ], yval [0 ], yval [1 ]])
63+ #axs[1,2].set_title('A' if not titleA else titleA)
64+ axs [1 ,2 ].set_xticks ([])
65+ axs [1 ,2 ].set_yticks ([])
66+ # pdf.savefig()
67+ # plt.close()
68+
69+ # fig, axs = plt.subplots(3, 2, figsize=(6,8))
70+ plt .close ()
71+ pdf .savefig (fig )
72+
973
1074def main (args ):
11- if os .path .isfile (args .model .strip ('/' )+ '/best-model.yaml' ):
12- fn = args .model .strip ('/' )+ '/best-model.yaml'
75+ model = args .model .strip ('/' )
76+ if os .path .isfile (model + '/best-model.yaml' ):
77+ fn = model + '/best-model.yaml'
1378 else :
14- fn = args . model . strip ( '/' ) + '/input-runcard.json'
79+ fn = model + '/input-runcard.json'
1580 hps = load_yaml (fn )
1681 cgan = CycleGAN (hps )
17- cgan .load (args . model . strip ( '/' ) )
82+ cgan .load (model )
1883 refA = np .array (cgan .imagesA )
1984 refB = np .array (cgan .imagesB )
2085 # generating predicted sample
2186 predictA = cgan .g_BA .predict (refA )
2287 predictB = cgan .g_AB .predict (refB )
88+ predictA2 = cgan .g_AB .predict (predictA )
89+ predictB2 = cgan .g_BA .predict (predictB )
2390 refA = cgan .preproc .inverse (refA )
2491 refB = cgan .preproc .inverse (refB )
2592 predictA = cgan .preproc .inverse (predictA )
2693 predictB = cgan .preproc .inverse (predictB )
2794 if args .savefull :
28- np .save ('%s/referenceA' % args . model . strip ( '/' ) , refA )
29- np .save ('%s/referenceB' % args . model . strip ( '/' ) , refB )
30- np .save ('%s/predictedA' % args . model . strip ( '/' ) , predictA )
31- np .save ('%s/predictedB' % args . model . strip ( '/' ) , predictB )
95+ np .save ('%s/referenceA' % model , refA )
96+ np .save ('%s/referenceB' % model , refB )
97+ np .save ('%s/predictedA' % model , predictA )
98+ np .save ('%s/predictedB' % model , predictB )
3299
33- # now create diagnostic plots
34- figfn = '%s/result.pdf' % args .model .strip ('/' )
35- plot_model (figfn , refA , refB , predictA , predictB ,
100+ # now create plots
101+ figfn1 = '%s/result.pdf' % model
102+ plot_model (figfn1 , refA , refB , predictA , predictB ,
103+ titleA = args .titleA , titleB = args .titleB )
104+ figfn2 = '%s/result_event.pdf' % model
105+ plot_event (figfn2 , refA , refB , predictA , predictB ,
106+ predictA , predictB , cgan .avg ,
36107 titleA = args .titleA , titleB = args .titleB )
37108
38109#----------------------------------------------------------------------
@@ -46,7 +117,6 @@ def main(args):
46117 help = 'Title of sample A.' )
47118 parser .add_argument ('--titleB' , type = str , default = None ,
48119 help = 'Title of sample A.' )
49-
50120 parser .add_argument ('--savefull' , action = 'store_true' )
51121 args = parser .parse_args ()
52122 main (args )
0 commit comments