Skip to content

Commit b76d2cd

Browse files
psychedelicioushipsterusername
authored andcommitted
fix(ui): handle base model compat when recalling parameters
We had a one-behind issue with recalling metadata items that had a model. For example, when recalling LoRAs, we check against the current main model to decide whether or not the requested LoRA is compatible and may be recalled. When recalling all params, we are often also recalling the main model, but the compat logic didn't compare against this new main model. The logic is updated to check against the new main model, if one is being set. Closes #5512
1 parent 022b32c commit b76d2cd

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import {
4747
vaeSelected,
4848
widthChanged,
4949
} from 'features/parameters/store/generationSlice';
50+
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
5051
import {
5152
isParameterCFGRescaleMultiplier,
5253
isParameterCFGScale,
@@ -480,7 +481,7 @@ export const useRecallParameters = () => {
480481
const { data: loraModels } = useGetLoRAModelsQuery(undefined);
481482

482483
const prepareLoRAMetadataItem = useCallback(
483-
(loraMetadataItem: LoRAMetadataItem) => {
484+
(loraMetadataItem: LoRAMetadataItem, newModel?: ParameterModel) => {
484485
if (!isParameterLoRAModel(loraMetadataItem.lora)) {
485486
return { lora: null, error: 'Invalid LoRA model' };
486487
}
@@ -499,7 +500,7 @@ export const useRecallParameters = () => {
499500
}
500501

501502
const isCompatibleBaseModel =
502-
matchingLoRA?.base_model === model?.base_model;
503+
matchingLoRA?.base_model === (newModel ?? model)?.base_model;
503504

504505
if (!isCompatibleBaseModel) {
505506
return {
@@ -510,7 +511,7 @@ export const useRecallParameters = () => {
510511

511512
return { lora: matchingLoRA, error: null };
512513
},
513-
[loraModels, model?.base_model]
514+
[loraModels, model]
514515
);
515516

516517
const recallLoRA = useCallback(
@@ -538,7 +539,10 @@ export const useRecallParameters = () => {
538539
const { data: controlNetModels } = useGetControlNetModelsQuery(undefined);
539540

540541
const prepareControlNetMetadataItem = useCallback(
541-
(controlnetMetadataItem: ControlNetMetadataItem) => {
542+
(
543+
controlnetMetadataItem: ControlNetMetadataItem,
544+
newModel?: ParameterModel
545+
) => {
542546
if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) {
543547
return { controlnet: null, error: 'Invalid ControlNet model' };
544548
}
@@ -565,7 +569,7 @@ export const useRecallParameters = () => {
565569
}
566570

567571
const isCompatibleBaseModel =
568-
matchingControlNetModel?.base_model === model?.base_model;
572+
matchingControlNetModel?.base_model === (newModel ?? model)?.base_model;
569573

570574
if (!isCompatibleBaseModel) {
571575
return {
@@ -600,7 +604,7 @@ export const useRecallParameters = () => {
600604

601605
return { controlnet, error: null };
602606
},
603-
[controlNetModels, model?.base_model]
607+
[controlNetModels, model]
604608
);
605609

606610
const recallControlNet = useCallback(
@@ -631,7 +635,10 @@ export const useRecallParameters = () => {
631635
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery(undefined);
632636

633637
const prepareT2IAdapterMetadataItem = useCallback(
634-
(t2iAdapterMetadataItem: T2IAdapterMetadataItem) => {
638+
(
639+
t2iAdapterMetadataItem: T2IAdapterMetadataItem,
640+
newModel?: ParameterModel
641+
) => {
635642
if (
636643
!isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)
637644
) {
@@ -659,7 +666,7 @@ export const useRecallParameters = () => {
659666
}
660667

661668
const isCompatibleBaseModel =
662-
matchingT2IAdapterModel?.base_model === model?.base_model;
669+
matchingT2IAdapterModel?.base_model === (newModel ?? model)?.base_model;
663670

664671
if (!isCompatibleBaseModel) {
665672
return {
@@ -690,7 +697,7 @@ export const useRecallParameters = () => {
690697

691698
return { t2iAdapter, error: null };
692699
},
693-
[model?.base_model, t2iAdapterModels]
700+
[model, t2iAdapterModels]
694701
);
695702

696703
const recallT2IAdapter = useCallback(
@@ -721,7 +728,10 @@ export const useRecallParameters = () => {
721728
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery(undefined);
722729

723730
const prepareIPAdapterMetadataItem = useCallback(
724-
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
731+
(
732+
ipAdapterMetadataItem: IPAdapterMetadataItem,
733+
newModel?: ParameterModel
734+
) => {
725735
if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
726736
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
727737
}
@@ -746,7 +756,7 @@ export const useRecallParameters = () => {
746756
}
747757

748758
const isCompatibleBaseModel =
749-
matchingIPAdapterModel?.base_model === model?.base_model;
759+
matchingIPAdapterModel?.base_model === (newModel ?? model)?.base_model;
750760

751761
if (!isCompatibleBaseModel) {
752762
return {
@@ -768,7 +778,7 @@ export const useRecallParameters = () => {
768778

769779
return { ipAdapter, error: null };
770780
},
771-
[ipAdapterModels, model?.base_model]
781+
[ipAdapterModels, model]
772782
);
773783

774784
const recallIPAdapter = useCallback(
@@ -840,6 +850,13 @@ export const useRecallParameters = () => {
840850
t2iAdapters,
841851
} = metadata;
842852

853+
let newModel: ParameterModel | undefined = undefined;
854+
855+
if (isParameterModel(model)) {
856+
newModel = model;
857+
dispatch(modelSelected(model));
858+
}
859+
843860
if (isParameterCFGScale(cfg_scale)) {
844861
dispatch(setCfgScale(cfg_scale));
845862
}
@@ -848,10 +865,6 @@ export const useRecallParameters = () => {
848865
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
849866
}
850867

851-
if (isParameterModel(model)) {
852-
dispatch(modelSelected(model));
853-
}
854-
855868
if (isParameterPositivePrompt(positive_prompt)) {
856869
dispatch(setPositivePrompt(positive_prompt));
857870
}
@@ -953,29 +966,29 @@ export const useRecallParameters = () => {
953966

954967
dispatch(lorasCleared());
955968
loras?.forEach((lora) => {
956-
const result = prepareLoRAMetadataItem(lora);
969+
const result = prepareLoRAMetadataItem(lora, newModel);
957970
if (result.lora) {
958971
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
959972
}
960973
});
961974

962975
dispatch(controlAdaptersReset());
963976
controlnets?.forEach((controlnet) => {
964-
const result = prepareControlNetMetadataItem(controlnet);
977+
const result = prepareControlNetMetadataItem(controlnet, newModel);
965978
if (result.controlnet) {
966979
dispatch(controlAdapterRecalled(result.controlnet));
967980
}
968981
});
969982

970983
ipAdapters?.forEach((ipAdapter) => {
971-
const result = prepareIPAdapterMetadataItem(ipAdapter);
984+
const result = prepareIPAdapterMetadataItem(ipAdapter, newModel);
972985
if (result.ipAdapter) {
973986
dispatch(controlAdapterRecalled(result.ipAdapter));
974987
}
975988
});
976989

977990
t2iAdapters?.forEach((t2iAdapter) => {
978-
const result = prepareT2IAdapterMetadataItem(t2iAdapter);
991+
const result = prepareT2IAdapterMetadataItem(t2iAdapter, newModel);
979992
if (result.t2iAdapter) {
980993
dispatch(controlAdapterRecalled(result.t2iAdapter));
981994
}

0 commit comments

Comments
 (0)