@@ -59,25 +59,16 @@ def expand_xpoints_mask(binary_mask, kernel_size=9):
5959
6060 return expanded_mask
6161
62- def plotSimple (arr , outfile ):
63- plt .imshow (arr , interpolation = "nearest" , origin = "upper" )
64- plt .colorbar ()
65- plt .savefig (outfile )
66- plt .clf ()
67-
6862def rotate (frameData ,deg ):
6963 if deg not in [90 , 180 , 270 ]:
7064 print (f"invalid rotation specified... exiting" )
7165 sys .exit ()
66+
7267 psi = v2 .functional .rotate (frameData ["psi" ], deg , v2 .InterpolationMode .BILINEAR )
7368 all = v2 .functional .rotate (frameData ["all" ], deg , v2 .InterpolationMode .BILINEAR )
74-
75- plotSimple (all [0 ], f"{ frameData ['fnum' ]} _rotation{ deg } _all0.png" )
76- plotSimple (all [1 ], f"{ frameData ['fnum' ]} _rotation{ deg } _all1.png" )
77- plotSimple (all [2 ], f"{ frameData ['fnum' ]} _rotation{ deg } _all2.png" )
78- plotSimple (all [3 ], f"{ frameData ['fnum' ]} _rotation{ deg } _all3.png" )
79-
69+ # For mask, use nearest neighbor interpolation to preserve binary values
8070 mask = v2 .functional .rotate (frameData ["mask" ], deg , v2 .InterpolationMode .NEAREST )
71+
8172 return {
8273 "fnum" : frameData ["fnum" ],
8374 "rotation" : deg ,
@@ -100,12 +91,6 @@ def reflect(frameData,axis):
10091 all = torch .flip (frameData ["all" ], dims = (axis + 1 ,))
10192 mask = torch .flip (frameData ["mask" ], dims = (axis + 1 ,))
10293
103- plotSimple (all [0 ], f"{ frameData ['fnum' ]} _reflectionAxis{ axis } _all0.png" )
104- plotSimple (all [1 ], f"{ frameData ['fnum' ]} _reflectionAxis{ axis } _all1.png" )
105- plotSimple (all [2 ], f"{ frameData ['fnum' ]} _reflectionAxis{ axis } _all2.png" )
106- plotSimple (all [3 ], f"{ frameData ['fnum' ]} _reflectionAxis{ axis } _all3.png" )
107-
108-
10994 return {
11095 "fnum" : frameData ["fnum" ],
11196 "rotation" : 0 ,
@@ -278,10 +263,6 @@ def load(self, fnum):
278263 by_torch = torch .from_numpy (fields ["By" ]).float ().unsqueeze (0 )
279264 jz_torch = torch .from_numpy (fields ["Jz" ]).float ().unsqueeze (0 )
280265 all_torch = torch .cat ((psi_torch ,bx_torch ,by_torch ,jz_torch )) # [4, Nx, Ny]
281- plotSimple (all_torch [0 ], f"{ fnum } _all0.png" )
282- plotSimple (all_torch [1 ], f"{ fnum } _all1.png" )
283- plotSimple (all_torch [2 ], f"{ fnum } _all2.png" )
284- plotSimple (all_torch [3 ], f"{ fnum } _all3.png" )
285266 mask_torch = torch .from_numpy (binaryMap ).float ().unsqueeze (0 ) # [1, Nx, Ny]
286267
287268 if self .verbosity > 0 :
0 commit comments