File tree Expand file tree Collapse file tree 2 files changed +12
-8
lines changed
examples/consistency_distillation Expand file tree Collapse file tree 2 files changed +12
-8
lines changed Original file line number Diff line number Diff line change @@ -921,10 +921,12 @@ def main(args):
921
921
922
922
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
923
923
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
924
- if teacher_unet .config .time_cond_proj_dim is None :
925
- teacher_unet .config ["time_cond_proj_dim" ] = args .unet_time_cond_proj_dim
926
- time_cond_proj_dim = teacher_unet .config .time_cond_proj_dim
927
- unet = UNet2DConditionModel (** teacher_unet .config )
924
+ time_cond_proj_dim = (
925
+ teacher_unet .config .time_cond_proj_dim
926
+ if teacher_unet .config .time_cond_proj_dim is not None
927
+ else args .unet_time_cond_proj_dim
928
+ )
929
+ unet = UNet2DConditionModel .from_config (teacher_unet .config , time_cond_proj_dim = time_cond_proj_dim )
928
930
# load teacher_unet weights into unet
929
931
unet .load_state_dict (teacher_unet .state_dict (), strict = False )
930
932
unet .train ()
Original file line number Diff line number Diff line change @@ -980,10 +980,12 @@ def main(args):
980
980
981
981
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
982
982
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
983
- if teacher_unet .config .time_cond_proj_dim is None :
984
- teacher_unet .config ["time_cond_proj_dim" ] = args .unet_time_cond_proj_dim
985
- time_cond_proj_dim = teacher_unet .config .time_cond_proj_dim
986
- unet = UNet2DConditionModel (** teacher_unet .config )
983
+ time_cond_proj_dim = (
984
+ teacher_unet .config .time_cond_proj_dim
985
+ if teacher_unet .config .time_cond_proj_dim is not None
986
+ else args .unet_time_cond_proj_dim
987
+ )
988
+ unet = UNet2DConditionModel .from_config (teacher_unet .config , time_cond_proj_dim = time_cond_proj_dim )
987
989
# load teacher_unet weights into unet
988
990
unet .load_state_dict (teacher_unet .state_dict (), strict = False )
989
991
unet .train ()
You can’t perform that action at this time.
0 commit comments