diff --git a/src/components/ModelCatalog.tsx b/src/components/ModelCatalog.tsx index e65eb0f4144643..35083fc6e44b0b 100644 --- a/src/components/ModelCatalog.tsx +++ b/src/components/ModelCatalog.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import ModelInfo from "./models/ModelInfo"; import ModelBadges from "./models/ModelBadges"; import { authorData } from "./models/data"; @@ -19,6 +19,22 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => { capabilities: [], }); + useEffect(() => { + const params = new URLSearchParams(window.location.search); + + const search = params.get("search") ?? ""; + const authors = params.getAll("authors"); + const tasks = params.getAll("tasks"); + const capabilities = params.getAll("capabilities"); + + setFilters({ + search, + authors, + tasks, + capabilities, + }); + }, []); + const mapped = models.map((model) => ({ model: { ...model, @@ -43,21 +59,21 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => { const authors = [...new Set(models.map((model) => model.name.split("/")[1]))]; const capabilities = [ ...new Set( - models - .map((model) => - model.properties - .flatMap(({ property_id, value }) => { - if (property_id === "lora" && value === "true") { - return "LoRA"; - } + models.flatMap((model) => + model.properties + .flatMap(({ property_id, value }) => { + if (property_id === "lora" && value === "true") { + return "LoRA"; + } + + if (property_id === "function_calling" && value === "true") { + return "Function calling"; + } - if (property_id === "function_calling" && value === "true") { - return "Function calling"; - } - }) - .filter((p) => Boolean(p)), - ) - .flat(), + return []; + }) + .filter((p) => Boolean(p)), + ), ), ]; @@ -102,7 +118,7 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => {
- ▼ Model Types + ▼ Tasks {tasks.map((task) => ( @@ -111,7 +127,8 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => { type="checkbox" className="mr-2" value={task} - onClick={(e) => { + checked={filters.tasks.includes(task)} + onChange={(e) => { const target = e.target as HTMLInputElement; if (target.checked) { @@ -142,8 +159,9 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => { { + onChange={(e) => { const target = e.target as HTMLInputElement; if (target.checked) { @@ -177,7 +195,8 @@ const ModelCatalog = ({ models }: { models: WorkersAIModelsSchema[] }) => { type="checkbox" className="mr-2" value={author} - onClick={(e) => { + checked={filters.authors.includes(author)} + onChange={(e) => { const target = e.target as HTMLInputElement; if (target.checked) { diff --git a/src/components/models/ModelBadges.tsx b/src/components/models/ModelBadges.tsx index 5392cd5cede048..9327d377e69620 100644 --- a/src/components/models/ModelBadges.tsx +++ b/src/components/models/ModelBadges.tsx @@ -33,7 +33,7 @@ const ModelBadges = ({ model }: { model: WorkersAIModelsSchema }) => {