Skip to content

Commit 6fbc772

Browse files
committed
fixed bugs in rotate and reflect functions
1 parent 4bbf134 commit 6fbc772

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

XPointMLTest.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def rotate(frameData,deg):
7777
plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png")
7878
plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png")
7979

80-
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
80+
# mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
81+
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST)
8182
return {
8283
"fnum": frameData["fnum"],
8384
"rotation": deg,
@@ -95,13 +96,21 @@ def reflect(frameData,axis):
9596
if axis not in [0,1]:
9697
print(f"invalid reflection axis specified... exiting")
9798
sys.exit()
98-
psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
99-
all = torch.flip(frameData["all"], dims=(axis,))
99+
100+
# psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
101+
# all = torch.flip(frameData["all"], dims=(axis,))
102+
# mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
103+
104+
psi = torch.flip(frameData["psi"], dims=(axis+1,))
105+
all = torch.flip(frameData["all"], dims=(axis+1,))
106+
mask = torch.flip(frameData["mask"], dims=(axis+1,))
107+
100108
plotSimple(all[0], f"{frameData['fnum']}_reflectionAxis{axis}_all0.png")
101109
plotSimple(all[1], f"{frameData['fnum']}_reflectionAxis{axis}_all1.png")
102110
plotSimple(all[2], f"{frameData['fnum']}_reflectionAxis{axis}_all2.png")
103111
plotSimple(all[3], f"{frameData['fnum']}_reflectionAxis{axis}_all3.png")
104-
mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
112+
113+
105114
return {
106115
"fnum": frameData["fnum"],
107116
"rotation": 0,

0 commit comments

Comments
 (0)