Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions StabilityMatrix.Avalonia/App.axaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
<StyleInclude Source="Controls/Inference/UnetModelCard.axaml" />
<StyleInclude Source="Controls/Inference/DiscreteModelSamplingCard.axaml" />
<StyleInclude Source="Controls/Inference/RescaleCfgCard.axaml" />
<StyleInclude Source="Controls/Inference/TiledVAECard.axaml" />
<StyleInclude Source="Controls/Painting/PaintCanvas.axaml" />
<StyleInclude Source="Controls/MarkdownViewer.axaml" />
<StyleInclude Source="Controls/Inference/WanModelCard.axaml" />
Expand Down
79 changes: 79 additions & 0 deletions StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
<Styles
xmlns="https://github.com/avaloniaui"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
x:DataType="vmInference:TiledVAECardViewModel">
<Design.PreviewWith>
<controls:TiledVAECard />
</Design.PreviewWith>

<Style Selector="controls|TiledVAECard">
<!-- Set Defaults -->
<Setter Property="Template">
<ControlTemplate>
<controls:Card x:Name="PART_Card">
<controls:Card.Styles>
<Style Selector="ui|NumberBox">
<Setter Property="Margin" Value="12,0,0,0" />
<Setter Property="MinWidth" Value="70" />
<Setter Property="HorizontalAlignment" Value="Stretch" />
<Setter Property="ValidationMode" Value="InvalidInputOverwritten" />
<Setter Property="SmallChange" Value="32" />
<Setter Property="LargeChange" Value="128" />
<Setter Property="SpinButtonPlacementMode" Value="Inline" />
</Style>
</controls:Card.Styles>
<StackPanel Spacing="8">
<!-- Tile Size -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Tile Size" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TileSize, Mode=TwoWay}" />
</Grid>

<!-- Overlap -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Overlap" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding Overlap, Mode=TwoWay}" />
</Grid>

<!-- Temporal Size (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Size" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalSize, Mode=TwoWay}" />
</Grid>

<!-- Temporal Overlap (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Overlap" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalOverlap, Mode=TwoWay}"
SmallChange="4"
LargeChange="16" />
</Grid>
</StackPanel>
</controls:Card>
</ControlTemplate>
</Setter>
</Style>
</Styles>
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using Injectio.Attributes;

namespace StabilityMatrix.Avalonia.Controls;

[RegisterTransient<TiledVAECard>]
public class TiledVAECard : TemplatedControlBase { }
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(PlasmaNoiseCardViewModel), PlasmaNoiseCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(NrsCardViewModel), NrsCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(CfzCudnnToggleCardViewModel), CfzCudnnToggleCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(TiledVAECardViewModel), TiledVAECardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUModule))]
[JsonDerivedType(typeof(HiresFixModule))]
[JsonDerivedType(typeof(FluxHiresFixModule))]
Expand All @@ -43,6 +44,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(PlasmaNoiseModule))]
[JsonDerivedType(typeof(NRSModule))]
[JsonDerivedType(typeof(CfzCudnnToggleModule))]
[JsonDerivedType(typeof(TiledVAEModule))]
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Injectio.Attributes;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;

namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;

[ManagedService]
[RegisterTransient<TiledVAEModule>]
public class TiledVAEModule : ModuleBase
{
public TiledVAEModule(IServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
Title = "Tiled VAE Decode";
AddCards(vmFactory.Get<TiledVAECardViewModel>());
}

protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var card = GetCard<TiledVAECardViewModel>();

// Register a pre-output action that replaces standard VAE decode with tiled decode
e.PreOutputActions.Add(args =>
{
var builder = args.Builder;

// Only apply if primary is in latent space
if (builder.Connections.Primary?.IsT0 != true)
return;

var latent = builder.Connections.Primary.AsT0;
var vae = builder.Connections.GetDefaultVAE();

// Use tiled VAE decode instead of standard decode
var tiledDecode = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.TiledVAEDecode
{
Name = builder.Nodes.GetUniqueName("TiledVAEDecode"),
Samples = latent,
Vae = vae,
TileSize = card.TileSize,
Overlap = card.Overlap,
TemporalSize = card.TemporalSize,
TemporalOverlap = card.TemporalOverlap
}
);

// Update primary connection to the decoded image
builder.Connections.Primary = tiledDecode.Output;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ TabContext tabContext
typeof(RescaleCfgModule),
typeof(PlasmaNoiseModule),
typeof(NRSModule),
typeof(TiledVAEModule),
];
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System.ComponentModel.DataAnnotations;
using CommunityToolkit.Mvvm.ComponentModel;
using Injectio.Attributes;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

[View(typeof(TiledVAECard))]
[ManagedService]
[RegisterTransient<TiledVAECardViewModel>]
public partial class TiledVAECardViewModel : LoadableViewModelBase
{
public const string ModuleKey = "TiledVAE";

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(64, 4096)]
private int tileSize = 512;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(0, 4096)]
private int overlap = 64;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(8, 4096)]
private int temporalSize = 64;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(4, 4096)]
private int temporalOverlap = 8;
}
19 changes: 19 additions & 0 deletions StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ public record VAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
public required VAENodeConnection Vae { get; init; }
}

[TypedNodeOptions(Name = "VAEDecodeTiled")]
public record TiledVAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
{
public required LatentNodeConnection Samples { get; init; }
public required VAENodeConnection Vae { get; init; }

[Range(64, 4096)]
public int TileSize { get; init; } = 512;

[Range(0, 4096)]
public int Overlap { get; init; } = 64;

[Range(8, 4096)]
public int TemporalSize { get; init; } = 64;

[Range(4, 4096)]
public int TemporalOverlap { get; init; } = 8;
}

public record KSampler : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
Expand Down
Loading