Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions StabilityMatrix.Avalonia/App.axaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
<StyleInclude Source="Controls/Inference/UnetModelCard.axaml" />
<StyleInclude Source="Controls/Inference/DiscreteModelSamplingCard.axaml" />
<StyleInclude Source="Controls/Inference/RescaleCfgCard.axaml" />
<StyleInclude Source="Controls/Inference/TiledVAECard.axaml" />
<StyleInclude Source="Controls/Painting/PaintCanvas.axaml" />
<StyleInclude Source="Controls/MarkdownViewer.axaml" />
<StyleInclude Source="Controls/Inference/WanModelCard.axaml" />
Expand Down
81 changes: 81 additions & 0 deletions StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<Styles
xmlns="https://github.com/avaloniaui"
xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"
xmlns:controls="using:StabilityMatrix.Avalonia.Controls"
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
x:DataType="vmInference:TiledVAECardViewModel">
<Design.PreviewWith>
<controls:TiledVAECard />
</Design.PreviewWith>

<Style Selector="controls|TiledVAECard">
<!-- Set Defaults -->
<Setter Property="Template">
<ControlTemplate>
<controls:Card x:Name="PART_Card">
<controls:Card.Styles>
<Style Selector="ui|NumberBox">
<Setter Property="Margin" Value="12,4,0,4" />
<Setter Property="MinWidth" Value="70" />
<Setter Property="HorizontalAlignment" Value="Stretch" />
<Setter Property="ValidationMode" Value="InvalidInputOverwritten" />
<Setter Property="SmallChange" Value="32" />
<Setter Property="LargeChange" Value="128" />
<Setter Property="SpinButtonPlacementMode" Value="Inline" />
</Style>
</controls:Card.Styles>
<StackPanel Spacing="8">
<!-- Tile Size -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Tile Size" />
<ui:NumberBox
Grid.Column="1"
Margin="12,0,0,0"
Value="{Binding TileSize, Mode=TwoWay}" />
</Grid>

<!-- Overlap -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Overlap" />
<ui:NumberBox
Grid.Column="1"
Margin="12,0,0,0"
Value="{Binding Overlap, Mode=TwoWay}" />
</Grid>

<!-- Temporal Size (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Size" />
<ui:NumberBox
Grid.Column="1"
Margin="12,0,0,0"
Value="{Binding TemporalSize, Mode=TwoWay}" />
</Grid>

<!-- Temporal Overlap (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Overlap" />
<ui:NumberBox
Grid.Column="1"
Margin="12,0,0,0"
Value="{Binding TemporalOverlap, Mode=TwoWay}" />
</Grid>
</StackPanel>
</controls:Card>
</ControlTemplate>
</Setter>
</Style>
</Styles>
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Avalonia;
using Avalonia.Controls;
using Avalonia.Controls.Primitives;
using Injectio.Attributes;

namespace StabilityMatrix.Avalonia.Controls;

[RegisterTransient<TiledVAECard>]
public class TiledVAECard : TemplatedControlBase { }
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(PlasmaNoiseCardViewModel), PlasmaNoiseCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(NrsCardViewModel), NrsCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(CfzCudnnToggleCardViewModel), CfzCudnnToggleCardViewModel.ModuleKey)]
[JsonDerivedType(typeof(TiledVAECardViewModel), TiledVAECardViewModel.ModuleKey)]
[JsonDerivedType(typeof(FreeUModule))]
[JsonDerivedType(typeof(HiresFixModule))]
[JsonDerivedType(typeof(FluxHiresFixModule))]
Expand All @@ -43,6 +44,7 @@ namespace StabilityMatrix.Avalonia.ViewModels.Base;
[JsonDerivedType(typeof(PlasmaNoiseModule))]
[JsonDerivedType(typeof(NRSModule))]
[JsonDerivedType(typeof(CfzCudnnToggleModule))]
[JsonDerivedType(typeof(TiledVAEModule))]
public abstract class LoadableViewModelBase : ViewModelBase, IJsonLoadableState
{
private static readonly Logger Logger = LogManager.GetCurrentClassLogger();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Injectio.Attributes;
using StabilityMatrix.Avalonia.Models.Inference;
using StabilityMatrix.Avalonia.Services;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;

namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;

[ManagedService]
[RegisterTransient<TiledVAEModule>]
public class TiledVAEModule : ModuleBase
{
public TiledVAEModule(IServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
Title = "Tiled VAE Decode";
AddCards(vmFactory.Get<TiledVAECardViewModel>());
}

protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var card = GetCard<TiledVAECardViewModel>();

// Register a pre-output action that replaces standard VAE decode with tiled decode
e.PreOutputActions.Add(args =>
{
var builder = args.Builder;

// Only apply if primary is in latent space
if (builder.Connections.Primary?.IsT0 != true)
return;

var latent = builder.Connections.Primary.AsT0;
var vae = builder.Connections.GetDefaultVAE();

// Use tiled VAE decode instead of standard decode
var tiledDecode = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.TiledVAEDecode
{
Name = builder.Nodes.GetUniqueName("TiledVAEDecode"),
Samples = latent,
Vae = vae,
TileSize = card.TileSize,
Overlap = card.Overlap,
TemporalSize = card.TemporalSize,
TemporalOverlap = card.TemporalOverlap
}
);

// Update primary connection to the decoded image
builder.Connections.Primary = tiledDecode.Output;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ TabContext tabContext
typeof(RescaleCfgModule),
typeof(PlasmaNoiseModule),
typeof(NRSModule),
typeof(TiledVAEModule),
];
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System.ComponentModel.DataAnnotations;
using CommunityToolkit.Mvvm.ComponentModel;
using Injectio.Attributes;
using StabilityMatrix.Avalonia.Controls;
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;

namespace StabilityMatrix.Avalonia.ViewModels.Inference;

[View(typeof(TiledVAECard))]
[ManagedService]
[RegisterTransient<TiledVAECardViewModel>]
public partial class TiledVAECardViewModel : LoadableViewModelBase
{
public const string ModuleKey = "TiledVAE";

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(64, 4096)]
private int tileSize = 512;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(0, 4096)]
private int overlap = 64;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(8, 4096)]
private int temporalSize = 64;

[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(4, 4096)]
private int temporalOverlap = 8;
}
19 changes: 19 additions & 0 deletions StabilityMatrix.Core/Models/Api/Comfy/Nodes/ComfyNodeBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ public record VAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
public required VAENodeConnection Vae { get; init; }
}

[TypedNodeOptions(Name = "VAEDecodeTiled")]
public record TiledVAEDecode : ComfyTypedNodeBase<ImageNodeConnection>
{
public required LatentNodeConnection Samples { get; init; }
public required VAENodeConnection Vae { get; init; }

[Range(64, 4096)]
public int TileSize { get; init; } = 512;

[Range(0, 4096)]
public int Overlap { get; init; } = 64;

[Range(8, 4096)]
public int TemporalSize { get; init; } = 64;

[Range(4, 4096)]
public int TemporalOverlap { get; init; } = 8;
}

public record KSampler : ComfyTypedNodeBase<LatentNodeConnection>
{
public required ModelNodeConnection Model { get; init; }
Expand Down
Loading