File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed
tests/pipelines/animatediff Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -281,7 +281,6 @@ def test_inference_batch_single_identical(
281281 max_diff = np .abs (to_np (output_batch [0 ][0 ]) - to_np (output [0 ][0 ])).max ()
282282 assert max_diff < expected_max_diff
283283
284- @unittest .skipIf (torch_device != "cuda" , reason = "CUDA and CPU are required to switch devices" )
285284 def test_to_device (self ):
286285 components = self .get_dummy_components ()
287286 pipe = self .pipeline_class (** components )
@@ -297,14 +296,14 @@ def test_to_device(self):
297296 output_cpu = pipe (** self .get_dummy_inputs ("cpu" ))[0 ]
298297 self .assertTrue (np .isnan (output_cpu ).sum () == 0 )
299298
300- pipe .to ("cuda" )
299+ pipe .to (torch_device )
301300 model_devices = [
302301 component .device .type for component in pipe .components .values () if hasattr (component , "device" )
303302 ]
304- self .assertTrue (all (device == "cuda" for device in model_devices ))
303+ self .assertTrue (all (device == torch_device for device in model_devices ))
305304
306- output_cuda = pipe (** self .get_dummy_inputs ("cuda" ))[0 ]
307- self .assertTrue (np .isnan (to_np (output_cuda )).sum () == 0 )
305+ output_device = pipe (** self .get_dummy_inputs (torch_device ))[0 ]
306+ self .assertTrue (np .isnan (to_np (output_device )).sum () == 0 )
308307
309308 def test_to_dtype (self ):
310309 components = self .get_dummy_components ()
You can’t perform that action at this time.
0 commit comments