4 changes: 2 additions & 2 deletions ai_diffusion/api.py
@@ -40,7 +40,6 @@ class ImageInput:
def from_extent(e: Extent):
return ImageInput(ExtentInput(e, e, e, e))


@dataclass
class LoraInput:
name: str
@@ -64,7 +63,8 @@ class CheckpointInput:
self_attention_guidance: bool = False
dynamic_caching: bool = False
tiled_vae: bool = False

magcache_enabled: bool = False
magcache_thresh: float = 0.24

@dataclass
class SamplingInput:
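The new fields default to off, so existing payloads keep deserializing unchanged. A minimal sketch of opting in; apart from magcache_enabled and magcache_thresh, the other field name is an assumption about the existing dataclass (checkpoint is the model-name field referenced later in workflow.py), and the model name is a placeholder:

from ai_diffusion.api import CheckpointInput

# Hypothetical usage; "flux1-dev.safetensors" is a placeholder name.
ckpt = CheckpointInput(
    checkpoint="flux1-dev.safetensors",
    magcache_enabled=True,
    magcache_thresh=0.24,  # dataclass default shown above
)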
1 change: 1 addition & 0 deletions ai_diffusion/client.py
@@ -273,6 +273,7 @@ class ClientFeatures(NamedTuple):
max_control_layers: int = 1000
wave_speed: bool = False
gguf: bool = False
magcache: bool = False


class Client(ABC):
3 changes: 3 additions & 0 deletions ai_diffusion/comfy_client.py
@@ -130,6 +130,7 @@ async def connect(url=default_url, access_token=""):
languages=await _list_languages(client),
wave_speed="ApplyFBCacheOnModel" in nodes,
gguf="UnetLoaderGGUF" in nodes,
magcache="MagCache" in nodes,
)

# Check for required and optional model resources
@@ -489,6 +490,8 @@ def performance_settings(self):
max_pixel_count=settings.max_pixel_count,
tiled_vae=settings.tiled_vae,
dynamic_caching=settings.dynamic_caching and self.features.wave_speed,
magcache_enabled=settings.magcache_enabled,
magcache_thresh=settings.magcache_thresh,
)

async def upload_loras(self, work: WorkflowInput, local_job_id: str):
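One detail worth flagging in performance_settings above: dynamic_caching is ANDed with self.features.wave_speed, but the MagCache values are passed through without an equivalent check. A hypothetical helper sketching the symmetric gating (name and structure are mine, not from this diff):

def magcache_kwargs(settings, features) -> dict:
    # Mirror the wave_speed pattern: only enable MagCache when the
    # server actually reports the node as installed.
    enabled = settings.magcache_enabled and features.magcache
    return {
        "magcache_enabled": enabled,
        "magcache_thresh": settings.magcache_thresh,
    }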
22 changes: 22 additions & 0 deletions ai_diffusion/comfy_workflow.py
@@ -865,6 +865,28 @@ def inpaint_image(self, model: Output, image: Output, mask: Output):
return self.add(
"INPAINT_InpaintWithModel", 1, inpaint_model=model, image=image, mask=mask, seed=834729
)
def apply_magcache(
self,
model: Output,
model_type: str = "flux",
magcache_thresh: float = 0.24,
retention_ratio: float = 0.1,
magcache_K: int = 5,
start_step: int = 0,
end_step: int = -1,
):
"""Apply MagCache acceleration to Flux models"""
return self.add(
"MagCache",
1,
model=model,
model_type=model_type,
magcache_thresh=magcache_thresh,
retention_ratio=retention_ratio,
magcache_K=magcache_K,
start_step=start_step,
end_step=end_step,
)

def crop_mask(self, mask: Output, bounds: Bounds):
return self.add(
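apply_magcache is a thin wrapper around self.add with a single output. A sketch of how it slots into graph construction, assuming a ComfyWorkflow instance w and a model Output produced by an earlier loader node; the parameter comments restate the upstream node's options and should be treated as assumptions:

model = w.apply_magcache(
    model,
    model_type="flux",      # "flux_kontext" for Flux Kontext checkpoints
    magcache_thresh=0.24,   # error tolerance before a cached step is rejected
    retention_ratio=0.1,    # fraction of early steps that are always computed
    magcache_K=5,           # cap on consecutive cached (skipped) steps
)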
7 changes: 7 additions & 0 deletions ai_diffusion/resources.py
@@ -68,6 +68,13 @@ class CustomNode(NamedTuple):
"16ec6f344f8cecbbf006d374043f85af22b7a51d",
["ApplyFBCacheOnModel"],
),
CustomNode(
"MagCache",
"ComfyUI-MagCache",
"https://github.com/Zehong-Ma/ComfyUI-MagCache",
"7d4e982bf7955498afca891c7094c48a70985537",
["MagCache"],
),
]


23 changes: 21 additions & 2 deletions ai_diffusion/settings.py
@@ -82,7 +82,8 @@ class PerformanceSettings:
max_pixel_count: int = 6
dynamic_caching: bool = False
tiled_vae: bool = False

magcache_enabled: bool = False
magcache_thresh: float = 0.24

class Setting:
def __init__(self, name: str, default, desc="", help="", items=None):
@@ -292,6 +293,20 @@ class Settings(QObject):
_("Re-use outputs of previous steps (First Block Cache) to speed up generation."),
)

magcache_enabled: bool
_magcache_enabled = Setting(
_("MagCache Acceleration"),
False,
_("Accelerate Flux model inference using MagCache technology."),
)

magcache_thresh: float
_magcache_thresh = Setting(
_("MagCache Strength"),
0.24,
_("Strength value for MagCache activation (lower values = more powerful)."),
)

tiled_vae: bool
_tiled_vae = Setting(
_("Tiled VAE"),
@@ -405,6 +420,10 @@ def apply_performance_preset(self, preset: PerformancePreset):
for k, v in self._performance_presets[preset]._asdict().items():
self._values[k] = v

def configure_magcache_for_arch(self, arch_name: str):
"""Configure MagCache settings for specific architecture."""
pass

def _migrate_legacy_settings(self, path: Path):
if path == self.default_path:
legacy_path = Path(__file__).parent / "settings.json"
@@ -416,4 +435,4 @@ def _migrate_legacy_settings(self, path: Path):
log.warning(f"Failed to migrate settings from {legacy_path} to {path}: {e}")


settings = Settings()
settings = Settings()
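Both values travel through the normal Settings machinery, so they persist and reload like any other performance setting; a quick sketch:

from ai_diffusion.settings import settings

settings.magcache_enabled = True
settings.magcache_thresh = 0.3  # within the 0.1-0.5 range the UI slider exposes
settings.save()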
51 changes: 47 additions & 4 deletions ai_diffusion/ui/settings.py
@@ -619,17 +619,31 @@ def __init__(self):
self._dynamic_caching.value_changed.connect(self.write)
self._layout.addWidget(self._dynamic_caching)

self._magcache_enabled = SwitchSetting(
Settings._magcache_enabled,
text=(_("Enabled"), _("Disabled")),
parent=self
)
self._magcache_enabled.value_changed.connect(self.write)
self._layout.addWidget(self._magcache_enabled)

self._magcache_thresh = SliderSetting(
Settings._magcache_thresh, self, 0.1, 0.5, "{:.2f}"
)
self._magcache_thresh.value_changed.connect(self.write)
self._layout.addWidget(self._magcache_thresh)

self._layout.addStretch()

def _change_performance_preset(self, index):
self.write()
is_custom = settings.performance_preset is PerformancePreset.custom
self._advanced.setEnabled(is_custom)
if (
settings.performance_preset is PerformancePreset.auto
and root.connection.state is ConnectionState.connected
):

if (settings.performance_preset is PerformancePreset.auto and
root.connection.state is ConnectionState.connected):
apply_performance_preset(settings, root.connection.client.device_info)

if not is_custom:
self.read()

@@ -640,12 +654,35 @@ def update_client_info(self):
_("Device")
+ f": [{client.device_info.type.upper()}] {client.device_info.name} ({client.device_info.vram} GB)"
)

self._dynamic_caching.enabled = client.features.wave_speed
self._dynamic_caching.setToolTip(
_("The {node_name} node is not installed.").format(node_name="Comfy-WaveSpeed")
if not client.features.wave_speed
else ""
)

self._magcache_enabled.enabled = client.features.magcache
self._magcache_enabled.setToolTip(
_("The {node_name} node is not installed.").format(node_name="MagCache")
if not client.features.magcache
else ""
)

self._magcache_thresh.enabled = client.features.magcache
self._magcache_thresh.setToolTip(
_("The {node_name} node is not installed.").format(node_name="MagCache")
if not client.features.magcache
else ""
)
else:
self._device_info.setText(_("Not connected"))
self._dynamic_caching.enabled = False
self._magcache_enabled.enabled = False
self._magcache_thresh.enabled = False
self._dynamic_caching.setToolTip(_("Not connected to server"))
self._magcache_enabled.setToolTip(_("Not connected to server"))
self._magcache_thresh.setToolTip(_("Not connected to server"))

def _read(self):
self._history_size.value = settings.history_size
@@ -660,6 +697,9 @@ def _read(self):
self._max_pixel_count.value = settings.max_pixel_count
self._tiled_vae.value = settings.tiled_vae
self._dynamic_caching.value = settings.dynamic_caching
self._magcache_enabled.value = settings.magcache_enabled
self._magcache_thresh.value = settings.magcache_thresh

self.update_client_info()

def _write(self):
@@ -673,6 +713,8 @@ def _write(self):
self._performance_preset.currentIndex()
]
settings.dynamic_caching = self._dynamic_caching.value
settings.magcache_enabled = self._magcache_enabled.value
settings.magcache_thresh = self._magcache_thresh.value


class AboutSettings(SettingsTab):
@@ -943,4 +985,5 @@ def _open_settings_folder(self):
QDesktopServices.openUrl(QUrl.fromLocalFile(str(util.user_data_dir)))

def _close(self):
settings.save()
_ = self.close()
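The enable/tooltip pattern now appears three times in update_client_info (once for wave_speed, twice for MagCache). A hypothetical helper that would collapse the repetition, assuming the module's _() translator (sketch only; the name is not from this diff):

def _sync_node_feature(widget, installed: bool, node_name: str):
    # Disable the control and explain why whenever the server lacks the node.
    widget.enabled = installed
    widget.setToolTip(
        "" if installed
        else _("The {node_name} node is not installed.").format(node_name=node_name)
    )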
24 changes: 24 additions & 0 deletions ai_diffusion/workflow.py
@@ -87,6 +87,8 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None):
return params




def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels):
arch = checkpoint.version
model_info = models.checkpoints.get(checkpoint.checkpoint)
@@ -153,6 +155,17 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels):
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
model = w.apply_self_attention_guidance(model)

if checkpoint.magcache_enabled and arch in [Arch.flux, Arch.flux_k]:
model_type = "flux_kontext" if arch is Arch.flux_k else "flux"
try:
Owner review comment: no need to wrap in try/except, this should not fail (or do so loudly if the code is broken)

model = w.apply_magcache(
model,
model_type=model_type,
magcache_thresh=checkpoint.magcache_thresh,
)
except Exception:
pass

return model, Clip(clip, arch), vae
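Taking the review comment above at face value, the guard can drop the try/except entirely and let a broken MagCache call fail loudly; a sketch of that simplification:

if checkpoint.magcache_enabled and arch in [Arch.flux, Arch.flux_k]:
    model_type = "flux_kontext" if arch is Arch.flux_k else "flux"
    # No try/except: if the MagCache node is missing or miswired,
    # surface the error instead of silently skipping acceleration.
    model = w.apply_magcache(
        model,
        model_type=model_type,
        magcache_thresh=checkpoint.magcache_thresh,
    )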


Expand Down Expand Up @@ -751,7 +764,10 @@ def generate(
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)


model = apply_ip_adapter(w, model, cond.control, models)

model_orig = copy(model)
model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
@@ -864,6 +880,7 @@ def inpaint(

model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model = w.differential_diffusion(model)

model_orig = copy(model)

upscale_extent = ScaledExtent( # after crop to the masked region
@@ -992,7 +1009,10 @@ def refine(
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)


model = apply_ip_adapter(w, model, cond.control, models)

model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
in_image = w.load_image(image)
@@ -1368,6 +1388,10 @@ def prepare(
face_weight = median_or_zero(c.strength for c in all_control if c.mode is ControlMode.face)
if face_weight > 0:
i.models.loras.append(LoraInput(model_set.lora["face"], 0.65 * face_weight))

i.models.magcache_enabled = perf.magcache_enabled and arch in [Arch.flux, Arch.flux_k]
if i.models.magcache_enabled:
i.models.magcache_thresh = perf.magcache_thresh

if kind is WorkflowKind.generate:
assert isinstance(canvas, Extent)