Skip to content

Commit 8ff4511

Browse files
authored
fix bugs in inference.py
line 39, checkpoint = torch.load(model_dir, map_location='cuda:0')
1 parent 3852400 commit 8ff4511

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/DOSE/inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def predict(condition=None, model_dir=None, params=None, device=device, fast_sam
3535
if os.path.exists(f'{model_dir}/weights.pt'):
3636
checkpoint = torch.load(f'{model_dir}/weights.pt')
3737
else:
38-
checkpoint = torch.load(model_dir)
38+
# checkpoint = torch.load(model_dir)
39+
checkpoint = torch.load(model_dir, map_location='cuda:0')
3940
model = DOSE(AttrDict(base_params)).to(device)
4041
model.load_state_dict(checkpoint['model'])
4142
model.eval()
@@ -160,4 +161,4 @@ def main(args):
160161
help='output file name')
161162
parser.add_argument('--fast', '-f', action='store_true',
162163
help='fast sampling procedure')
163-
main(parser.parse_args())
164+
main(parser.parse_args())

0 commit comments

Comments
 (0)