Skip to content

Commit 36fb141

Browse files
committed
Add Feature - Tiled VAE module for Inference
- Add TiledVAEModule with support for VAEDecodeTiled Comfy node - Add TiledVAECardViewModel and TiledVAECard UI with configurable parameters - Register TiledVAEModule as available module in SamplerCardViewModel - Add JsonDerivedType attributes for serialization support
1 parent 7984ccc commit 36fb141

File tree

8 files changed

+208
-0
lines changed

8 files changed

+208
-0
lines changed

StabilityMatrix.Avalonia/App.axaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
<StyleInclude Source="Controls/Inference/UnetModelCard.axaml" />
9696
<StyleInclude Source="Controls/Inference/DiscreteModelSamplingCard.axaml" />
9797
<StyleInclude Source="Controls/Inference/RescaleCfgCard.axaml" />
98+
<StyleInclude Source="Controls/Inference/TiledVAECard.axaml" />
9899
<StyleInclude Source="Controls/Painting/PaintCanvas.axaml" />
99100
<StyleInclude Source="Controls/MarkdownViewer.axaml" />
100101
<StyleInclude Source="Controls/Inference/WanModelCard.axaml" />
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
<Styles
2+
xmlns="https://github.com/avaloniaui"
3+
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
4+
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
5+
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
6+
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
7+
x:DataType="vmInference:TiledVAECardViewModel">
8+
<Design.PreviewWith>
9+
<controls:TiledVAECard />
10+
</Design.PreviewWith>
11+
12+
<Style Selector="controls|TiledVAECard">
13+
<!-- Set Defaults -->
14+
<Setter Property="Template">
15+
<ControlTemplate>
16+
<controls:Card x:Name="PART_Card">
17+
<controls:Card.Styles>
18+
<Style Selector="ui|NumberBox">
19+
<Setter Property="Margin" Value="12,4,0,4" />
20+
<Setter Property="MinWidth" Value="70" />
21+
<Setter Property="HorizontalAlignment" Value="Stretch" />
22+
<Setter Property="ValidationMode" Value="InvalidInputOverwritten" />
23+
<Setter Property="SmallChange" Value="32" />
24+
<Setter Property="LargeChange" Value="128" />
25+
<Setter Property="SpinButtonPlacementMode" Value="Inline" />
26+
</Style>
27+
</controls:Card.Styles>
28+
<StackPanel Spacing="8">
29+
<!-- Tile Size -->
30+
<Grid ColumnDefinitions="Auto,*">
31+
<TextBlock
32+
Grid.Column="0"
33+
VerticalAlignment="Center"
34+
Text="Tile Size" />
35+
<ui:NumberBox
36+
Grid.Column="1"
37+
Margin="12,0,0,0"
38+
Value="{Binding TileSize, Mode=TwoWay}" />
39+
</Grid>
40+
41+
<!-- Overlap -->
42+
<Grid ColumnDefinitions="Auto,*">
43+
<TextBlock
44+
Grid.Column="0"
45+
VerticalAlignment="Center"
46+
Text="Overlap" />
47+
<ui:NumberBox
48+
Grid.Column="1"
49+
Margin="12,0,0,0"
50+
Value="{Binding Overlap, Mode=TwoWay}" />
51+
</Grid>
52+
53+
<!-- Temporal Size (for Video VAEs) -->
54+
<Grid ColumnDefinitions="Auto,*">
55+
<TextBlock
56+
Grid.Column="0"
57+
VerticalAlignment="Center"
58+
Text="Temporal Size" />
59+
<ui:NumberBox
60+
Grid.Column="1"
61+
Margin="12,0,0,0"
62+
Value="{Binding TemporalSize, Mode=TwoWay}" />
63+
</Grid>
64+
65+
<!-- Temporal Overlap (for Video VAEs) -->
66+
<Grid ColumnDefinitions="Auto,*">
67+
<TextBlock
68+
Grid.Column="0"
69+
VerticalAlignment="Center"
70+
Text="Temporal Overlap" />
71+
<ui:NumberBox
72+
Grid.Column="1"
73+
Margin="12,0,0,0"
74+
Value="{Binding TemporalOverlap, Mode=TwoWay}" />
75+
</Grid>
76+
</StackPanel>
77+
</controls:Card>
78+
</ControlTemplate>
79+
</Setter>
80+
</Style>
81+
</Styles>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using Avalonia;
2+
using Avalonia.Controls;
3+
using Avalonia.Controls.Primitives;
4+
using Injectio.Attributes;
5+
6+
namespace StabilityMatrix.Avalonia.Controls;
7+
8+
[RegisterTransient<TiledVAECard>]
9+
public class TiledVAECard : TemplatedControlBase { }

StabilityMatrix.Avalonia/ViewModels/Base/LoadableViewModelBase.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
2727
[JsonDerivedType(typeof(PlasmaNoiseCardViewModel), PlasmaNoiseCardViewModel.ModuleKey)]
2828
[JsonDerivedType(typeof(NrsCardViewModel), NrsCardViewModel.ModuleKey)]
2929
[JsonDerivedType(typeof(CfzCudnnToggleCardViewModel), CfzCudnnToggleCardViewModel.ModuleKey)]
30+
[JsonDerivedType(typeof(TiledVAECardViewModel), TiledVAECardViewModel.ModuleKey)]
3031
[JsonDerivedType(typeof(FreeUModule))]
3132
[JsonDerivedType(typeof(HiresFixModule))]
3233
[JsonDerivedType(typeof(FluxHiresFixModule))]
@@ -43,6 +44,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
4344
[JsonDerivedType(typeof(PlasmaNoiseModule))]
4445
[JsonDerivedType(typeof(NRSModule))]
4546
[JsonDerivedType(typeof(CfzCudnnToggleModule))]
47+
[JsonDerivedType(typeof(TiledVAEModule))]
4648
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
4749
{
4850
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
using Injectio.Attributes;
2+
using StabilityMatrix.Avalonia.Models.Inference;
3+
using StabilityMatrix.Avalonia.Services;
4+
using StabilityMatrix.Avalonia.ViewModels.Base;
5+
using StabilityMatrix.Core.Attributes;
6+
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
7+
8+
namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;
9+
10+
[ManagedService]
11+
[RegisterTransient<TiledVAEModule>]
12+
public class TiledVAEModule : ModuleBase
13+
{
14+
public TiledVAEModule(IServiceManager<ViewModelBase> vmFactory)
15+
: base(vmFactory)
16+
{
17+
Title = "Tiled VAE";
18+
AddCards(vmFactory.Get<TiledVAECardViewModel>());
19+
}
20+
21+
protected override void OnApplyStep(ModuleApplyStepEventArgs e)
22+
{
23+
var card = GetCard<TiledVAECardViewModel>();
24+
25+
// Register a pre-output action that replaces standard VAE decode with tiled decode
26+
e.PreOutputActions.Add(args =>
27+
{
28+
var builder = args.Builder;
29+
30+
// Only apply if primary is in latent space
31+
if (builder.Connections.Primary?.IsT0 != true)
32+
return;
33+
34+
var latent = builder.Connections.Primary.AsT0;
35+
var vae = builder.Connections.GetDefaultVAE();
36+
37+
// Use tiled VAE decode instead of standard decode
38+
var tiledDecode = builder.Nodes.AddTypedNode(
39+
new ComfyNodeBuilder.TiledVAEDecode
40+
{
41+
Name = builder.Nodes.GetUniqueName("TiledVAEDecode"),
42+
Samples = latent,
43+
Vae = vae,
44+
TileSize = card.TileSize,
45+
Overlap = card.Overlap,
46+
TemporalSize = card.TemporalSize,
47+
TemporalOverlap = card.TemporalOverlap
48+
}
49+
);
50+
51+
// Update primary connection to the decoded image
52+
builder.Connections.Primary = tiledDecode.Output;
53+
});
54+
}
55+
}

StabilityMatrix.Avalonia/ViewModels/Inference/SamplerCardViewModel.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ TabContext tabContext
155155
typeof(RescaleCfgModule),
156156
typeof(PlasmaNoiseModule),
157157
typeof(NRSModule),
158+
typeof(TiledVAEModule),
158159
];
159160
});
160161
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System.ComponentModel.DataAnnotations;
2+
using CommunityToolkit.Mvvm.ComponentModel;
3+
using Injectio.Attributes;
4+
using StabilityMatrix.Avalonia.Controls;
5+
using StabilityMatrix.Avalonia.ViewModels.Base;
6+
using StabilityMatrix.Core.Attributes;
7+
8+
namespace StabilityMatrix.Avalonia.ViewModels.Inference;
9+
10+
[View(typeof(TiledVAECard))]
11+
[ManagedService]
12+
[RegisterTransient<TiledVAECardViewModel>]
13+
public partial class TiledVAECardViewModel : LoadableViewModelBase
14+
{
15+
public const string ModuleKey = "TiledVAE";
16+
17+
[ObservableProperty]
18+
[NotifyDataErrorInfo]
19+
[Required]
20+
[Range(64, 4096)]
21+
private int tileSize = 512;
22+
23+
[ObservableProperty]
24+
[NotifyDataErrorInfo]
25+
[Required]
26+
[Range(0, 4096)]
27+
private int overlap = 64;
28+
29+
[ObservableProperty]
30+
[NotifyDataErrorInfo]
31+
[Required]
32+
[Range(8, 4096)]
33+
private int temporalSize = 64;
34+
35+
[ObservableProperty]
36+
[NotifyDataErrorInfo]
37+
[Required]
38+
[Range(4, 4096)]
39+
private int temporalOverlap = 8;
40+
}

StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ public record VAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
6262
public required VAENodeConnection Vae { get; init; }
6363
}
6464

65+
[TypedNodeOptions(Name = "VAEDecodeTiled")]
66+
public record TiledVAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
67+
{
68+
public required LatentNodeConnection Samples { get; init; }
69+
public required VAENodeConnection Vae { get; init; }
70+
71+
[Range(64, 4096)]
72+
public int TileSize { get; init; } = 512;
73+
74+
[Range(0, 4096)]
75+
public int Overlap { get; init; } = 64;
76+
77+
[Range(8, 4096)]
78+
public int TemporalSize { get; init; } = 64;
79+
80+
[Range(4, 4096)]
81+
public int TemporalOverlap { get; init; } = 8;
82+
}
83+
6584
public record KSampler : ComfyTypedNodeBase<LatentNodeConnection>
6685
{
6786
public required ModelNodeConnection Model { get; init; }

0 commit comments

Comments
 (0)