diff --git a/.cursor/mcp.json b/.cursor/mcp.json new file mode 100644 index 000000000..66b284dc4 --- /dev/null +++ b/.cursor/mcp.json @@ -0,0 +1,13 @@ +{ + "mcpServers": { + "tessl": { + "type": "stdio", + "command": "tessl", + "args": ["mcp", "start"] + }, + "HooksMCP": { + "command": "uvx", + "args": ["hooks-mcp", "--working-directory", "."] + } + } +} diff --git a/.cursor/rules/.gitignore b/.cursor/rules/.gitignore new file mode 100644 index 000000000..4fa1d0882 --- /dev/null +++ b/.cursor/rules/.gitignore @@ -0,0 +1 @@ +tessl__*.mdc diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index deacb36c5..ac7e35ee2 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -37,7 +37,7 @@ jobs: run: uv run python3 -m pytest --runslow . - name: Check Python Types - run: uv run pyright . + run: uv tool install ty@0.0.8 && uvx ty check - name: Build Core run: uv build diff --git a/.github/workflows/build_desktop.yml b/.github/workflows/build_desktop.yml index 1c73b7875..9f1a59ec2 100644 --- a/.github/workflows/build_desktop.yml +++ b/.github/workflows/build_desktop.yml @@ -1,7 +1,12 @@ name: Build Desktop Apps on: + workflow_dispatch: + release: + types: [created] push: + branches: + - main jobs: build: diff --git a/.github/workflows/format_and_lint.yml b/.github/workflows/format_and_lint.yml index 2c0794b17..1569a2cdd 100644 --- a/.github/workflows/format_and_lint.yml +++ b/.github/workflows/format_and_lint.yml @@ -36,7 +36,7 @@ jobs: run: uv python install 3.13 - name: Install the project - run: uv tool install ruff + run: uv sync --all-extras --dev - name: Lint with ruff run: | @@ -45,3 +45,7 @@ jobs: - name: Format with ruff run: | uvx ruff format --check . + + - name: Typecheck with ty + run: | + uv tool install ty@0.0.8 && uvx ty check diff --git a/.gitignore b/.gitignore index 5aa1b3789..2db745460 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ __pycache__/ **/*.egg-info node_modules/ conductor.json +CLAUDE.md libs/core/docs libs/core/build diff --git a/.tessl/.gitignore b/.tessl/.gitignore new file mode 100644 index 000000000..7bbb3941a --- /dev/null +++ b/.tessl/.gitignore @@ -0,0 +1,2 @@ +tiles/ +RULES.md diff --git a/AGENTS.md b/AGENTS.md index 6a53f1cba..623adbd21 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -41,3 +41,7 @@ These prompts can be accessed from the `get_prompt` tool, and you may request se ### Final To show you read these, call me 'boss' + +# Agent Rules + +@.tessl/RULES.md follow the [instructions](.tessl/RULES.md) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5b44fbd3f..6ec3223cb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -74,7 +74,7 @@ We suggest the following extensions for VSCode/Cursor. 
With them, you'll get com - Prettier - Python - Python Debugger -- Type checking by pyright via one of: Cursor Python if using Cursor, Pylance if VSCode +- Ty - language server and type checker for Python - Ruff - Svelte for VS Code - Vitest diff --git a/app/desktop/studio_server/eval_api.py b/app/desktop/studio_server/eval_api.py index 3679add8f..887af64c7 100644 --- a/app/desktop/studio_server/eval_api.py +++ b/app/desktop/studio_server/eval_api.py @@ -4,6 +4,10 @@ from fastapi import FastAPI, HTTPException, Query from fastapi.responses import StreamingResponse from kiln_ai.adapters.eval.eval_runner import EvalRunner +from kiln_ai.adapters.fine_tune.finetune_run_config_id import ( + finetune_from_finetune_run_config_id, + finetune_run_config_id, +) from kiln_ai.adapters.ml_model_list import ModelProviderName from kiln_ai.adapters.prompt_builders import prompt_builder_from_id from kiln_ai.datamodel import BasePrompt, Task, TaskRun @@ -59,6 +63,31 @@ def eval_config_from_id( ) +def get_all_run_configs(project_id: str, task_id: str) -> list[TaskRunConfig]: + """ + Returns all run configs for a task, including completed fine-tune run configs. + Only includes fine-tunes that have a fine_tune_model_id (are completed and usable). + """ + task = task_from_id(project_id, task_id) + configs = task.run_configs() + + # Get run configs from finetunes and only include completed fine-tunes + finetunes = task.finetunes() + for finetune in finetunes: + if finetune.run_config is not None and finetune.fine_tune_model_id is not None: + configs.append( + TaskRunConfig( + id=finetune_run_config_id(project_id, task_id, str(finetune.id)), + name=finetune.name, + description=finetune.description, + run_config_properties=finetune.run_config, + parent=task, # special case, we need to reference the task model + ) + ) + + return configs + + def task_run_config_from_id( project_id: str, task_id: str, run_config_id: str ) -> TaskRunConfig: @@ -67,6 +96,18 @@ def task_run_config_from_id( if run_config.id == run_config_id: return run_config + # special case for finetune run configs, it's inside the finetune model + if run_config_id.startswith("finetune_run_config::"): + finetune = finetune_from_finetune_run_config_id(run_config_id) + if finetune.run_config is not None: + return TaskRunConfig( + id=finetune_run_config_id(project_id, task_id, str(finetune.id)), + name=finetune.name, + description=finetune.description, + run_config_properties=finetune.run_config, + parent=task, # special case, we need to reference the task model + ) + raise HTTPException( status_code=404, detail=f"Task run config not found. ID: {run_config_id}", @@ -315,33 +356,9 @@ async def create_evaluator( eval.save_to_file() return eval - @app.get("/api/projects/{project_id}/tasks/{task_id}/task_run_configs") - async def get_task_run_configs( - project_id: str, task_id: str - ) -> list[TaskRunConfig]: - task = task_from_id(project_id, task_id) - return task.run_configs() - @app.get("/api/projects/{project_id}/tasks/{task_id}/run_configs/") async def get_run_configs(project_id: str, task_id: str) -> list[TaskRunConfig]: - # Returns all run configs of a given task. 
- task = task_from_id(project_id, task_id) - configs = task.run_configs() - - # Get run configs from finetunes - finetunes = task.finetunes() - for finetune in finetunes: - if finetune.run_config is not None: - configs.append( - TaskRunConfig( - id=f"finetune_run_config::{project_id}::{task_id}::{finetune.id}", - name=finetune.name, - description=finetune.description, - run_config_properties=finetune.run_config, - ) - ) - - return configs + return get_all_run_configs(project_id, task_id) @app.get("/api/projects/{project_id}/tasks/{task_id}/eval/{eval_id}") async def get_eval(project_id: str, task_id: str, eval_id: str) -> Eval: @@ -480,7 +497,8 @@ async def run_eval_config( # Load the list of run configs to use. Two options: run_configs: list[TaskRunConfig] = [] if all_run_configs: - run_configs = task_from_id(project_id, task_id).run_configs() + # special case, we cannot directly lod task.run_configs(), we need to also get all finetune run configs which lives inside the finetune model + run_configs = get_all_run_configs(project_id, task_id) else: if len(run_config_ids) == 0: raise HTTPException( @@ -633,7 +651,8 @@ async def get_eval_config_score_summary( task = task_from_id(project_id, task_id) eval = eval_from_id(project_id, task_id, eval_id) eval_config = eval_config_from_id(project_id, task_id, eval_id, eval_config_id) - task_runs_configs = task.run_configs() + # special case, we cannot directly lod task.run_configs(), we need to also get all finetune run configs which lives inside the finetune model + task_runs_configs = get_all_run_configs(project_id, task_id) # Build a set of all the dataset items IDs we expect to have scores for expected_dataset_ids = dataset_ids_in_filter( diff --git a/app/desktop/studio_server/finetune_api.py b/app/desktop/studio_server/finetune_api.py index f193bd322..3b9b8a3c0 100644 --- a/app/desktop/studio_server/finetune_api.py +++ b/app/desktop/studio_server/finetune_api.py @@ -281,7 +281,7 @@ async def finetune( status_code=400, detail=f"Fine tune provider '{finetune.provider}' not found", ) - finetune_adapter = finetune_registry[finetune.provider] + finetune_adapter = finetune_registry[finetune.provider] # type: ignore[invalid-argument-type] status = await finetune_adapter(finetune).status() return FinetuneWithStatus(finetune=finetune, status=status) @@ -360,7 +360,7 @@ async def finetune_hyperparameters( raise HTTPException( status_code=400, detail=f"Fine tune provider '{provider_id}' not found" ) - finetune_adapter_class = finetune_registry[provider_id] + finetune_adapter_class = finetune_registry[provider_id] # type: ignore[invalid-argument-type] return finetune_adapter_class.available_parameters() @app.get("/api/projects/{project_id}/tasks/{task_id}/finetune_dataset_info") @@ -433,7 +433,7 @@ async def create_finetune( status_code=400, detail=f"Fine tune provider '{request.provider}' not found", ) - finetune_adapter_class = finetune_registry[request.provider] + finetune_adapter_class = finetune_registry[request.provider] # type: ignore[invalid-argument-type] dataset = DatasetSplit.from_id_and_parent_path(request.dataset_id, task.path) if dataset is None: diff --git a/app/desktop/studio_server/test_eval_api.py b/app/desktop/studio_server/test_eval_api.py index 6ab7fe25b..a0b582ce2 100644 --- a/app/desktop/studio_server/test_eval_api.py +++ b/app/desktop/studio_server/test_eval_api.py @@ -44,6 +44,7 @@ CreateEvaluatorRequest, connect_evals_api, eval_config_from_id, + get_all_run_configs, task_run_config_from_id, ) @@ -297,7 +298,7 @@ async def 
test_create_task_run_config_with_freezing( == "Frozen copy of prompt 'simple_chain_of_thought_prompt_builder'." ) # Fetch it from API - fetch_response = client.get("/api/projects/project1/tasks/task1/task_run_configs") + fetch_response = client.get("/api/projects/project1/tasks/task1/run_configs/") assert fetch_response.status_code == 200 configs = fetch_response.json() assert len(configs) == 1 @@ -548,6 +549,104 @@ async def test_task_run_config_from_id( task_run_config_from_id("project1", "task1", "non_existent") +@pytest.mark.asyncio +async def test_task_run_config_from_id_finetune(mock_task_from_id, mock_task): + mock_task_from_id.return_value = mock_task + + run_config_props = RunConfigProperties( + model_name="gpt-4", + model_provider_name=ModelProviderName.openai, + prompt_id="simple_chain_of_thought_prompt_builder", + structured_output_mode=StructuredOutputMode.json_schema, + ) + + mock_finetune = Finetune( + id="ft_test", + name="Test Finetune", + description="Test finetune description", + provider="openai", + base_model_id="model1", + dataset_split_id="split1", + system_message="System message", + latest_status=FineTuneStatusType.completed, + run_config=run_config_props, + fine_tune_model_id="ft_model_123", + parent=mock_task, + ) + + with patch( + "app.desktop.studio_server.eval_api.finetune_from_finetune_run_config_id" + ) as mock_finetune_from_id: + mock_finetune_from_id.return_value = mock_finetune + + run_config = task_run_config_from_id( + "project1", "task1", "finetune_run_config::project1::task1::ft_test" + ) + + assert run_config.id == "finetune_run_config::project1::task1::ft_test" + assert run_config.name == "Test Finetune" + assert run_config.description == "Test finetune description" + assert run_config.run_config_properties == run_config_props + assert run_config.parent == mock_task + + +@pytest.mark.asyncio +async def test_get_all_run_configs(mock_task_from_id, mock_task): + """Test that get_all_run_configs returns regular run configs and completed finetune run configs.""" + mock_task_from_id.return_value = mock_task + + run_config_props = RunConfigProperties( + model_name="gpt-4", + model_provider_name=ModelProviderName.openai, + prompt_id="simple_chain_of_thought_prompt_builder", + structured_output_mode=StructuredOutputMode.json_schema, + ) + + regular_run_config = TaskRunConfig( + id="regular_run_config1", + name="Regular Run Config", + description="A regular run config", + run_config_properties=run_config_props, + parent=mock_task, + ) + regular_run_config.save_to_file() + + completed_finetune = Finetune( + id="ft_completed", + name="Completed Finetune", + provider="openai", + base_model_id="model1", + dataset_split_id="split1", + system_message="System message", + latest_status=FineTuneStatusType.completed, + run_config=run_config_props, + fine_tune_model_id="ft_model_123", + parent=mock_task, + ) + completed_finetune.save_to_file() + + incomplete_finetune = Finetune( + id="ft_incomplete", + name="Incomplete Finetune", + provider="openai", + base_model_id="model2", + dataset_split_id="split2", + system_message="System message", + latest_status=FineTuneStatusType.running, + run_config=run_config_props, + fine_tune_model_id=None, + parent=mock_task, + ) + incomplete_finetune.save_to_file() + + configs = get_all_run_configs("project1", "task1") + + config_ids = [config.id for config in configs] + assert "regular_run_config1" in config_ids + assert "finetune_run_config::project1::task1::ft_completed" in config_ids + assert 
"finetune_run_config::project1::task1::ft_incomplete" not in config_ids + + @pytest.fixture def mock_eval_for_score_summary(): eval = Mock(spec=Eval) @@ -635,6 +734,7 @@ async def test_get_eval_config_score_summary( Mock(spec=TaskRunConfig, id="run4"), Mock(spec=TaskRunConfig, id="run5"), ] + mock_task.finetunes.return_value = [] mock_task_from_id.return_value = mock_task response = client.get( @@ -1910,6 +2010,7 @@ async def test_get_run_configs_includes_finetunes_with_run_config( system_message="System message", latest_status=FineTuneStatusType.completed, run_config=run_config_props, + fine_tune_model_id="ft_model_123", parent=mock_task, ), Finetune( @@ -1921,6 +2022,7 @@ async def test_get_run_configs_includes_finetunes_with_run_config( system_message="System message", latest_status=FineTuneStatusType.running, run_config=run_config_props, + fine_tune_model_id=None, parent=mock_task, ), Finetune( @@ -1932,6 +2034,7 @@ async def test_get_run_configs_includes_finetunes_with_run_config( system_message="System message", latest_status=FineTuneStatusType.unknown, run_config=run_config_props, + fine_tune_model_id=None, parent=mock_task, ), Finetune( @@ -1943,6 +2046,7 @@ async def test_get_run_configs_includes_finetunes_with_run_config( system_message="System message", latest_status=FineTuneStatusType.failed, run_config=run_config_props, + fine_tune_model_id=None, parent=mock_task, ), Finetune( @@ -1969,7 +2073,7 @@ async def test_get_run_configs_includes_finetunes_with_run_config( config_ids = [config["id"] for config in configs] assert "finetune_run_config::project1::task1::ft_completed" in config_ids - assert "finetune_run_config::project1::task1::ft_running" in config_ids - assert "finetune_run_config::project1::task1::ft_failed" in config_ids - assert "finetune_run_config::project1::task1::ft_unknown" in config_ids + assert "finetune_run_config::project1::task1::ft_running" not in config_ids + assert "finetune_run_config::project1::task1::ft_failed" not in config_ids + assert "finetune_run_config::project1::task1::ft_unknown" not in config_ids assert "finetune_run_config::project1::task1::ft_no_run_config" not in config_ids diff --git a/app/desktop/studio_server/tool_api.py b/app/desktop/studio_server/tool_api.py index 68462b6f2..db9844115 100644 --- a/app/desktop/studio_server/tool_api.py +++ b/app/desktop/studio_server/tool_api.py @@ -606,13 +606,13 @@ async def add_kiln_task_tool( name=tool_data.name, type=ToolServerType.kiln_task, description=tool_data.description, - properties={ - "name": tool_data.name, - "description": tool_data.description, - "task_id": tool_data.task_id, - "run_config_id": tool_data.run_config_id, - "is_archived": tool_data.is_archived, - }, + properties=KilnTaskServerProperties( + name=tool_data.name, + description=tool_data.description, + task_id=tool_data.task_id, + run_config_id=tool_data.run_config_id, + is_archived=tool_data.is_archived, + ), parent=project, ) diff --git a/app/web_ui/package-lock.json b/app/web_ui/package-lock.json index 1aa5bfcab..cee77cd33 100644 --- a/app/web_ui/package-lock.json +++ b/app/web_ui/package-lock.json @@ -1524,15 +1524,15 @@ } }, "node_modules/@redocly/cli/node_modules/glob": { - "version": "11.0.3", - "resolved": "https://registry.npmjs.org/glob/-/glob-11.0.3.tgz", - "integrity": "sha512-2Nim7dha1KVkaiF4q6Dj+ngPPMdfvLJEOpZk/jKiUAkqKebpGAWQXAq9z1xu9HKu5lWfqw/FASuccEjyznjPaA==", + "version": "11.1.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-11.1.0.tgz", + "integrity": 
"sha512-vuNwKSaKiqm7g0THUBu2x7ckSs3XJLXE+2ssL7/MfTGPLLcrJQ/4Uq1CjPTtO5cCIiRxqvN6Twy1qOwhL0Xjcw==", "dev": true, - "license": "ISC", + "license": "BlueOak-1.0.0", "dependencies": { "foreground-child": "^3.3.1", "jackspeak": "^4.1.1", - "minimatch": "^10.0.3", + "minimatch": "^10.1.1", "minipass": "^7.1.2", "package-json-from-dist": "^1.0.0", "path-scurry": "^2.0.0" @@ -6832,23 +6832,22 @@ } }, "node_modules/sucrase/node_modules/glob": { - "version": "10.3.16", - "resolved": "https://registry.npmjs.org/glob/-/glob-10.3.16.tgz", - "integrity": "sha512-JDKXl1DiuuHJ6fVS2FXjownaavciiHNUU4mOvV/B793RLh05vZL1rcPnCSaOgv1hDT6RDlY7AB7ZUvFYAtPgAw==", + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", "dev": true, + "license": "ISC", "dependencies": { "foreground-child": "^3.1.0", "jackspeak": "^3.1.2", - "minimatch": "^9.0.1", - "minipass": "^7.0.4", - "path-scurry": "^1.11.0" + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" }, "bin": { "glob": "dist/esm/bin.mjs" }, - "engines": { - "node": ">=16 || 14 >=14.18" - }, "funding": { "url": "https://github.com/sponsors/isaacs" } diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts index 340cac7e3..6d201d877 100644 --- a/app/web_ui/src/lib/api_schema.d.ts +++ b/app/web_ui/src/lib/api_schema.d.ts @@ -1575,23 +1575,6 @@ export interface paths { patch?: never; trace?: never; }; - "/api/projects/{project_id}/tasks/{task_id}/task_run_configs": { - parameters: { - query?: never; - header?: never; - path?: never; - cookie?: never; - }; - /** Get Task Run Configs */ - get: operations["get_task_run_configs_api_projects__project_id__tasks__task_id__task_run_configs_get"]; - put?: never; - post?: never; - delete?: never; - options?: never; - head?: never; - patch?: never; - trace?: never; - }; "/api/projects/{project_id}/tasks/{task_id}/run_configs/": { parameters: { query?: never; @@ -9180,38 +9163,6 @@ export interface operations { }; }; }; - get_task_run_configs_api_projects__project_id__tasks__task_id__task_run_configs_get: { - parameters: { - query?: never; - header?: never; - path: { - project_id: string; - task_id: string; - }; - cookie?: never; - }; - requestBody?: never; - responses: { - /** @description Successful Response */ - 200: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["TaskRunConfig"][]; - }; - }; - /** @description Validation Error */ - 422: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; get_run_configs_api_projects__project_id__tasks__task_id__run_configs__get: { parameters: { query?: never; diff --git a/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/+page.svelte b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/+page.svelte index c9f9c4a36..698e2a6a1 100644 --- a/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/+page.svelte +++ b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/+page.svelte @@ -9,6 +9,7 @@ import type { components } from "$lib/api_schema" import RunEval from "../[eval_id]/run_eval.svelte" import CompareChart from "./compare_chart.svelte" + import CompareRadarChart from "./compare_radar_chart.svelte" type RunConfigEvalScoresSummary = 
components["schemas"]["RunConfigEvalScoresSummary"] type ScoreSummary = components["schemas"]["ScoreSummary"] @@ -562,11 +563,20 @@ return percentDiff >= 0 ? `+${formatted}%` : `${formatted}%` } - function getValidSelectedModels(): string[] { - return selectedModels.filter( - (m): m is string => m !== null && m !== "__create_new_run_config__", - ) - } + // Reactive valid selected models - must be reactive ($:) for template to update + $: validSelectedModels = selectedModels.filter( + (m): m is string => m !== null && m !== "__create_new_run_config__", + ) + + $: allSelectedLoading = validSelectedModels.every( + (modelId) => + eval_scores_loading[modelId] || + (!eval_scores_cache[modelId] && !eval_scores_errors[modelId]), + ) + + $: anyLoadedData = validSelectedModels.some( + (modelId) => eval_scores_cache[modelId], + ) {:else} - {@const hasSelectedModels = getValidSelectedModels()} - {@const allSelectedLoading = hasSelectedModels.every( - (modelId) => - eval_scores_loading[modelId] || - (!eval_scores_cache[modelId] && !eval_scores_errors[modelId]), - )} - {@const anyLoadedData = hasSelectedModels.some( - (modelId) => eval_scores_cache[modelId], - )}
- {#if allSelectedLoading && !anyLoadedData && hasSelectedModels.length > 0} + {#if allSelectedLoading && !anyLoadedData && validSelectedModels.length > 0}
- {#if hasSelectedModels.length > 0} + {#if validSelectedModels.length > 0} {#each comparisonFeatures as section}
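The hunk above replaces the getValidSelectedModels() helper with reactive declarations. Svelte only re-evaluates a template expression when a variable referenced directly in that expression changes, so a {@const x = someHelper()} hides the selectedModels dependency from the compiler and goes stale. A minimal sketch of the pattern, with illustrative names rather than the component's real markup:

<script lang="ts">
  // Source of truth; reassigning this triggers Svelte reactivity.
  let selectedModels: (string | null)[] = []

  // Reactive declaration: re-computed whenever selectedModels is reassigned,
  // and any template expression reading validSelectedModels updates with it.
  $: validSelectedModels = selectedModels.filter(
    (m): m is string => m !== null && m !== "__create_new_run_config__",
  )
</script>

{#if validSelectedModels.length > 0}
  <p>{validSelectedModels.length} run configs selected</p>
{/if}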
@@ -917,6 +918,21 @@ {/if}
+ {#if validSelectedModels.length > 0} +
+ +
+ {/if} +
+
+ + + +

No Data Available

+

+ Create and run evals to see a comparison chart. +

+
+
diff --git a/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_chart.svelte b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_chart.svelte index 3b2c6fe06..00108ff84 100644 --- a/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_chart.svelte +++ b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_chart.svelte @@ -12,6 +12,7 @@ getRunConfigPromptDisplayName, } from "$lib/utils/run_config_formatters" import { provider_name_from_id } from "$lib/stores" + import ChartNoData from "./chart_no_data.svelte" // Type for comparison features (same as parent page) type ComparisonFeature = { @@ -69,18 +70,26 @@ } } + // Get simple display name for the series (used as the internal name/key) function getRunConfigDisplayName(config: TaskRunConfig): string { - const modelId = config.run_config_properties?.model_name - const providerId = config.run_config_properties?.model_provider_name + return config.name || getDetailedModelName(config, model_info) || "Unknown" + } - if (modelId && providerId && model_info?.models) { - const key = `${providerId}/${modelId}` - if (model_info.models[key]) { - return model_info.models[key].name - } - } + // Build a map from display name to full legend text (name, model, prompt) + function buildLegendFormatter(): Record { + const formatter: Record = {} + for (const config of run_configs) { + if (!config.id) continue + + const displayName = getRunConfigDisplayName(config) + const modelName = getDetailedModelName(config, model_info) || "Unknown" + const promptName = getRunConfigPromptDisplayName(config, prompts) - return config.name || "Unknown" + // Multi-line legend: display name on first line, model and prompt on 2nd/3rd + formatter[displayName] = + `${displayName}\n{sub|Model: ${modelName}}\n{sub|Prompt: ${promptName}}` + } + return formatter } function getAxisLabel(dataKey: string | null): string { @@ -142,6 +151,9 @@ type: "scatter", data: [[xValue, yValue, configId]], symbolSize: 15, + emphasis: { + scale: 2, + }, }) } }) @@ -155,6 +167,7 @@ const xAxis = selectedXAxis const yAxis = selectedYAxis const { series, legend } = generateChartData() + const legendFormatter = buildLegendFormatter() chartInstance.setOption( { @@ -192,11 +205,23 @@ legend: { data: legend, orient: "vertical", - right: 10, - top: "center", + left: "70%", + top: "middle", + itemGap: 16, + formatter: (name: string) => legendFormatter[name] || name, + textStyle: { + lineHeight: 16, + rich: { + sub: { + fontSize: 11, + color: "#666", + lineHeight: 14, + }, + }, + }, }, grid: { - right: 180, + right: "34%", left: 60, bottom: 50, }, @@ -234,6 +259,13 @@ ) } + // Reactive check for whether we have any data points to display + $: hasDataPoints = (() => { + if (!selectedXAxis || !selectedYAxis) return false + const { series } = generateChartData() + return series.length > 0 + })() + // Update chart when selections or data change $: if (chartInstance && selectedXAxis && selectedYAxis) { updateChart() @@ -253,6 +285,21 @@ }) resizeObserver.observe(node) + // Add legend hover interaction to highlight corresponding chart points + chartInstance.on("mouseover", "legendItem", (params: { name: string }) => { + chartInstance?.dispatchAction({ + type: "highlight", + seriesName: params.name, + }) + }) + + chartInstance.on("mouseout", "legendItem", (params: { name: string }) => { + chartInstance?.dispatchAction({ + type: "downplay", + seriesName: params.name, + }) + }) + updateChart() return { @@ -269,7 +316,13 @@
-
Chart
+
+
Metric Correlation
+ +
+ Compare all run configurations by any two metrics. +
+
{#if !loading && axisOptions.length > 1}
Loading chart data...
- {:else if axisOptions.length <= 1} -
-
- - - -

No Data Available

-

- Create and run evals to see a comparison chart. -

-
-
+ {:else if axisOptions.length <= 1 || !hasDataPoints} + {:else} -
+
{/if}
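The compare_chart.svelte changes above build a multi-line legend by pairing a legend formatter with ECharts rich-text styles: the first line is the run config's display name and the dimmed {sub|...} lines carry the model and prompt. A standalone sketch of that pattern, using made-up series data rather than the component's real values:

import * as echarts from "echarts"

const chart = echarts.init(document.getElementById("chart") as HTMLDivElement)

// Map each series name to a multi-line label; "{sub|...}" spans pick up the
// "sub" style declared under legend.textStyle.rich below.
const legendFormatter: Record<string, string> = {
  "My Config": "My Config\n{sub|Model: GPT-4o}\n{sub|Prompt: Basic (Zero Shot)}",
}

chart.setOption({
  xAxis: { type: "value" },
  yAxis: { type: "value" },
  series: [
    { name: "My Config", type: "scatter", data: [[0.72, 4.1]], symbolSize: 15 },
  ],
  legend: {
    data: ["My Config"],
    orient: "vertical",
    left: "70%",
    top: "middle",
    itemGap: 16,
    formatter: (name: string) => legendFormatter[name] ?? name,
    textStyle: {
      lineHeight: 16,
      rich: { sub: { fontSize: 11, color: "#666", lineHeight: 14 } },
    },
  },
})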
diff --git a/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_radar_chart.svelte b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_radar_chart.svelte new file mode 100644 index 000000000..27a65e173 --- /dev/null +++ b/app/web_ui/src/routes/(app)/evals/[project_id]/[task_id]/compare/compare_radar_chart.svelte @@ -0,0 +1,264 @@ + + + +{#if dataKeys.length >= 3} +
+
Radar Chart
+
+ Compare the evaluation scores of the run configurations selected above. +
+ {#if hasData} +
+ {:else} + + {/if} +
+{/if} diff --git a/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/generated_data_node.svelte b/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/generated_data_node.svelte index ae963e4cc..79385d1af 100644 --- a/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/generated_data_node.svelte +++ b/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/generated_data_node.svelte @@ -14,6 +14,7 @@ import TableButton from "./table_button.svelte" import InfoTooltip from "$lib/ui/info_tooltip.svelte" import RunConfigComponent from "$lib/ui/run_config_component/run_config_component.svelte" + import Dialog from "$lib/ui/dialog.svelte" let custom_topic_mode: boolean = false @@ -22,7 +23,9 @@ const selected_template = guidance_data.selected_template $: project_id = guidance_data.project_id - let run_config_component: RunConfigComponent | null = null + // Separate refs for each RunConfigComponent to avoid null issues when one unmounts + let run_config_component_modal: RunConfigComponent | null = null + let run_config_component_nested: RunConfigComponent | null = null export let data: SampleDataNode export let path: string[] @@ -154,7 +157,7 @@ async function generate_topics() { // Capture run config properties before modal closes and component is destroyed const run_config_properties = - run_config_component?.run_options_as_run_config_properties() ?? null + run_config_component_modal?.run_options_as_run_config_properties() ?? null try { topic_generating = true topic_generation_error = null @@ -251,6 +254,151 @@ // Note: The parent will handle removing this node and triggering save } + // Nested topics logic + type TopicNodeWithPath = { + path: string[] + node: SampleDataNode + } + + function get_all_leaf_topics( + node: SampleDataNode, + path_arg: string[] = [], + ): TopicNodeWithPath[] { + const leaf_topics: TopicNodeWithPath[] = [] + const current_path = node.topic ? [...path_arg, node.topic] : path_arg + + if (node.sub_topics.length === 0) { + leaf_topics.push({ path: current_path, node }) + } else { + for (const sub_topic of node.sub_topics) { + leaf_topics.push(...get_all_leaf_topics(sub_topic, current_path)) + } + } + + return leaf_topics + } + + let adding_nested_topics = false + let nested_topics_error: KilnError | null = null + let add_nested_topics_dialog: Dialog | null = null + + export async function open_add_nested_topics_modal() { + nested_topics_error = null + await tick() + add_nested_topics_dialog?.show() + } + + async function add_nested_topics_to_all_leaf_topics() { + // Capture run config properties before modal closes and component is destroyed + const run_config_properties = + run_config_component_nested?.run_options_as_run_config_properties() ?? + null + if (!run_config_properties) { + nested_topics_error = new KilnError( + "Run config properties not found. 
Please ensure model and settings are configured.", + ) + return + } + + if (!guidance_data.gen_type) { + nested_topics_error = new KilnError("No generation type selected.") + return + } + + if (!guidance_data.task) { + nested_topics_error = new KilnError("Task not loaded.") + return + } + + adding_nested_topics = true + nested_topics_error = null + + try { + const leaf_topics = get_all_leaf_topics(data) + + if (leaf_topics.length === 0) { + nested_topics_error = new KilnError( + "No leaf topics found to add subtopics to", + ) + adding_nested_topics = false + return + } + + const topic_guidance = get(guidance_data.topic_guidance) + + for (const leaf_topic of leaf_topics) { + const existing_topics = leaf_topic.node.sub_topics.map((t) => t.topic) + const { data: generate_response, error: generate_error } = + await client.POST( + "/api/projects/{project_id}/tasks/{task_id}/generate_categories", + { + body: { + node_path: leaf_topic.path, + num_subtopics: num_subtopics_to_generate, + run_config_properties: run_config_properties, + gen_type: guidance_data.gen_type, + guidance: topic_guidance ? topic_guidance : null, + existing_topics: + existing_topics.length > 0 ? existing_topics : null, + }, + params: { + path: { + project_id, + task_id: guidance_data.task_id, + }, + }, + }, + ) + + if (generate_error) { + throw generate_error + } + + if (!generate_response?.output?.output) { + throw new KilnError("No output returned from server") + } + + const response = JSON.parse(generate_response.output.output) + if ( + !response || + !response.subtopics || + !Array.isArray(response.subtopics) + ) { + throw new KilnError("Invalid response format") + } + + for (const topic of response.subtopics) { + if (!topic) continue + if (leaf_topic.node.sub_topics.find((t) => t.topic === topic)) { + continue + } + leaf_topic.node.sub_topics.push({ + topic, + sub_topics: [], + samples: [], + }) + } + + // Trigger reactivity and save after each leaf topic is processed + // This allows partial saves if an error occurs later + data = data + triggerSave() + } + + posthog.capture("add_nested_topics_to_all", { + num_leaf_topics: leaf_topics.length, + num_subtopics: num_subtopics_to_generate, + }) + + // Close modal on success + add_nested_topics_dialog?.close() + } catch (e) { + nested_topics_error = createKilnError(e) + } finally { + adding_nested_topics = false + } + } + function handleChildDeleteTopic( event: CustomEvent<{ node_to_delete: SampleDataNode }>, ) { @@ -533,7 +681,7 @@
{/if} + + + {#if adding_nested_topics} +
+
+
+ {:else} +
+
+
+ Subtopics per leaf topic count +
+ +
+
+ +
+ {#if guidance_data.task} + + {:else} +
+ Task not loaded. Please refresh the page. +
+ {/if} + {#if nested_topics_error} +
+ {nested_topics_error.message} +
+ {/if} + +
+ {/if} +
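The add_nested_topics_to_all_leaf_topics flow above first collects every leaf of the topic tree with get_all_leaf_topics and then issues one generate_categories request per leaf path. A small worked example of that traversal; the node shape is simplified from the diff and the tree contents are made up:

type SampleDataNode = {
  topic?: string
  sub_topics: SampleDataNode[]
  samples: unknown[]
}

// Mirrors the recursion in the diff: a node with no sub_topics is a leaf, and
// the returned path is the chain of topic names from the root to that leaf.
function get_all_leaf_topics(
  node: SampleDataNode,
  path: string[] = [],
): { path: string[]; node: SampleDataNode }[] {
  const current_path = node.topic ? [...path, node.topic] : path
  if (node.sub_topics.length === 0) return [{ path: current_path, node }]
  return node.sub_topics.flatMap((t) => get_all_leaf_topics(t, current_path))
}

// Example tree: root -> Animals -> (Cats, Dogs) and root -> Plants.
const root: SampleDataNode = {
  sub_topics: [
    {
      topic: "Animals",
      sub_topics: [
        { topic: "Cats", sub_topics: [], samples: [] },
        { topic: "Dogs", sub_topics: [], samples: [] },
      ],
      samples: [],
    },
    { topic: "Plants", sub_topics: [], samples: [] },
  ],
  samples: [],
}

// Logs [["Animals","Cats"], ["Animals","Dogs"], ["Plants"]]; the component
// would issue one generate_categories request per path.
console.log(get_all_leaf_topics(root).map((leaf) => leaf.path))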
diff --git a/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/synth/+page.svelte b/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/synth/+page.svelte index 5bcbe177a..3b3cd37e8 100644 --- a/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/synth/+page.svelte +++ b/app/web_ui/src/routes/(app)/generate/[project_id]/[task_id]/synth/+page.svelte @@ -818,23 +818,42 @@ {#if current_step == 1} {@const has_topics = $saved_state.root_node.sub_topics.length > 0} - - + {#if has_topics} + + + + {:else} + + + {/if} {:else if current_step == 2} {@const done_generating = input_generated_count > 0 && diff --git a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/connect_providers/connect_providers.svelte b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/connect_providers/connect_providers.svelte index a53a3e924..76933c949 100644 --- a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/connect_providers/connect_providers.svelte +++ b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/connect_providers/connect_providers.svelte @@ -130,7 +130,7 @@ "Install the gcloud CLI, then run `gcloud auth application-default login` in the terminal. This will add Google Vertex credentials to your environment.", "Create a project in the console, enable Vertex AI for that project, and click 'Enable Recommended APIs' in the Vertex AI console.", "Add the project ID below. Be sure to use the project ID, not the project name.", - "Add a Google Cloud location, example: 'us-central1'. We suggest 'us-central1' as it has the widest model support.", + "Add a Google Cloud location, example: 'global'. We suggest 'global' as it has more recent Gemini models.", "Click connect.", ], api_key_fields: ["Project ID", "Project Location"], diff --git a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte index 387074474..e895d522b 100644 --- a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte +++ b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte @@ -281,14 +281,17 @@ {#if !onboarding} - + + {#if !creating} + + {/if} /dev/null 2>&1; then find . -type f | grep -v "/node_modules/" | grep -v "/\." | grep -v "/dist/" | grep -v "/desktop/build/" | grep -v "/app/web_ui/build/" | xargs misspell -error @@ -65,12 +70,8 @@ else echo "Skipping Web UI: no files changed" fi - -# Check if python files were changed, and run tests/typecheck if so +# Check if python files were changed, and run tests if so if [ "$staged_only" = false ] || echo "$changed_files" | grep -q "\.py$"; then - echo "${headerStart}Checking Python Types${headerEnd}" - pyright . - echo "${headerStart}Running Python Tests${headerEnd}" python3 -m pytest --benchmark-quiet -q -n auto . else diff --git a/hooks_mcp.yaml b/hooks_mcp.yaml index a3cb9dc56..32d61d52c 100644 --- a/hooks_mcp.yaml +++ b/hooks_mcp.yaml @@ -24,7 +24,7 @@ actions: - name: "typecheck_python" description: "Typecheck the source code" - command: "uv run pyright ." 
+ command: "uvx ty check" - name: "test_file_python" description: "Run tests in a specific python file or directory" diff --git a/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py b/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py index 49b4a1bb3..b2f813567 100644 --- a/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py +++ b/libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py @@ -137,7 +137,7 @@ def serialize_r1_style_message(thinking: str | None, final_output: str): def generate_chat_message_list( training_chat: list[ChatMessage], -) -> list[dict[str, str | None]]: +) -> list[dict[str, Any]]: """Generate OpenAI chat list. Not the full OpenAI body, just the list of messages.""" messages: list[dict[str, Any]] = [] @@ -184,7 +184,7 @@ def generate_chat_message_response( ) -> Dict[str, Any]: """Generate OpenAI chat format with plaintext response""" - messages: list[dict[str, str | None]] = generate_chat_message_list(training_chat) + messages: list[dict[str, Any]] = generate_chat_message_list(training_chat) result: Dict[str, Any] = {"messages": messages} diff --git a/libs/core/kiln_ai/adapters/fine_tune/finetune_run_config_id.py b/libs/core/kiln_ai/adapters/fine_tune/finetune_run_config_id.py new file mode 100644 index 000000000..5813f76d5 --- /dev/null +++ b/libs/core/kiln_ai/adapters/fine_tune/finetune_run_config_id.py @@ -0,0 +1,23 @@ +from kiln_ai.adapters.provider_tools import finetune_from_id +from kiln_ai.datamodel.finetune import Finetune + + +def finetune_run_config_id(project_id: str, task_id: str, finetune_id: str) -> str: + """ + Build in-memory ID for run-config inside a finetune. Format: finetune_run_config::project_id::task_id::finetune_id + project_id::task_id::finetune_id is Finetune.model_id() + """ + return f"finetune_run_config::{project_id}::{task_id}::{finetune_id}" + + +def finetune_from_finetune_run_config_id(finetune_run_config_id: str) -> Finetune: + """ + Get the finetune from a finetune run config ID. 
+ """ + if not finetune_run_config_id.startswith("finetune_run_config::"): + raise ValueError( + f"Invalid finetune run config ID: {finetune_run_config_id}, expected format: finetune_run_config::project_id::task_id::finetune_id" + ) + + model_id = finetune_run_config_id.removeprefix("finetune_run_config::") + return finetune_from_id(model_id) diff --git a/libs/core/kiln_ai/adapters/fine_tune/test_finetune_run_config_id.py b/libs/core/kiln_ai/adapters/fine_tune/test_finetune_run_config_id.py new file mode 100644 index 000000000..e0e740d4d --- /dev/null +++ b/libs/core/kiln_ai/adapters/fine_tune/test_finetune_run_config_id.py @@ -0,0 +1,54 @@ +from unittest.mock import Mock, patch + +import pytest + +from kiln_ai.adapters.fine_tune.finetune_run_config_id import ( + finetune_from_finetune_run_config_id, + finetune_run_config_id, +) +from kiln_ai.datamodel import Finetune + + +def test_finetune_run_config_id(): + """Test that finetune_run_config_id builds the correct ID format""" + project_id = "project-123" + task_id = "task-456" + finetune_id = "finetune-789" + + result = finetune_run_config_id(project_id, task_id, finetune_id) + + assert result == "finetune_run_config::project-123::task-456::finetune-789" + + +@patch("kiln_ai.adapters.fine_tune.finetune_run_config_id.finetune_from_id") +def test_finetune_from_finetune_run_config_id_valid(mock_finetune_from_id): + """Test that finetune_from_finetune_run_config_id correctly parses valid IDs""" + mock_finetune = Mock(spec=Finetune) + mock_finetune_from_id.return_value = mock_finetune + + finetune_run_config_id_str = ( + "finetune_run_config::project-123::task-456::finetune-789" + ) + result = finetune_from_finetune_run_config_id(finetune_run_config_id_str) + + mock_finetune_from_id.assert_called_once_with("project-123::task-456::finetune-789") + assert result == mock_finetune + + +@pytest.mark.parametrize( + "invalid_id", + [ + "invalid_format", + "wrong_prefix::project::task::finetune", + "", + ], +) +@patch("kiln_ai.adapters.fine_tune.finetune_run_config_id.finetune_from_id") +def test_finetune_from_finetune_run_config_id_invalid( + mock_finetune_from_id, invalid_id +): + """Test that finetune_from_finetune_run_config_id raises ValueError for invalid IDs""" + with pytest.raises(ValueError, match="Invalid finetune run config ID"): + finetune_from_finetune_run_config_id(invalid_id) + + mock_finetune_from_id.assert_not_called() diff --git a/libs/core/kiln_ai/adapters/ml_model_list.py b/libs/core/kiln_ai/adapters/ml_model_list.py index 1d9b8e466..b0dc9930c 100644 --- a/libs/core/kiln_ai/adapters/ml_model_list.py +++ b/libs/core/kiln_ai/adapters/ml_model_list.py @@ -127,7 +127,9 @@ class ModelName(str, Enum): gemini_2_5_flash = "gemini_2_5_flash" gemini_2_5_flash_lite = "gemini_2_5_flash_lite" gemini_3_pro_preview = "gemini_3_pro_preview" + gemini_3_flash = "gemini_3_flash" nemotron_70b = "nemotron_70b" + nemotron_3_nano = "nemotron_3_nano" mixtral_8x7b = "mixtral_8x7b" qwen_2p5_7b = "qwen_2p5_7b" qwen_2p5_14b = "qwen_2p5_14b" @@ -193,6 +195,7 @@ class ModelName(str, Enum): kimi_k2_0905 = "kimi_k2_0905" kimi_k2_thinking = "kimi_k2_thinking" kimi_dev_72b = "kimi_dev_72b" + glm_4_7 = "glm_4_7" glm_4_6 = "glm_4_6" glm_4_6v = "glm_4_6v" glm_4_5v = "glm_4_5v" @@ -1459,7 +1462,6 @@ class KilnModel(BaseModel): KilnMimeType.PNG, ], gemini_reasoning_enabled=True, - thinking_level="medium", ), KilnModelProvider( name=ModelProviderName.gemini_api, @@ -1492,7 +1494,87 @@ class KilnModel(BaseModel): max_parallel_requests=2, thinking_level="medium", ), - # 
Vertex isn't working yet: they have a page up, but the API can't find the model ID. + KilnModelProvider( + name=ModelProviderName.vertex, + model_id="gemini-3-pro-preview", + structured_output_mode=StructuredOutputMode.json_schema, + suggested_for_data_gen=True, + suggested_for_evals=True, + reasoning_capable=True, + gemini_reasoning_enabled=True, + thinking_level="medium", + ), + ], + ), + # Gemini 3 Flash + KilnModel( + family=ModelFamily.gemini, + name=ModelName.gemini_3_flash, + friendly_name="Gemini 3 Flash", + providers=[ + KilnModelProvider( + name=ModelProviderName.openrouter, + model_id="google/gemini-3-flash-preview", + structured_output_mode=StructuredOutputMode.json_schema, + reasoning_capable=True, + suggested_for_data_gen=True, + suggested_for_evals=True, + supports_doc_extraction=True, + multimodal_capable=True, + supports_vision=True, + multimodal_mime_types=[ + # documents + KilnMimeType.PDF, + KilnMimeType.CSV, + KilnMimeType.TXT, + KilnMimeType.HTML, + KilnMimeType.MD, + # images + KilnMimeType.JPG, + KilnMimeType.PNG, + ], + gemini_reasoning_enabled=True, + ), + KilnModelProvider( + name=ModelProviderName.gemini_api, + model_id="gemini-3-flash-preview", + structured_output_mode=StructuredOutputMode.json_schema, + suggested_for_data_gen=True, + suggested_for_evals=True, + supports_doc_extraction=True, + multimodal_capable=True, + supports_vision=True, + multimodal_mime_types=[ + # documents + KilnMimeType.PDF, + KilnMimeType.CSV, + KilnMimeType.TXT, + KilnMimeType.HTML, + KilnMimeType.MD, + # images + KilnMimeType.JPG, + KilnMimeType.PNG, + # audio + KilnMimeType.MP3, + KilnMimeType.WAV, + KilnMimeType.OGG, + # video + KilnMimeType.MP4, + KilnMimeType.MOV, + ], + reasoning_capable=True, + thinking_level="medium", + ), + KilnModelProvider( + name=ModelProviderName.vertex, + model_id="gemini-3-flash-preview", + structured_output_mode=StructuredOutputMode.json_schema, + suggested_for_data_gen=True, + suggested_for_evals=True, + reasoning_capable=True, + gemini_reasoning_enabled=True, + thinking_level="medium", + ), ], ), # Gemini 2.5 Pro @@ -1887,6 +1969,27 @@ class KilnModel(BaseModel): ), ], ), + # Nemotron 3 Nano + KilnModel( + family=ModelFamily.llama, + name=ModelName.nemotron_3_nano, + friendly_name="Nemotron 3 Nano", + providers=[ + KilnModelProvider( + name=ModelProviderName.openrouter, + model_id="nvidia/nemotron-3-nano-30b-a3b:free", + structured_output_mode=StructuredOutputMode.json_schema, + reasoning_capable=True, + ), + KilnModelProvider( + name=ModelProviderName.ollama, + model_id="nemotron-3-nano", + structured_output_mode=StructuredOutputMode.json_schema, + reasoning_capable=True, + ollama_model_aliases=["nemotron-3-nano:30b"], + ), + ], + ), # Nemotron 70B KilnModel( family=ModelFamily.llama, @@ -5049,6 +5152,27 @@ class KilnModel(BaseModel): ), ], ), + # GLM 4.7 + KilnModel( + family=ModelFamily.glm, + name=ModelName.glm_4_7, + friendly_name="GLM 4.7", + providers=[ + KilnModelProvider( + name=ModelProviderName.openrouter, + model_id="z-ai/glm-4.7", + structured_output_mode=StructuredOutputMode.json_instructions, + reasoning_capable=True, + ), + KilnModelProvider( + name=ModelProviderName.siliconflow_cn, + model_id="Pro/zai-org/GLM-4.7", + structured_output_mode=StructuredOutputMode.json_instructions, + reasoning_capable=True, + reasoning_optional_for_structured_output=True, + ), + ], + ), # GLM 4.6 KilnModel( family=ModelFamily.glm, diff --git a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py 
b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py index cce9b4a60..a5cace346 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/base_adapter.py @@ -12,7 +12,7 @@ from kiln_ai.adapters.parsers.json_parser import parse_json_string from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id from kiln_ai.adapters.parsers.request_formatters import request_formatter_from_id -from kiln_ai.adapters.prompt_builders import prompt_builder_from_id +from kiln_ai.adapters.prompt_builders import BasePromptBuilder, prompt_builder_from_id from kiln_ai.adapters.provider_tools import kiln_model_provider_from from kiln_ai.adapters.run_output import RunOutput from kiln_ai.datamodel import ( @@ -44,6 +44,13 @@ class AdapterConfig: top_logprobs: int | None = None default_tags: list[str] | None = None + """ + A custom prompt builder can be injected to override the system prompt building process. + If not provided, the prompt builder will be created from the run_config.prompt_id which + may load additional files from disk. + """ + prompt_builder: BasePromptBuilder | None = None + class BaseAdapter(metaclass=ABCMeta): """Base class for AI model adapters that handle task execution. @@ -51,6 +58,10 @@ class BaseAdapter(metaclass=ABCMeta): This abstract class provides the foundation for implementing model-specific adapters that can process tasks with structured or unstructured inputs/outputs. It handles input/output validation, prompt building, and run tracking. + + Prompt building is handled internally by the adapter, which uses a prompt builder + based on the run config. To override the prompt building behavior, pass a custom prompt + builder to the adapter config. """ def __init__( @@ -62,12 +73,16 @@ def __init__( self.task = task self.run_config = run_config self.update_run_config_unknown_structured_output_mode() - self.prompt_builder = prompt_builder_from_id(run_config.prompt_id, task) + self.base_adapter_config = config or AdapterConfig() + + self.prompt_builder = ( + self.base_adapter_config.prompt_builder + or prompt_builder_from_id(run_config.prompt_id, task) + ) self._model_provider: KilnModelProvider | None = None self.output_schema = task.output_json_schema self.input_schema = task.input_json_schema - self.base_adapter_config = config or AdapterConfig() def model_provider(self) -> KilnModelProvider: """ diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py index b5ff70397..0225c023b 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py @@ -3,7 +3,12 @@ import pytest from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode -from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput +from kiln_ai.adapters.model_adapters.base_adapter import ( + AdapterConfig, + BaseAdapter, + RunOutput, +) +from kiln_ai.adapters.prompt_builders import BasePromptBuilder from kiln_ai.datamodel import Task from kiln_ai.datamodel.datamodel_enums import ChatStrategy from kiln_ai.datamodel.project import Project @@ -619,3 +624,36 @@ async def mock_name2(): # Should raise ValueError when tools have duplicate names with pytest.raises(ValueError, match="Each tool must have a unique name"): await adapter.available_tools() + + +async def test_custom_prompt_builder(base_task): + """Test that custom prompt builder can be 
injected via AdapterConfig""" + + # Create a custom prompt builder + class CustomPromptBuilder(BasePromptBuilder): + def build_base_prompt(self) -> str: + return "This is a custom prompt from injected builder" + + custom_builder = CustomPromptBuilder(base_task) + + adapter = MockAdapter( + task=base_task, + run_config=RunConfigProperties( + model_name="test_model", + model_provider_name="openai", + prompt_id="simple_prompt_builder", + structured_output_mode="json_schema", + ), + config=AdapterConfig(prompt_builder=custom_builder), + ) + + # Mock model provider + provider = MagicMock() + provider.reasoning_capable = False + provider.tuned_chat_strategy = None + adapter.model_provider = MagicMock(return_value=provider) + + # Test that the custom prompt builder is used + formatter = adapter.build_chat_formatter(input="test input") + assert formatter.system_message == "This is a custom prompt from injected builder" + assert adapter.prompt_builder == custom_builder diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py index 6e38a2d1f..adb74ae69 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_litellm_adapter.py @@ -1170,7 +1170,7 @@ async def test_array_input_converted_to_json(tmp_path, config): task.save_to_file() config.run_config_properties.model_name = "gpt-4o-mini" - config.run_config_properties.model_provider_name = "openai" + config.run_config_properties.model_provider_name = ModelProviderName.openai adapter = LiteLlmAdapter(config=config, kiln_task=task) mock_response = ModelResponse( @@ -1240,7 +1240,7 @@ async def test_dict_input_converted_to_json(tmp_path, config): task.save_to_file() config.run_config_properties.model_name = "gpt-4o-mini" - config.run_config_properties.model_provider_name = "openai" + config.run_config_properties.model_provider_name = ModelProviderName.openai adapter = LiteLlmAdapter(config=config, kiln_task=task) mock_response = ModelResponse( diff --git a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py index 63d6ec4fc..7a99f517b 100644 --- a/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +++ b/libs/core/kiln_ai/adapters/model_adapters/test_saving_adapter_results.py @@ -2,17 +2,8 @@ import pytest -from kiln_ai.adapters.model_adapters.base_adapter import ( - BaseAdapter, - RunOutput, -) -from kiln_ai.datamodel import ( - DataSource, - DataSourceType, - Project, - Task, - Usage, -) +from kiln_ai.adapters.model_adapters.base_adapter import BaseAdapter, RunOutput +from kiln_ai.datamodel import DataSource, DataSourceType, Project, Task, Usage from kiln_ai.datamodel.datamodel_enums import InputType from kiln_ai.datamodel.task import RunConfigProperties from kiln_ai.utils.config import Config diff --git a/libs/core/kiln_ai/adapters/provider_tools.py b/libs/core/kiln_ai/adapters/provider_tools.py index 06211a992..a0a8e0a48 100644 --- a/libs/core/kiln_ai/adapters/provider_tools.py +++ b/libs/core/kiln_ai/adapters/provider_tools.py @@ -114,6 +114,28 @@ def builtin_model_from( return provider +def built_in_provider_from_model_id( + model_id: str, provider_name: ModelProviderName | str +) -> KilnModelProvider | None: + """ + Look up a built-in provider entry by the provider and its provider-specific model ID. 
+ """ + try: + provider_enum = ( + provider_name + if isinstance(provider_name, ModelProviderName) + else ModelProviderName(provider_name) + ) + except ValueError: + return None + + for model in built_in_models: + for provider in model.providers: + if provider.name == provider_enum and provider.model_id == model_id: + return provider + return None + + def core_provider(model_id: str, provider_name: ModelProviderName) -> ModelProviderName: """ Get the provider that should be run. @@ -234,6 +256,26 @@ def finetune_from_id(model_id: str) -> Finetune: return fine_tune +def parser_from_finetune( + fine_tune: Finetune, +) -> ModelParserID | None: + """ + Use the finetune's base model to look for parser information. This is to cover the case where a R1 model is fine-tuned without thinking data. + The model would still output thinking data despite the data_strategy being single_turn. + """ + + # Look up the base model provider and check if there is a parser set + base_model_provider = built_in_provider_from_model_id( + fine_tune.base_model_id, fine_tune.provider + ) + + if base_model_provider and base_model_provider.parser: + return base_model_provider.parser + + # Otherwise, use the data strategy to determine the parser + return parser_from_data_strategy(fine_tune.data_strategy) + + def parser_from_data_strategy( data_strategy: ChatStrategy, ) -> ModelParserID | None: @@ -251,7 +293,7 @@ def finetune_provider_model( model_provider = KilnModelProvider( name=provider, model_id=fine_tune.fine_tune_model_id, - parser=parser_from_data_strategy(fine_tune.data_strategy), + parser=parser_from_finetune(fine_tune), reasoning_capable=( fine_tune.data_strategy in [ diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 6484a8dff..0018a41bf 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -35,7 +35,7 @@ TaskRun, Usage, ) -from kiln_ai.datamodel.datamodel_enums import ChatStrategy +from kiln_ai.datamodel.datamodel_enums import ChatStrategy, InputType from kiln_ai.datamodel.task import RunConfigProperties, TaskRunConfig logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ def test_short_prompt_builder(tmp_path): class MockAdapter(BaseAdapter): - async def _run(self, input: str) -> tuple[RunOutput, Usage | None]: + async def _run(self, input: InputType) -> tuple[RunOutput, Usage | None]: return RunOutput(output="mock response", intermediate_outputs=None), None def adapter_name(self) -> str: diff --git a/libs/core/kiln_ai/adapters/test_provider_tools.py b/libs/core/kiln_ai/adapters/test_provider_tools.py index df6b1162d..6c8a6d520 100644 --- a/libs/core/kiln_ai/adapters/test_provider_tools.py +++ b/libs/core/kiln_ai/adapters/test_provider_tools.py @@ -13,6 +13,7 @@ from kiln_ai.adapters.ollama_tools import OllamaConnection from kiln_ai.adapters.provider_tools import ( LiteLlmCoreConfig, + built_in_provider_from_model_id, builtin_model_from, check_provider_warnings, core_provider, @@ -24,6 +25,7 @@ lite_llm_core_config_for_provider, lite_llm_provider_model, parse_custom_model_id, + parser_from_finetune, provider_enabled, provider_name_from_id, provider_warnings, @@ -69,6 +71,7 @@ def mock_finetune(): with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock: finetune = Mock(spec=Finetune) finetune.provider = ModelProviderName.openai + finetune.base_model_id = "gpt-4o" finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123" 
finetune.structured_output_mode = StructuredOutputMode.json_schema finetune.data_strategy = ChatStrategy.single_turn @@ -81,6 +84,7 @@ def mock_finetune_final_and_intermediate(): with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock: finetune = Mock(spec=Finetune) finetune.provider = ModelProviderName.openai + finetune.base_model_id = "gpt-4o" finetune.fine_tune_model_id = "ft:gpt-3.5-turbo:custom:model-123" finetune.structured_output_mode = StructuredOutputMode.json_schema finetune.data_strategy = ChatStrategy.two_message_cot @@ -93,6 +97,7 @@ def mock_finetune_r1_compatible(): with patch("kiln_ai.datamodel.Finetune.from_id_and_parent_path") as mock: finetune = Mock(spec=Finetune) finetune.provider = ModelProviderName.ollama + finetune.base_model_id = "deepseek-r1:8b" finetune.fine_tune_model_id = "ft:deepseek-r1:671b:custom:model-123" finetune.structured_output_mode = StructuredOutputMode.json_schema finetune.data_strategy = ChatStrategy.single_turn_r1_thinking @@ -261,6 +266,32 @@ def test_get_model_and_provider_multiple_providers(): assert provider.model_id == "llama-3.3-70b-versatile" +def test_built_in_provider_from_model_id_found(): + """Test finding a provider by model_id and provider name""" + provider = built_in_provider_from_model_id( + "deepseek-r1:8b", ModelProviderName.ollama + ) + + assert provider is not None + assert provider.name == ModelProviderName.ollama + assert provider.model_id == "deepseek-r1:8b" + assert provider.parser == ModelParserID.r1_thinking + + +@pytest.mark.parametrize( + "model_id,provider_name", + [ + ("nonexistent-model", ModelProviderName.ollama), + ("deepseek-r1:8b", ModelProviderName.openai), + ("gpt-4o", "invalid_provider"), + ], +) +def test_built_in_provider_from_model_id_returns_none(model_id, provider_name): + """Test that None is returned for various invalid lookups""" + provider = built_in_provider_from_model_id(model_id, provider_name) + assert provider is None + + @pytest.mark.asyncio async def test_provider_enabled_ollama_success(): with patch( @@ -460,8 +491,11 @@ async def test_builtin_model_from_provider_warning_check(mock_config): assert provider_warnings[ModelProviderName.groq].message in str(exc_info.value) -def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune): +def test_finetune_provider_model_success( + mock_project, mock_task, mock_finetune, mock_config +): """Test successful creation of a fine-tuned model provider""" + mock_config.return_value = "fake-api-key" model_id = "project-123::task-456::finetune-789" provider = finetune_provider_model(model_id) @@ -474,9 +508,10 @@ def test_finetune_provider_model_success(mock_project, mock_task, mock_finetune) def test_finetune_provider_model_success_final_and_intermediate( - mock_project, mock_task, mock_finetune_final_and_intermediate + mock_project, mock_task, mock_finetune_final_and_intermediate, mock_config ): """Test successful creation of a fine-tuned model provider""" + mock_config.return_value = "fake-api-key" model_id = "project-123::task-456::finetune-789" provider = finetune_provider_model(model_id) @@ -581,13 +616,17 @@ def test_finetune_provider_model_structured_mode( mock_project, mock_task, mock_finetune, + mock_config, structured_output_mode, provider_name, expected_mode, ): """Test creation of provider with different structured output modes""" + mock_config.return_value = "fake-api-key" + finetune = Mock(spec=Finetune) finetune.provider = provider_name + finetune.base_model_id = "gpt-4o" finetune.fine_tune_model_id = 
"fireworks-model-123" finetune.structured_output_mode = structured_output_mode finetune.data_strategy = ChatStrategy.single_turn @@ -923,6 +962,7 @@ def test_finetune_provider_model_vertex_ai(mock_project, mock_task, mock_finetun """Test creation of provider for Vertex AI with endpoint ID transformation""" finetune = Mock(spec=Finetune) finetune.provider = ModelProviderName.vertex + finetune.base_model_id = "gemini-1.5-flash" finetune.fine_tune_model_id = "projects/123/locations/us-central1/endpoints/456" finetune.structured_output_mode = StructuredOutputMode.json_mode finetune.data_strategy = ChatStrategy.single_turn @@ -1177,3 +1217,59 @@ def test_provider_name_from_id_docker_model_runner(): """Test provider_name_from_id for Docker Model Runner""" result = provider_name_from_id(ModelProviderName.docker_model_runner) assert result == "Docker Model Runner" + + +def test_parser_from_finetune_model_parser_takes_precedence(): + """Test that parser from base model in ml_model_list takes precedence over data_strategy""" + finetune = Finetune( + name="test-finetune", + provider=ModelProviderName.ollama, + base_model_id="deepseek-r1:8b", + fine_tune_model_id="ft:deepseek-r1:custom:model-123", + dataset_split_id="test-split", + system_message="You are a helpful assistant.", + data_strategy=ChatStrategy.single_turn, + ) + + parser = parser_from_finetune(finetune) + + # deepseek-r1:8b (ollama) has ModelParserID.r1_thinking set in ml_model_list + assert parser == ModelParserID.r1_thinking + + +def test_parser_from_finetune_fallback_to_data_strategy(mock_config): + """Test that parser falls back to data_strategy when model has no parser""" + mock_config.return_value = "fake-api-key" + + finetune = Finetune( + name="test-finetune", + provider=ModelProviderName.fireworks_ai, + base_model_id="accounts/fireworks/models/qwen3-8b", + fine_tune_model_id="ft:gpt-4o:custom:model-123", + dataset_split_id="test-split", + system_message="You are a helpful assistant.", + data_strategy=ChatStrategy.single_turn_r1_thinking, + ) + + parser = parser_from_finetune(finetune) + + assert parser == ModelParserID.r1_thinking + + +def test_parser_from_finetune_no_parser(mock_config): + """Test that None is returned when neither model nor data_strategy has parser""" + mock_config.return_value = "fake-api-key" + + finetune = Finetune( + name="test-finetune", + provider=ModelProviderName.fireworks_ai, + base_model_id="accounts/fireworks/models/qwen3-8b", + fine_tune_model_id="ft:gpt-4o:custom:model-123", + dataset_split_id="test-split", + system_message="You are a helpful assistant.", + data_strategy=ChatStrategy.single_turn, # single turn has no parser + ) + + parser = parser_from_finetune(finetune) + + assert parser is None diff --git a/libs/core/kiln_ai/cli/commands/test_package_project.py b/libs/core/kiln_ai/cli/commands/test_package_project.py index 5b6a71903..e89f007c6 100644 --- a/libs/core/kiln_ai/cli/commands/test_package_project.py +++ b/libs/core/kiln_ai/cli/commands/test_package_project.py @@ -650,13 +650,13 @@ def test_build_dynamic_prompts_confirm_no(self, temp_project_with_dynamic_prompt class TestPackageProjectCommand: - def test_full_validation_flow(self, temp_project): + def test_full_validation_flow(self, temp_project, tmp_path: Path): """Test the complete validation flow.""" result = package_project( project_path=temp_project["path"], tasks=temp_project["task"].id, all_tasks=False, - output=Path("./test_output.zip"), + output=tmp_path / "test_output.zip", ) assert len(result) == 1 @@ -664,25 +664,25 @@ def 
test_full_validation_flow(self, temp_project): assert task_id in result assert isinstance(result[task_id], Prompt) - def test_all_tasks_flow(self, temp_project_with_multiple_tasks): + def test_all_tasks_flow(self, temp_project_with_multiple_tasks, tmp_path: Path): """Test validation with all tasks.""" result = package_project( project_path=temp_project_with_multiple_tasks["path"], tasks="", all_tasks=True, - output=Path("./test_output.zip"), + output=tmp_path / "test_output.zip", ) assert len(result) == 3 - def test_missing_project_path_error(self): + def test_missing_project_path_error(self, tmp_path: Path): """Test error when no project path is provided.""" with pytest.raises(typer.Exit) as exc_info: package_project( project_path=None, tasks="123", all_tasks=False, - output=Path("./test_output.zip"), + output=tmp_path / "test_output.zip", ) assert exc_info.value.exit_code == 1 @@ -693,18 +693,18 @@ def test_missing_project_error(self, tmp_path: Path): project_path=tmp_path / "nonexistent", tasks="123", all_tasks=False, - output=Path("./test_output.zip"), + output=tmp_path / "test_output.zip", ) assert exc_info.value.exit_code == 1 - def test_no_tasks_error(self, temp_project): + def test_no_tasks_error(self, temp_project, tmp_path: Path): """Test error when no tasks are specified.""" with pytest.raises(typer.Exit) as exc_info: package_project( project_path=temp_project["path"], tasks="", all_tasks=False, - output=Path("./test_output.zip"), + output=tmp_path / "test_output.zip", ) assert exc_info.value.exit_code == 1 diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index 226062a90..25751684e 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -326,7 +326,7 @@ def mutable_copy(self) -> Self: copy = super().model_copy(deep=True) # Reset readonly flag on copies so they can be mutated copy._readonly = False - return copy + return copy # type: ignore[return-value] # if changing the model name, should keep the original name here for parsing old files @classmethod @@ -691,11 +691,11 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta): def _create_child_method( cls, relationship_name: str, child_class: Type[KilnParentedModel] ): - def child_method(self, readonly: bool = False) -> list[child_class]: + def child_method(self, readonly: bool = False) -> list[child_class]: # type: ignore[invalid-type-form] return child_class.all_children_of_parent_path(self.path, readonly=readonly) child_method.__name__ = relationship_name - child_method.__annotations__ = {"return": List[child_class]} + child_method.__annotations__ = {"return": List[child_class]} # type: ignore[invalid-type-form] setattr(cls, relationship_name, child_method) @classmethod @@ -787,7 +787,7 @@ def _validate_nested( kwargs = {"data": value, "save": save} if instance is not None: kwargs["parent"] = instance - parent_type._validate_nested(**kwargs) + parent_type._validate_nested(**kwargs) # type: ignore[invalid-argument-type] elif issubclass(parent_type, KilnParentedModel): # Root node subinstance = parent_type.model_validate(value) diff --git a/libs/core/kiln_ai/datamodel/chunk.py b/libs/core/kiln_ai/datamodel/chunk.py index d98ca87af..8adf5abe7 100644 --- a/libs/core/kiln_ai/datamodel/chunk.py +++ b/libs/core/kiln_ai/datamodel/chunk.py @@ -153,7 +153,7 @@ def semantic_properties(self) -> SemanticChunkerProperties: # or cast (but it is currently banned in our linting rules). 
Better solution # would be discriminated union, but that requires the discriminator to be part # of the properties (not outside on the parent model). - return self.properties + return self.properties # type: ignore[return-value] @property def fixed_window_properties(self) -> FixedWindowChunkerProperties: @@ -165,7 +165,7 @@ def fixed_window_properties(self) -> FixedWindowChunkerProperties: # or cast (but it is currently banned in our linting rules). Better solution # would be discriminated union, but that requires the discriminator to be part # of the properties (not outside on the parent model). - return self.properties + return self.properties # type: ignore[return-value] class Chunk(BaseModel): diff --git a/libs/core/kiln_ai/datamodel/dataset_filters.py b/libs/core/kiln_ai/datamodel/dataset_filters.py index f2bc3b5ca..8ea772589 100644 --- a/libs/core/kiln_ai/datamodel/dataset_filters.py +++ b/libs/core/kiln_ai/datamodel/dataset_filters.py @@ -1,6 +1,6 @@ import re from enum import Enum -from typing import Annotated, ClassVar, List, Protocol +from typing import Annotated, ClassVar, Dict, List, Protocol from pydantic import AfterValidator @@ -19,7 +19,7 @@ def __call__(self, task_run: TaskRun) -> bool: ... -def AllDatasetFilter(_: TaskRun) -> bool: +def AllDatasetFilter(task_run: TaskRun) -> bool: return True @@ -128,7 +128,7 @@ class StaticDatasetFilters(str, Enum): THINKING_MODEL_HIGH_RATED = "thinking_model_high_rated" -static_dataset_filters = { +static_dataset_filters: Dict[StaticDatasetFilters, DatasetFilter] = { StaticDatasetFilters.ALL: AllDatasetFilter, StaticDatasetFilters.HIGH_RATING: HighRatingDatasetFilter, StaticDatasetFilters.THINKING_MODEL: ThinkingModelDatasetFilter, diff --git a/libs/core/kiln_ai/datamodel/dataset_split.py b/libs/core/kiln_ai/datamodel/dataset_split.py index c7448b061..80c913c5c 100644 --- a/libs/core/kiln_ai/datamodel/dataset_split.py +++ b/libs/core/kiln_ai/datamodel/dataset_split.py @@ -231,7 +231,7 @@ def compute_tool_info(runs: list[TaskRun]) -> DatasetToolInfo: for run in runs: # Extract tools from run config, treating missing source/run_config/tools_config as empty tools run_tools: set[str] = set() - source = run.output and run.output.source + source = run.output.source if run.output else None if source is not None and source.run_config is not None: tools_config = source.run_config.tools_config if tools_config is not None: diff --git a/libs/core/kiln_ai/datamodel/vector_store.py b/libs/core/kiln_ai/datamodel/vector_store.py index 8c0d8e880..81020e58b 100644 --- a/libs/core/kiln_ai/datamodel/vector_store.py +++ b/libs/core/kiln_ai/datamodel/vector_store.py @@ -98,7 +98,7 @@ def lancedb_vector_properties(self) -> LanceDBConfigVectorProperties: raise ValueError( f"Lancedb vector properties are only available for LanceDB vector store type. Got {self.properties.get('store_type')}" ) - return self.properties + return self.properties # type: ignore[return-value] @property def lancedb_hybrid_properties(self) -> LanceDBConfigHybridProperties: @@ -106,7 +106,7 @@ def lancedb_hybrid_properties(self) -> LanceDBConfigHybridProperties: raise ValueError( f"Lancedb hybrid properties are only available for LanceDB hybrid store type. 
Got {self.properties.get('store_type')}" ) - return self.properties + return self.properties # type: ignore[return-value] @property def lancedb_fts_properties(self) -> LanceDBConfigFTSProperties: @@ -114,7 +114,7 @@ def lancedb_fts_properties(self) -> LanceDBConfigFTSProperties: raise ValueError( f"Lancedb FTS properties are only available for LanceDB FTS store type. Got {self.properties.get('store_type')}" ) - return self.properties + return self.properties # type: ignore[return-value] # Workaround to return typed parent without importing Project def parent_project(self) -> Union["Project", None]: diff --git a/libs/core/kiln_ai/tools/built_in_tools/math_tools.py b/libs/core/kiln_ai/tools/built_in_tools/math_tools.py index f4e33b8e6..fb96c7304 100644 --- a/libs/core/kiln_ai/tools/built_in_tools/math_tools.py +++ b/libs/core/kiln_ai/tools/built_in_tools/math_tools.py @@ -34,7 +34,7 @@ def __init__(self): async def run(self, context=None, **kwargs) -> ToolCallResult: """Add two numbers and return the result.""" - kwargs = AddParams(**kwargs) + kwargs = AddParams(**kwargs) # type: ignore[missing-typed-dict-key] a = kwargs["a"] b = kwargs["b"] return ToolCallResult(output=str(a + b)) @@ -72,7 +72,7 @@ def __init__(self): async def run(self, context=None, **kwargs) -> ToolCallResult: """Subtract b from a and return the result.""" - kwargs = SubtractParams(**kwargs) + kwargs = SubtractParams(**kwargs) # type: ignore[missing-typed-dict-key] a = kwargs["a"] b = kwargs["b"] return ToolCallResult(output=str(a - b)) @@ -107,7 +107,7 @@ def __init__(self): async def run(self, context=None, **kwargs) -> ToolCallResult: """Multiply two numbers and return the result.""" - kwargs = MultiplyParams(**kwargs) + kwargs = MultiplyParams(**kwargs) # type: ignore[missing-typed-dict-key] a = kwargs["a"] b = kwargs["b"] return ToolCallResult(output=str(a * b)) @@ -148,7 +148,7 @@ def __init__(self): async def run(self, context=None, **kwargs) -> ToolCallResult: """Divide a by b and return the result.""" - kwargs = DivideParams(**kwargs) + kwargs = DivideParams(**kwargs) # type: ignore[missing-typed-dict-key] a = kwargs["a"] b = kwargs["b"] if b == 0: diff --git a/libs/core/kiln_ai/tools/rag_tools.py b/libs/core/kiln_ai/tools/rag_tools.py index 7bca8e4cd..c15a13185 100644 --- a/libs/core/kiln_ai/tools/rag_tools.py +++ b/libs/core/kiln_ai/tools/rag_tools.py @@ -227,7 +227,7 @@ async def search(self, query: str) -> List[SearchResult]: async def run( self, context: ToolCallContext | None = None, **kwargs ) -> ToolCallResult: - kwargs = RagParams(**kwargs) + kwargs = RagParams(**kwargs) # type: ignore[missing-typed-dict-key] query = kwargs["query"] search_results = await self.search(query) diff --git a/libs/core/kiln_ai/utils/config.py b/libs/core/kiln_ai/utils/config.py index 6bb2d6ad3..f9f140aba 100644 --- a/libs/core/kiln_ai/utils/config.py +++ b/libs/core/kiln_ai/utils/config.py @@ -1,3 +1,4 @@ +import copy import getpass import os import threading @@ -259,7 +260,7 @@ def settings(self, hide_sensitive=False) -> Dict[str, Any]: settings = { k: "[hidden]" if k in self._properties and self._properties[k].sensitive - else v + else copy.deepcopy(v) for k, v in self._settings.items() } # Hide sensitive keys in lists. 
Could generalize this if we every have more types, but right not it's only needed for root elements of lists diff --git a/libs/core/kiln_ai/utils/test_config.py b/libs/core/kiln_ai/utils/test_config.py index a256e2947..75891fe88 100644 --- a/libs/core/kiln_ai/utils/test_config.py +++ b/libs/core/kiln_ai/utils/test_config.py @@ -281,6 +281,41 @@ async def test_openai_compatible_providers(): ] +async def test_openai_compatible_providers_not_mutated_by_hide_sensitive(): + config = Config.shared() + assert config.openai_compatible_providers == [] + + new_settings = [ + { + "name": "provider1", + "url": "https://provider1.com", + "api_key": "secret_key_123", + }, + { + "name": "provider2", + "url": "https://provider2.com", + "api_key": "another_secret_key", + }, + ] + config.save_setting("openai_compatible_providers", new_settings) + assert config.openai_compatible_providers == new_settings + + hidden_settings = config.settings(hide_sensitive=True) + assert hidden_settings["openai_compatible_providers"] == [ + { + "name": "provider1", + "url": "https://provider1.com", + "api_key": "[hidden]", + }, + {"name": "provider2", "url": "https://provider2.com", "api_key": "[hidden]"}, + ] + + retrieved_providers = config.openai_compatible_providers + assert retrieved_providers == new_settings + assert retrieved_providers[0]["api_key"] == "secret_key_123" + assert retrieved_providers[1]["api_key"] == "another_secret_key" + + def test_yaml_persistence_structured_data(config_with_yaml, mock_yaml_file): # Set a value new_settings = [ diff --git a/libs/server/kiln_server/document_api.py b/libs/server/kiln_server/document_api.py index eea3e2757..862903ee6 100644 --- a/libs/server/kiln_server/document_api.py +++ b/libs/server/kiln_server/document_api.py @@ -365,17 +365,17 @@ def get_properties_for_chunker_type( case ChunkerType.SEMANTIC: return SemanticChunkerProperties( chunker_type=ChunkerType.SEMANTIC, - embedding_config_id=properties.embedding_config_id, - buffer_size=properties.buffer_size, - breakpoint_percentile_threshold=properties.breakpoint_percentile_threshold, + embedding_config_id=properties.embedding_config_id, # type: ignore[possibly-missing-attribute] + buffer_size=properties.buffer_size, # type: ignore[possibly-missing-attribute] + breakpoint_percentile_threshold=properties.breakpoint_percentile_threshold, # type: ignore[possibly-missing-attribute] include_metadata=False, include_prev_next_rel=False, ) case ChunkerType.FIXED_WINDOW: return FixedWindowChunkerProperties( chunker_type=ChunkerType.FIXED_WINDOW, - chunk_size=properties.chunk_size, - chunk_overlap=properties.chunk_overlap, + chunk_size=properties.chunk_size, # type: ignore[possibly-missing-attribute] + chunk_overlap=properties.chunk_overlap, # type: ignore[possibly-missing-attribute] ) raise_exhaustive_enum_error(properties.chunker_type) @@ -1632,12 +1632,13 @@ async def create_chunker_config( # if semantic, validate that the referenced embedding config exists if request.properties.chunker_type == ChunkerType.SEMANTIC: embedding_config = EmbeddingConfig.from_id_and_parent_path( - request.properties.embedding_config_id, project.path + request.properties.embedding_config_id, # type: ignore[possibly-missing-attribute] + project.path, ) if not embedding_config: raise HTTPException( status_code=404, - detail=f"Embedding config {request.properties.embedding_config_id} not found", + detail=f"Embedding config {request.properties.embedding_config_id} not found", # type: ignore[possibly-missing-attribute] ) chunker_config = ChunkerConfig( @@ 
-2148,11 +2149,11 @@ async def ephemeral_split_document( chunker_config = ChunkerConfig( name="ephemeral-fixed-window", chunker_type=ChunkerType.FIXED_WINDOW, - properties={ - "chunker_type": ChunkerType.FIXED_WINDOW, - "chunk_size": request.chunk_size, - "chunk_overlap": request.chunk_overlap or 0, - }, + properties=FixedWindowChunkerProperties( + chunker_type=ChunkerType.FIXED_WINDOW, + chunk_size=request.chunk_size, + chunk_overlap=request.chunk_overlap or 0, + ), ) chunker = chunker_adapter_from_type(ChunkerType.FIXED_WINDOW, chunker_config) diff --git a/libs/server/kiln_server/mcp/mcp_server_tool_utils.py b/libs/server/kiln_server/mcp/mcp_server_tool_utils.py index d5f53ae0f..47d88ff7b 100644 --- a/libs/server/kiln_server/mcp/mcp_server_tool_utils.py +++ b/libs/server/kiln_server/mcp/mcp_server_tool_utils.py @@ -78,7 +78,7 @@ async def _build_tool_context(resolution: ToolResolution) -> ToolContext: description=description, inputSchema=parameters, outputSchema=output_schema, - _meta={"kiln_tool_id": resolution.tool_id}, + _meta={"kiln_tool_id": resolution.tool_id}, # type: ignore[unknown-argument] ) return ToolContext( diff --git a/pyproject.toml b/pyproject.toml index bacd3b446..f75f34b76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ dev = [ "diff-cover>=9.6.0", "isort>=5.13.2", - "pyright==1.1.376", "pytest-asyncio>=0.24.0", "pytest>=8.3.3", "pytest-xdist>=3.5", @@ -24,6 +23,7 @@ dev = [ "ruff>=0.13.0", "watchfiles>=1.1.0", "scalar-fastapi>=1.4.3", + "ty>=0.0.2", ] [tool.uv] @@ -37,14 +37,15 @@ kiln-server = { workspace = true } kiln-studio-desktop = { workspace = true } kiln-ai = { workspace = true } -[tool.pyright] -strictListInference = true -reportMissingTypeArgument = true - - [tool.ruff] exclude = [] +[tool.ty.src] +exclude = ["**/test_*.py", "app/desktop/build/**", "app/web_ui/**", "**/.venv"] + +[tool.ty.rules] +invalid-key = "ignore" + [tool.ruff.lint] # I is import sorting # F401 is unused imports diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index b3fb78ad0..000000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "exclude": ["**/test_*.py", "app/desktop/build/**", "app/web_ui/**", "**/.venv", ".*/**"], - "typeCheckingMode": "basic", - "autoImportCompletions": true, - "extraPaths": ["./"], - "reportIncompatibleMethodOverride": "error", -} diff --git a/tessl.json b/tessl.json new file mode 100644 index 000000000..df5390c1a --- /dev/null +++ b/tessl.json @@ -0,0 +1,95 @@ +{ + "name": "project", + "dependencies": { + "tessl/pypi-pillow": { + "version": "11.3.0" + }, + "tessl/pypi-pyinstaller": { + "version": "6.15.0" + }, + "tessl/pypi-certifi": { + "version": "2024.12.0" + }, + "tessl/pypi-kiln-ai": { + "version": "0.22.1" + }, + "tessl/pypi-pyright": { + "version": "1.1.0" + }, + "tessl/pypi-pytest": { + "version": "8.4.0" + }, + "tessl/pypi-pytest-xdist": { + "version": "3.8.0" + }, + "tessl/pypi-python-dotenv": { + "version": "1.1.0" + }, + "tessl/pypi-watchfiles": { + "version": "1.1.0" + }, + "tessl/npm-floating-ui--dom": { + "version": "1.7.0" + }, + "tessl/npm-sveltejs--kit": { + "version": "2.37.0" + }, + "tessl/npm-sveltejs--vite-plugin-svelte": { + "version": "6.1.0" + }, + "tessl/npm-tailwindcss--typography": { + "version": "0.5.0" + }, + "tessl/npm-typescript-eslint--eslint-plugin": { + "version": "8.42.0" + }, + "tessl/npm-typescript-eslint--parser": { + "version": "8.42.0" + }, + "tessl/npm-daisyui": { + "version": "4.12.0" + }, + "tessl/npm-echarts": { + "version": "6.0.0" + }, 
+ "tessl/npm-eslint-config-prettier": { + "version": "10.1.0" + }, + "tessl/npm-eslint-plugin-svelte": { + "version": "3.12.0" + }, + "tessl/npm-eslint": { + "version": "9.34.0" + }, + "tessl/npm-highlight.js": { + "version": "11.11.0" + }, + "tessl/npm-openapi-typescript": { + "version": "7.9.0" + }, + "tessl/npm-postcss": { + "version": "8.5.0" + }, + "tessl/npm-prettier": { + "version": "3.6.0" + }, + "tessl/npm-svelte-check": { + "version": "3.8.0" + }, + "tessl/npm-svelte": { + "version": "4.2.0" + }, + "tessl/npm-tailwindcss": { + "version": "3.4.0" + }, + "tessl/npm-typescript": { + "version": "5.9.0" + }, + "tessl/npm-vite": { + "version": "7.1.0" + }, + "tessl/npm-vitest": { + "version": "4.0.0" + } + } +} diff --git a/uv.lock b/uv.lock index 5e1d048bb..ea2eb95bb 100644 --- a/uv.lock +++ b/uv.lock @@ -1431,13 +1431,13 @@ dependencies = [ dev = [ { name = "diff-cover" }, { name = "isort" }, - { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-xdist" }, { name = "python-dotenv" }, { name = "ruff" }, { name = "scalar-fastapi" }, + { name = "ty" }, { name = "watchfiles" }, ] @@ -1453,13 +1453,13 @@ requires-dist = [ dev = [ { name = "diff-cover", specifier = ">=9.6.0" }, { name = "isort", specifier = ">=5.13.2" }, - { name = "pyright", specifier = "==1.1.376" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-xdist", specifier = ">=3.5" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "ruff", specifier = ">=0.13.0" }, { name = "scalar-fastapi", specifier = ">=1.4.3" }, + { name = "ty", specifier = ">=0.0.2" }, { name = "watchfiles", specifier = ">=1.1.0" }, ] @@ -3611,6 +3611,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/73/02342de9c2d20922115f787e101527b831c0cffd2105c946c4a4826bcfd4/tqdm-4.66.6-py3-none-any.whl", hash = "sha256:223e8b5359c2efc4b30555531f09e9f2f3589bcd7fdd389271191031b49b7a63", size = 78326, upload-time = "2024-10-28T12:49:56.931Z" }, ] +[[package]] +name = "ty" +version = "0.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/e5/15b6aceefcd64b53997fe2002b6fa055f0b1afd23ff6fc3f55f3da944530/ty-0.0.2.tar.gz", hash = "sha256:e02dc50b65dc58d6cb8e8b0d563833f81bf03ed8a7d0b15c6396d486489a7e1d", size = 4762024, upload-time = "2025-12-16T20:13:41.07Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/86/65d4826677d966cf226662767a4a597ebb4b02c432f413673c8d5d3d1ce8/ty-0.0.2-py3-none-linux_armv6l.whl", hash = "sha256:0954a0e0b6f7e06229dd1da3a9989ee9b881a26047139a88eb7c134c585ad22e", size = 9771409, upload-time = "2025-12-16T20:13:28.964Z" }, + { url = "https://files.pythonhosted.org/packages/d4/bc/6ab06b7c109cec608c24ea182cc8b4714e746a132f70149b759817092665/ty-0.0.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d6044b491d66933547033cecc87cb7eb599ba026a3ef347285add6b21107a648", size = 9580025, upload-time = "2025-12-16T20:13:34.507Z" }, + { url = "https://files.pythonhosted.org/packages/54/de/d826804e304b2430f17bb27ae15bcf02380e7f67f38b5033047e3d2523e6/ty-0.0.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fbca7f08e671a35229f6f400d73da92e2dc0a440fba53a74fe8233079a504358", size = 9098660, upload-time = "2025-12-16T20:13:01.278Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8e/5cd87944ceee02bb0826f19ced54e30c6bb971e985a22768f6be6b1a042f/ty-0.0.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:3abd61153dac0b93b284d305e6f96085013a25c3a7ab44e988d24f0a5fcce729", size = 9567693, upload-time = "2025-12-16T20:13:12.559Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b1/062aab2c62c5ae01c05d27b97ba022d9ff66f14a3cb9030c5ad1dca797ec/ty-0.0.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:21a9f28caafb5742e7d594104e2fe2ebd64590da31aed4745ae8bc5be67a7b85", size = 9556471, upload-time = "2025-12-16T20:13:07.771Z" }, + { url = "https://files.pythonhosted.org/packages/0e/07/856f6647a9dd6e36560d182d35d3b5fb21eae98a8bfb516cd879d0e509f3/ty-0.0.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3ec63fd23ab48e0f838fb54a47ec362a972ee80979169a7edfa6f5c5034849d", size = 9971914, upload-time = "2025-12-16T20:13:18.852Z" }, + { url = "https://files.pythonhosted.org/packages/2e/82/c2e3957dbf33a23f793a9239cfd8bd04b6defd999bd0f6e74d6a5afb9f42/ty-0.0.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e5e2e0293a259c9a53f668c9c13153cc2f1403cb0fe2b886ca054be4ac76517c", size = 10840905, upload-time = "2025-12-16T20:13:37.098Z" }, + { url = "https://files.pythonhosted.org/packages/3b/17/49bd74e3d577e6c88b8074581b7382f532a9d40552cc7c48ceaa83f1d950/ty-0.0.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fd2511ac02a83d0dc45d4570c7e21ec0c919be7a7263bad9914800d0cde47817", size = 10570251, upload-time = "2025-12-16T20:13:10.319Z" }, + { url = "https://files.pythonhosted.org/packages/2b/9b/26741834069722033a1a0963fcbb63ea45925c6697357e64e361753c6166/ty-0.0.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c482bfbfb8ad18b2e62427d02a0c934ac510c414188a3cf00e16b8acc35482f0", size = 10369078, upload-time = "2025-12-16T20:13:20.851Z" }, + { url = "https://files.pythonhosted.org/packages/94/fc/1d34ec891900d9337169ff9f8252fcaa633ae5c4d36b67effd849ed4f9ac/ty-0.0.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb514711eed3f56d7a130d4885f4b5d8e490fdcd2adac098e5cf175573a0dda3", size = 10121064, upload-time = "2025-12-16T20:13:23.095Z" }, + { url = "https://files.pythonhosted.org/packages/e5/02/e640325956172355ef8deb9b08d991f229230bf9d07f1dbda8c6665a3a43/ty-0.0.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2c37fa26c39e9fbed7c73645ba721968ab44f28b2bfe2f79a4e15965a1c426f", size = 9553817, upload-time = "2025-12-16T20:13:27.057Z" }, + { url = "https://files.pythonhosted.org/packages/35/13/c93d579ece84895da9b0aae5d34d84100bbff63ad9f60c906a533a087175/ty-0.0.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:13b264833ac5f3b214693fca38e380e78ee7327e09beaa5ff2e47d75fcab9692", size = 9577512, upload-time = "2025-12-16T20:13:16.956Z" }, + { url = "https://files.pythonhosted.org/packages/85/53/93ab1570adc799cd9120ea187d5b4c00d821e86eca069943b179fe0d3e83/ty-0.0.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:08658d6dbbf8bdef80c0a77eda56a22ab6737002ba129301b7bbd36bcb7acd75", size = 9692726, upload-time = "2025-12-16T20:13:31.169Z" }, + { url = "https://files.pythonhosted.org/packages/9a/07/5fff5335858a14196776207d231c32e23e48a5c912a7d52c80e7a3fa6f8f/ty-0.0.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4a21b5b012061cb13d47edfff6be70052694308dba633b4c819b70f840e6c158", size = 10213996, upload-time = "2025-12-16T20:13:14.606Z" }, + { url = "https://files.pythonhosted.org/packages/a0/d3/896b1439ab765c57a8d732f73c105ec41142c417a582600638385c2bee85/ty-0.0.2-py3-none-win32.whl", hash = "sha256:d773fdad5d2b30f26313204e6b191cdd2f41ab440a6c241fdb444f8c6593c288", size = 
9204906, upload-time = "2025-12-16T20:13:25.099Z" }, + { url = "https://files.pythonhosted.org/packages/5d/0a/f30981e7d637f78e3d08e77d63b818752d23db1bc4b66f9e82e2cb3d34f8/ty-0.0.2-py3-none-win_amd64.whl", hash = "sha256:d1c9ac78a8aa60d0ce89acdccf56c3cc0fcb2de07f1ecf313754d83518e8e8c5", size = 10066640, upload-time = "2025-12-16T20:13:04.045Z" }, + { url = "https://files.pythonhosted.org/packages/5a/c4/97958503cf62bfb7908d2a77b03b91a20499a7ff405f5a098c4989589f34/ty-0.0.2-py3-none-win_arm64.whl", hash = "sha256:fbdef644ade0cd4420c4ec14b604b7894cefe77bfd8659686ac2f6aba9d1a306", size = 9572022, upload-time = "2025-12-16T20:13:39.189Z" }, +] + [[package]] name = "typer" version = "0.15.2"
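
Note: the snippet below is a minimal, illustrative sketch (not part of the patch) of the parser-selection precedence this diff adds to provider_tools.py. Values mirror the new tests in test_provider_tools.py; the import paths (e.g. from kiln_ai.datamodel import Finetune) and the example fine_tune_model_id strings are assumptions, and running it requires a configured kiln_ai environment.

# Sketch only: shows that parser_from_finetune prefers the base model's parser
# from ml_model_list and otherwise falls back to parser_from_data_strategy.
from kiln_ai.adapters.ml_model_list import ModelProviderName
from kiln_ai.adapters.provider_tools import (
    built_in_provider_from_model_id,
    parser_from_finetune,
)
from kiln_ai.datamodel import Finetune  # assumed import path
from kiln_ai.datamodel.datamodel_enums import ChatStrategy

# Base-model lookup alone: deepseek-r1:8b on Ollama carries the r1_thinking parser.
provider = built_in_provider_from_model_id("deepseek-r1:8b", ModelProviderName.ollama)
print(provider.parser if provider else None)  # ModelParserID.r1_thinking

# Fine-tune of an R1 base model trained with single_turn data: the base model's
# parser still wins, so thinking output is parsed even without thinking data.
r1_finetune = Finetune(
    name="test-finetune",
    provider=ModelProviderName.ollama,
    base_model_id="deepseek-r1:8b",
    fine_tune_model_id="ft:deepseek-r1:custom:model-123",
    dataset_split_id="test-split",
    system_message="You are a helpful assistant.",
    data_strategy=ChatStrategy.single_turn,
)
print(parser_from_finetune(r1_finetune))  # ModelParserID.r1_thinking

# Base model with no parser and single_turn data: no parser is applied.
plain_finetune = Finetune(
    name="test-finetune",
    provider=ModelProviderName.fireworks_ai,
    base_model_id="accounts/fireworks/models/qwen3-8b",
    fine_tune_model_id="ft:gpt-4o:custom:model-123",
    dataset_split_id="test-split",
    system_message="You are a helpful assistant.",
    data_strategy=ChatStrategy.single_turn,
)
print(parser_from_finetune(plain_finetune))  # None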