Skip to content

Commit c78cf64

Browse files
turn python values into scalar tensors
1 parent a09fffc commit c78cf64

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

core/raft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def upsample_flow(self, flow, mask):
8383
return up_flow.reshape(N, 2, 8*H, 8*W)
8484

8585

86-
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
86+
def forward(self, image1, image2, iters=torch.tensor(12), flow_init=torch.tensor([]), upsample=torch.tensor(True), test_mode=torch.tensor(False)):
8787
""" Estimate optical flow between pair of frames """
8888

8989
image1 = 2 * (image1 / 255.0) - 1.0
@@ -115,7 +115,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
115115

116116
coords0, coords1 = self.initialize_flow(image1)
117117

118-
if flow_init is not None:
118+
if flow_init is not None and flow_init.numel()>0:
119119
coords1 = coords1 + flow_init
120120

121121
flow_predictions = []

demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def demo(args):
6161
padder = InputPadder(image1.shape)
6262
image1, image2 = padder.pad(image1, image2)
6363

64-
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
64+
flow_low, flow_up = model(image1, image2, iters=torch.tensor(20), test_mode=torch.tensor(True))
6565
viz(image1, flow_up)
6666

6767

0 commit comments

Comments
 (0)