@@ -95,67 +95,67 @@ def get_hyperparameters(cls, **override_defaults) -> List:
9595 lower = 0 ,
9696 upper = 4096 ,
9797 default_value = numeric_defaults ["training_batch_size" ],
98- meta = dict ( desc = " Number of steps from each diffusion process to use for distillation.") ,
98+ meta = { " desc" : " Number of steps from each diffusion process to use for distillation."} ,
9999 ),
100100 UniformIntegerHyperparameter (
101101 "gradient_accumulation_steps" ,
102102 lower = 1 ,
103103 upper = 1024 ,
104104 default_value = numeric_defaults ["gradient_accumulation_steps" ],
105- meta = dict ( desc = " Number of captions processed to estimate each gradient step.") ,
105+ meta = { " desc" : " Number of captions processed to estimate each gradient step."} ,
106106 ),
107107 UniformIntegerHyperparameter (
108108 "num_epochs" ,
109109 lower = 0 ,
110110 upper = 4096 ,
111111 default_value = numeric_defaults ["num_epochs" ],
112- meta = dict ( desc = " Number of epochs for distillation.") ,
112+ meta = { " desc" : " Number of epochs for distillation."} ,
113113 ),
114114 UniformFloatHyperparameter (
115115 "validate_every_n_epoch" ,
116116 lower = 0.0 ,
117117 upper = 4096.0 ,
118118 default_value = numeric_defaults ["validate_every_n_epoch" ],
119- meta = dict (
120- desc = "Number of epochs between each round of validation and model checkpointing. "
119+ meta = {
120+ " desc" : "Number of epochs between each round of validation and model checkpointing. "
121121 "If the value is between 0 and 1, validation will be performed multiple times per epoch, "
122122 "e.g. 1/8 will result in 8 validations per epoch."
123- ) ,
123+ } ,
124124 ),
125125 UniformFloatHyperparameter (
126126 "learning_rate" ,
127127 lower = 0.0 ,
128128 upper = 1.0 ,
129129 default_value = numeric_defaults ["learning_rate" ],
130- meta = dict ( desc = " Learning rate for distillation.") ,
130+ meta = { " desc" : " Learning rate for distillation."} ,
131131 ),
132132 Constant ("weight_decay" , numeric_defaults ["weight_decay" ]),
133133 # report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet
134134 Constant ("report_to" , string_defaults ["report_to" ]),
135135 Boolean (
136136 "use_cpu_offloading" ,
137137 default = False ,
138- meta = dict ( desc = " Whether to use CPU offloading for distillation.") ,
138+ meta = { " desc" : " Whether to use CPU offloading for distillation."} ,
139139 ),
140140 CategoricalHyperparameter (
141141 "optimizer" ,
142142 choices = ["AdamW8bit" , "AdamW" , "Adam" ],
143143 default_value = string_defaults ["optimizer" ],
144- meta = dict ( desc = " Which optimizer to use for distillation.") ,
144+ meta = { " desc" : " Which optimizer to use for distillation."} ,
145145 ),
146146 UniformFloatHyperparameter (
147147 "lr_decay" ,
148148 lower = 0.0 ,
149149 upper = 1.0 ,
150150 default_value = numeric_defaults ["lr_decay" ],
151- meta = dict ( desc = " Learning rate decay, applied at each epoch.") ,
151+ meta = { " desc" : " Learning rate decay, applied at each epoch."} ,
152152 ),
153153 UniformIntegerHyperparameter (
154154 "warmup_steps" ,
155155 lower = 0 ,
156156 upper = 2 ** 14 ,
157157 default_value = numeric_defaults ["warmup_steps" ],
158- meta = dict ( desc = " Number of warmup steps for the learning rate scheduler.") ,
158+ meta = { " desc" : " Number of warmup steps for the learning rate scheduler."} ,
159159 ),
160160 ]
161161
@@ -405,7 +405,7 @@ def distillation_forward(*args, **kwargs):
405405 output ["sample" ] if ("return_dict" in kwargs and kwargs ["return_dict" ]) else output [0 ]
406406 )
407407 loss = self .loss (latent_output , latent_targets [self .num_previous_steps ])
408- if is_training :
408+ if is_training and active_steps is not None :
409409 accumulation_normalized_loss = loss / (len (active_steps ) * self .gradient_accumulation_steps )
410410 self .manual_backward (accumulation_normalized_loss )
411411 diffusion_step_losses .append (loss )
0 commit comments