Skip to content

Commit 17a1b83

Browse files
authored
Update model.py
1 parent 5e4e3c3 commit 17a1b83

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/train/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def __init__(
3535
FluxOminiKontextPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
3636
)
3737
self.transformer = self.flux_pipe.transformer
38-
self.transformer.gradient_checkpointing = gradient_checkpointing
38+
if gradient_checkpointing:
39+
self.transformer.enable_gradient_checkpointing()
3940
self.transformer.train()
4041

4142
# Freeze the Flux pipeline components

0 commit comments

Comments
 (0)