File tree Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change 4444
4545 # create MODNet and load the pre-trained ckpt
4646 modnet = MODNet (backbone_pretrained = False )
47- modnet = nn .DataParallel (modnet ).cuda ()
48- modnet .load_state_dict (torch .load (args .ckpt_path ))
47+ modnet = nn .DataParallel (modnet )
48+
49+ if torch .cuda .is_available ():
50+ modnet = modnet .cuda ()
51+ weights = torch .load (args .ckpt_path )
52+ else :
53+ weights = torch .load (args .ckpt_path , map_location = torch .device ('cpu' ))
54+ modnet .load_state_dict (weights )
4955 modnet .eval ()
5056
5157 # inference images
9096 im = F .interpolate (im , size = (im_rh , im_rw ), mode = 'area' )
9197
9298 # inference
93- _ , _ , matte = modnet (im .cuda (), True )
99+ _ , _ , matte = modnet (im .cuda () if torch . cuda . is_available () else im , True )
94100
95101 # resize and save matte
96102 matte = F .interpolate (matte , size = (im_h , im_w ), mode = 'area' )
You can’t perform that action at this time.
0 commit comments