Changes from 2 commits
8 changes: 6 additions & 2 deletions ai_diffusion/api.py
@@ -40,7 +40,6 @@ class ImageInput:
def from_extent(e: Extent):
return ImageInput(ExtentInput(e, e, e, e))


@dataclass
class LoraInput:
name: str
@@ -64,7 +63,12 @@ class CheckpointInput:
self_attention_guidance: bool = False
dynamic_caching: bool = False
tiled_vae: bool = False

magcache_enabled: bool = False
magcache_thresh: float = 0.24
Acly (Owner) commented:
Please also remove the parameters from here and from settings.py.

They can be dug up from git if ever needed, but I'd rather not have them in the code without a good reason.

And if more configurability is desired at some point, a single "strength" value or some preset would be a more user friendly way to control it anyway.
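
For illustration, one way such a preset could look (hypothetical `magcache_strength` knob; the mapping below is illustrative, not tuned against the MagCache node):

```python
# Hypothetical: collapse the four MagCache knobs into one strength value.
# 0.0 = conservative (slower, higher fidelity), 1.0 = aggressive (faster).
def magcache_params_from_strength(strength: float) -> dict:
    s = max(0.0, min(1.0, strength))
    return {
        "magcache_thresh": 0.12 + 0.24 * s,  # wider error tolerance when aggressive
        "retention_ratio": 0.2 - 0.1 * s,    # fewer always-computed warm-up steps
        "magcache_K": 3 + round(4 * s),      # allow longer runs of skipped steps
    }
```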

magcache_retention_ratio: float = 0.1
magcache_K: int = 5
magcache_start_step: int = 0
magcache_end_step: int = -1

@dataclass
class SamplingInput:
1 change: 1 addition & 0 deletions ai_diffusion/client.py
@@ -273,6 +273,7 @@ class ClientFeatures(NamedTuple):
max_control_layers: int = 1000
wave_speed: bool = False
gguf: bool = False
magcache: bool = False


class Client(ABC):
5 changes: 5 additions & 0 deletions ai_diffusion/comfy_client.py
@@ -130,6 +130,7 @@ async def connect(url=default_url, access_token=""):
languages=await _list_languages(client),
wave_speed="ApplyFBCacheOnModel" in nodes,
gguf="UnetLoaderGGUF" in nodes,
magcache="MagCache" in nodes,
)

# Check for required and optional model resources
@@ -489,6 +490,10 @@ def performance_settings(self):
max_pixel_count=settings.max_pixel_count,
tiled_vae=settings.tiled_vae,
dynamic_caching=settings.dynamic_caching and self.features.wave_speed,
magcache_enabled=settings.magcache_enabled,
magcache_thresh=settings.magcache_thresh,
magcache_retention_ratio=settings.magcache_retention_ratio,
magcache_K=settings.magcache_K,
)

async def upload_loras(self, work: WorkflowInput, local_job_id: str):
22 changes: 22 additions & 0 deletions ai_diffusion/comfy_workflow.py
@@ -865,6 +865,28 @@ def inpaint_image(self, model: Output, image: Output, mask: Output):
return self.add(
"INPAINT_InpaintWithModel", 1, inpaint_model=model, image=image, mask=mask, seed=834729
)
def apply_magcache(
self,
model: Output,
model_type: str = "flux",
magcache_thresh: float = 0.24,
retention_ratio: float = 0.1,
magcache_K: int = 5,
start_step: int = 0,
end_step: int = -1,
):
"""Apply MagCache acceleration to Flux models"""
return self.add(
"MagCache",
1,
model=model,
model_type=model_type,
magcache_thresh=magcache_thresh,
retention_ratio=retention_ratio,
magcache_K=magcache_K,
start_step=start_step,
end_step=end_step,
)

def crop_mask(self, mask: Output, bounds: Bounds):
return self.add(
7 changes: 7 additions & 0 deletions ai_diffusion/resources.py
@@ -68,6 +68,13 @@ class CustomNode(NamedTuple):
"16ec6f344f8cecbbf006d374043f85af22b7a51d",
["ApplyFBCacheOnModel"],
),
CustomNode(
"MagCache",
"ComfyUI-MagCache",
"https://github.com/Zehong-Ma/ComfyUI-MagCache",
"7d4e982bf7955498afca891c7094c48a70985537",
["MagCache"],
),
]


33 changes: 32 additions & 1 deletion ai_diffusion/settings.py
@@ -82,7 +82,10 @@ class PerformanceSettings:
max_pixel_count: int = 6
dynamic_caching: bool = False
tiled_vae: bool = False

magcache_enabled: bool = False
magcache_thresh: float = 0.24
magcache_retention_ratio: float = 0.1
magcache_K: int = 5

class Setting:
def __init__(self, name: str, default, desc="", help="", items=None):
@@ -292,6 +295,34 @@ class Settings(QObject):
_("Re-use outputs of previous steps (First Block Cache) to speed up generation."),
)

magcache_enabled: bool
_magcache_enabled = Setting(
_("MagCache Acceleration"),
False,
_("Accelerate Flux model inference using MagCache technology."),
)

magcache_thresh: float
_magcache_thresh = Setting(
_("MagCache Threshold"),
0.24,
_("Threshold value for MagCache activation (lower values = more aggressive caching)."),
)

magcache_retention_ratio: float
_magcache_retention_ratio = Setting(
_("Retention Ratio"),
0.1,
_("Ratio of cached values to retain across timesteps."),
)

magcache_K: int
_magcache_K = Setting(
_("MagCache K"),
5,
_("Number of cached timesteps to consider."),
)

tiled_vae: bool
_tiled_vae = Setting(
_("Tiled VAE"),
50 changes: 40 additions & 10 deletions ai_diffusion/ui/settings.py
@@ -619,17 +619,25 @@ def __init__(self):
self._dynamic_caching.value_changed.connect(self.write)
self._layout.addWidget(self._dynamic_caching)

self._magcache_enabled = SwitchSetting(
Settings._magcache_enabled,
text=(_("Enabled"), _("Disabled")),
parent=self
)
self._magcache_enabled.value_changed.connect(self.write)
self._layout.addWidget(self._magcache_enabled)

self._layout.addStretch()

def _change_performance_preset(self, index):
self.write()
is_custom = settings.performance_preset is PerformancePreset.custom
self._advanced.setEnabled(is_custom)
if (
settings.performance_preset is PerformancePreset.auto
and root.connection.state is ConnectionState.connected
):

if (settings.performance_preset is PerformancePreset.auto and
root.connection.state is ConnectionState.connected):
apply_performance_preset(settings, root.connection.client.device_info)

if not is_custom:
self.read()

@@ -640,12 +648,30 @@ def update_client_info(self):
_("Device")
+ f": [{client.device_info.type.upper()}] {client.device_info.name} ({client.device_info.vram} GB)"
)
self._dynamic_caching.enabled = client.features.wave_speed
self._dynamic_caching.setToolTip(
_("The {node_name} node is not installed.").format(node_name="Comfy-WaveSpeed")
if not client.features.wave_speed
else ""
)

has_wave_speed = hasattr(client.features, 'wave_speed') and client.features.wave_speed
Acly (Owner) commented on Jul 13, 2025:
the hasattr is not needed

(see also below)
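
A minimal sketch of the simplification (ClientFeatures is a NamedTuple with defaults, see client.py above, so plain attribute access cannot fail):

```python
# No hasattr needed: wave_speed and magcache are fields of ClientFeatures
# with default values, so they always exist on client.features.
has_wave_speed = client.features.wave_speed
has_magcache = client.features.magcache
```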

self._dynamic_caching.enabled = has_wave_speed
if not has_wave_speed:
self._dynamic_caching.setToolTip(
_("The {node_name} node is not installed.").format(node_name="Comfy-WaveSpeed")
)
else:
self._dynamic_caching.setToolTip("")

has_magcache = hasattr(client.features, 'magcache') and client.features.magcache
self._magcache_enabled.enabled = has_magcache
if not has_magcache:
self._magcache_enabled.setToolTip(
_("The {node_name} node is not installed.").format(node_name="MagCache")
)
else:
self._magcache_enabled.setToolTip(_("Accelerate Flux model inference using MagCache. Parameters are automatically configured based on model architecture."))
else:
self._device_info.setText(_("Not connected"))
self._dynamic_caching.enabled = False
self._magcache_enabled.enabled = False
self._dynamic_caching.setToolTip(_("Not connected to server"))
self._magcache_enabled.setToolTip(_("Not connected to server"))

def _read(self):
self._history_size.value = settings.history_size
@@ -660,6 +686,8 @@ def _read(self):
self._max_pixel_count.value = settings.max_pixel_count
self._tiled_vae.value = settings.tiled_vae
self._dynamic_caching.value = settings.dynamic_caching
self._magcache_enabled.value = settings.magcache_enabled

self.update_client_info()

def _write(self):
@@ -673,6 +701,7 @@ def _write(self):
self._performance_preset.currentIndex()
]
settings.dynamic_caching = self._dynamic_caching.value
settings.magcache_enabled = self._magcache_enabled.value


class AboutSettings(SettingsTab):
@@ -943,4 +972,5 @@ def _open_settings_folder(self):
QDesktopServices.openUrl(QUrl.fromLocalFile(str(util.user_data_dir)))

def _close(self):
settings.save()
_ = self.close()
41 changes: 41 additions & 0 deletions ai_diffusion/workflow.py
@@ -87,6 +87,8 @@ def _sampler_params(sampling: SamplingInput, strength: float | None = None):
return params




def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, models: ClientModels):
arch = checkpoint.version
model_info = models.checkpoints.get(checkpoint.checkpoint)
@@ -153,6 +155,29 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
model = w.apply_self_attention_guidance(model)

# Apply MagCache as a model patch if enabled
if checkpoint.magcache_enabled:
if arch in [Arch.flux, Arch.flux_k]:
print(f"Applying MagCache patch for {arch}")
Acly (Owner) commented:
Please remove all the prints


model_type = "flux_kontext" if arch is Arch.flux_k else "flux"

try:
Acly (Owner) commented:
no need to wrap in try/except, this should not fail (or do so loudly if the code is broken)

model = w.apply_magcache(
model,
model_type=model_type,
magcache_thresh=checkpoint.magcache_thresh,
retention_ratio=checkpoint.magcache_retention_ratio,
magcache_K=checkpoint.magcache_K,
start_step=checkpoint.magcache_start_step,
end_step=checkpoint.magcache_end_step,
)
print(f"MagCache patch applied successfully with {model_type} settings")
except Exception as e:
print(f"Failed to apply MagCache patch: {e}")
else:
print(f"MagCache not supported for architecture: {arch}")

return model, Clip(clip, arch), vae
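
Folding both comments in, the MagCache block would reduce to something like this sketch (no prints, no try/except; otherwise the same calls as in the diff above):

```python
# Sketch: apply MagCache directly and let a broken node fail loudly.
if checkpoint.magcache_enabled and arch in [Arch.flux, Arch.flux_k]:
    model = w.apply_magcache(
        model,
        model_type="flux_kontext" if arch is Arch.flux_k else "flux",
        magcache_thresh=checkpoint.magcache_thresh,
        retention_ratio=checkpoint.magcache_retention_ratio,
        magcache_K=checkpoint.magcache_K,
        start_step=checkpoint.magcache_start_step,
        end_step=checkpoint.magcache_end_step,
    )
```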


@@ -751,7 +776,10 @@ def generate(
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)


model = apply_ip_adapter(w, model, cond.control, models)

model_orig = copy(model)
model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
@@ -864,6 +892,7 @@ def inpaint(

model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
model = w.differential_diffusion(model)

model_orig = copy(model)

upscale_extent = ScaledExtent( # after crop to the masked region
@@ -992,7 +1021,10 @@ def refine(
models: ModelDict,
):
model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)


model = apply_ip_adapter(w, model, cond.control, models)

model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
in_image = w.load_image(image)
@@ -1368,6 +1400,15 @@ def prepare(
face_weight = median_or_zero(c.strength for c in all_control if c.mode is ControlMode.face)
if face_weight > 0:
i.models.loras.append(LoraInput(model_set.lora["face"], 0.65 * face_weight))

if perf.magcache_enabled and arch in [Arch.flux, Arch.flux_k]:
i.models.magcache_enabled = True
i.models.magcache_thresh = perf.magcache_thresh
i.models.magcache_retention_ratio = perf.magcache_retention_ratio
i.models.magcache_K = perf.magcache_K
print(f"MagCache settings added to WorkflowInput for {arch}")
else:
i.models.magcache_enabled = False
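
The "remove all the prints" comment applies here too, and since `magcache_enabled` already defaults to False in CheckpointInput, the else branch is redundant; a sketch:

```python
# Sketch: copy MagCache settings only for supported architectures;
# the dataclass default (False) already covers the disabled case.
if perf.magcache_enabled and arch in [Arch.flux, Arch.flux_k]:
    i.models.magcache_enabled = True
    i.models.magcache_thresh = perf.magcache_thresh
    i.models.magcache_retention_ratio = perf.magcache_retention_ratio
    i.models.magcache_K = perf.magcache_K
```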

if kind is WorkflowKind.generate:
assert isinstance(canvas, Extent)