@@ -1144,20 +1144,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
1144
1144
self .assertLess (
1145
1145
max_diff , expected_max_diff , "running CPU offloading 2nd time should not affect the inference results"
1146
1146
)
1147
- offloaded_modules = [
1148
- v
1147
+ offloaded_modules = {
1148
+ k : v
1149
1149
for k , v in pipe .components .items ()
1150
1150
if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1151
- ]
1152
- (
1153
- self . assertTrue ( all (v .device .type == "cpu" for v in offloaded_modules )),
1154
- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'cpu' ]} " ,
1151
+ }
1152
+ self . assertTrue (
1153
+ all (v .device .type == "cpu" for v in offloaded_modules . values ( )),
1154
+ f"Not offloaded: { [k for k , v in offloaded_modules . items () if v .device .type != 'cpu' ]} " ,
1155
1155
)
1156
1156
1157
- offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr (v , "_hf_hook" )]
1158
- (
1159
- self .assertTrue (all (isinstance (v , accelerate .hooks .CpuOffload ) for v in offloaded_modules_with_hooks )),
1160
- f"Not installed correct hook: { [v for v in offloaded_modules_with_hooks if not isinstance (v , accelerate .hooks .CpuOffload )]} " ,
1157
+ offloaded_modules_with_incorrect_hooks = {}
1158
+ for k , v in offloaded_modules .items ():
1159
+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .CpuOffload ):
1160
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1161
+
1162
+ self .assertTrue (
1163
+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1164
+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
1161
1165
)
1162
1166
1163
1167
@unittest .skipIf (
@@ -1189,22 +1193,23 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
1189
1193
self .assertLess (
1190
1194
max_diff , expected_max_diff , "running sequential offloading second time should have the inference results"
1191
1195
)
1192
- offloaded_modules = [
1193
- v
1196
+ offloaded_modules = {
1197
+ k : v
1194
1198
for k , v in pipe .components .items ()
1195
1199
if isinstance (v , torch .nn .Module ) and k not in pipe ._exclude_from_cpu_offload
1196
- ]
1197
- (
1198
- self . assertTrue ( all (v .device .type == "meta" for v in offloaded_modules )),
1199
- f"Not offloaded: { [v for v in offloaded_modules if v .device .type != 'meta' ]} " ,
1200
+ }
1201
+ self . assertTrue (
1202
+ all (v .device .type == "meta" for v in offloaded_modules . values ( )),
1203
+ f"Not offloaded: { [k for k , v in offloaded_modules . items () if v .device .type != 'meta' ]} " ,
1200
1204
)
1205
+ offloaded_modules_with_incorrect_hooks = {}
1206
+ for k , v in offloaded_modules .items ():
1207
+ if hasattr (v , "_hf_hook" ) and not isinstance (v ._hf_hook , accelerate .hooks .AlignDevicesHook ):
1208
+ offloaded_modules_with_incorrect_hooks [k ] = type (v ._hf_hook )
1201
1209
1202
- offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr (v , "_hf_hook" )]
1203
- (
1204
- self .assertTrue (
1205
- all (isinstance (v , accelerate .hooks .AlignDevicesHook ) for v in offloaded_modules_with_hooks )
1206
- ),
1207
- f"Not installed correct hook: { [v for v in offloaded_modules_with_hooks if not isinstance (v , accelerate .hooks .AlignDevicesHook )]} " ,
1210
+ self .assertTrue (
1211
+ len (offloaded_modules_with_incorrect_hooks ) == 0 ,
1212
+ f"Not installed correct hook: { offloaded_modules_with_incorrect_hooks } " ,
1208
1213
)
1209
1214
1210
1215
@unittest .skipIf (
0 commit comments