@@ -63,9 +63,12 @@ def rotate(frameData,deg):
6363 if deg not in [90 , 180 , 270 ]:
6464 print (f"invalid rotation specified... exiting" )
6565 sys .exit ()
66+
6667 psi = v2 .functional .rotate (frameData ["psi" ], deg , v2 .InterpolationMode .BILINEAR )
6768 all = v2 .functional .rotate (frameData ["all" ], deg , v2 .InterpolationMode .BILINEAR )
68- mask = v2 .functional .rotate (frameData ["mask" ], deg , v2 .InterpolationMode .BILINEAR )
69+ # For mask, use nearest neighbor interpolation to preserve binary values
70+ mask = v2 .functional .rotate (frameData ["mask" ], deg , v2 .InterpolationMode .NEAREST )
71+
6972 return {
7073 "fnum" : frameData ["fnum" ],
7174 "rotation" : deg ,
@@ -83,9 +86,11 @@ def reflect(frameData,axis):
8386 if axis not in [0 ,1 ]:
8487 print (f"invalid reflection axis specified... exiting" )
8588 sys .exit ()
86- psi = torch .flip (frameData ["psi" ][0 ], dims = (axis ,)).unsqueeze (0 )
87- all = torch .flip (frameData ["all" ], dims = (axis ,))
88- mask = torch .flip (frameData ["mask" ][0 ], dims = (axis ,)).unsqueeze (0 )
89+
90+ psi = torch .flip (frameData ["psi" ], dims = (axis + 1 ,))
91+ all = torch .flip (frameData ["all" ], dims = (axis + 1 ,))
92+ mask = torch .flip (frameData ["mask" ], dims = (axis + 1 ,))
93+
8994 return {
9095 "fnum" : frameData ["fnum" ],
9196 "rotation" : 0 ,
@@ -844,9 +849,6 @@ def main():
844849 t2 = timer ()
845850 print ("time (s) to prepare model: " + str (t2 - t1 ))
846851
847- train_loss = []
848- val_loss = []
849-
850852 num_epochs = args .epochs
851853 for epoch in range (start_epoch , num_epochs ):
852854 train_loss .append (train_one_epoch (model , train_loader , criterion , optimizer , device ))
@@ -942,4 +944,4 @@ def main():
942944 print ("total time (s): " + str (t5 - t0 ))
943945
944946if __name__ == "__main__" :
945- main ()
947+ main ()
0 commit comments