Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions StabilityMatrix.Avalonia/Controls/Inference/TiledVAECard.axaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
xmlns:ui="clr-namespace:FluentAvalonia.UI.Controls;assembly=FluentAvalonia"
xmlns:vmInference="clr-namespace:StabilityMatrix.Avalonia.ViewModels.Inference"
x:DataType="vmInference:TiledVAECardViewModel">

<Design.PreviewWith>
<controls:TiledVAECard />
</Design.PreviewWith>

<Style Selector="controls|TiledVAECard">
<!-- Set Defaults -->
<Setter Property="Template">
<ControlTemplate>
<controls:Card x:Name="PART_Card">

<controls:Card.Styles>
<Style Selector="ui|NumberBox">
<Setter Property="Margin" Value="12,0,0,0" />
Expand All @@ -25,8 +26,10 @@
<Setter Property="SpinButtonPlacementMode" Value="Inline" />
</Style>
</controls:Card.Styles>

<StackPanel Spacing="8">
<!-- Tile Size -->

<!-- Tile Size -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
Expand All @@ -37,7 +40,7 @@
Value="{Binding TileSize, Mode=TwoWay}" />
</Grid>

<!-- Overlap -->
<!-- Overlap -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
Expand All @@ -48,29 +51,40 @@
Value="{Binding Overlap, Mode=TwoWay}" />
</Grid>

<!-- Temporal Size (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Size" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalSize, Mode=TwoWay}" />
</Grid>
<!-- Enable Temporal Tiling -->
<ToggleSwitch
Header="Enable Temporal Tiling"
IsChecked="{Binding UseCustomTemporalTiling}" />

<!-- Temporal Controls (Visible only when enabled) -->
<StackPanel IsVisible="{Binding UseCustomTemporalTiling}" Spacing="8">

<!-- Temporal Size -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Size" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalSize, Mode=TwoWay}" />
</Grid>

<!-- Temporal Overlap -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Overlap" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalOverlap, Mode=TwoWay}"
SmallChange="4"
LargeChange="16" />
</Grid>

</StackPanel>

<!-- Temporal Overlap (for Video VAEs) -->
<Grid ColumnDefinitions="Auto,*">
<TextBlock
Grid.Column="0"
VerticalAlignment="Center"
Text="Temporal Overlap" />
<ui:NumberBox
Grid.Column="1"
Value="{Binding TemporalOverlap, Mode=TwoWay}"
SmallChange="4"
LargeChange="16" />
</Grid>
</StackPanel>
</controls:Card>
</ControlTemplate>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ RunningPackageService runningPackageService

BatchSizeCardViewModel = vmFactory.Get<BatchSizeCardViewModel>();

VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>(
vm => vm.Fps = 16.0d
VideoOutputSettingsCardViewModel = vmFactory.Get<VideoOutputSettingsCardViewModel>(vm =>
vm.Fps = 16.0d
);

StackCardViewModel = vmFactory.Get<StackCardViewModel>();
Expand All @@ -94,7 +94,7 @@ protected override void BuildPrompt(BuildPromptEventArgs args)
builder.Connections.Seed = args.SeedOverride switch
{
{ } seed => Convert.ToUInt64(seed),
_ => Convert.ToUInt64(SeedCardViewModel.Seed)
_ => Convert.ToUInt64(SeedCardViewModel.Seed),
};

// Load models
Expand All @@ -115,7 +115,6 @@ protected override void BuildPrompt(BuildPromptEventArgs args)

SamplerCardViewModel.ApplyStep(args);

// Animated webp output
VideoOutputSettingsCardViewModel.ApplyStep(args);
}

Expand Down Expand Up @@ -165,13 +164,13 @@ CancellationToken cancellationToken
OutputNodeNames = buildPromptArgs.Builder.Connections.OutputNodeNames.ToArray(),
Parameters = SaveStateToParameters(new GenerationParameters()) with
{
Seed = Convert.ToUInt64(seed)
Seed = Convert.ToUInt64(seed),
},
Project = inferenceProject,
FilesToTransfer = buildPromptArgs.FilesToTransfer,
BatchIndex = i,
// Only clear output images on the first batch
ClearOutputImages = i == 0
ClearOutputImages = i == 0,
};

batchArgs.Add(generationArgs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using StabilityMatrix.Avalonia.ViewModels.Base;
using StabilityMatrix.Core.Attributes;
using StabilityMatrix.Core.Models.Api.Comfy.Nodes;
using NLog;


namespace StabilityMatrix.Avalonia.ViewModels.Inference.Modules;

Expand All @@ -14,6 +16,7 @@ public class TiledVAEModule : ModuleBase
public TiledVAEModule(IServiceManager<ViewModelBase> vmFactory)
: base(vmFactory)
{
this.logger = logger;
Title = "Tiled VAE Decode";
AddCards(vmFactory.Get<TiledVAECardViewModel>());
}
Expand All @@ -22,7 +25,6 @@ protected override void OnApplyStep(ModuleApplyStepEventArgs e)
{
var card = GetCard<TiledVAECardViewModel>();

// Register a pre-output action that replaces standard VAE decode with tiled decode
e.PreOutputActions.Add(args =>
{
var builder = args.Builder;
Expand All @@ -34,22 +36,29 @@ protected override void OnApplyStep(ModuleApplyStepEventArgs e)
var latent = builder.Connections.Primary.AsT0;
var vae = builder.Connections.GetDefaultVAE();

// Use tiled VAE decode instead of standard decode
var tiledDecode = builder.Nodes.AddTypedNode(
logger.LogDebug("TiledVAE: Injecting TiledVAEDecode");
logger.LogDebug(
"UseCustomTemporalTiling value at runtime: {value}",
card.UseCustomTemporalTiling
);
Comment on lines +39 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logger still does not appear to be injected into this class - this will not compile


var node = builder.Nodes.AddTypedNode(
new ComfyNodeBuilder.TiledVAEDecode
{
Name = builder.Nodes.GetUniqueName("TiledVAEDecode"),
Samples = latent,
Vae = vae,
TileSize = card.TileSize,
Overlap = card.Overlap,
TemporalSize = card.TemporalSize,
TemporalOverlap = card.TemporalOverlap

// Temporal tiling (WAN requires temporal tiling)
TemporalSize = card.UseCustomTemporalTiling ? card.TemporalSize : 64,
TemporalOverlap = card.UseCustomTemporalTiling ? card.TemporalOverlap : 8,
}
);

// Update primary connection to the decoded image
builder.Connections.Primary = tiledDecode.Output;
builder.Connections.Primary = node.Output;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,32 @@ public partial class TiledVAECardViewModel : LoadableViewModelBase
{
public const string ModuleKey = "TiledVAE";

// Spatial tile size (valid for Wan)
[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(64, 4096)]
private int tileSize = 512;

// Spatial overlap
[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(0, 4096)]
private int overlap = 64;

// Toggle: Use custom temporal tiling settings
[ObservableProperty]
private bool useCustomTemporalTiling = false;

// Temporal tile size (must be >= 8)
[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
[Range(8, 4096)]
private int temporalSize = 64;

// Temporal overlap (must be >= 4)
[ObservableProperty]
[NotifyDataErrorInfo]
[Required]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ public partial class InferenceSettingsViewModel : PageViewModelBase
[ObservableProperty]
private bool filterExtraNetworksByBaseModel;

[ObservableProperty]
private bool enableTiledVae;


private List<string> ignoredFileNameFormatVars =
[
"author",
Expand Down Expand Up @@ -191,6 +195,12 @@ ISettingsManager settingsManager
settings => settings.InferenceDimensionStepChange,
true
);
settingsManager.RelayPropertyFor(
this,
vm => vm.EnableTiledVae,
settings => settings.EnableTiledVae,
true
);

FavoriteDimensions
.ToObservableChangeSet()
Expand Down
32 changes: 16 additions & 16 deletions StabilityMatrix.Core/Inference/ComfyClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ public class ComfyClient : InferenceClientBase
private readonly IComfyApi comfyApi;
private bool isDisposed;

private readonly JsonSerializerOptions jsonSerializerOptions =
new()
private readonly JsonSerializerOptions jsonSerializerOptions = new()
{
PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower,
Converters =
{
PropertyNamingPolicy = JsonNamingPolicies.SnakeCaseLower,
Converters =
{
new NodeConnectionBaseJsonConverter(),
new OneOfJsonConverter<string, StringNodeConnection>()
}
};
new NodeConnectionBaseJsonConverter(),
new OneOfJsonConverter<string, StringNodeConnection>(),
},
};

// ReSharper disable once MemberCanBePrivate.Global
public string ClientId { get; } = Guid.NewGuid().ToString();
Expand Down Expand Up @@ -111,20 +110,20 @@ public ComfyClient(IApiFactory apiFactory, Uri baseAddress)
{
Scheme = "ws",
Path = "/ws",
Query = $"clientId={ClientId}"
Query = $"clientId={ClientId}",
}.Uri;

webSocketClient = new WebsocketClient(wsUri)
{
Name = nameof(ComfyClient),
ReconnectTimeout = TimeSpan.FromSeconds(30)
ReconnectTimeout = TimeSpan.FromSeconds(30),
};

webSocketClient.DisconnectionHappened.Subscribe(
info => Logger.Info("Websocket Disconnected, ({Type})", info.Type)
webSocketClient.DisconnectionHappened.Subscribe(info =>
Logger.Info("Websocket Disconnected, ({Type})", info.Type)
);
webSocketClient.ReconnectionHappened.Subscribe(
info => Logger.Info("Websocket Reconnected, ({Type})", info.Type)
webSocketClient.ReconnectionHappened.Subscribe(info =>
Logger.Info("Websocket Reconnected, ({Type})", info.Type)
);

webSocketClient.MessageReceived.Subscribe(OnMessageReceived);
Expand Down Expand Up @@ -287,7 +286,7 @@ private void HandleBinaryMessage(byte[] data)
Array.Reverse(typeBytes);
}*/

PreviewImageReceived?.Invoke(this, new ComfyWebSocketImageData { ImageBytes = data[8..], });
PreviewImageReceived?.Invoke(this, new ComfyWebSocketImageData { ImageBytes = data[8..] });
}

public override async Task ConnectAsync(CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -332,6 +331,7 @@ public async Task<ComfyTask> QueuePromptAsync(
)
{
var request = new ComfyPromptRequest { ClientId = ClientId, Prompt = nodes };

var result = await comfyApi.PostPrompt(request, cancellationToken).ConfigureAwait(false);

// Add task to dictionary and set it as the current task
Expand Down
Loading