Skip to content

Commit 95f44ff

Browse files
fix: Make the name extraction work for both ckpts and folders
1 parent f9c3c07 commit 95f44ff

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddCheckpoint.tsx

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
1414
import BaseModelSelect from '../shared/BaseModelSelect';
1515
import CheckpointConfigsSelect from '../shared/CheckpointConfigsSelect';
1616
import ModelVariantSelect from '../shared/ModelVariantSelect';
17+
import { getModelName } from './util';
1718

1819
type AdvancedAddCheckpointProps = {
1920
model_path?: string;
@@ -28,7 +29,7 @@ export default function AdvancedAddCheckpoint(
2829

2930
const advancedAddCheckpointForm = useForm<CheckpointModelConfig>({
3031
initialValues: {
31-
model_name: model_path?.match(/[^\\/]+$/)?.[0]?.split('.')[0] ?? '',
32+
model_name: model_path ? getModelName(model_path) : '',
3233
base_model: 'sd-1',
3334
model_type: 'main',
3435
path: model_path ? model_path : '',
@@ -102,10 +103,7 @@ export default function AdvancedAddCheckpoint(
102103
{...advancedAddCheckpointForm.getInputProps('path')}
103104
onBlur={(e) => {
104105
if (advancedAddCheckpointForm.values['model_name'] === '') {
105-
const modelName = e.currentTarget.value
106-
.match(/[^\\/]+$/)?.[0]
107-
?.split('.')[0];
108-
106+
const modelName = getModelName(e.currentTarget.value);
109107
if (modelName) {
110108
advancedAddCheckpointForm.setFieldValue(
111109
'model_name',

invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/AddModelsPanel/AdvancedAddDiffusers.tsx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { DiffusersModelConfig } from 'services/api/types';
1111
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
1212
import BaseModelSelect from '../shared/BaseModelSelect';
1313
import ModelVariantSelect from '../shared/ModelVariantSelect';
14+
import { getModelName } from './util';
1415

1516
type AdvancedAddDiffusersProps = {
1617
model_path?: string;
@@ -25,7 +26,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
2526

2627
const advancedAddDiffusersForm = useForm<DiffusersModelConfig>({
2728
initialValues: {
28-
model_name: model_path?.match(/[^\\/]+$/)?.[0] ?? '',
29+
model_name: model_path ? getModelName(model_path, false) : '',
2930
base_model: 'sd-1',
3031
model_type: 'main',
3132
path: model_path ? model_path : '',
@@ -94,7 +95,7 @@ export default function AdvancedAddDiffusers(props: AdvancedAddDiffusersProps) {
9495
{...advancedAddDiffusersForm.getInputProps('path')}
9596
onBlur={(e) => {
9697
if (advancedAddDiffusersForm.values['model_name'] === '') {
97-
const modelName = e.currentTarget.value.match(/[^\\/]+$/)?.[0];
98+
const modelName = getModelName(e.currentTarget.value, false);
9899
if (modelName) {
99100
advancedAddDiffusersForm.setFieldValue(
100101
'model_name',
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
2+
let regex;
3+
if (isCheckpoint) {
4+
regex = new RegExp('[^\\\\/]+(?=\\.)');
5+
} else {
6+
regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
7+
}
8+
9+
const match = filepath.match(regex);
10+
if (match) {
11+
return match[0];
12+
} else {
13+
return '';
14+
}
15+
}

0 commit comments

Comments
 (0)