@@ -457,14 +457,20 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
457457 opts .emphasis ,
458458 )
459459
460- def apply_generation_params_states (self , generation_params_states ):
460+ def apply_generation_params_list (self , generation_params_states ):
461461 """add and apply generation_params_states to self.extra_generation_params"""
462462 for key , value in generation_params_states .items ():
463463 if key in self .extra_generation_params and isinstance (current_value := self .extra_generation_params [key ], util .GenerationParametersList ):
464464 self .extra_generation_params [key ] = current_value + value
465465 else :
466466 self .extra_generation_params [key ] = value
467467
468+ def clear_marked_generation_params (self ):
469+ """clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
470+ for key , value in list (self .extra_generation_params .items ()):
471+ if getattr (value , 'to_be_clear_before_batch' , False ):
472+ self .extra_generation_params .pop (key )
473+
468474 def get_conds_with_caching (self , function , required_prompts , steps , caches , extra_network_data , hires_steps = None ):
469475 """
470476 Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
491497 if len (cache ) == 3 :
492498 generation_params_states , cached_cached_params = cache [2 ]
493499 if cached_params == cached_cached_params :
494- self .apply_generation_params_states (generation_params_states )
500+ self .apply_generation_params_list (generation_params_states )
495501 return cache [1 ]
496502
497503 cache = caches [0 ]
@@ -500,7 +506,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
500506 cache [1 ] = function (shared .sd_model , required_prompts , steps , hires_steps , shared .opts .use_old_scheduling )
501507
502508 generation_params_states = model_hijack .extract_generation_params_states ()
503- self .apply_generation_params_states (generation_params_states )
509+ self .apply_generation_params_list (generation_params_states )
504510 if len (cache ) == 2 :
505511 cache .append ((generation_params_states , cached_params ))
506512 else :
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
959965 if state .interrupted or state .stopping_generation :
960966 break
961967
968+ p .clear_marked_generation_params () # clean up some generation params are tagged to be cleared before batch
962969 sd_models .reload_model_weights () # model can be changed for example by refiner
963970
964971 p .prompts = p .all_prompts [n * p .batch_size :(n + 1 ) * p .batch_size ]
0 commit comments