We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 5e4e3c3 commit 17a1b83Copy full SHA for 17a1b83
src/train/model.py
@@ -35,7 +35,8 @@ def __init__(
35
FluxOminiKontextPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
36
)
37
self.transformer = self.flux_pipe.transformer
38
- self.transformer.gradient_checkpointing = gradient_checkpointing
+ if gradient_checkpointing:
39
+ self.transformer.enable_gradient_checkpointing()
40
self.transformer.train()
41
42
# Freeze the Flux pipeline components
0 commit comments