diff --git a/apps/streamlit/DiffSynth_Studio.py b/apps/streamlit/DiffSynth_Studio.py index cfd38565..1681524e 100644 --- a/apps/streamlit/DiffSynth_Studio.py +++ b/apps/streamlit/DiffSynth_Studio.py @@ -1,15 +1,24 @@ # Set web page format import streamlit as st st.set_page_config(layout="wide") -# Disable virtual VRAM on windows system +# Configure GPU memory usage based on available hardware import torch -torch.cuda.set_per_process_memory_fraction(0.999, 0) +import platform +# Check for CUDA (NVIDIA GPUs) +if torch.cuda.is_available(): + torch.cuda.set_per_process_memory_fraction(0.999, 0) + device = "cuda" +# Check for MPS (Apple Silicon) +elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm': + device = "mps" +else: + device = "cpu" -st.markdown(""" +st.markdown(f""" # DiffSynth Studio -[Source Code](https://github.com/Artiprocher/DiffSynth-Studio) +[Source Code](https://github.com/modelscope/DiffSynth-Studio) -Welcome to DiffSynth Studio. +Welcome to DiffSynth Studio. Running on: {device.upper()} """) diff --git a/apps/streamlit/pages/1_Image_Creator.py b/apps/streamlit/pages/1_Image_Creator.py index 732d2195..0a35ec80 100644 --- a/apps/streamlit/pages/1_Image_Creator.py +++ b/apps/streamlit/pages/1_Image_Creator.py @@ -74,11 +74,29 @@ def release_model(): del st.session_state["loaded_model_path"] del st.session_state["model_manager"] del st.session_state["pipeline"] - torch.cuda.empty_cache() + # Clear GPU memory based on available hardware + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # No equivalent memory management function for MPS yet def load_model(model_type, model_path): - model_manager = ModelManager() + # Determine the best available device + import platform + if torch.cuda.is_available(): + device = "cuda" + torch_dtype = torch.bfloat16 if model_type == "FLUX" else None + elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm': + device = "mps" + # Use float32 on MPS for better compatibility + torch_dtype = torch.float32 # Force full precision on Apple Silicon + else: + device = "cpu" + torch_dtype = None + + st.info(f"Using device: {device.upper()}") + + model_manager = ModelManager(device=device, torch_dtype=torch_dtype) if model_type == "HunyuanDiT": model_manager.load_models([ os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"), @@ -93,7 +111,6 @@ def load_model(model_type, model_path): os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"), ]) elif model_type == "FLUX": - model_manager.torch_dtype = torch.bfloat16 file_list = [ os.path.join(model_path, "text_encoder/model.safetensors"), os.path.join(model_path, "text_encoder_2"), diff --git a/apps/streamlit/pages/2_Video_Creator.py b/apps/streamlit/pages/2_Video_Creator.py index 87480726..39bd11f8 100644 --- a/apps/streamlit/pages/2_Video_Creator.py +++ b/apps/streamlit/pages/2_Video_Creator.py @@ -3,6 +3,8 @@ from diffsynth import SDVideoPipelineRunner import os import numpy as np +import torch +import platform def load_model_list(folder): @@ -20,11 +22,19 @@ def match_processor_id(model_name, supported_processor_id_list): return 0 +# Determine the appropriate device +if torch.cuda.is_available(): + device = "cuda" +elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm': + device = "mps" +else: + device = "cpu" + config = { "models": { "model_list": [], "textual_inversion_folder": "models/textual_inversion", - "device": "cuda", + "device": device, "lora_alphas": [], "controlnet_units": [] }, diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 7303dff1..f5141e41 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -192,6 +192,15 @@ def encode_prompt(self, prompt, positive=True, t5_sequence_length=512): def prepare_extra_input(self, latents=None, guidance=1.0): + if self.dit is None: + # Create dummy data for when DiT model is missing + dummy_shape = latents.shape + return { + "image_ids": torch.zeros(dummy_shape[0], 1, dummy_shape[2], dummy_shape[3], device=latents.device, dtype=latents.dtype), + "guidance": torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) + } + + # Normal case when DiT is available latent_image_ids = self.dit.prepare_image_ids(latents) guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype) return {"image_ids": latent_image_ids, "guidance": guidance} @@ -532,49 +541,18 @@ def lets_dance_flux( tea_cache: TeaCache = None, **kwargs ): - if tiled: - def flux_forward_fn(hl, hr, wl, wr): - tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None - return lets_dance_flux( - dit=dit, - controlnet=controlnet, - hidden_states=hidden_states[:, :, hl: hr, wl: wr], - timestep=timestep, - prompt_emb=prompt_emb, - pooled_prompt_emb=pooled_prompt_emb, - guidance=guidance, - text_ids=text_ids, - image_ids=None, - controlnet_frames=tiled_controlnet_frames, - tiled=False, - **kwargs - ) - return FastTileWorker().tiled_forward( - flux_forward_fn, - hidden_states, - tile_size=tile_size, - tile_stride=tile_stride, - tile_device=hidden_states.device, - tile_dtype=hidden_states.dtype - ) - - - # ControlNet + # Handle missing DiT model + if dit is None: + # Return hidden_states unchanged as a fallback + return hidden_states + + # Continue with normal processing when DiT is available if controlnet is not None and controlnet_frames is not None: - controlnet_extra_kwargs = { - "hidden_states": hidden_states, - "timestep": timestep, - "prompt_emb": prompt_emb, - "pooled_prompt_emb": pooled_prompt_emb, - "guidance": guidance, - "text_ids": text_ids, - "image_ids": image_ids, - "tiled": tiled, - "tile_size": tile_size, - "tile_stride": tile_stride, - } - controlnet_res_stack, controlnet_single_res_stack = controlnet( - controlnet_frames, **controlnet_extra_kwargs + hidden_states = controlnet( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=prompt_emb, + controlnet_frames=controlnet_frames ) if image_ids is None: diff --git a/requirements.txt b/requirements.txt index 63a871b9..c7955f4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ torch>=2.0.0 torchvision -cupy-cuda12x transformers==4.46.2 controlnet-aux==0.0.7 imageio @@ -11,3 +10,6 @@ sentencepiece protobuf modelscope ftfy +pillow +numpy +tqdm