diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 3b2101543..dae989a26 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1360,10 +1360,7 @@ def run(self): # print("sage attention is not installed. Using SDP instead") if self.train_config.gradient_checkpointing: - if self.sd.is_flux: - unet.gradient_checkpointing = True - else: - unet.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() if isinstance(text_encoder, list): for te in text_encoder: if hasattr(te, 'enable_gradient_checkpointing'):