@@ -1360,6 +1360,8 @@ def _test_attention_slicing_forward_pass(
13601360 reason = "CPU offload is only available with CUDA and `accelerate v0.14.0` or higher" ,
13611361 )
13621362 def test_sequential_cpu_offload_forward_pass (self , expected_max_diff = 1e-4 ):
1363+ import accelerate
1364+
13631365 components = self .get_dummy_components ()
13641366 pipe = self .pipeline_class (** components )
13651367 for component in pipe .components .values ():
@@ -1373,18 +1375,56 @@ def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
13731375 output_without_offload = pipe (** inputs )[0 ]
13741376
13751377 pipe .enable_sequential_cpu_offload ()
1378+ assert pipe ._execution_device .type == pipe ._offload_device .type
13761379
13771380 inputs = self .get_dummy_inputs (generator_device )
13781381 output_with_offload = pipe (** inputs )[0 ]
13791382
13801383 max_diff = np .abs (to_np (output_with_offload ) - to_np (output_without_offload )).max ()
13811384 self .assertLess (max_diff , expected_max_diff , "CPU offloading should not affect the inference results" )
13821385
1386+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1387+ offloaded_modules = {
1388+ k : v
1389+ for k , v in pipe .components .items ()
1390+ if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1391+ }
1392+ # 1. all offloaded modules should be saved to cpu and moved to meta device
1393+ self .assertTrue (
1394+ all (v .device .type == "meta" for v in offloaded_modules .values ()),
1395+ f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'meta' ]} " ,
1396+ )
1397+ # 2. all offloaded modules should have hook installed
1398+ self .assertTrue (
1399+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1400+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1401+ )
1402+ # 3. all offloaded modules should have correct hooks installed, should be either one of these two
1403+ # - `AlignDevicesHook`
1404+ # - a `SequentialHook` that contains `AlignDevicesHook`
1405+ offloaded_modules_with_incorrect_hooks = {}
1406+ for k , v in offloaded_modules .items ():
1407+ if hasattr (v , "_hf_hook" ):
1408+ if isinstance (v ._hf_hook , accelerate .hooks .SequentialHook ):
1409+ # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
1410+ for hook in v ._hf_hook .hooks :
1411+ if not isinstance (hook , accelerate .hooks .AlignDevicesHook ):
1412+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook .hooks [0 ])
1413+ elif not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1414+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1415+
1416+ self .assertTrue (
1417+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1418+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
1419+ )
1420+
13831421 @unittest .skipIf (
13841422 torch_device != "cuda" or not is_accelerate_available () or is_accelerate_version ("<" , "0.17.0" ),
13851423 reason = "CPU offload is only available with CUDA and `accelerate v0.17.0` or higher" ,
13861424 )
13871425 def test_model_cpu_offload_forward_pass (self , expected_max_diff = 2e-4 ):
1426+ import accelerate
1427+
13881428 generator_device = "cpu"
13891429 components = self .get_dummy_components ()
13901430 pipe = self .pipeline_class (** components )
@@ -1400,19 +1440,39 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
14001440 output_without_offload = pipe (** inputs )[0 ]
14011441
14021442 pipe .enable_model_cpu_offload ()
1443+ assert pipe ._execution_device .type == pipe ._offload_device .type
1444+
14031445 inputs = self .get_dummy_inputs (generator_device )
14041446 output_with_offload = pipe (** inputs )[0 ]
14051447
14061448 max_diff = np .abs (to_np (output_with_offload ) - to_np (output_without_offload )).max ()
14071449 self .assertLess (max_diff , expected_max_diff , "CPU offloading should not affect the inference results" )
1408- offloaded_modules = [
1409- v
1450+
1451+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
1452+ offloaded_modules = {
1453+ k : v
14101454 for k , v in pipe .components .items ()
14111455 if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1412- ]
1413- (
1414- self .assertTrue (all (v .device .type == "cpu" for v in offloaded_modules )),
1415- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'cpu' ]} " ,
1456+ }
1457+ # 1. check if all offloaded modules are saved to cpu
1458+ self .assertTrue (
1459+ all (v .device .type == "cpu" for v in offloaded_modules .values ()),
1460+ f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'cpu' ]} " ,
1461+ )
1462+ # 2. check if all offloaded modules have hooks installed
1463+ self .assertTrue (
1464+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1465+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1466+ )
1467+ # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
1468+ offloaded_modules_with_incorrect_hooks = {}
1469+ for k , v in offloaded_modules .items ():
1470+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
1471+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1472+
1473+ self .assertTrue (
1474+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1475+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
14161476 )
14171477
14181478 @unittest .skipIf (
@@ -1444,16 +1504,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
14441504 self .assertLess (
14451505 max_diff , expected_max_diff , "running CPU offloading 2nd time should not affect the inference results"
14461506 )
1507+
1508+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
14471509 offloaded_modules = {
14481510 k : v
14491511 for k , v in pipe .components .items ()
14501512 if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
14511513 }
1514+ # 1. check if all offloaded modules are saved to cpu
14521515 self .assertTrue (
14531516 all (v .device .type == "cpu" for v in offloaded_modules .values ()),
14541517 f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'cpu' ]} " ,
14551518 )
1456-
1519+ # 2. check if all offloaded modules have hooks installed
1520+ self .assertTrue (
1521+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1522+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1523+ )
1524+ # 3. check if all offloaded modules have correct type of hooks installed, should be `CpuOffload`
14571525 offloaded_modules_with_incorrect_hooks = {}
14581526 for k , v in offloaded_modules .items ():
14591527 if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
@@ -1493,19 +1561,36 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
14931561 self .assertLess (
14941562 max_diff , expected_max_diff , "running sequential offloading second time should have the inference results"
14951563 )
1564+
1565+ # make sure all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are offloaded correctly
14961566 offloaded_modules = {
14971567 k : v
14981568 for k , v in pipe .components .items ()
14991569 if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
15001570 }
1571+ # 1. check if all offloaded modules are moved to meta device
15011572 self .assertTrue (
15021573 all (v .device .type == "meta" for v in offloaded_modules .values ()),
15031574 f"Not offloaded: { [k for k , v in offloaded_modules .items () if v .device .type != 'meta' ]} " ,
15041575 )
1576+ # 2. check if all offloaded modules have hook installed
1577+ self .assertTrue (
1578+ all (hasattr (v , "_hf_hook" ) for k , v in offloaded_modules .items ()),
1579+ f"No hook attached: { [k for k , v in offloaded_modules .items () if not hasattr (v , '_hf_hook' )]} " ,
1580+ )
1581+ # 3. check if all offloaded modules have correct hooks installed, should be either one of these two
1582+ # - `AlignDevicesHook`
1583+ # - a `SequentialHook` that contains `AlignDevicesHook`
15051584 offloaded_modules_with_incorrect_hooks = {}
15061585 for k , v in offloaded_modules .items ():
1507- if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1508- offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1586+ if hasattr (v , "_hf_hook" ):
1587+ if isinstance (v ._hf_hook , accelerate .hooks .SequentialHook ):
1588+ # if it is a `SequentialHook`, we loop through its `hooks` attribute to check if it only contains `AlignDevicesHook`
1589+ for hook in v ._hf_hook .hooks :
1590+ if not isinstance (hook , accelerate .hooks .AlignDevicesHook ):
1591+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook .hooks [0 ])
1592+ elif not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1593+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
15091594
15101595 self .assertTrue (
15111596 len (offloaded_modules_with_incorrect_hooks ) == 0 ,
0 commit comments