Skip to content

Commit bfb0c98

Browse files
committed
removed plotSimple
1 parent fc8b2bb commit bfb0c98

File tree

1 file changed

+3
-22
lines changed

1 file changed

+3
-22
lines changed

XPointMLTest.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,16 @@ def expand_xpoints_mask(binary_mask, kernel_size=9):
5959

6060
return expanded_mask
6161

62-
def plotSimple(arr, outfile):
63-
plt.imshow(arr, interpolation="nearest", origin="upper")
64-
plt.colorbar()
65-
plt.savefig(outfile)
66-
plt.clf()
67-
6862
def rotate(frameData,deg):
6963
if deg not in [90, 180, 270]:
7064
print(f"invalid rotation specified... exiting")
7165
sys.exit()
66+
7267
psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
7368
all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR)
74-
75-
plotSimple(all[0], f"{frameData['fnum']}_rotation{deg}_all0.png")
76-
plotSimple(all[1], f"{frameData['fnum']}_rotation{deg}_all1.png")
77-
plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png")
78-
plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png")
79-
69+
# For mask, use nearest neighbor interpolation to preserve binary values
8070
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST)
71+
8172
return {
8273
"fnum": frameData["fnum"],
8374
"rotation": deg,
@@ -100,12 +91,6 @@ def reflect(frameData,axis):
10091
all = torch.flip(frameData["all"], dims=(axis+1,))
10192
mask = torch.flip(frameData["mask"], dims=(axis+1,))
10293

103-
plotSimple(all[0], f"{frameData['fnum']}_reflectionAxis{axis}_all0.png")
104-
plotSimple(all[1], f"{frameData['fnum']}_reflectionAxis{axis}_all1.png")
105-
plotSimple(all[2], f"{frameData['fnum']}_reflectionAxis{axis}_all2.png")
106-
plotSimple(all[3], f"{frameData['fnum']}_reflectionAxis{axis}_all3.png")
107-
108-
10994
return {
11095
"fnum": frameData["fnum"],
11196
"rotation": 0,
@@ -278,10 +263,6 @@ def load(self, fnum):
278263
by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0)
279264
jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0)
280265
all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny]
281-
plotSimple(all_torch[0], f"{fnum}_all0.png")
282-
plotSimple(all_torch[1], f"{fnum}_all1.png")
283-
plotSimple(all_torch[2], f"{fnum}_all2.png")
284-
plotSimple(all_torch[3], f"{fnum}_all3.png")
285266
mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny]
286267

287268
if self.verbosity > 0:

0 commit comments

Comments
 (0)