Skip to content

Commit 14ce0f9

Browse files
committed
Use ModelCompatChecker for ExtraNetworkCard
1 parent 9aafc83 commit 14ce0f9

File tree

2 files changed

+68
-6
lines changed

2 files changed

+68
-6
lines changed

StabilityMatrix.Avalonia/ViewModels/Inference/ExtraNetworkCardViewModel.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
using StabilityMatrix.Avalonia.Services;
1313
using StabilityMatrix.Avalonia.ViewModels.Base;
1414
using StabilityMatrix.Core.Attributes;
15-
using StabilityMatrix.Core.Extensions;
15+
using StabilityMatrix.Core.Helper;
1616
using StabilityMatrix.Core.Models;
1717
using StabilityMatrix.Core.Services;
1818

@@ -26,6 +26,8 @@ namespace StabilityMatrix.Avalonia.ViewModels.Inference;
2626
public partial class ExtraNetworkCardViewModel : DisposableLoadableViewModelBase
2727
{
2828
private readonly ISettingsManager settingsManager;
29+
private readonly ModelCompatChecker modelCompatChecker = new();
30+
2931
public const string ModuleKey = "ExtraNetwork";
3032

3133
/// <summary>
@@ -161,11 +163,7 @@ private bool FilterCompatibleLoras(HybridModelFile? lora)
161163
if (!settingsManager.Settings.FilterExtraNetworksByBaseModel)
162164
return true;
163165

164-
return SelectedBaseModel is null
165-
|| lora?.Local?.ConnectedModelInfo == null
166-
|| SelectedBaseModel.Local?.ConnectedModelInfo == null
167-
|| lora.Local?.ConnectedModelInfo?.BaseModel
168-
== SelectedBaseModel.Local?.ConnectedModelInfo?.BaseModel;
166+
return modelCompatChecker.IsLoraCompatibleWithBaseModel(lora, SelectedBaseModel) ?? true;
169167
}
170168

171169
internal class ExtraNetworkCardModel
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using StabilityMatrix.Core.Extensions;
2+
using StabilityMatrix.Core.Models;
3+
using StabilityMatrix.Core.Models.Api;
4+
5+
namespace StabilityMatrix.Core.Helper;
6+
7+
public class ModelCompatChecker
8+
{
9+
private readonly Dictionary<string, CivitBaseModelType> baseModelNamesToTypes =
10+
Enum.GetValues<CivitBaseModelType>().ToDictionary(x => x.GetStringValue());
11+
12+
public bool? IsLoraCompatibleWithBaseModel(HybridModelFile? lora, HybridModelFile? baseModel)
13+
{
14+
// Require connected info for both
15+
if (
16+
lora?.Local?.ConnectedModelInfo is not { } loraInfo
17+
|| baseModel?.Local?.ConnectedModelInfo is not { } baseModelInfo
18+
)
19+
return null;
20+
21+
if (
22+
loraInfo.BaseModel is null
23+
|| !baseModelNamesToTypes.TryGetValue(loraInfo.BaseModel, out var loraBaseModelType)
24+
)
25+
return null;
26+
27+
if (
28+
baseModelInfo.BaseModel is null
29+
|| !baseModelNamesToTypes.TryGetValue(baseModelInfo.BaseModel, out var baseModelType)
30+
)
31+
return null;
32+
33+
// Normalize both
34+
var normalizedLoraBaseModelType = NormalizeBaseModelType(loraBaseModelType);
35+
var normalizedBaseModelType = NormalizeBaseModelType(baseModelType);
36+
37+
// Ignore if either is "Other"
38+
if (
39+
normalizedLoraBaseModelType == CivitBaseModelType.Other
40+
|| normalizedBaseModelType == CivitBaseModelType.Other
41+
)
42+
return null;
43+
44+
return normalizedLoraBaseModelType == normalizedBaseModelType;
45+
}
46+
47+
// Normalize base model type
48+
private static CivitBaseModelType NormalizeBaseModelType(CivitBaseModelType baseModel)
49+
{
50+
return baseModel switch
51+
{
52+
CivitBaseModelType.Sdxl09 => CivitBaseModelType.Sdxl10,
53+
CivitBaseModelType.Sdxl10Lcm => CivitBaseModelType.Sdxl10,
54+
CivitBaseModelType.SdxlDistilled => CivitBaseModelType.Sdxl10,
55+
CivitBaseModelType.SdxlHyper => CivitBaseModelType.Sdxl10,
56+
CivitBaseModelType.SdxlLightning => CivitBaseModelType.Sdxl10,
57+
CivitBaseModelType.SdxlTurbo => CivitBaseModelType.Sdxl10,
58+
CivitBaseModelType.Pony => CivitBaseModelType.Sdxl10,
59+
CivitBaseModelType.NoobAi => CivitBaseModelType.Sdxl10,
60+
CivitBaseModelType.Illustrious => CivitBaseModelType.Sdxl10,
61+
_ => baseModel,
62+
};
63+
}
64+
}

0 commit comments

Comments
 (0)