@@ -31,57 +31,57 @@ def __init__(
31
31
resize_mode : str ,
32
32
):
33
33
super ().__init__ ()
34
- self .model = model
35
- self .image = image
36
- self .weight = weight
37
- self .begin_step_percent = begin_step_percent
38
- self .end_step_percent = end_step_percent
39
- self .control_mode = control_mode
40
- self .resize_mode = resize_mode
34
+ self ._model = model
35
+ self ._image = image
36
+ self ._weight = weight
37
+ self ._begin_step_percent = begin_step_percent
38
+ self ._end_step_percent = end_step_percent
39
+ self ._control_mode = control_mode
40
+ self ._resize_mode = resize_mode
41
41
42
- self .image_tensor : Optional [torch .Tensor ] = None
42
+ self ._image_tensor : Optional [torch .Tensor ] = None
43
43
44
44
@contextmanager
45
45
def patch_extension (self , ctx : DenoiseContext ):
46
46
try :
47
- original_processors = self .model .attn_processors
48
- self .model .set_attn_processor (ctx .inputs .attention_processor_cls ())
47
+ original_processors = self ._model .attn_processors
48
+ self ._model .set_attn_processor (ctx .inputs .attention_processor_cls ())
49
49
50
50
yield None
51
51
finally :
52
- self .model .set_attn_processor (original_processors )
52
+ self ._model .set_attn_processor (original_processors )
53
53
54
54
@callback (ExtensionCallbackType .PRE_DENOISE_LOOP )
55
55
def resize_image (self , ctx : DenoiseContext ):
56
56
_ , _ , latent_height , latent_width = ctx .latents .shape
57
57
image_height = latent_height * LATENT_SCALE_FACTOR
58
58
image_width = latent_width * LATENT_SCALE_FACTOR
59
59
60
- self .image_tensor = prepare_control_image (
61
- image = self .image ,
60
+ self ._image_tensor = prepare_control_image (
61
+ image = self ._image ,
62
62
do_classifier_free_guidance = False ,
63
63
width = image_width ,
64
64
height = image_height ,
65
65
# batch_size=batch_size * num_images_per_prompt,
66
66
# num_images_per_prompt=num_images_per_prompt,
67
67
device = ctx .latents .device ,
68
68
dtype = ctx .latents .dtype ,
69
- control_mode = self .control_mode ,
70
- resize_mode = self .resize_mode ,
69
+ control_mode = self ._control_mode ,
70
+ resize_mode = self ._resize_mode ,
71
71
)
72
72
73
73
@callback (ExtensionCallbackType .PRE_UNET )
74
74
def pre_unet_step (self , ctx : DenoiseContext ):
75
75
# skip if model not active in current step
76
76
total_steps = len (ctx .inputs .timesteps )
77
- first_step = math .floor (self .begin_step_percent * total_steps )
78
- last_step = math .ceil (self .end_step_percent * total_steps )
77
+ first_step = math .floor (self ._begin_step_percent * total_steps )
78
+ last_step = math .ceil (self ._end_step_percent * total_steps )
79
79
if ctx .step_index < first_step or ctx .step_index > last_step :
80
80
return
81
81
82
82
# convert mode to internal flags
83
- soft_injection = self .control_mode in ["more_prompt" , "more_control" ]
84
- cfg_injection = self .control_mode in ["more_control" , "unbalanced" ]
83
+ soft_injection = self ._control_mode in ["more_prompt" , "more_control" ]
84
+ cfg_injection = self ._control_mode in ["more_control" , "unbalanced" ]
85
85
86
86
# no negative conditioning in cfg_injection mode
87
87
if cfg_injection :
@@ -117,7 +117,7 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
117
117
total_steps = len (ctx .inputs .timesteps )
118
118
119
119
model_input = ctx .latent_model_input
120
- image_tensor = self .image_tensor
120
+ image_tensor = self ._image_tensor
121
121
if conditioning_mode == ConditioningMode .Both :
122
122
model_input = torch .cat ([model_input ] * 2 )
123
123
image_tensor = torch .cat ([image_tensor ] * 2 )
@@ -134,7 +134,7 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
134
134
ctx .inputs .conditioning_data .to_unet_kwargs (cn_unet_kwargs , conditioning_mode = conditioning_mode )
135
135
136
136
# get static weight, or weight corresponding to current step
137
- weight = self .weight
137
+ weight = self ._weight
138
138
if isinstance (weight , list ):
139
139
weight = weight [ctx .step_index ]
140
140
@@ -144,7 +144,7 @@ def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: Con
144
144
tmp_kwargs .pop ("down_intrablock_additional_residuals" , None )
145
145
146
146
# controlnet(s) inference
147
- down_samples , mid_sample = self .model (
147
+ down_samples , mid_sample = self ._model (
148
148
controlnet_cond = image_tensor ,
149
149
conditioning_scale = weight , # controlnet specific, NOT the guidance scale
150
150
guess_mode = soft_injection , # this is still called guess_mode in diffusers ControlNetModel
0 commit comments