Skip to content

Commit 51b9c6f

Browse files
committed
add runtime check
1 parent c9f29b5 commit 51b9c6f

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

notebooks/02_inference_review.ipynb

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8387,7 +8387,9 @@
83878387
"source": [
83888388
"subset = tokenized_datasets[\"test\"].num_rows\n",
83898389
"\n",
8390-
"if torch.cuda.device_count() <=1:\n",
8390+
"if torch.cuda.device_count() ==0:\n",
8391+
" subset = 20\n",
8392+
"elif torch.cuda.device_count() ==1:\n",
83918393
" subset = 1_000\n",
83928394
"\n",
83938395
"test_dataset = tokenized_datasets[\"test\"].shuffle(42).select(range(subset)) \n",
@@ -8678,7 +8680,7 @@
86788680
}
86798681
],
86808682
"source": [
8681-
"prediction_batch(model,test_dataset.select(range(1_000)),device=device )"
8683+
"prediction_batch(model,test_dataset, device=device )"
86828684
]
86838685
},
86848686
{
@@ -8745,11 +8747,13 @@
87458747
],
87468748
"source": [
87478749
"%%timeit -r 3 -n 5\n",
8748-
"device_current = 'cuda'\n",
8749-
"model.to(device_current)\n",
87508750
"\n",
8751-
"model( input_ids =res['input_ids'].to(device_current)\n",
8752-
" , attention_mask =res['attention_mask'].to(device_current) )"
8751+
"if torch.cuda.is_available():\n",
8752+
" device_current = 'cuda'\n",
8753+
" model.to(device_current)\n",
8754+
"\n",
8755+
" model( input_ids =res['input_ids'].to(device_current)\n",
8756+
" , attention_mask =res['attention_mask'].to(device_current) )"
87538757
]
87548758
},
87558759
{

0 commit comments

Comments
 (0)