diff --git a/denseformer/denseformer.py b/denseformer/denseformer.py index 83b3867..1144abd 100644 --- a/denseformer/denseformer.py +++ b/denseformer/denseformer.py @@ -7,7 +7,7 @@ class InPlaceSetSlice(torch.autograd.Function): def forward(ctx, full_tensor, last_slice, x_idx, x_val): full_tensor[x_idx] = x_val ctx.x_idx = x_idx - ret = torch.Tensor().to(full_tensor.device) + ret = torch.tensor([], dtype=full_tensor.dtype, device=full_tensor.device) ret.set_(full_tensor[:x_idx + 1]) return ret