Skip to content

Commit 8116e73

Browse files
add device arg to load_model
1 parent c82fd36 commit 8116e73

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tutorials/notebooks/task_notebooks/pytorch/example_deeplabv3p_pytorch_mixed_precision_ptq.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,10 @@
216216
"metadata": {},
217217
"outputs": [],
218218
"source": [
219+
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
219220
"model_builder = network.modeling.__dict__[\"deeplabv3plus_mobilenet\"]\n",
220221
"float_model = model_builder(output_stride=16)\n",
221-
"float_model.load_state_dict(torch.load( \"best_deeplabv3plus_mobilenet_voc_os16.pth\", weights_only=False)['model_state'])"
222+
"float_model.load_state_dict(torch.load( \"best_deeplabv3plus_mobilenet_voc_os16.pth\", weights_only=False, map_location=device)['model_state'])"
222223
]
223224
},
224225
{
@@ -532,7 +533,7 @@
532533
],
533534
"metadata": {
534535
"kernelspec": {
535-
"display_name": "torch11 (3.10.12)",
536+
"display_name": "py310-deeplabv3ptest (3.10.12)",
536537
"language": "python",
537538
"name": "python3"
538539
},

0 commit comments

Comments
 (0)