Skip to content

Commit 624e08f

Browse files
authored
Merge pull request #12 from SCOREC/cws/debugRotation
fixes in rotate and reflect functions
2 parents 9e3ce78 + bfb0c98 commit 624e08f

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

XPointMLTest.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,12 @@ def rotate(frameData,deg):
6363
if deg not in [90, 180, 270]:
6464
print(f"invalid rotation specified... exiting")
6565
sys.exit()
66+
6667
psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR)
6768
all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR)
68-
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR)
69+
# For mask, use nearest neighbor interpolation to preserve binary values
70+
mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST)
71+
6972
return {
7073
"fnum": frameData["fnum"],
7174
"rotation": deg,
@@ -83,9 +86,11 @@ def reflect(frameData,axis):
8386
if axis not in [0,1]:
8487
print(f"invalid reflection axis specified... exiting")
8588
sys.exit()
86-
psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0)
87-
all = torch.flip(frameData["all"], dims=(axis,))
88-
mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0)
89+
90+
psi = torch.flip(frameData["psi"], dims=(axis+1,))
91+
all = torch.flip(frameData["all"], dims=(axis+1,))
92+
mask = torch.flip(frameData["mask"], dims=(axis+1,))
93+
8994
return {
9095
"fnum": frameData["fnum"],
9196
"rotation": 0,
@@ -844,9 +849,6 @@ def main():
844849
t2 = timer()
845850
print("time (s) to prepare model: " + str(t2-t1))
846851

847-
train_loss = []
848-
val_loss = []
849-
850852
num_epochs = args.epochs
851853
for epoch in range(start_epoch, num_epochs):
852854
train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device))
@@ -942,4 +944,4 @@ def main():
942944
print("total time (s): " + str(t5-t0))
943945

944946
if __name__ == "__main__":
945-
main()
947+
main()

0 commit comments

Comments
 (0)