diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py index 06b4f48de..8f7496015 100644 --- a/ai_diffusion/api.py +++ b/ai_diffusion/api.py @@ -40,7 +40,6 @@ class ImageInput: def from_extent(e: Extent): return ImageInput(ExtentInput(e, e, e, e)) - @dataclass class LoraInput: name: str @@ -64,7 +63,8 @@ class CheckpointInput: self_attention_guidance: bool = False dynamic_caching: bool = False tiled_vae: bool = False - + magcache_enabled: bool = False + magcache_thresh: float = 0.24 @dataclass class SamplingInput: diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 574a6d37a..375393762 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -273,6 +273,7 @@ class ClientFeatures(NamedTuple): max_control_layers: int = 1000 wave_speed: bool = False gguf: bool = False + magcache: bool = False class Client(ABC): diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 03b0f872e..6bdab2e33 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -130,6 +130,7 @@ async def connect(url=default_url, access_token=""): languages=await _list_languages(client), wave_speed="ApplyFBCacheOnModel" in nodes, gguf="UnetLoaderGGUF" in nodes, + magcache="MagCache" in nodes, ) # Check for required and optional model resources @@ -489,6 +490,8 @@ def performance_settings(self): max_pixel_count=settings.max_pixel_count, tiled_vae=settings.tiled_vae, dynamic_caching=settings.dynamic_caching and self.features.wave_speed, + magcache_enabled=settings.magcache_enabled, + magcache_thresh=settings.magcache_thresh, ) async def upload_loras(self, work: WorkflowInput, local_job_id: str): diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 3cfd7b74c..386340946 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -865,6 +865,28 @@ def inpaint_image(self, model: Output, image: Output, mask: Output): return self.add( "INPAINT_InpaintWithModel", 1, inpaint_model=model, image=image, mask=mask, seed=834729 ) + def apply_magcache( + self, + model: Output, + model_type: str = "flux", + magcache_thresh: float = 0.24, + retention_ratio: float = 0.1, + magcache_K: int = 5, + start_step: int = 0, + end_step: int = -1, + ): + """Apply MagCache acceleration to Flux models""" + return self.add( + "MagCache", + 1, + model=model, + model_type=model_type, + magcache_thresh=magcache_thresh, + retention_ratio=retention_ratio, + magcache_K=magcache_K, + start_step=start_step, + end_step=end_step, + ) def crop_mask(self, mask: Output, bounds: Bounds): return self.add( diff --git a/ai_diffusion/resources.py b/ai_diffusion/resources.py index f52bf5b62..222650d8e 100644 --- a/ai_diffusion/resources.py +++ b/ai_diffusion/resources.py @@ -68,6 +68,13 @@ class CustomNode(NamedTuple): "16ec6f344f8cecbbf006d374043f85af22b7a51d", ["ApplyFBCacheOnModel"], ), + CustomNode( + "MagCache", + "ComfyUI-MagCache", + "https://github.com/Zehong-Ma/ComfyUI-MagCache", + "7d4e982bf7955498afca891c7094c48a70985537", + ["MagCache"], + ), ] diff --git a/ai_diffusion/settings.py b/ai_diffusion/settings.py index f35fee942..049bbaa89 100644 --- a/ai_diffusion/settings.py +++ b/ai_diffusion/settings.py @@ -82,7 +82,8 @@ class PerformanceSettings: max_pixel_count: int = 6 dynamic_caching: bool = False tiled_vae: bool = False - + magcache_enabled: bool = False + magcache_thresh: float = 0.24 class Setting: def __init__(self, name: str, default, desc="", help="", items=None): @@ -292,6 +293,20 @@ class Settings(QObject): _("Re-use outputs of previous steps (First Block Cache) to speed up generation."), ) + magcache_enabled: bool + _magcache_enabled = Setting( + _("MagCache Acceleration"), + False, + _("Accelerate Flux model inference using MagCache technology."), + ) + + magcache_thresh: float + _magcache_thresh = Setting( + _("MagCache Strength"), + 0.24, + _("Strength value for MagCache activation (lower values = more powerful)."), + ) + tiled_vae: bool _tiled_vae = Setting( _("Tiled VAE"), @@ -405,6 +420,10 @@ def apply_performance_preset(self, preset: PerformancePreset): for k, v in self._performance_presets[preset]._asdict().items(): self._values[k] = v + def configure_magcache_for_arch(self, arch_name: str): + """Configure MagCache settings for specific architecture.""" + pass + def _migrate_legacy_settings(self, path: Path): if path == self.default_path: legacy_path = Path(__file__).parent / "settings.json" @@ -416,4 +435,4 @@ def _migrate_legacy_settings(self, path: Path): log.warning(f"Failed to migrate settings from {legacy_path} to {path}: {e}") -settings = Settings() +settings = Settings() \ No newline at end of file diff --git a/ai_diffusion/ui/settings.py b/ai_diffusion/ui/settings.py index c1bfd6225..4ca4b8c88 100644 --- a/ai_diffusion/ui/settings.py +++ b/ai_diffusion/ui/settings.py @@ -619,17 +619,31 @@ def __init__(self): self._dynamic_caching.value_changed.connect(self.write) self._layout.addWidget(self._dynamic_caching) + self._magcache_enabled = SwitchSetting( + Settings._magcache_enabled, + text=(_("Enabled"), _("Disabled")), + parent=self + ) + self._magcache_enabled.value_changed.connect(self.write) + self._layout.addWidget(self._magcache_enabled) + + self._magcache_thresh = SliderSetting( + Settings._magcache_thresh, self, 0.1, 0.5, "{:.2f}" + ) + self._magcache_thresh.value_changed.connect(self.write) + self._layout.addWidget(self._magcache_thresh) + self._layout.addStretch() def _change_performance_preset(self, index): self.write() is_custom = settings.performance_preset is PerformancePreset.custom self._advanced.setEnabled(is_custom) - if ( - settings.performance_preset is PerformancePreset.auto - and root.connection.state is ConnectionState.connected - ): + + if (settings.performance_preset is PerformancePreset.auto and + root.connection.state is ConnectionState.connected): apply_performance_preset(settings, root.connection.client.device_info) + if not is_custom: self.read() @@ -640,12 +654,35 @@ def update_client_info(self): _("Device") + f": [{client.device_info.type.upper()}] {client.device_info.name} ({client.device_info.vram} GB)" ) + self._dynamic_caching.enabled = client.features.wave_speed self._dynamic_caching.setToolTip( _("The {node_name} node is not installed.").format(node_name="Comfy-WaveSpeed") if not client.features.wave_speed else "" ) + + self._magcache_enabled.enabled = client.features.magcache + self._magcache_enabled.setToolTip( + _("The {node_name} node is not installed.").format(node_name="MagCache") + if not client.features.magcache + else "" + ) + + self._magcache_thresh.enabled = client.features.magcache + self._magcache_thresh.setToolTip( + _("The {node_name} node is not installed.").format(node_name="MagCache") + if not client.features.magcache + else "" + ) + else: + self._device_info.setText(_("Not connected")) + self._dynamic_caching.enabled = False + self._magcache_enabled.enabled = False + self._magcache_thresh.enabled = False + self._dynamic_caching.setToolTip(_("Not connected to server")) + self._magcache_enabled.setToolTip(_("Not connected to server")) + self._magcache_thresh.setToolTip(_("Not connected to server")) def _read(self): self._history_size.value = settings.history_size @@ -660,6 +697,9 @@ def _read(self): self._max_pixel_count.value = settings.max_pixel_count self._tiled_vae.value = settings.tiled_vae self._dynamic_caching.value = settings.dynamic_caching + self._magcache_enabled.value = settings.magcache_enabled + self._magcache_thresh.value = settings.magcache_thresh + self.update_client_info() def _write(self): @@ -673,6 +713,8 @@ def _write(self): self._performance_preset.currentIndex() ] settings.dynamic_caching = self._dynamic_caching.value + settings.magcache_enabled = self._magcache_enabled.value + settings.magcache_thresh = self._magcache_thresh.value class AboutSettings(SettingsTab): @@ -943,4 +985,5 @@ def _open_settings_folder(self): QDesktopServices.openUrl(QUrl.fromLocalFile(str(util.user_data_dir))) def _close(self): + settings.save() _ = self.close() diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 6d6abb87a..96a66794e 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -87,6 +87,8 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None): return params + + def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels): arch = checkpoint.version model_info = models.checkpoints.get(checkpoint.checkpoint) @@ -153,6 +155,17 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod if arch.supports_attention_guidance and checkpoint.self_attention_guidance: model = w.apply_self_attention_guidance(model) + if checkpoint.magcache_enabled and arch in [Arch.flux, Arch.flux_k]: + model_type = "flux_kontext" if arch is Arch.flux_k else "flux" + try: + model = w.apply_magcache( + model, + model_type=model_type, + magcache_thresh=checkpoint.magcache_thresh, + ) + except Exception: + pass + return model, Clip(clip, arch), vae @@ -751,7 +764,10 @@ def generate( models: ModelDict, ): model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + + model = apply_ip_adapter(w, model, cond.control, models) + model_orig = copy(model) model, regions = apply_attention_mask(w, model, cond, clip, extent.initial) model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models) @@ -864,6 +880,7 @@ def inpaint( model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) model = w.differential_diffusion(model) + model_orig = copy(model) upscale_extent = ScaledExtent( # after crop to the masked region @@ -992,7 +1009,10 @@ def refine( models: ModelDict, ): model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all) + + model = apply_ip_adapter(w, model, cond.control, models) + model, regions = apply_attention_mask(w, model, cond, clip, extent.initial) model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models) in_image = w.load_image(image) @@ -1368,6 +1388,10 @@ def prepare( face_weight = median_or_zero(c.strength for c in all_control if c.mode is ControlMode.face) if face_weight > 0: i.models.loras.append(LoraInput(model_set.lora["face"], 0.65 * face_weight)) + + i.models.magcache_enabled = perf.magcache_enabled and arch in [Arch.flux, Arch.flux_k] + if i.models.magcache_enabled: + i.models.magcache_thresh = perf.magcache_thresh if kind is WorkflowKind.generate: assert isinstance(canvas, Extent)