Skip to content

Commit 5ffef17

Browse files
committed
Add InferenceClientManager.LoraModelsChangeSet
1 parent 65f5398 commit 5ffef17

File tree

3 files changed

+62
-38
lines changed

3 files changed

+62
-38
lines changed

StabilityMatrix.Avalonia/DesignData/MockInferenceClientManager.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ public partial class MockInferenceClientManager : ObservableObject, IInferenceCl
3030
public IObservableCollection<HybridModelFile> LoraModels { get; } =
3131
new ObservableCollectionExtended<HybridModelFile>();
3232

33+
public IObservable<IChangeSet<HybridModelFile, string>> LoraModelsChangeSet { get; }
34+
3335
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
3436
new ObservableCollectionExtended<HybridModelFile>();
3537

StabilityMatrix.Avalonia/Services/IInferenceClientManager.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Diagnostics.CodeAnalysis;
44
using System.Threading;
55
using System.Threading.Tasks;
6+
using DynamicData;
67
using DynamicData.Binding;
78
using StabilityMatrix.Avalonia.Models;
89
using StabilityMatrix.Core.Inference;
@@ -51,6 +52,7 @@ public interface IInferenceClientManager : IDisposable, INotifyPropertyChanged,
5152
IObservableCollection<HybridModelFile> UnetModels { get; }
5253
IObservableCollection<HybridModelFile> ClipModels { get; }
5354
IObservableCollection<HybridModelFile> ClipVisionModels { get; }
55+
IObservable<IChangeSet<HybridModelFile, string>> LoraModelsChangeSet { get; }
5456

5557
Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default);
5658

StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,25 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
7373

7474
private readonly SourceCache<HybridModelFile, string> controlNetModelsSource = new(p => p.GetId());
7575

76-
private readonly SourceCache<HybridModelFile, string> downloadableControlNetModelsSource =
77-
new(p => p.GetId());
76+
private readonly SourceCache<HybridModelFile, string> downloadableControlNetModelsSource = new(p =>
77+
p.GetId()
78+
);
7879

7980
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
8081
new ObservableCollectionExtended<HybridModelFile>();
8182

82-
public readonly SourceCache<HybridModelFile, string> LoraModelsSource = new(p => p.GetId());
83+
private readonly SourceCache<HybridModelFile, string> loraModelsSource = new(p => p.GetId());
84+
85+
public IObservable<IChangeSet<HybridModelFile, string>> LoraModelsChangeSet { get; }
8386

8487
public IObservableCollection<HybridModelFile> LoraModels { get; } =
8588
new ObservableCollectionExtended<HybridModelFile>();
8689

8790
private readonly SourceCache<HybridModelFile, string> promptExpansionModelsSource = new(p => p.GetId());
8891

89-
private readonly SourceCache<HybridModelFile, string> downloadablePromptExpansionModelsSource =
90-
new(p => p.GetId());
92+
private readonly SourceCache<HybridModelFile, string> downloadablePromptExpansionModelsSource = new(p =>
93+
p.GetId()
94+
);
9195

9296
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
9397
new ObservableCollectionExtended<HybridModelFile>();
@@ -121,8 +125,9 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
121125

122126
private readonly SourceCache<HybridModelFile, string> ultralyticsModelsSource = new(p => p.GetId());
123127

124-
private readonly SourceCache<HybridModelFile, string> downloadableUltralyticsModelsSource =
125-
new(p => p.GetId());
128+
private readonly SourceCache<HybridModelFile, string> downloadableUltralyticsModelsSource = new(p =>
129+
p.GetId()
130+
);
126131

127132
public IObservableCollection<HybridModelFile> SamModels { get; } =
128133
new ObservableCollectionExtended<HybridModelFile>();
@@ -143,8 +148,9 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
143148
new ObservableCollectionExtended<HybridModelFile>();
144149

145150
private readonly SourceCache<HybridModelFile, string> clipVisionModelsSource = new(p => p.GetId());
146-
private readonly SourceCache<HybridModelFile, string> downloadableClipVisionModelsSource =
147-
new(p => p.GetId());
151+
private readonly SourceCache<HybridModelFile, string> downloadableClipVisionModelsSource = new(p =>
152+
p.GetId()
153+
);
148154

149155
public IObservableCollection<HybridModelFile> ClipVisionModels { get; } =
150156
new ObservableCollectionExtended<HybridModelFile>();
@@ -188,7 +194,23 @@ ICompletionProvider completionProvider
188194
.ObserveOn(SynchronizationContext.Current)
189195
.Subscribe();
190196

191-
LoraModelsSource
197+
LoraModelsChangeSet = loraModelsSource
198+
.Connect()
199+
.DeferUntilLoaded()
200+
// Adding .RefCount() if multiple consumers might subscribe to this
201+
// LoraModelsChangeSet property. It keeps the upstream connection active as long
202+
// as there's at least one subscriber. This is usually a good idea when exposing streams.
203+
.RefCount();
204+
205+
LoraModelsChangeSet
206+
.SortAndBind(
207+
LoraModels,
208+
SortExpressionComparer<HybridModelFile>.Ascending(f => f.Type).ThenByAscending(f => f.SortKey)
209+
)
210+
.ObserveOn(SynchronizationContext.Current)
211+
.Subscribe();
212+
213+
loraModelsSource
192214
.Connect()
193215
.DeferUntilLoaded()
194216
.SortAndBind(
@@ -331,8 +353,8 @@ ICompletionProvider completionProvider
331353
if (IsConnected)
332354
{
333355
LoadSharedPropertiesAsync()
334-
.SafeFireAndForget(
335-
onException: ex => logger.LogError(ex, "Error loading shared properties")
356+
.SafeFireAndForget(onException: ex =>
357+
logger.LogError(ex, "Error loading shared properties")
336358
);
337359
}
338360
};
@@ -370,7 +392,7 @@ await Client.GetNodeOptionNamesAsync("ControlNetLoader", "control_net_name") is
370392
// Get Lora model names
371393
if (await Client.GetNodeOptionNamesAsync("LoraLoader", "lora_name") is { } loraModelNames)
372394
{
373-
LoraModelsSource.EditDiff(
395+
loraModelsSource.EditDiff(
374396
loraModelNames.Select(HybridModelFile.FromRemote),
375397
HybridModelFile.Comparer
376398
);
@@ -385,7 +407,7 @@ await Client.GetOptionalNodeOptionNamesAsync("UltralyticsDetectorProvider", "mod
385407
IEnumerable<HybridModelFile> models =
386408
[
387409
HybridModelFile.None,
388-
..ultralyticsModelNames.Select(HybridModelFile.FromRemote)
410+
.. ultralyticsModelNames.Select(HybridModelFile.FromRemote),
389411
];
390412
ultralyticsModelsSource.EditDiff(models, HybridModelFile.Comparer);
391413
}
@@ -396,7 +418,7 @@ await Client.GetOptionalNodeOptionNamesAsync("UltralyticsDetectorProvider", "mod
396418
IEnumerable<HybridModelFile> models =
397419
[
398420
HybridModelFile.None,
399-
..samModelNames.Select(HybridModelFile.FromRemote)
421+
.. samModelNames.Select(HybridModelFile.FromRemote),
400422
];
401423
samModelsSource.EditDiff(models, HybridModelFile.Comparer);
402424
}
@@ -484,7 +506,7 @@ await Client.GetRequiredNodeOptionNamesFromOptionalNodeAsync("UnetLoaderGGUF", "
484506
IEnumerable<HybridModelFile> models =
485507
[
486508
HybridModelFile.None,
487-
..clipModelNames.Select(HybridModelFile.FromRemote)
509+
.. clipModelNames.Select(HybridModelFile.FromRemote),
488510
];
489511
clipModelsSource.EditDiff(models, HybridModelFile.Comparer);
490512
}
@@ -495,7 +517,7 @@ await Client.GetRequiredNodeOptionNamesFromOptionalNodeAsync("UnetLoaderGGUF", "
495517
IEnumerable<HybridModelFile> models =
496518
[
497519
HybridModelFile.None,
498-
..clipVisionModelNames.Select(HybridModelFile.FromRemote)
520+
.. clipVisionModelNames.Select(HybridModelFile.FromRemote),
499521
];
500522
clipVisionModelsSource.EditDiff(models, HybridModelFile.Comparer);
501523
}
@@ -521,13 +543,13 @@ private void ResetSharedProperties()
521543
);
522544

523545
// Downloadable ControlNet models
524-
var downloadableControlNets = RemoteModels.ControlNetModels.Where(
525-
u => !controlNetModelsSource.Lookup(u.GetId()).HasValue
546+
var downloadableControlNets = RemoteModels.ControlNetModels.Where(u =>
547+
!controlNetModelsSource.Lookup(u.GetId()).HasValue
526548
);
527549
downloadableControlNetModelsSource.EditDiff(downloadableControlNets, HybridModelFile.Comparer);
528550

529551
// Load local Lora / LyCORIS models
530-
LoraModelsSource.EditDiff(
552+
loraModelsSource.EditDiff(
531553
modelIndexService
532554
.FindByModelType(SharedFolderType.Lora | SharedFolderType.LyCORIS)
533555
.Select(HybridModelFile.FromLocal),
@@ -544,8 +566,8 @@ private void ResetSharedProperties()
544566

545567
// Downloadable PromptExpansion models
546568
downloadablePromptExpansionModelsSource.EditDiff(
547-
RemoteModels.PromptExpansionModels.Where(
548-
u => !promptExpansionModelsSource.Lookup(u.GetId()).HasValue
569+
RemoteModels.PromptExpansionModels.Where(u =>
570+
!promptExpansionModelsSource.Lookup(u.GetId()).HasValue
549571
),
550572
HybridModelFile.Comparer
551573
);
@@ -560,29 +582,27 @@ private void ResetSharedProperties()
560582
IEnumerable<HybridModelFile> ultralyticsModels =
561583
[
562584
HybridModelFile.None,
563-
..modelIndexService
564-
.FindByModelType(SharedFolderType.Ultralytics)
565-
.Select(HybridModelFile.FromLocal)
585+
.. modelIndexService
586+
.FindByModelType(SharedFolderType.Ultralytics)
587+
.Select(HybridModelFile.FromLocal),
566588
];
567589
ultralyticsModelsSource.EditDiff(ultralyticsModels, HybridModelFile.Comparer);
568590

569-
var downloadableUltralyticsModels = RemoteModels.UltralyticsModelFiles.Where(
570-
u => !ultralyticsModelsSource.Lookup(u.GetId()).HasValue
591+
var downloadableUltralyticsModels = RemoteModels.UltralyticsModelFiles.Where(u =>
592+
!ultralyticsModelsSource.Lookup(u.GetId()).HasValue
571593
);
572594
downloadableUltralyticsModelsSource.EditDiff(downloadableUltralyticsModels, HybridModelFile.Comparer);
573595

574596
// Load SAM models
575597
IEnumerable<HybridModelFile> samModels =
576598
[
577599
HybridModelFile.None,
578-
..modelIndexService
579-
.FindByModelType(SharedFolderType.Sams)
580-
.Select(HybridModelFile.FromLocal)
600+
.. modelIndexService.FindByModelType(SharedFolderType.Sams).Select(HybridModelFile.FromLocal),
581601
];
582602
samModelsSource.EditDiff(samModels, HybridModelFile.Comparer);
583603

584-
var downloadableSamModels = RemoteModels.SamModelFiles.Where(
585-
u => !samModelsSource.Lookup(u.GetId()).HasValue
604+
var downloadableSamModels = RemoteModels.SamModelFiles.Where(u =>
605+
!samModelsSource.Lookup(u.GetId()).HasValue
586606
);
587607
downloadableSamModelsSource.EditDiff(downloadableSamModels, HybridModelFile.Comparer);
588608

@@ -600,8 +620,8 @@ private void ResetSharedProperties()
600620
HybridModelFile.Comparer
601621
);
602622

603-
var downloadableClipModels = RemoteModels.ClipModelFiles.Where(
604-
u => !clipModelsSource.Lookup(u.GetId()).HasValue
623+
var downloadableClipModels = RemoteModels.ClipModelFiles.Where(u =>
624+
!clipModelsSource.Lookup(u.GetId()).HasValue
605625
);
606626
downloadableClipModelsSource.EditDiff(downloadableClipModels, HybridModelFile.Comparer);
607627

@@ -610,8 +630,8 @@ private void ResetSharedProperties()
610630
HybridModelFile.Comparer
611631
);
612632

613-
var downloadableClipVisionModels = RemoteModels.ClipVisionModelFiles.Where(
614-
u => !clipVisionModelsSource.Lookup(u.GetId()).HasValue
633+
var downloadableClipVisionModels = RemoteModels.ClipVisionModelFiles.Where(u =>
634+
!clipVisionModelsSource.Lookup(u.GetId()).HasValue
615635
);
616636
downloadableClipVisionModelsSource.EditDiff(downloadableClipVisionModels, HybridModelFile.Comparer);
617637

@@ -632,8 +652,8 @@ private void ResetSharedProperties()
632652
);
633653

634654
// Remote upscalers
635-
var remoteUpscalers = ComfyUpscaler.DefaultDownloadableModels.Where(
636-
u => !modelUpscalersSource.Lookup(u.Name).HasValue
655+
var remoteUpscalers = ComfyUpscaler.DefaultDownloadableModels.Where(u =>
656+
!modelUpscalersSource.Lookup(u.Name).HasValue
637657
);
638658
downloadableUpscalersSource.EditDiff(remoteUpscalers, ComfyUpscaler.Comparer);
639659

0 commit comments

Comments
 (0)