@@ -94,13 +94,13 @@ def init_tensors(self, ctx: DenoiseContext):
             generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
         ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
 
-    # TODO: order value
+    # Use negative order to make extensions with default order work with patched latents
     @callback(ExtensionCallbackType.PRE_STEP, order=-100)
     def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
         ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
 
-    # TODO: order value
     # TODO: redo this with preview events rewrite
+    # Use negative order to make extensions with default order work with patched latents
     @callback(ExtensionCallbackType.POST_STEP, order=-100)
     def apply_mask_to_step_output(self, ctx: DenoiseContext):
         timestep = ctx.scheduler.timesteps[-1]
@@ -111,8 +111,7 @@ def apply_mask_to_step_output(self, ctx: DenoiseContext):
         else:
             ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
 
-    # TODO: should here be used order?
-    # restore unmasked part after the last step is completed
+    # Restore unmasked part after the last step is completed
     @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
     def restore_unmasked(self, ctx: DenoiseContext):
         if self._is_gradient_mask:
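For context, here is a minimal sketch of how an order-aware callback registry can back the `@callback(..., order=...)` decorator used above. The registry internals (`_REGISTRY`, `run_callbacks`, and the enum values) are assumptions for illustration, not InvokeAI's actual implementation; the point is only that handlers with a lower `order` run first, which is why `order=-100` lets the mask callbacks patch the latents before extensions registered with the default order of 0 see them.

```python
# Minimal sketch of an order-aware callback registry (assumed internals, not
# InvokeAI's actual code): handlers with a lower `order` run first, so the
# inpaint callbacks registered with order=-100 patch the latents before any
# extension left at the default order of 0 reads them.
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple


class ExtensionCallbackType(Enum):
    PRE_STEP = "pre_step"
    POST_STEP = "post_step"
    POST_DENOISE_LOOP = "post_denoise_loop"


_REGISTRY: Dict[ExtensionCallbackType, List[Tuple[int, Callable[..., None]]]] = {}


def callback(kind: ExtensionCallbackType, order: int = 0) -> Callable:
    """Register a handler for `kind`; `order` decides its place in the run order."""

    def decorator(func: Callable[..., None]) -> Callable[..., None]:
        _REGISTRY.setdefault(kind, []).append((order, func))
        return func

    return decorator


def run_callbacks(kind: ExtensionCallbackType, ctx: Any) -> None:
    # Stable sort on `order` alone: ties keep registration order, and the
    # handler functions themselves never need to be comparable.
    for _, func in sorted(_REGISTRY.get(kind, []), key=lambda pair: pair[0]):
        func(ctx)
```

Under this scheme, `apply_mask_to_initial_latents` (order=-100) always fires before any default-order PRE_STEP handler, so the latter observes latents that already contain the masked region.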