diff --git a/packages/frontend/src/lib/select/InferenceRuntimeSelect.svelte b/packages/frontend/src/lib/select/InferenceRuntimeSelect.svelte new file mode 100644 index 000000000..5cd4907a9 --- /dev/null +++ b/packages/frontend/src/lib/select/InferenceRuntimeSelect.svelte @@ -0,0 +1,33 @@ + + + ({ + value: type, + label: type, + }))} /> diff --git a/packages/frontend/src/pages/CreateService.svelte b/packages/frontend/src/pages/CreateService.svelte index 148db7406..0d7bb23cd 100644 --- a/packages/frontend/src/pages/CreateService.svelte +++ b/packages/frontend/src/pages/CreateService.svelte @@ -16,6 +16,8 @@ import { containerProviderConnections } from '/@/stores/containerProviderConnect import ContainerProviderConnectionSelect from '/@/lib/select/ContainerProviderConnectionSelect.svelte'; import ContainerConnectionWrapper from '/@/lib/notification/ContainerConnectionWrapper.svelte'; import TrackedTasks from '/@/lib/progress/TrackedTasks.svelte'; +import InferenceRuntimeSelect from '/@/lib/select/InferenceRuntimeSelect.svelte'; +import { InferenceType } from '@shared/models/IInference'; interface Props { // The tracking id is a unique identifier provided by the @@ -25,9 +27,15 @@ interface Props { let { trackingId }: Props = $props(); +// The runtime to use +let runtime: InferenceType = $state(InferenceType.LLAMA_CPP); + // List of the models available locally let localModels: ModelInfo[] = $derived($modelsInfo.filter(model => model.file)); +// List of the models filtered by runtime +let filteredModels: Array = $derived(localModels.filter(model => model.backend === runtime)); + // The container provider connection to use let containerProviderConnection: ContainerProviderConnectionInfo | undefined = $state(undefined); @@ -51,9 +59,14 @@ let available: boolean = $derived(!!containerId && $inferenceServers.some(server let loading = $derived(trackingId !== undefined && !errorMsg); $effect(() => { + // remove any incompatible model + if (model?.backend !== runtime) { + model = undefined; + } + // Select default model - if (!model && localModels.length > 0) { - model = localModels[0]; + if (!model && filteredModels.length > 0) { + model = filteredModels[0]; } // Select default connection @@ -129,6 +142,7 @@ function populateModelFromTasks(trackedTasks: Task[]): void { if (!mModel) return; model = mModel; + runtime = mModel.backend as InferenceType; } onMount(() => { @@ -145,6 +159,9 @@ onMount(() => { const queryModelId = router.location.query.get('model-id'); if (queryModelId !== undefined && typeof queryModelId === 'string') { model = localModels.find(mModel => mModel.id === queryModelId); + if (model) { + runtime = model.backend as InferenceType; + } } }); @@ -182,6 +199,11 @@ export function goToUpPage(): void { + + Inference Runtime + + {#if startedContainerProviderConnectionInfo.length > 1} Model - - {#if localModels.length === 0} + + {#if filteredModels.length === 0}