Skip to content

Commit ba3e715

Browse files
committed
Fix design test error: Make MockInferenceModelManager override instead
1 parent 33295d7 commit ba3e715

File tree

3 files changed

+67
-126
lines changed

3 files changed

+67
-126
lines changed
Lines changed: 35 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,48 @@
1-
using System;
2-
using System.Linq;
3-
using System.Threading;
4-
using System.Threading.Tasks;
5-
using CommunityToolkit.Mvvm.ComponentModel;
6-
using DynamicData;
7-
using DynamicData.Binding;
8-
using StabilityMatrix.Avalonia.Models;
1+
using DynamicData;
2+
using Microsoft.Extensions.Logging;
3+
using StabilityMatrix.Avalonia.Models.TagCompletion;
94
using StabilityMatrix.Avalonia.Services;
10-
using StabilityMatrix.Core.Inference;
5+
using StabilityMatrix.Core.Api;
116
using StabilityMatrix.Core.Models;
12-
using StabilityMatrix.Core.Models.Api.Comfy;
13-
using StabilityMatrix.Core.Models.FileInterfaces;
7+
using StabilityMatrix.Core.Services;
148

159
namespace StabilityMatrix.Avalonia.DesignData;
1610

17-
public partial class MockInferenceClientManager : ObservableObject, IInferenceClientManager
11+
public class MockInferenceClientManager : InferenceClientManager
1812
{
19-
public ComfyClient? Client { get; set; }
20-
21-
public IObservableCollection<HybridModelFile> Models { get; } =
22-
new ObservableCollectionExtended<HybridModelFile>();
23-
24-
public IObservableCollection<HybridModelFile> VaeModels { get; } =
25-
new ObservableCollectionExtended<HybridModelFile>();
26-
27-
public IObservableCollection<HybridModelFile> ControlNetModels { get; } =
28-
new ObservableCollectionExtended<HybridModelFile>();
29-
30-
public IObservableCollection<HybridModelFile> LoraModels { get; } =
31-
new ObservableCollectionExtended<HybridModelFile>();
32-
33-
public IObservable<IChangeSet<HybridModelFile, string>> LoraModelsChangeSet { get; }
34-
35-
public IObservableCollection<HybridModelFile> PromptExpansionModels { get; } =
36-
new ObservableCollectionExtended<HybridModelFile>();
37-
38-
public IObservableCollection<ComfySampler> Samplers { get; } =
39-
new ObservableCollectionExtended<ComfySampler>(ComfySampler.Defaults);
40-
41-
public IObservableCollection<ComfyUpscaler> Upscalers { get; } =
42-
new ObservableCollectionExtended<ComfyUpscaler>(
43-
ComfyUpscaler.Defaults.Concat(ComfyUpscaler.DefaultDownloadableModels)
44-
);
45-
46-
public IObservableCollection<ComfyScheduler> Schedulers { get; } =
47-
new ObservableCollectionExtended<ComfyScheduler>(ComfyScheduler.Defaults);
48-
49-
public IObservableCollection<ComfyAuxPreprocessor> Preprocessors { get; } =
50-
new ObservableCollectionExtended<ComfyAuxPreprocessor>(ComfyAuxPreprocessor.Defaults);
51-
52-
public IObservableCollection<HybridModelFile> UltralyticsModels { get; } =
53-
new ObservableCollectionExtended<HybridModelFile>();
54-
55-
public IObservableCollection<HybridModelFile> SamModels { get; } =
56-
new ObservableCollectionExtended<HybridModelFile>();
57-
58-
public IObservableCollection<HybridModelFile> UnetModels { get; } =
59-
new ObservableCollectionExtended<HybridModelFile>();
60-
61-
public IObservableCollection<HybridModelFile> ClipModels { get; } =
62-
new ObservableCollectionExtended<HybridModelFile>();
63-
64-
public IObservableCollection<HybridModelFile> ClipVisionModels { get; } =
65-
new ObservableCollectionExtended<HybridModelFile>();
66-
67-
[ObservableProperty]
68-
[NotifyPropertyChangedFor(nameof(CanUserConnect))]
69-
private bool isConnected;
70-
71-
[ObservableProperty]
72-
[NotifyPropertyChangedFor(nameof(CanUserConnect))]
73-
private bool isConnecting;
74-
75-
/// <inheritdoc />
76-
public bool CanUserConnect => !IsConnected && !IsConnecting;
77-
78-
/// <inheritdoc />
79-
public bool CanUserDisconnect => IsConnected && !IsConnecting;
80-
81-
public MockInferenceClientManager()
13+
public MockInferenceClientManager(
14+
ILogger<InferenceClientManager> logger,
15+
IApiFactory apiFactory,
16+
IModelIndexService modelIndexService,
17+
ISettingsManager settingsManager,
18+
ICompletionProvider completionProvider
19+
)
20+
: base(logger, apiFactory, modelIndexService, settingsManager, completionProvider)
8221
{
83-
Models.AddRange(
84-
new[]
85-
{
86-
HybridModelFile.FromRemote("v1-5-pruned-emaonly.safetensors"),
87-
HybridModelFile.FromRemote("artshaper1.safetensors"),
88-
}
89-
);
22+
// Load our initial models
23+
ResetSharedProperties();
9024
}
9125

92-
/// <inheritdoc />
93-
public Task CopyImageToInputAsync(FilePath imageFile, CancellationToken cancellationToken = default)
94-
{
95-
return Task.CompletedTask;
96-
}
26+
public new bool IsConnected { get; set; }
9727

98-
/// <inheritdoc />
99-
public Task UploadInputImageAsync(ImageSource image, CancellationToken cancellationToken = default)
28+
protected override Task LoadSharedPropertiesAsync()
10029
{
101-
return Task.CompletedTask;
102-
}
30+
if (Models.Any(m => m.IsRemote))
31+
{
32+
return Task.CompletedTask;
33+
}
34+
35+
Models.Add(
36+
[
37+
HybridModelFile.FromRemote("v1-5-pruned-emaonly.safetensors"),
38+
HybridModelFile.FromRemote("art-shaper1.safetensors"),
39+
]
40+
);
10341

104-
/// <inheritdoc />
105-
public Task WriteImageToInputAsync(ImageSource imageSource, CancellationToken cancellationToken = default)
106-
{
10742
return Task.CompletedTask;
10843
}
10944

110-
public async Task ConnectAsync(CancellationToken cancellationToken = default)
45+
public override async Task ConnectAsync(CancellationToken cancellationToken = default)
11146
{
11247
IsConnecting = true;
11348
await Task.Delay(5000, cancellationToken);
@@ -116,20 +51,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
11651
IsConnected = true;
11752
}
11853

119-
/// <inheritdoc />
120-
public Task ConnectAsync(PackagePair packagePair, CancellationToken cancellationToken = default)
121-
{
122-
return Task.CompletedTask;
123-
}
124-
125-
public Task CloseAsync()
126-
{
127-
IsConnected = false;
128-
return Task.CompletedTask;
129-
}
130-
131-
public void Dispose()
54+
public override async Task ConnectAsync(
55+
PackagePair packagePair,
56+
CancellationToken cancellationToken = default
57+
)
13258
{
133-
GC.SuppressFinalize(this);
59+
await ConnectAsync(cancellationToken);
13460
}
13561
}

StabilityMatrix.Avalonia/DesignData/MockModelIndexService.cs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
using System;
2-
using System.Collections.Generic;
3-
using System.Linq;
4-
using System.Threading.Tasks;
5-
using Nito.Disposables.Internals;
1+
using Nito.Disposables.Internals;
62
using StabilityMatrix.Core.Models;
73
using StabilityMatrix.Core.Models.Database;
84
using StabilityMatrix.Core.Services;
@@ -12,7 +8,23 @@ namespace StabilityMatrix.Avalonia.DesignData;
128
public class MockModelIndexService : IModelIndexService
139
{
1410
/// <inheritdoc />
15-
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; } = new();
11+
public Dictionary<SharedFolderType, List<LocalModelFile>> ModelIndex { get; } =
12+
new()
13+
{
14+
[SharedFolderType.Lora] =
15+
[
16+
new LocalModelFile
17+
{
18+
RelativePath = "Lora/mock_model_1.safetensors",
19+
SharedFolderType = SharedFolderType.Lora,
20+
},
21+
new LocalModelFile
22+
{
23+
RelativePath = "Lora/mock_model_2.safetensors",
24+
SharedFolderType = SharedFolderType.Lora,
25+
},
26+
],
27+
};
1628

1729
/// <inheritdoc />
1830
public IReadOnlySet<string> ModelIndexBlake3Hashes =>
@@ -27,7 +39,7 @@ public Task RefreshIndex()
2739
/// <inheritdoc />
2840
public IEnumerable<LocalModelFile> FindByModelType(SharedFolderType types)
2941
{
30-
return Array.Empty<LocalModelFile>();
42+
return ModelIndex.Where(kvp => (kvp.Key & types) != 0).SelectMany(kvp => kvp.Value);
3143
}
3244

3345
/// <inheritdoc />

StabilityMatrix.Avalonia/Services/InferenceClientManager.cs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public partial class InferenceClientManager : ObservableObject, IInferenceClient
4747
private ComfyClient? client;
4848

4949
[MemberNotNullWhen(true, nameof(Client))]
50-
public bool IsConnected => Client is not null;
50+
public virtual bool IsConnected => Client is not null;
5151

5252
[ObservableProperty]
5353
[NotifyPropertyChangedFor(nameof(CanUserConnect))]
@@ -367,7 +367,7 @@ private void EnsureConnected()
367367
throw new InvalidOperationException("Client is not connected");
368368
}
369369

370-
private async Task LoadSharedPropertiesAsync()
370+
protected virtual async Task LoadSharedPropertiesAsync()
371371
{
372372
EnsureConnected();
373373

@@ -526,7 +526,7 @@ .. clipVisionModelNames.Select(HybridModelFile.FromRemote),
526526
/// <summary>
527527
/// Clears shared properties and sets them to local defaults
528528
/// </summary>
529-
private void ResetSharedProperties()
529+
protected void ResetSharedProperties()
530530
{
531531
// Load local models
532532
modelsSource.EditDiff(
@@ -784,12 +784,6 @@ private async Task ConnectAsyncImpl(Uri uri, CancellationToken cancellationToken
784784
}
785785
}
786786

787-
/// <inheritdoc />
788-
public Task ConnectAsync(CancellationToken cancellationToken = default)
789-
{
790-
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
791-
}
792-
793787
private async Task MigrateLinksIfNeeded(PackagePair packagePair)
794788
{
795789
if (packagePair.InstalledPackage.FullPath is not { } packagePath)
@@ -827,7 +821,16 @@ private async Task MigrateLinksIfNeeded(PackagePair packagePair)
827821
}
828822

829823
/// <inheritdoc />
830-
public async Task ConnectAsync(PackagePair packagePair, CancellationToken cancellationToken = default)
824+
public virtual Task ConnectAsync(CancellationToken cancellationToken = default)
825+
{
826+
return ConnectAsyncImpl(new Uri("http://127.0.0.1:8188"), cancellationToken);
827+
}
828+
829+
/// <inheritdoc />
830+
public virtual async Task ConnectAsync(
831+
PackagePair packagePair,
832+
CancellationToken cancellationToken = default
833+
)
831834
{
832835
if (IsConnected)
833836
return;

0 commit comments

Comments
 (0)