Skip to content

Commit 0778dd9

Browse files
committed
fix Text Encoder only LoRA training
1 parent 59b3b94 commit 0778dd9

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

flux_train_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def get_noise_pred_and_target(
378378
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379379
# if not args.split_mode:
380380
# normal forward
381-
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
381+
with torch.set_grad_enabled(is_train), accelerator.autocast():
382382
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
383383
model_pred = unet(
384384
img=img,

sd3_train_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def get_noise_pred_and_target(
345345
t5_attn_mask = None
346346

347347
# call model
348-
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
348+
with torch.set_grad_enabled(is_train), accelerator.autocast():
349349
# TODO support attention mask
350350
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
351351

train_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def get_noise_pred_and_target(
233233
t.requires_grad_(True)
234234

235235
# Predict the noise residual
236-
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
236+
with torch.set_grad_enabled(is_train), accelerator.autocast():
237237
noise_pred = self.call_unet(
238238
args,
239239
accelerator,

0 commit comments

Comments
 (0)