diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md
index 7a7b4841125f..c95f34e32f38 100644
--- a/examples/controlnet/README_sd3.md
+++ b/examples/controlnet/README_sd3.md
@@ -1,6 +1,6 @@
-# ControlNet training example for Stable Diffusion 3 (SD3)
+# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)
 
-The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
+The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).
 
 ## Running locally with PyTorch
@@ -51,9 +51,9 @@ Please download the dataset and unzip it in the directory `fill50k` in the `exam
 
 ## Training
 
-First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
+First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as the base model for ControlNet training.
 
 > [!NOTE]
-> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the [Stable Diffusion 3.5 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
 
 ```bash
 huggingface-cli login
@@ -90,6 +90,8 @@ accelerate launch train_controlnet_sd3.py \
     --gradient_accumulation_steps=4
 ```
 
+To train a ControlNet model for Stable Diffusion 3.5, replace `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.
+
 To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
 
 Our experiments were conducted on a single 40GB A100 GPU.
@@ -124,6 +126,8 @@ image = pipe(
 image.save("./output.png")
 ```
 
+Similarly, for SD3.5, replace `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and `controlnet_path` with `DavyMorgan/sd35-controlnet-out`.
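For reference, the fully substituted SD3.5 snippet would look roughly like the sketch below. This is a minimal adaptation of the SD3 inference example referenced in the hunk above, assuming a fill50k-style conditioning image; the local file path and prompt are illustrative placeholders.

```python
import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# SD3.5 base model and the example ControlNet checkpoint named above.
base_model_path = "stabilityai/stable-diffusion-3.5-medium"
controlnet_path = "DavyMorgan/sd35-controlnet-out"

# Load the trained ControlNet and attach it to the SD3.5 base pipeline.
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)
pipe.to("cuda")

# Illustrative conditioning image and prompt in the style of fill50k.
control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"

generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=25, generator=generator, control_image=control_image
).images[0]
image.save("./output.png")
```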
+
 ## Notes
 
 ### GPU usage
@@ -135,6 +139,8 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
 
 ## Example results
 
+### SD3
+
 #### After 500 steps with batch size 8
 
 | | |
 |-------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------:|
@@ -150,3 +156,20 @@ || pale golden rod circle with old lace background |
 ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-6500.png) |
 
+### SD3.5
+
+#### After 500 steps with batch size 8
+
+| | |
+|-------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500-3.5.png) |
+
+#### After 3000 steps with batch size 8
+
+| | |
+|-------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------:|
+|| pale golden rod circle with old lace background |
+![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-3000-3.5.png) |
+
diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py
index 3c508f80f1a4..d595a1a312b0 100644
--- a/examples/controlnet/test_controlnet.py
+++ b/examples/controlnet/test_controlnet.py
@@ -138,6 +138,27 @@ def test_controlnet_sd3(self):
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
 
 
+class ControlNetSD35(ExamplesTestsAccelerate):
+    def test_controlnet_sd35(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+            examples/controlnet/train_controlnet_sd3.py
+            --pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
+            --dataset_name=hf-internal-testing/fill10
+            --output_dir={tmpdir}
+            --resolution=64
+            --train_batch_size=1
+            --gradient_accumulation_steps=1
+            --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
+            --max_train_steps=4
+            --checkpointing_steps=2
+            """.split()
+
+            run_command(self._launch_args + test_args)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+
 class ControlNetflux(ExamplesTestsAccelerate):
     def test_controlnet_flux(self):
         with tempfile.TemporaryDirectory() as tmpdir:
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index 2bb68220e268..cbbce2932ef8 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -263,6 +263,12 @@ def parse_args(input_args=None):
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.", ) + parser.add_argument( + "--num_extra_conditioning_channels", + type=int, + default=0, + help="Number of extra conditioning channels for controlnet.", + ) parser.add_argument( "--revision", type=str, @@ -539,6 +545,9 @@ def parse_args(input_args=None): default=77, help="Maximum sequence length to use with with the T5 text encoder", ) + parser.add_argument( + "--dataset_preprocess_batch_size", type=int, default=1000, help="Batch size for preprocessing dataset." + ) parser.add_argument( "--validation_prompt", type=str, @@ -986,7 +995,9 @@ def main(args): controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) else: logger.info("Initializing controlnet weights from transformer") - controlnet = SD3ControlNetModel.from_transformer(transformer) + controlnet = SD3ControlNetModel.from_transformer( + transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels + ) transformer.requires_grad_(False) vae.requires_grad_(False) @@ -1123,7 +1134,12 @@ def compute_text_embeddings(batch, text_encoders, tokenizers): # fingerprint used by the cache for the other processes to load the result # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 new_fingerprint = Hasher.hash(args) - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + train_dataset = train_dataset.map( + compute_embeddings_fn, + batched=True, + batch_size=args.dataset_preprocess_batch_size, + new_fingerprint=new_fingerprint, + ) del text_encoder_one, text_encoder_two, text_encoder_three del tokenizer_one, tokenizer_two, tokenizer_three diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 887e8afd2106..79452bb85176 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-
-
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 90c253f783c6..5c547164c29a 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -60,7 +60,9 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
     )
     batch_params = frozenset(["prompt", "negative_prompt"])
 
-    def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm"):
+    def get_dummy_components(
+        self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention: bool = False
+    ):
         torch.manual_seed(0)
         transformer = SD3Transformer2DModel(
             sample_size=32,
@@ -74,6 +76,7 @@ def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional
             pooled_projection_dim=64,
             out_channels=8,
             qk_norm=qk_norm,
+            dual_attention_layers=() if not use_dual_attention else (0, 1),
         )
 
         torch.manual_seed(0)
@@ -88,7 +91,10 @@ def get_dummy_components(self, num_controlnet_layers: int = 3, qk_norm: Optional
             caption_projection_dim=32,
             pooled_projection_dim=64,
             out_channels=8,
+            qk_norm=qk_norm,
+            dual_attention_layers=() if not use_dual_attention else (0,),
         )
+
         clip_text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
@@ -173,8 +179,7 @@ def get_dummy_inputs(self, device, seed=0):
 
         return inputs
 
-    def test_controlnet_sd3(self):
-        components = self.get_dummy_components()
+    def run_pipe(self, components, use_sd35: bool = False):
         sd_pipe = StableDiffusion3ControlNetPipeline(**components)
         sd_pipe = sd_pipe.to(torch_device, dtype=torch.float16)
         sd_pipe.set_progress_bar_config(disable=None)
@@ -187,12 +192,23 @@ def test_controlnet_sd3(self):
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+        if not use_sd35:
+            expected_slice = np.array([0.5767, 0.7100, 0.5981, 0.5674, 0.5952, 0.4102, 0.5093, 0.5044, 0.6030])
+        else:
+            expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
 
         assert (
             np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
 
+    def test_controlnet_sd3(self):
+        components = self.get_dummy_components()
+        self.run_pipe(components)
+
+    def test_controlnet_sd35(self):
+        components = self.get_dummy_components(num_controlnet_layers=1, qk_norm="rms_norm", use_dual_attention=True)
+        self.run_pipe(components, use_sd35=True)
+
     @unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
     def test_xformers_attention_forwardGenerator_pass(self):
         pass
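Two of the training-script additions above deserve a quick gloss: `--dataset_preprocess_batch_size` is simply forwarded to `datasets.Dataset.map` as its `batch_size`, while `--num_extra_conditioning_channels` is forwarded to `SD3ControlNetModel.from_transformer`, which reserves that many extra input channels in the ControlNet's conditioning patch embedding. Below is a minimal sketch of the latter mechanism, using a deliberately tiny transformer config modeled on the dummy configs in the tests above; the sizes are illustrative, not a real SD3/3.5 configuration.

```python
import torch
from diffusers import SD3ControlNetModel, SD3Transformer2DModel

# Deliberately tiny transformer, mirroring the dummy test configs above.
transformer = SD3Transformer2DModel(
    sample_size=32,
    patch_size=1,
    in_channels=8,
    num_layers=4,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    out_channels=8,
)

# This mirrors what train_controlnet_sd3.py now does when no pretrained
# ControlNet is given: copy weights from the transformer and reserve one
# extra input channel for the conditioning embedding.
controlnet = SD3ControlNetModel.from_transformer(
    transformer, num_layers=2, num_extra_conditioning_channels=1
)
print(controlnet.config.extra_conditioning_channels)  # -> 1
```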