Skip to content

Commit b628860

Browse files
committed
store psi and mask once per frame
1 parent 8ec9b84 commit b628860

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

XPointMLTest.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,6 @@ def load(self, fnum):
223223
"fnum": fnum,
224224
"psi": psi_torch, # shape [1, Nx, Ny]
225225
"mask": mask_torch, # shape [1, Nx, Ny] // Used in: psi, mask = batch["psi"].to(device), batch["mask"].to(device)
226-
"psi_np": psi, # 2D np array [Nx, Ny]
227-
"mask_np": binaryMap, # 2D np array [Nx, Ny]
228226
"x": x,
229227
"y": y,
230228
"filenameBase": tmp.filenameBase,
@@ -696,8 +694,8 @@ def main():
696694
for item in set:
697695
# item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params
698696
fnum = item["fnum"]
699-
psi_np = item["psi_np"]
700-
mask_gt = item["mask_np"]
697+
psi_np = np.array(item["psi"])[0]
698+
mask_gt = np.array(item["mask"])[0]
701699
x = item["x"]
702700
y = item["y"]
703701
filenameBase = item["filenameBase"]

0 commit comments

Comments
 (0)