1414 ResidualBlock ,
1515 TFSamepaddingLayer ,
1616)
17+ from tiatoolbox .utils .misc import select_device
1718from tiatoolbox .wsicore .wsireader import WSIReader
1819
1920
@@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
3435 weights_path = fetch_pretrained_weights ("hovernet_fast-pannuke" )
3536 pretrained = torch .load (weights_path )
3637 model .load_state_dict (pretrained )
37- output = model .infer_batch (model , batch , on_gpu = False )
38+ output = model .infer_batch (model , batch , device = select_device ( on_gpu = False ) )
3839 output = [v [0 ] for v in output ]
3940 output = model .postproc (output )
4041 assert len (output [1 ]) > 0 , "Must have some nuclei."
@@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
5152 weights_path = fetch_pretrained_weights ("hovernet_fast-monusac" )
5253 pretrained = torch .load (weights_path )
5354 model .load_state_dict (pretrained )
54- output = model .infer_batch (model , batch , on_gpu = False )
55+ output = model .infer_batch (model , batch , device = select_device ( on_gpu = False ) )
5556 output = [v [0 ] for v in output ]
5657 output = model .postproc (output )
5758 assert len (output [1 ]) > 0 , "Must have some nuclei."
@@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
6869 weights_path = fetch_pretrained_weights ("hovernet_original-consep" )
6970 pretrained = torch .load (weights_path )
7071 model .load_state_dict (pretrained )
71- output = model .infer_batch (model , batch , on_gpu = False )
72+ output = model .infer_batch (model , batch , device = select_device ( on_gpu = False ) )
7273 output = [v [0 ] for v in output ]
7374 output = model .postproc (output )
7475 assert len (output [1 ]) > 0 , "Must have some nuclei."
@@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
8586 weights_path = fetch_pretrained_weights ("hovernet_original-kumar" )
8687 pretrained = torch .load (weights_path )
8788 model .load_state_dict (pretrained )
88- output = model .infer_batch (model , batch , on_gpu = False )
89+ output = model .infer_batch (model , batch , device = select_device ( on_gpu = False ) )
8990 output = [v [0 ] for v in output ]
9091 output = model .postproc (output )
9192 assert len (output [1 ]) > 0 , "Must have some nuclei."
0 commit comments