@@ -38,13 +38,18 @@ def send_me_to_gpu(module, _):
38
38
# see below for register_forward_pre_hook;
39
39
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
40
40
# useless here, and we just replace those methods
41
- def first_stage_model_encode_wrap (self , encoder , x ):
42
- send_me_to_gpu (self , None )
43
- return encoder (x )
44
41
45
- def first_stage_model_decode_wrap (self , decoder , z ):
46
- send_me_to_gpu (self , None )
47
- return decoder (z )
42
+ first_stage_model = sd_model .first_stage_model
43
+ first_stage_model_encode = sd_model .first_stage_model .encode
44
+ first_stage_model_decode = sd_model .first_stage_model .decode
45
+
46
+ def first_stage_model_encode_wrap (x ):
47
+ send_me_to_gpu (first_stage_model , None )
48
+ return first_stage_model_encode (x )
49
+
50
+ def first_stage_model_decode_wrap (z ):
51
+ send_me_to_gpu (first_stage_model , None )
52
+ return first_stage_model_decode (z )
48
53
49
54
# remove three big modules, cond, first_stage, and unet from the model and then
50
55
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def first_stage_model_decode_wrap(self, decoder, z):
56
61
# register hooks for those the first two models
57
62
sd_model .cond_stage_model .transformer .register_forward_pre_hook (send_me_to_gpu )
58
63
sd_model .first_stage_model .register_forward_pre_hook (send_me_to_gpu )
59
- sd_model .first_stage_model .encode = lambda x , en = sd_model . first_stage_model . encode : first_stage_model_encode_wrap ( sd_model . first_stage_model , en , x )
60
- sd_model .first_stage_model .decode = lambda z , de = sd_model . first_stage_model . decode : first_stage_model_decode_wrap ( sd_model . first_stage_model , de , z )
64
+ sd_model .first_stage_model .encode = first_stage_model_encode_wrap
65
+ sd_model .first_stage_model .decode = first_stage_model_decode_wrap
61
66
parents [sd_model .cond_stage_model .transformer ] = sd_model .cond_stage_model
62
67
63
68
if use_medvram :
0 commit comments