import os
import torch
from dataclasses import dataclass, field
from typing import List, Tuple, Optional

from diffsynth_engine.configs.controlnet import ControlType


@dataclass
class BaseConfig:
    model_path: str | os.PathLike | List[str | os.PathLike]
    model_dtype: torch.dtype
    batch_cfg: bool = False
    vae_tiled: bool = False
    vae_tile_size: int | Tuple[int, int] = 256
    vae_tile_stride: int | Tuple[int, int] = 256
    device: str = "cuda"
    offload_mode: Optional[str] = None


@dataclass
class AttentionConfig:
    dit_attn_impl: str = "auto"
    # Sparge Attention
    sparge_smooth_k: bool = True
    sparge_cdfthreshd: float = 0.6
    sparge_simthreshd1: float = 0.98
    sparge_pvthreshd: float = 50.0


@dataclass
class OptimizationConfig:
    use_fp8_linear: bool = False
    use_fbcache: bool = False
    fbcache_relative_l1_threshold: float = 0.05


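# cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree must equal parallelism;
# init_parallel_config (bottom of this file) fills in any degrees left as None.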
@dataclass
class ParallelConfig:
    parallelism: int = 1
    use_cfg_parallel: bool = False
    cfg_degree: Optional[int] = None
    sp_ulysses_degree: Optional[int] = None
    sp_ring_degree: Optional[int] = None
    tp_degree: Optional[int] = None
    use_fsdp: bool = False


@dataclass
class SDPipelineConfig(BaseConfig):
    model_path: str | os.PathLike | List[str | os.PathLike]
    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.float16
    clip_dtype: torch.dtype = torch.float16
    vae_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        offload_mode: Optional[str] = None,
    ) -> "SDPipelineConfig":
        return cls(
            model_path=model_path,
            device=device,
            offload_mode=offload_mode,
        )


@dataclass
class SDXLPipelineConfig(BaseConfig):
    model_path: str | os.PathLike | List[str | os.PathLike]
    clip_l_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    clip_g_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.float16
    clip_l_dtype: torch.dtype = torch.float16
    clip_g_dtype: torch.dtype = torch.float16
    vae_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        offload_mode: Optional[str] = None,
    ) -> "SDXLPipelineConfig":
        return cls(
            model_path=model_path,
            device=device,
            offload_mode=offload_mode,
        )


@dataclass
class FluxPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
    model_path: str | os.PathLike | List[str | os.PathLike]
    clip_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.bfloat16
    clip_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    vae_dtype: torch.dtype = torch.bfloat16

    load_text_encoder: bool = True
    control_type: ControlType = ControlType.normal

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "FluxPipelineConfig":
        return cls(
            model_path=model_path,
            device=device,
            parallelism=parallelism,
            use_fsdp=True,
            offload_mode=offload_mode,
        )

    def __post_init__(self):
        init_parallel_config(self)
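
# Example (sketch): FluxPipelineConfig.basic_config("/path/to/flux", parallelism=4)
# enables FSDP and, with use_cfg_parallel left at False, resolves in __post_init__
# to cfg_degree=1, sp_ulysses_degree=4, sp_ring_degree=1, tp_degree=1.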


@dataclass
class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
    model_path: str | os.PathLike | List[str | os.PathLike]
    t5_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    model_dtype: torch.dtype = torch.bfloat16
    t5_dtype: torch.dtype = torch.bfloat16
    vae_dtype: torch.dtype = torch.bfloat16
    image_encoder_dtype: torch.dtype = torch.bfloat16

    shift: Optional[float] = field(default=None, init=False)  # Rectified Flow scheduler shift factor, set by model type

    # override BaseConfig
    vae_tiled: bool = True
    vae_tile_size: Tuple[int, int] = (34, 34)
    vae_tile_stride: Tuple[int, int] = (18, 16)

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "WanPipelineConfig":
        return cls(
            model_path=model_path,
            image_encoder_path=image_encoder_path,
            device=device,
            parallelism=parallelism,
            use_cfg_parallel=True,
            use_fsdp=True,
            offload_mode=offload_mode,
        )

    def __post_init__(self):
        init_parallel_config(self)
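
# Example (sketch): WanPipelineConfig.basic_config("/path/to/wan", parallelism=4)
# turns on classifier-free-guidance parallelism, so init_parallel_config resolves
# cfg_degree=2, sp_ulysses_degree=2, sp_ring_degree=1, tp_degree=1 and sets batch_cfg=True.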


def init_parallel_config(config: FluxPipelineConfig | WanPipelineConfig):
    assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
    if config.parallelism > 1 and config.use_cfg_parallel:
        config.batch_cfg = True

    if config.use_cfg_parallel is True and config.cfg_degree is not None:
        raise ValueError("use_cfg_parallel and cfg_degree should not be specified together")
    config.cfg_degree = (2 if config.use_cfg_parallel else 1) if config.cfg_degree is None else config.cfg_degree

    if config.tp_degree is not None:
        # tensor parallel: mutually exclusive with sequence parallel and FSDP
        assert config.sp_ulysses_degree is None and config.sp_ring_degree is None, (
            "not allowed to enable sequence parallel and tensor parallel together; "
            "either set sp_ulysses_degree=None, sp_ring_degree=None or set tp_degree=None during pipeline initialization"
        )
        assert config.use_fsdp is False, (
            "not allowed to enable fully sharded data parallel and tensor parallel together; "
            "either set use_fsdp=False or set tp_degree=None during pipeline initialization"
        )
        assert config.parallelism == config.cfg_degree * config.tp_degree, (
            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * tp_degree ({config.tp_degree})"
        )
        config.sp_ulysses_degree = 1
        config.sp_ring_degree = 1
    elif config.sp_ulysses_degree is None and config.sp_ring_degree is None:
        # use Ulysses sequence parallel for the remaining ranks if not specified
        config.sp_ulysses_degree = config.parallelism // config.cfg_degree
        config.sp_ring_degree = 1
        config.tp_degree = 1
    elif config.sp_ulysses_degree is not None and config.sp_ring_degree is not None:
        # both sequence-parallel degrees were given explicitly; validate the product
        assert config.parallelism == config.cfg_degree * config.sp_ulysses_degree * config.sp_ring_degree, (
            f"parallelism ({config.parallelism}) must be equal to cfg_degree ({config.cfg_degree}) * "
            f"sp_ulysses_degree ({config.sp_ulysses_degree}) * sp_ring_degree ({config.sp_ring_degree})"
        )
        config.tp_degree = 1
    else:
        raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
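
# Explicit degree control (sketch): constructing FluxPipelineConfig or WanPipelineConfig
# with parallelism=8, use_cfg_parallel=True, sp_ulysses_degree=2, sp_ring_degree=2
# satisfies 8 == 2 * 2 * 2; combining tp_degree with the sp_* degrees raises an AssertionError.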