diff --git a/.gitignore b/.gitignore index a7cf771599..0c4b3800f5 100644 --- a/.gitignore +++ b/.gitignore @@ -21,13 +21,6 @@ myenv/ **/error-*.log - -# hosting/ -! hosting/docker-compose/oss/.env.oss.dev.example -! hosting/docker-compose/oss/.env.oss.gh.example -! hosting/docker-compose/ee/.env.ee.dev.example -! hosting/docker-compose/ee/.env.ee.gh.example - # examples/ examples/**/config.toml examples/**/agenta.py diff --git a/.gitleaks.toml b/.gitleaks.toml new file mode 100644 index 0000000000..d1991dc726 --- /dev/null +++ b/.gitleaks.toml @@ -0,0 +1,32 @@ +title = "Agenta Gitleaks Configuration" +version = 2 + +[extend] +useDefault = true + +[allowlist] +paths = [ + # ---------------------------------------------------------------- PUBLIC DOCS + '''^website/docs/reference/api/.*\.mdx''', + '''^core/docs/docs/reference/api/.*\.mdx''', + '''^docs/docs/reference/api/.*\.mdx''', + '''^docs/.docusaurus/.*''', + # -------------------------------------------------------------- WEB ARTIFACTS + '''^.*/\.pnpm-store/.*''', + '''^.*/public/__env\.js$''', + '''^.*/\.next/.*''', + # -------------------------------------------------------------- ALL ENV FILES + '''^.*\.env.*$''', + # ---------------------------------------------------------------------------- +] +regexes = [ + # ------------------------------------------------------------ FALSE POSITIVES + '''is_completion=True''', + '''YOUR_API_KEY''', + '''_SECRET_KEY''', + # ---------------------------------------------------------------------------- +] + +# USEFUL GITLEAKS COMMANDS +# gitleaks --config .gitleaks.toml --exit-code 1 --verbose git +# gitleaks --config .gitleaks.toml --exit-code 1 --verbose detect --no-git diff --git a/.gitleaksignore b/.gitleaksignore new file mode 100644 index 0000000000..097b6b8358 --- /dev/null +++ b/.gitleaksignore @@ -0,0 +1,196 @@ +# LEGACY / REVOKED / BENIGN CREDENTIALS, FROM PAST COMMITS, FROM BEFORE CLEANUP +docs/docusaurus.config.ts:generic-api-key:236 +api/oss/tests/manual/tracing/windowing.http:generic-api-key:3 +sdk/tests/legacy/baggage/config.toml:generic-api-key:4 +sdk/tests/legacy/debugging/simple-app/config.toml:generic-api-key:4 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/auth/admin.http:generic-api-key:3 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/annotations/crud.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/evaluators/crud.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/testsets/crud.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/02_span_id.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/03_parent_id.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/01_trace_id.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/07_end_time.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/06_start_time.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/05_span_name.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/04_span_kind.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/08_status_code.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/11_links.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/00_user_id.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/12_references.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/09_status_message.http:generic-api-key:2 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/filtering/10_attributes.http:generic-api-key:2 +854f1ca002740cd51252f81660701a3b6f9d6a8a:agenta-cli/debugging/simple-app/config.toml:generic-api-key:4 +51020488ce57a4b964c05cc0c41cecb4eb67692c:agenta-cli/debugging/simple-app/config.toml:generic-api-key:4 +6a4288ba3b4a2f95f24ed372bce7ac0679b5b868:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openinference.py:generic-api-key:12 +6a4288ba3b4a2f95f24ed372bce7ac0679b5b868:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openllmetery.py:generic-api-key:12 +6a4288ba3b4a2f95f24ed372bce7ac0679b5b868:agenta-cli/tests/observability_sdk/sanity_check/app_local.py:generic-api-key:8 +0f9c9ac3afcb8df950a743206715ab5ebe8808eb:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openllmetery.py:generic-api-key:12 +0f9c9ac3afcb8df950a743206715ab5ebe8808eb:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openinference.py:generic-api-key:12 +0f9c9ac3afcb8df950a743206715ab5ebe8808eb:agenta-cli/tests/observability_sdk/sanity_check/app_local.py:generic-api-key:8 +8cd7319eb87e87723a310555a820433f77ab01fd:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openinference.py:generic-api-key:11 +8cd7319eb87e87723a310555a820433f77ab01fd:agenta-cli/tests/observability_sdk/sanity_check/app_local.py:generic-api-key:7 +50c0e27be4960b5f06b5edbed6af912d79ea0f27:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain_openinference.py:generic-api-key:11 +50c0e27be4960b5f06b5edbed6af912d79ea0f27:agenta-cli/tests/observability_sdk/sanity_check/app_local.py:generic-api-key:7 +03b90aadcd58abd101454da5e3b925dde8e6cd43:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain.py:generic-api-key:11 +a1dbd3f3eafbe326a246a16fe70e02350cefdf2f:agenta-cli/tests/observability_sdk/integrations/langchain/simple_chain.py:generic-api-key:11 +86c2a2430e3ddbc544361343b7e9ea0152e53ab7:api/oss/tests/workflows/observability/test_otlp_ingestion.py:generic-api-key:21 +dc4370980d17ba1643a5c081670404f942ebfc57:agenta-cli/tests/management_sdk/manual_tests/apps_with_new_sdk/config.toml:generic-api-key:4 +850314eb630ca7fdcf756c7ffe36a6adad5cc845:agenta-cli/tests/management_sdk/manual_tests/apps_with_new_sdk/config.toml:generic-api-key:4 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/crud.http:generic-api-key:3 +3db7c34a8f206fb4a19e525ac5a964185d502c4a:api/oss/tests/manual/tracing/windowing.http:generic-api-key:3 +7ee494e8cdad4f54073be483e373e7a5bf273ea5:agenta-cli/tests/baggage/config.toml:generic-api-key:4 +a5e3197cd247c5468d8739ef9de811cd2a1cbc2f:agenta-cli/tests/baggage/config.toml:generic-api-key:4 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/kubernetes/oss/secret.yml:kubernetes-secret-yaml:2 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/kubernetes/oss/secret.yml:generic-api-key:12 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/helm/oss/values.yaml:generic-api-key:88 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/helm/oss/values.yaml:generic-api-key:89 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/kubernetes/oss/secret.yml:generic-api-key:8 +e84abaed2592e50d660d180d7fd373376b544f14:hosting/kubernetes/oss/secret.yml:generic-api-key:9 +81a2b05aa4624cfc39587e5384bf7c106e547449:.github/workflows/frontend-test.yml:openai-api-key:27 +4857d8f04896e27d707e2967bb361eb1a0b129db:.github/workflows/frontend-test.yml:openai-api-key:27 +8465021df57fca629f14c269d3f37d18d6fdcd11:services/completion-new-sdk-prompt/docker-compose.yml:openai-api-key:10 +406e68077c51da204b1f63f193a2defe6031c966:agenta-web/cypress.config.ts:openai-api-key:10 +450a435754557bfa1d3d3e372f4b47e4eb63f93e:agenta-web/cypress.config.ts:openai-api-key:10 +066e345ad9ba7318fc59b191cf33af2e81634aa8:agenta-web/cypress/support/commands/evaluations.ts:openai-api-key:106 +3533b30e483378a8ecb900c603a3c54ffc9cc390:agenta-web/cypress/support/commands/evaluations.ts:openai-api-key:106 +41e5d327e87083f55850c6933611cdc79ea9d204:agenta-backend/agenta_backend/tests/variants_evaluators_router/conftest.py:openai-api-key:25 +9968400e3095fdc1fb219f45c0d73db13c6de499:agenta-backend/agenta_backend/tests/variants_evaluators_router/conftest.py:openai-api-key:25 +a8efa140a02295ef6accbd02bc7c4c4eeb75e435:agenta-backend/agenta_backend/tests/variants_evaluators_router/conftest.py:openai-api-key:17 +d343b2a5b12387fc6b99d508b5e776f7689736c1:agenta-backend/agenta_backend/tests/variants_evaluators_router/conftest.py:openai-api-key:17 +5f37c440e203cf56d7f08a8efdd7ca372c646beb:docs/docs/prompt-management/05-adding-custom-providers.mdx:generic-api-key:81 +73644b725b5409be78d1aeecf7f5ff6a24ab3643:docs/docusaurus.config.ts:generic-api-key:220 +41c85fef68f9f8c2e4576956369ef600223193c8:website/docusaurus.config.ts:generic-api-key:184 +179d78e547e2eb92737cdd0ba7fd3eeb1f4bc5ce:website/docusaurus.config.ts:generic-api-key:184 +faf49eadbd38fd6771c4687fea78528ad73741b6:api/oss/tests/manual/annotations/crud.http:generic-api-key:2 +f86dddabb759924689022d2451d97efe218848c9:api/oss/tests/manual/evaluations/crud.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/auth/admin.http:generic-api-key:3 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/annotations/crud.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/evaluators/crud.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/testsets/crud.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/00_user_id.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/02_span_id.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/03_parent_id.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/01_trace_id.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/04_span_kind.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/06_start_time.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/08_status_code.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/05_span_name.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/07_end_time.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/09_status_message.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/12_references.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/11_links.http:generic-api-key:2 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/filtering/10_attributes.http:generic-api-key:2 +4888444e93a8438334a9dfb81c7979500d0ab4bf:api/oss/tests/manual/testsets/crud.http:generic-api-key:2 +5289a1c740cff9dec0047e7dc05902edbc471649:api/oss/tests/manual/tracing/filtering/12_references.http:generic-api-key:2 +96b9056a6ff1160f11dcc302321d4a29a7f1b8dd:api/oss/tests/manual/workflows/artifacts.http:generic-api-key:2 +6a4af8b70816f18ae69056df96b54c622e7ef494:api/oss/tests/manual/feedback/crud.http:generic-api-key:2 +7b0eeb2ae0cfa80e9a79cee814ed10c9b57ee9d3:api/oss/tests/manual/annotators/crud.http:generic-api-key:2 +876ba2f78358d43cc6dafe518e38ef404b6462f0:api/oss/tests/manual/annotations/crud.http:generic-api-key:2 +05f8741ea3349096e60a1686c1cef3585a6d34d7:api/oss/tests/manual/tracing/filtering/00_user_id.http:generic-api-key:2 +f47bb6f3b65c50664b354f33081d131289fa47cd:api/oss/tests/manual/tracing/filtering/11_links.http:generic-api-key:2 +18bfd66e6bc309ada998457a32b9a4ca689015a2:api/oss/tests/manual/tracing/filtering/10_attributes.http:generic-api-key:2 +0dcfe02574a545c8400b1b5385d7662143ec2544:api/oss/tests/manual/tracing/filtering/08_status_code.http:generic-api-key:2 +260822ac28ec5c08f9b4c2b04e895d46fcbfb164:api/oss/tests/manual/tracing/filtering/09_status_message.http:generic-api-key:2 +c16ca8eca0a2743c541457a80f88fe0ec71151cb:api/oss/tests/manual/tracing/filtering_parent_id.http:generic-api-key:2 +fe9a8b1c7518160bc3c9f80eff3ba1076a2a5030:api/oss/tests/manual/tracing/filtering_end_time.http:generic-api-key:2 +f158907f559e92fb91672c696715a90aef5470ab:api/oss/tests/manual/tracing/filtering_span_id.http:generic-api-key:2 +fe9a8b1c7518160bc3c9f80eff3ba1076a2a5030:api/oss/tests/manual/tracing/filtering_start_time.http:generic-api-key:2 +86dda27e6458ea8f7e64bfb4a9f63946c8fc82ce:api/oss/tests/manual/tracing/filtering_trace_id.http:generic-api-key:2 +7afc6f26080c6c37219995089aed409e50ef6279:api/oss/tests/manual/tracing/filtering_span_name.http:generic-api-key:2 +71a49c35758dab163bbbe700f55f6f50e6bdf9a5:api/oss/tests/manual/tracing/filtering_span_kind.http:generic-api-key:2 +b14e377b19cc4f77db9d0a2b51f72b88b6f54c6c:api/oss/tests/manual/auth/admin.http:generic-api-key:3 +0c6acff0523bd4e594e43caf63c4342e319476b8:hosting/kubernetes/oss/secret.yml:kubernetes-secret-yaml:2 +97c08e2f4ad87c2aacf6760da60eb01ec8d5d329:cloud/tests/conftest.py:generic-api-key:114 +6526d232893d18496af47a05c7b99e7c0c1fe510:docs/docs/prompt-management/05-adding-custom-providers.mdx:generic-api-key:71 +0c6acff0523bd4e594e43caf63c4342e319476b8:hosting/kubernetes/oss/secret.yml:generic-api-key:12 +bf7e1824839cea10432731174549faeb9bad3545:hosting/helm/oss/values.yaml:generic-api-key:83 +bf7e1824839cea10432731174549faeb9bad3545:hosting/helm/oss/values.yaml:generic-api-key:84 +0c6acff0523bd4e594e43caf63c4342e319476b8:hosting/kubernetes/oss/secret.yml:generic-api-key:8 +0c6acff0523bd4e594e43caf63c4342e319476b8:hosting/kubernetes/oss/secret.yml:generic-api-key:9 +fd477298c83aa220b01c6704058382c1ded1fdca:core/agenta-cli/debugging/simple-app/config.toml:generic-api-key:4 +fd477298c83aa220b01c6704058382c1ded1fdca:core/agenta-cli/tests/baggage/config.toml:generic-api-key:4 +f92a341a7e45fc051a08da1fa619137a192c89ae:api/ee/tests/manual/tracing.http:generic-api-key:5 +eba1ed50e6846a323d456b6da510f42d8c8bbe9a:api/ee/tests/manual/billing.http:generic-api-key:4 +9d741648f9ec1719c6f7f0fcb16cbf116458916c:api/oss/tests/manual/annotations/crud.http:generic-api-key:3 +9d741648f9ec1719c6f7f0fcb16cbf116458916c:api/oss/tests/manual/tracing/crud.http:generic-api-key:3 +273d2f5a1b37ef9420c4e40303b8fc6233362571:api/ee/tests/manual/billing.http:generic-api-key:4 +c078d4b1395ea2856891424f82e80f4fe60d7136:api/ee/tests/manual/billing.http:generic-api-key:4 +2793a4b2f065d7b588fa6733b74f68c0748473a5:api/oss/tests/manual/tracing/crud.http:generic-api-key:3 +ef6f83612a7cfd552147b49928feb8a5d4429c0d:api/oss/tests/manual/tracing/filtering.http:generic-api-key:3 +b14e377b19cc4f77db9d0a2b51f72b88b6f54c6c:api/oss/tests/manual/annotations/crud.http:generic-api-key:3 +b14e377b19cc4f77db9d0a2b51f72b88b6f54c6c:api/oss/tests/manual/tracing/crud.http:generic-api-key:3 +19ccc3f1f292edca26e840428ebc6224cbaef78a:api/ee/tests/manual/annotations/crud.http:generic-api-key:3 +bf42b5eaa7e805a249f52d65a6882d6ade2828f3:api/ee/tests/manual/tracing/windowing.http:generic-api-key:5 +bb0c1b3fb0032b6dbebe659d745d5cb90aa306ce:api/ee/tests/manual/tracing.http:generic-api-key:5 +16622c30916fae1b284b1b7150e4b7c57413ad17:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:16 +75ed5549eeb4685c5234c1ec577721920cc0ec9c:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:16 +4e743f16edcb3ff4e13b1400b9ff8175b072a5e1:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:16 +b587813ed56832b2df7fb7560775ee0b75f03674:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:12 +3abc3f4d2051c4df2f64c6d88608bd9bf1ae265f:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:9 +35442f703897393a3d2a5e9aa7a42985787bb24f:api/ee/tests/manual/evaluations/sdk/test.py:generic-api-key:16 +12f36507e801d41e2388889777c195557e7a6e5c:api/ee/tests/manual/evaluations/sdk/test_serve.py:generic-api-key:16 +12f36507e801d41e2388889777c195557e7a6e5c:api/ee/tests/manual/evaluations/sdk/test_handlers.py:generic-api-key:16 +e6d87a97aa4750ace564ac28eafda0123c21e017:api/oss/tests/workflows/observability/test_otlp_ingestion.py:generic-api-key:21 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docs/docusaurus.config.ts:generic-api-key:232 +c8d8f465b61764195de460164e6c27e0fe4b2b9a:docs/docs/self-host/05-advanced-configuration.mdx:generic-api-key:37 +a268b8a81a700704e28d82c5cb9af31dde32146b:ee/docker/docker-compose.demo.prod.yml:generic-api-key:25 +56829b2eccdec425954243d0ce5e4fcac9d05e9c:ee/docker/docker-compose.cloud.dev.yml:generic-api-key:25 +91678b6a27c326e0002205f79fd8999a7591e38f:ee/docker/docker-compose.demo.prod.yml:generic-api-key:25 +91678b6a27c326e0002205f79fd8999a7591e38f:ee/docker/docker-compose.demo.dev.yml:generic-api-key:25 +56829b2eccdec425954243d0ce5e4fcac9d05e9c:ee/docker/docker-compose.cloud.dev.yml:generic-api-key:22 +ad6b459dfc5ac1e5c140fcf3e03e247ba31383ae:ee/docker/docker-compose.demo.prod.yml:generic-api-key:22 +594f33f5b7eb665edb38208666d27d6de6365946:ee/docker/docker-compose.demo.prod.yml:generic-api-key:73 +594f33f5b7eb665edb38208666d27d6de6365946:ee/docker/docker-compose.demo.dev.yml:generic-api-key:22 +f6ef6aa32d569ee025bdb3ce9f515521a4095494:cloud/agenta-backend/agenta_backend/cloud/__init__.py:generic-api-key:155 +b3e5fae0e270f2c92a65360123c980d725c5f226:ee/agenta-backend/agenta_backend/ee/__init__.py:generic-api-key:107 +c98a5da1a33d2c0986e3c66329eaa5237fbccf3d:hosting/docker-compose/ee/aws/docker-compose.oss.prod.yml:generic-api-key:76 +4ee55c08b2fed661eaf90876f96c329d7c7eeb6b:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:46 +169c54d84f5d1931601550a9d3aa76874ef73ec5:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:47 +f725cfc9247743bdc44f84edee109ff36193d741:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:46 +cbd8ac00ecdc2c8124b989a274c6f835c09f8474:cloud/docker/docker-compose.oss.prod.yml:generic-api-key:46 +2fb6b0f94f8bf711255e0901c03787b73f3d650f:cloud/docker/docker-compose.oss.prod.yml:generic-api-key:46 +a5e7781869b2a4bf22dd5e22fd5e5ae2ec8d02ea:cloud/docker/newrelic-infra.yml:generic-api-key:1 +2fb6b0f94f8bf711255e0901c03787b73f3d650f:cloud/docker/docker-compose.oss.prod.yml:generic-api-key:171 +f6ef6aa32d569ee025bdb3ce9f515521a4095494:cloud/agenta-backend/agenta_backend/cloud/__init__.py:generic-api-key:144 +b3e5fae0e270f2c92a65360123c980d725c5f226:ee/agenta-backend/agenta_backend/ee/__init__.py:generic-api-key:96 +a268b8a81a700704e28d82c5cb9af31dde32146b:ee/docker/docker-compose.demo.prod.yml:generic-api-key:22 +9c55f5572904ae07b73f73ee365e833d0637633a:ee/docker/docker-compose.demo.prod.yml:generic-api-key:22 +ad74134f522cde71f860cb59b6363a8fdf0a64c6:ee/setup_agenta_web.sh:generic-api-key:26 +5a10aacebd0ed4f2e613eb9176e95836aea34f15:ee/setup_agenta_web.sh:generic-api-key:26 +637068dd09eff7b30b776061863027a9b9aa1deb:ee/setup_agenta_web.sh:generic-api-key:26 +590578c803d94d8ccb1a6ca977471f3d44b43fc3:hosting/helm/oss/templates/config/app-configmap.yaml:generic-api-key:45 +1d8f08b267675726441fcaaae24572bb635c5eac:api/oss/src/utils/env.py:generic-api-key:53 +55f27e52327062382beb299b162f94895268d766:web/oss/public/__ENV.js:generic-api-key:1 +c98a5da1a33d2c0986e3c66329eaa5237fbccf3d:hosting/docker-compose/ee/aws/docker-compose.oss.prod.yml:generic-api-key:73 +bf0cd42bffc2581b1df6f56fa6e4b20ff9b68c33:hosting/docker-compose/ee/aws/docker-compose.oss.aws.yml:generic-api-key:61 +52cd40cefd3121eea2e21205e8208712b093529a:core/hosting/docker-compose/ee/docker-compose.dev.yml:generic-api-key:18 +6efb8a0c9620a316cf81fe961b1407e93aa2efa7:core/hosting/docker-compose/oss/docker-compose.gh.yml:generic-api-key:17 +6efb8a0c9620a316cf81fe961b1407e93aa2efa7:core/hosting/docker-compose/oss/docker-compose.dev.yml:generic-api-key:17 +44d1669c1a53b3ca47e3689dda5500e6f742f525:core/hosting/docker-compose/ee/docker-compose.dev.yml:generic-api-key:18 +58a4230a8f2e63b2836f81bcf2341ba12003189e:core/hosting/docker-compose/oss/docker-compose.yml:generic-api-key:30 +fd477298c83aa220b01c6704058382c1ded1fdca:core/agenta-cli/agenta/cli/helper.py:generic-api-key:19 +fd477298c83aa220b01c6704058382c1ded1fdca:core/agenta-web/prod.gh.Dockerfile:generic-api-key:7 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docker-compose.gh.yml:generic-api-key:26 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docker-compose.gh.yml:generic-api-key:95 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docker-compose.yml:generic-api-key:30 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docker-compose.yml:generic-api-key:111 +fd477298c83aa220b01c6704058382c1ded1fdca:core/docker-compose.prod.yml:generic-api-key:102 +4ee55c08b2fed661eaf90876f96c329d7c7eeb6b:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:43 +169c54d84f5d1931601550a9d3aa76874ef73ec5:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:44 +9e7831e7500364776cb3e9eac41448907ef92dcd:cloud/docker/docker-compose.test.yml:generic-api-key:28 +9e7831e7500364776cb3e9eac41448907ef92dcd:cloud/docker/docker-compose.test.yml:generic-api-key:83 +fe72ad1a8d14e1f3bce547ef224fc75e3df8f4ff:cloud/docker/docker-compose.cloud.test.yml:generic-api-key:28 +fe72ad1a8d14e1f3bce547ef224fc75e3df8f4ff:cloud/docker/docker-compose.cloud.test.yml:generic-api-key:83 +6198bbc532d8e984ef94276b58a7ab8dc65a279f:cloud/docker/docker-compose.oss.stage.yml:generic-api-key:43 +4b4cdbdf4b8ad4a9342fdb939b7e30f88420fccd:cloud/docker/docker-compose.oss.prod.yml:generic-api-key:44 +a268b8a81a700704e28d82c5cb9af31dde32146b:ee/docker/docker-compose.demo.prod.yml:sendgrid-api-token:26 +11ad273f1039e9263cf8d2f61338a121d59b9cc7:ee/docker/docker-compose.cloud.prod.yml:sendgrid-api-token:22 +11ad273f1039e9263cf8d2f61338a121d59b9cc7:ee/docker/docker-compose.cloud.dev.yml:sendgrid-api-token:21 +11ad273f1039e9263cf8d2f61338a121d59b9cc7:ee/docker/docker-compose.demo.prod.yml:sendgrid-api-token:26 +11ad273f1039e9263cf8d2f61338a121d59b9cc7:ee/docker/docker-compose.demo.dev.yml:sendgrid-api-token:26 +b7bc21c67bbae3c06c372bc585c4917a80613a14:cloud/agenta-backend/agenta_backend/cloud/routers/payment_router.py:stripe-access-token:13 +d8b4af2ae8c1084dbdd30fca59aa84e8ece047db:examples/python/annotation-example.py:openai-api-key:19 +a268b8a81a700704e28d82c5cb9af31dde32146b:ee/docker/docker-compose.demo.prod.yml:openai-api-key:24 +02d9f665aed89e8d69e06acdc7d01d699ee5b0dd:ee/docker/docker-compose.demo.prod.yml:openai-api-key:24 +c8d8f465b61764195de460164e6c27e0fe4b2b9a:docs/docs/self-host/05-advanced-configuration.mdx:generic-api-key:46 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8363b88652..f498cd623a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,6 +3,12 @@ repos: hooks: - id: gitleaks-pre-commit name: gitleaks git (staged only) - entry: echo "Aloha" + entry: bash -c 'gitleaks --config .gitleaks.toml --exit-code 1 --verbose git --staged' language: system pass_filenames: false + - id: gitleaks-pre-push + name: gitleaks git (pre-push, scan diff) + entry: bash -c 'gitleaks --config .gitleaks.toml --exit-code 1 --verbose git --log-opts "$(git merge-base HEAD "origin/$(git rev-parse --abbrev-ref HEAD)" 2>/dev/null || git merge-base HEAD origin/main)..HEAD"' + language: system + stages: [pre-push] + pass_filenames: false diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5cfd9357c1..5810643c7c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,4 +40,9 @@ We had many zombie issues and PRs (assigned but inactive) in the past. We want t - An issue may only be assigned to one person for up to one week (three days for very simple issues). If the issue remains unsolved after a week, it will be unassigned and made available to others. - Any pull request (PR) left inactive by the author for over a week will be closed. The author can reopen it if they wish to continue. -We look forward to seeing your contributions to Agenta! \ No newline at end of file +We look forward to seeing your contributions to Agenta! + +## Contributor License Agreement +If you want to contribute, we need you to sign a Contributor License Agreement. We need this to avoid potential intellectual property problems in the future. You can sign the agreement by clicking a button. Here is how it works: + +After you open a PR, a bot will automatically comment asking you to sign the agreement. Click on the link in the comment, login with your Github account, and sign the agreement. \ No newline at end of file diff --git a/LICENSE b/LICENSE index 79b3725428..1fff9c4444 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,16 @@ -The MIT License +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” + +Portions of this software are licensed as follows: -Copyright (c) Agentatech UG (haftungsbeschränkt) +- All content that resides under any "ee/" directory of this repository, if +such directories exist, are licensed under the license defined in "ee/LICENSE". +- All third party components incorporated into the Agenta Software are licensed +under the original license provided by the owner of the applicable component. +- Content outside of the above mentioned directories or restrictions above is +available under the "MIT Expat" license as defined below. + +The MIT License Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 082ea25f29..6f9ef58382 100644 --- a/README.md +++ b/README.md @@ -266,4 +266,4 @@ This project follows the [all-contributors](https://github.com/all-contributors/ ## Disabling Anonymized Tracking -By default, Agenta automatically reports anonymized basic usage statistics. This helps us understand how Agenta is used and track its overall usage and growth. This data does not include any sensitive information. To disable anonymized telemetry set `AGENTA_TELEMETRY_ENABLED` to `false` in your `.env` file. \ No newline at end of file +By default, Agenta automatically reports anonymized basic usage statistics. This helps us understand how Agenta is used and track its overall usage and growth. This data does not include any sensitive information. To disable anonymized telemetry set `AGENTA_TELEMETRY_ENABLED` to `false` in your `.env` file. diff --git a/SECURITY.md b/SECURITY.md index f45e7a7624..fabf40c910 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,19 +1,76 @@ # Security Policy + ## Reporting a Vulnerability If you believe you have found a security vulnerability in any Agenta repository, please report it to us through coordinated disclosure. -Please do not report security vulnerabilities through public GitHub issues, discussions, or pull requests. +**Do not** report security vulnerabilities via public GitHub issues, pull requests, or discussions. + +Instead, please send an email to **security@agenta.ai**. + +--- + +## Information to Include + +Please include as much of the following as you can to help us reproduce and resolve the issue: + +- Type of issue (e.g., buffer overflow, SQL injection, cross-site scripting). +- Full paths of source files related to the issue. +- The location of the affected source code (tag, branch, commit SHA, or direct URL). +- Any special configuration or environment required to reproduce. +- Step-by-step instructions to reproduce. +- Proof-of-concept or exploit code (if possible). +- Expected vs actual behaviour and potential impact. +- Your contact details and disclosure timeline preference. + +--- + +## Our Process + +- **Acknowledgement**: We will acknowledge receipt within **3 business days**. +- **Triage**: We aim to complete an initial triage within **7 calendar days** and will share severity and next steps. +- **Remediation & Disclosure**: For critical vulnerabilities we aim to release a fix or mitigation within **30 days**. For other issues, typically within **90 days**. We will coordinate any public disclosure with you. +- We will provide status updates as needed during remediation. + +--- + +## Safe Harbor + +We respect and protect good-faith security research. If you follow this policy: + +- We will not initiate legal action against you for good-faith testing conducted as part of coordinated disclosure. +- Do not access, modify, or exfiltrate data beyond what is necessary to demonstrate the issue. +- Do not disrupt production services or attempt destructive actions. + +--- + +## Scope Exclusions + +The following are **out of scope**: + +- Third-party services not operated by Agenta. +- Physical security attacks or social engineering of personnel. +- Low-risk informational issues without security impact (e.g., generic version banners). +- Denial-of-service attacks (**we will not accept DoS testing against production**). + +--- + +## Recognition & Credits + +If you report a valid vulnerability and want public recognition, tell us how you wish to be credited (full name, handle, company, or anonymous). Recognition is discretionary and will be coordinated with you. + +--- + +## Emergency / Out-of-band + +If email is unavailable and you need an immediate or urgent channel, contact our general line: **team@agenta.ai** (monitored during business hours). For truly critical emergencies, include “EMERGENCY / SECURITY” in the subject line of your email. -Instead, please send an email to team@agenta.ai. +--- -Please include as much of the information listed below as you can to help us better understand and resolve the issue: +## Contact retention & privacy - The type of issue (e.g., buffer overflow, SQL injection, or cross-site scripting) - Full paths of source file(s) related to the manifestation of the issue - The location of the affected source code (tag/branch/commit or direct URL) - Any special configuration required to reproduce the issue - Step-by-step instructions to reproduce the issue - Proof-of-concept or exploit code (if possible) - Impact of the issue, including how an attacker might exploit the issue +- Report metadata will be retained for incident tracking and compliance. +- Personal data you provide will be handled according to our privacy policy. +- We will only share reporter data internally on a need-to-know basis. +--- diff --git a/api/ee/LICENSE b/api/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/api/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/hosting/gcp/credentials.json b/api/ee/__init__.py similarity index 100% rename from hosting/gcp/credentials.json rename to api/ee/__init__.py diff --git a/api/ee/databases/__init__.py b/api/ee/databases/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/databases/postgres/init-db-ee.sql b/api/ee/databases/postgres/init-db-ee.sql new file mode 100644 index 0000000000..e949c33926 --- /dev/null +++ b/api/ee/databases/postgres/init-db-ee.sql @@ -0,0 +1,39 @@ +-- Ensure we are connected to the default postgres database before creating new databases +\c postgres + +-- Create the 'username' role with a password if it doesn't exist +SELECT 'CREATE ROLE username WITH LOGIN PASSWORD ''password''' +WHERE NOT EXISTS (SELECT FROM pg_roles WHERE rolname = 'username')\gexec + +-- Create the 'agenta_ee_core' database if it doesn't exist +SELECT 'CREATE DATABASE agenta_ee_core' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'agenta_ee_core')\gexec + +-- Create the 'agenta_ee_tracing' database if it doesn't exist +SELECT 'CREATE DATABASE agenta_ee_tracing' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'agenta_ee_tracing')\gexec + +-- Create the 'agenta_ee_supertokens' database if it doesn't exist +SELECT 'CREATE DATABASE agenta_ee_supertokens' +WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'agenta_ee_supertokens')\gexec + +-- Grant necessary permissions to 'username' for both databases +GRANT ALL PRIVILEGES ON DATABASE agenta_ee_core TO username; +GRANT ALL PRIVILEGES ON DATABASE agenta_ee_tracing TO username; +GRANT ALL PRIVILEGES ON DATABASE agenta_ee_supertokens TO username; + + +-- Switch to 'agenta_ee_core' and grant schema permissions +\c agenta_ee_core +GRANT ALL ON SCHEMA public TO username; + +-- Switch to 'agenta_ee_tracing' and grant schema permissions +\c agenta_ee_tracing +GRANT ALL ON SCHEMA public TO username; + +-- Switch to 'agenta_ee_supertokens' and grant schema permissions +\c agenta_ee_supertokens +GRANT ALL ON SCHEMA public TO username; + +-- Return to postgres +\c postgres diff --git a/api/ee/databases/postgres/migrations/__init__.py b/api/ee/databases/postgres/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/databases/postgres/migrations/core/README.md b/api/ee/databases/postgres/migrations/core/README.md new file mode 100644 index 0000000000..8d8552e3c3 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/README.md @@ -0,0 +1,35 @@ +# Migrations with Alembic + +Generic single-database configuration with an async dbapi. + +## Autogenerate Migrations + +One of Alembic's key features is its ability to auto-generate migration scripts. By analyzing the current database state and comparing it with the application's table metadata, Alembic can automatically generate the necessary migration scripts using the `--autogenerate` flag in the alembic revision command. + +Note that autogenerate sometimes does not detect all database changes and it is always necessary to manually review (and correct if needed) the candidate migrations that autogenerate produces. + +### Making migrations + +To make migrations after creating a new table schema or modifying a current column in a table, run the following commands: + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini revision --autogenerate -m "migration message" +``` + +The above command will create a script that contains the changes that was made to the database schema. Kindly update "migration message" with a message that is clear to indicate what change was made. Here are some examples: + +- added username column in users table +- renamed template_uri to template_repository_uri +- etc + +### Applying Migrations + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini upgrade head +``` + +The above command will be used to apply the changes in the script created to the database table(s). If you'd like to revert the migration, run the following command: + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini downgrade head +``` diff --git a/api/ee/databases/postgres/migrations/core/alembic.ini b/api/ee/databases/postgres/migrations/core/alembic.ini new file mode 100644 index 0000000000..1888be8152 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/alembic.ini @@ -0,0 +1,112 @@ +# A generic, single database configuration. + +[alembic] +script_location = /app/ee/databases/postgres/migrations/core + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to postgres/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:postgres/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/api_keys.py b/api/ee/databases/postgres/migrations/core/data_migrations/api_keys.py new file mode 100644 index 0000000000..769b6b8157 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/api_keys.py @@ -0,0 +1,282 @@ +import uuid +import traceback +from typing import Optional + +import click +from sqlalchemy.future import select +from sqlalchemy import Connection, update, func, or_, insert, delete + +from oss.src.models.db_models import APIKeyDB +from ee.src.models.db_models import ProjectDB +from ee.src.models.extended.deprecated_models import DeprecatedAPIKeyDB + + +BATCH_SIZE = 200 + + +def get_project_id_from_workspace_id( + session: Connection, workspace_id: str +) -> Optional[str]: + statement = select(ProjectDB).filter_by( + workspace_id=uuid.UUID(workspace_id), is_default=True + ) + project = session.execute(statement).fetchone() + return str(project.id) if project is not None else None + + +def get_workspace_id_from_project_id( + session: Connection, project_id: str +) -> Optional[str]: + statement = select(ProjectDB).filter_by(id=uuid.UUID(project_id)) + project = session.execute(statement).fetchone() + return str(project.workspace_id) if project is not None else None + + +def update_api_key_to_make_use_of_project_id(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS = 0 + + # Count total rows with user_id & workspace_id isnot NULL and project_id is NULL + stmt = ( + select(func.count()) + .select_from(DeprecatedAPIKeyDB) + .filter( + DeprecatedAPIKeyDB.user_id.isnot(None), + DeprecatedAPIKeyDB.workspace_id.isnot(None), + DeprecatedAPIKeyDB.project_id.is_(None), + ) + ) + result = session.execute(stmt).scalar() + TOTAL_API_KEYS_WITH_USER_AND_WORKSPACE_ID = result if result is not None else 0 + print( + f"Total rows in api_keys table with user_id and workspace_id not been NULL is {TOTAL_API_KEYS_WITH_USER_AND_WORKSPACE_ID}" + ) + + while True: + # Fetch a batch of api_keys with user_id & workspace_id not been NULL + records = session.execute( + select(DeprecatedAPIKeyDB) + .filter( + or_( + DeprecatedAPIKeyDB.user_id.isnot(None), + DeprecatedAPIKeyDB.user_id != "None", + ), + or_( + DeprecatedAPIKeyDB.workspace_id != "None", + DeprecatedAPIKeyDB.workspace_id.isnot(None), + ), + DeprecatedAPIKeyDB.project_id.is_(None), + ) + .offset(offset) + .limit(BATCH_SIZE) + ).fetchall() + batch_migrated = len(records) + if not records: + break + + # Process and update records in the batch + for record in records: + print( + "Record (has workspace_id?, workspace id, user id, id, types [workspace_id & user_id]) --- ", + hasattr(record, "workspace_id"), + record.workspace_id, + record.user_id, + record.id, + type(record.workspace_id), + type(record.user_id), + ) + if ( + hasattr(record, "workspace_id") + and record.workspace_id + not in [ + "None", + "", + ] + and record.user_id not in ["None", ""] + ): + project_id = get_project_id_from_workspace_id( + session=session, workspace_id=str(record.workspace_id) + ) + if project_id is None: + SKIPPED_RECORDS += 1 + print( + f"Could not retrieve project_id from workspace_id for APIKey with ID {str(record.id)}." + ) + + batch_migrated -= 1 + print( + "Subtracting record from part of batch. Now, Skipping record..." + ) + continue + + # Add the new object to the session. + insert_statement = insert(APIKeyDB).values( + prefix=record.prefix, + hashed_key=record.hashed_key, + created_by_id=uuid.UUID(record.user_id), + project_id=uuid.UUID(project_id), + rate_limit=record.rate_limit, + hidden=record.hidden, + expiration_date=record.expiration_date, + created_at=record.created_at, + updated_at=record.updated_at, + ) + session.execute(insert_statement) + else: + SKIPPED_RECORDS += 1 + print( + f"No workspace_id found for APIKey with ID {str(record.id)}. Skipping record..." + ) + + batch_migrated -= 1 + print( + "Subtracting record from part of batch. Now, Skipping record..." + ) + continue + + # Update migration progress tracking + TOTAL_MIGRATED += batch_migrated + offset += BATCH_SIZE + remaining_records = ( + TOTAL_API_KEYS_WITH_USER_AND_WORKSPACE_ID - TOTAL_MIGRATED + ) + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}.", + fg="yellow", + ) + ) + + # Break if all records have been processed + if remaining_records <= 0: + break + + # Count total rows with user_id and/or workspace_id been NULL + stmt = ( + select(func.count()) + .select_from(DeprecatedAPIKeyDB) + .filter(DeprecatedAPIKeyDB.project_id.is_(None)) + ) + result = session.execute(stmt).scalar() + TOTAL_API_KEYS_WITH_NO_USER_AND_WORKSPACE_ID = ( + result if result is not None else 0 + ) + if TOTAL_API_KEYS_WITH_NO_USER_AND_WORKSPACE_ID >= 1: + session.execute( + delete(DeprecatedAPIKeyDB).where( + DeprecatedAPIKeyDB.project_id.is_(None) + ) + ) + + print( + f"Total rows in api_keys table with user_id and workspace_id been NULL is {TOTAL_API_KEYS_WITH_NO_USER_AND_WORKSPACE_ID} and have been deleted." + ) + except Exception as e: + session.rollback() + click.echo( + click.style( + f"ERROR updating api_keys to make use of project_id: {traceback.format_exc()}", + fg="red", + ) + ) + raise e + + +def revert_api_key_to_make_use_of_workspace_id(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS = 0 + + # Count total rows with created_by_id & project_id isnot NULL + stmt = ( + select(func.count()) + .select_from(DeprecatedAPIKeyDB) + .filter( + DeprecatedAPIKeyDB.created_by_id.isnot(None), + DeprecatedAPIKeyDB.project_id.isnot(None), + DeprecatedAPIKeyDB.workspace_id.is_(None), + ) + ) + result = session.execute(stmt).scalar() + TOTAL_API_KEYS_WITH_USER_AND_PROJECT_ID = result if result is not None else 0 + print( + f"Total rows in api_keys table with created_by_id and project_id not been NULL is {TOTAL_API_KEYS_WITH_USER_AND_PROJECT_ID}" + ) + + while True: + # Fetch a batch of api_keys with created_by_id & project_id isnot NULL + records = session.execute( + select(DeprecatedAPIKeyDB) + .filter( + DeprecatedAPIKeyDB.created_by_id.isnot(None), + DeprecatedAPIKeyDB.project_id.isnot(None), + DeprecatedAPIKeyDB.workspace_id.is_(None), + ) + .offset(offset) + .limit(BATCH_SIZE) + ).fetchall() + + if not records or len(records) <= 0: + break # Exit if no more records to process + + # Process and update records in the batch + for record in records: + workspace_id = get_workspace_id_from_project_id( + session=session, project_id=str(record.project_id) + ) + if workspace_id is None: + SKIPPED_RECORDS += 1 + print( + f"Could not retrieve workspace_id from project_id for APIKey with ID {str(record.id)}. Skipping record..." + ) + continue + + session.execute( + update(DeprecatedAPIKeyDB) + .where(DeprecatedAPIKeyDB.id == record.id) + .values( + user_id=str(record.created_by_id), + workspace_id=workspace_id, + ) + ) + + # Update migration progress tracking + batch_migrated = len(records) + TOTAL_MIGRATED += batch_migrated + offset += BATCH_SIZE + remaining_records = TOTAL_API_KEYS_WITH_USER_AND_PROJECT_ID - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}.", + fg="yellow", + ) + ) + + # Count total rows with created_by_id and/or project_id been NULL + stmt = ( + select(func.count()) + .select_from(DeprecatedAPIKeyDB) + .filter( + or_( + DeprecatedAPIKeyDB.created_by_id.is_(None), + DeprecatedAPIKeyDB.project_id.is_(None), + ), + ) + ) + result = session.execute(stmt).scalar() + TOTAL_API_KEYS_WITH_NO_USER_AND_PROJECT_ID = result if result is not None else 0 + print( + f"Total rows in api_keys table with created_by_id and project_id been NULL is {TOTAL_API_KEYS_WITH_NO_USER_AND_PROJECT_ID}" + ) + except Exception as e: + session.rollback() + click.echo( + click.style( + f"ERROR reverting api_keys to make use of workspace_id: {traceback.format_exc()}", + fg="red", + ) + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/applications.py b/api/ee/databases/postgres/migrations/core/data_migrations/applications.py new file mode 100644 index 0000000000..95353642ec --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/applications.py @@ -0,0 +1,124 @@ +import uuid +import traceback +from typing import Optional + + +import click +from sqlalchemy.future import select +from sqlalchemy import delete, Connection, update, func + +from oss.src.models.deprecated_models import ( # type: ignore + DeprecatedEvaluatorConfigDBwApp as DeprecatedEvaluatorConfigDB, + DeprecatedAppDB, +) + + +BATCH_SIZE = 200 + + +def get_app_db(session: Connection, app_id: str) -> Optional[DeprecatedAppDB]: + query = session.execute(select(DeprecatedAppDB).filter_by(id=uuid.UUID(app_id))) + return query.fetchone() # type: ignore + + +def update_evaluators_with_app_name(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS = 0 + + # Count total rows with a non-null app_id + total_query = ( + select(func.count()) + .select_from(DeprecatedEvaluatorConfigDB) + .filter(DeprecatedEvaluatorConfigDB.app_id.isnot(None)) + ) + result = session.execute(total_query).scalar() + TOTAL_EVALUATOR_CONFIGS = result if result is not None else 0 + print( + f"Total rows in evaluator_configs table with app_id: {TOTAL_EVALUATOR_CONFIGS}" + ) + + while True: + # Fetch a batch of evaluator_configs with non-null app_id + records = session.execute( + select(DeprecatedEvaluatorConfigDB) + .filter(DeprecatedEvaluatorConfigDB.app_id.isnot(None)) + .offset(offset) + .limit(BATCH_SIZE) + ).fetchall() + if not records: + break + + # Process and update records in the batch + for record in records: + if hasattr(record, "app_id") and record.app_id is not None: + evaluator_config_app = get_app_db( + session=session, app_id=str(record.app_id) + ) + if evaluator_config_app is not None: + # Update the name with the app_name as a prefix + new_name = f"{record.name} ({evaluator_config_app.app_name})" + session.execute( + update(DeprecatedEvaluatorConfigDB) + .where(DeprecatedEvaluatorConfigDB.id == record.id) + .values(name=new_name) + ) + else: + print( + f"Skipping... No application found for evaluator_config {str(record.id)}." + ) + SKIPPED_RECORDS += 1 + else: + print( + f"Skipping... evaluator_config {str(record.id)} have app_id that is NULL." + ) + SKIPPED_RECORDS += 1 + + session.commit() + + # Update progress tracking + batch_migrated = len(records) + TOTAL_MIGRATED += batch_migrated + offset += BATCH_SIZE + remaining_records = TOTAL_EVALUATOR_CONFIGS - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Break if all records have been processed + if remaining_records <= 0: + break + + # Delete deprecated evaluator configs with app_id as None + stmt = ( + select(func.count()) + .select_from(DeprecatedEvaluatorConfigDB) + .filter(DeprecatedEvaluatorConfigDB.app_id.is_(None)) + ) + result = session.execute(stmt).scalar() + TOTAL_EVALUATOR_CONFIGS_WITH_NO_APPID = result if result is not None else 0 + print( + f"Total rows in evaluator_configs table with no app_id: {TOTAL_EVALUATOR_CONFIGS_WITH_NO_APPID}. Deleting these rows..." + ) + + session.execute( + delete(DeprecatedEvaluatorConfigDB).where( + DeprecatedEvaluatorConfigDB.app_id.is_(None) + ) + ) + session.commit() + print("Successfully deleted rows in evaluator_configs with no app_id.") + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"ERROR updating evaluator config names: {traceback.format_exc()}", + fg="red", + ) + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/demos.py b/api/ee/databases/postgres/migrations/core/data_migrations/demos.py new file mode 100644 index 0000000000..06e2403fd2 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/demos.py @@ -0,0 +1,576 @@ +from os import getenv +from uuid import UUID +from json import loads +from functools import wraps +from traceback import format_exc +from typing import List, Optional + +from click import echo, style +from pydantic import BaseModel + + +from sqlalchemy import Connection, delete, insert +from sqlalchemy.future import select + +from oss.src.models.db_models import UserDB +from ee.src.models.db_models import ( + ProjectDB, + OrganizationMemberDB, + WorkspaceMemberDB, + ProjectMemberDB, +) + + +BATCH_SIZE = 100 +DEMOS = "AGENTA_DEMOS" +DEMO_ROLE = "viewer" +OWNER_ROLE = "owner" + + +class Demo(BaseModel): + organization_id: UUID + workspace_id: UUID + project_id: UUID + + +class User(BaseModel): + user_id: UUID + + +class Member(BaseModel): + user_id: UUID + + organization_id: Optional[UUID] = None + workspace_id: Optional[UUID] = None + project_id: Optional[UUID] = None + + role: Optional[str] = None + + +def with_rollback(): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as exc: + session = kwargs.get("session") + + session.rollback() + + log_error(format_exc()) + + raise exc + + return wrapper + + return decorator + + +def log_info(message) -> None: + echo(style(f"{message}", fg="green"), color=True) + + +def log_error(message) -> None: + echo(style(f"ERROR: {message}", fg="red"), color=True) + + +def fetch_project( + session: Connection, + project_id: UUID, +) -> ProjectDB: + result = session.execute( + select( + ProjectDB.id, + ProjectDB.workspace_id, + ProjectDB.organization_id, + ).where( + ProjectDB.id == project_id, + ) + ).first() + + project = ProjectDB( + id=result.id, + workspace_id=result.workspace_id, + organization_id=result.organization_id, + ) + + return project + + +def list_all_demos(session: Connection) -> List[Demo]: + demos = [] + + try: + demo_project_ids = loads(getenv(DEMOS) or "[]") + + for project_id in demo_project_ids: + project = fetch_project( + session, + project_id, + ) + + try: + demos.append( + Demo( + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=project_id, + ) + ) + + except: # pylint: disable=bare-except + pass + + except: # pylint: disable=bare-except + pass + + return demos + + +def list_all_users( + session: Connection, +) -> List[User]: + user_ids = session.execute(select(UserDB.id)).scalars().all() + + all_users = [User(user_id=user_id) for user_id in user_ids] + + return all_users + + +def fetch_organization_members( + session: Connection, + organization_id: UUID, +) -> List[Member]: + result = session.execute( + select( + OrganizationMemberDB.user_id, + OrganizationMemberDB.organization_id, + ).where( + OrganizationMemberDB.organization_id == organization_id, + ) + ).all() + + organization_members = [ + Member( + user_id=row.user_id, + organization_id=row.organization_id, + ) + for row in result + ] + + return organization_members + + +def get_new_organization_members( + users: List[User], + members: List[Member], +) -> List[Member]: + user_ids = {user.user_id for user in users} + member_user_ids = {member.user_id for member in members} + + new_user_ids = user_ids - member_user_ids + + new_members = [Member(user_id=user_id) for user_id in new_user_ids] + + return new_members + + +def add_new_members_to_organization( + session: Connection, + organization_id: UUID, + new_members: List[Member], +) -> None: + for i in range(0, len(new_members), BATCH_SIZE): + batch = new_members[i : i + BATCH_SIZE] + + values = [ + { + "user_id": member.user_id, + "organization_id": organization_id, + } + for member in batch + ] + + session.execute(insert(OrganizationMemberDB).values(values)) + + +def remove_all_members_from_organization( + session: Connection, + organization_id: UUID, +) -> None: + session.execute( + delete(OrganizationMemberDB).where( + OrganizationMemberDB.organization_id == organization_id, + ) + ) + + +def fetch_workspace_members( + session: Connection, + workspace_id: UUID, +) -> List[Member]: + result = session.execute( + select( + WorkspaceMemberDB.user_id, + WorkspaceMemberDB.workspace_id, + WorkspaceMemberDB.role, + ).where( + WorkspaceMemberDB.workspace_id == workspace_id, + ) + ).all() + + members = [ + Member( + user_id=row.user_id, + workspace_id=row.workspace_id, + role=row.role, + ) + for row in result + ] + + return members + + +def get_faulty_workspace_members( + members: List[Member], +) -> List[Member]: + member_user_ids = { + member.user_id + for member in members + if member.role not in [DEMO_ROLE, OWNER_ROLE] + } + + new_members = [Member(user_id=user_id) for user_id in member_user_ids] + + return new_members + + +def remove_faulty_workspace_members( + session: Connection, + workspace_id: UUID, + faulty_members: List[Member], +) -> None: + faulty_user_ids = [member.user_id for member in faulty_members] + + for i in range(0, len(faulty_user_ids), BATCH_SIZE): + batch = faulty_user_ids[i : i + BATCH_SIZE] + + session.execute( + delete(WorkspaceMemberDB) + .where(WorkspaceMemberDB.workspace_id == workspace_id) + .where(WorkspaceMemberDB.user_id.in_(batch)) + ) + + +def get_new_workspace_members( + users: List[User], + members: List[Member], +) -> List[Member]: + user_ids = {user.user_id for user in users} + member_user_ids = { + member.user_id for member in members if member.role in [DEMO_ROLE, OWNER_ROLE] + } + + new_user_ids = user_ids - member_user_ids + + new_members = [Member(user_id=user_id) for user_id in new_user_ids] + + return new_members + + +def add_new_members_to_workspace( + session: Connection, + workspace_id: UUID, + new_members: List[Member], +) -> None: + for i in range(0, len(new_members), BATCH_SIZE): + batch = new_members[i : i + BATCH_SIZE] + + values = [ + { + "user_id": member.user_id, + "workspace_id": workspace_id, + "role": DEMO_ROLE, + } + for member in batch + ] + + session.execute(insert(WorkspaceMemberDB).values(values)) + + +def remove_all_members_from_workspace( + session: Connection, + workspace_id: UUID, +) -> None: + session.execute( + delete(WorkspaceMemberDB).where( + WorkspaceMemberDB.workspace_id == workspace_id, + ) + ) + + +def fetch_project_members( + session: Connection, + project_id: UUID, +) -> List[Member]: + result = session.execute( + select( + ProjectMemberDB.user_id, + ProjectMemberDB.project_id, + ProjectMemberDB.role, + ).where( + ProjectMemberDB.project_id == project_id, + ) + ).all() + + members = [ + Member( + user_id=row.user_id, + project_id=row.project_id, + role=row.role, + ) + for row in result + ] + + return members + + +def get_faulty_project_members( + members: List[Member], +) -> List[Member]: + member_user_ids = { + member.user_id + for member in members + if member.role not in [DEMO_ROLE, OWNER_ROLE] + } + + new_members = [Member(user_id=user_id) for user_id in member_user_ids] + + return new_members + + +def remove_faulty_project_members( + session: Connection, + project_id: UUID, + faulty_members: List[Member], +) -> None: + faulty_user_ids = [member.user_id for member in faulty_members] + + for i in range(0, len(faulty_user_ids), BATCH_SIZE): + batch = faulty_user_ids[i : i + BATCH_SIZE] + + session.execute( + delete(ProjectMemberDB) + .where(ProjectMemberDB.project_id == project_id) + .where(ProjectMemberDB.user_id.in_(batch)) + ) + + +def get_new_project_members( + users: List[User], + members: List[Member], +) -> List[Member]: + user_ids = {user.user_id for user in users} + member_user_ids = { + member.user_id for member in members if member.role in [DEMO_ROLE, OWNER_ROLE] + } + + new_user_ids = user_ids - member_user_ids + + new_members = [Member(user_id=user_id) for user_id in new_user_ids] + + return new_members + + +def add_new_members_to_project( + session: Connection, + project_id: UUID, + new_members: List[Member], +) -> None: + for i in range(0, len(new_members), BATCH_SIZE): + batch = new_members[i : i + BATCH_SIZE] + + values = [ + { + "user_id": member.user_id, + "project_id": project_id, + "role": DEMO_ROLE, + "is_demo": True, + } + for member in batch + ] + + session.execute(insert(ProjectMemberDB).values(values)) + + +def remove_all_members_from_project( + session: Connection, + project_id: UUID, +) -> None: + session.execute( + delete(ProjectMemberDB).where( + ProjectMemberDB.project_id == project_id, + ) + ) + + +@with_rollback() +def add_users_to_demos(session: Connection) -> None: + log_info("Populating demos.") + + all_demos = list_all_demos(session) + + log_info(f"Found {len(all_demos)} demos.") + + all_users = list_all_users(session) + + log_info(f"Found {len(all_users)} users.") + + for i, demo in enumerate(all_demos): + log_info(f"Populating demo #{i}.") + + # DEMO ORGANIZATIONS + organization_members = fetch_organization_members( + session, + demo.organization_id, + ) + + log_info(f"Found {len(organization_members)} organization members.") + + new_organization_members = get_new_organization_members( + all_users, + organization_members, + ) + + log_info(f"Missing {len(new_organization_members)} organization members.") + + add_new_members_to_organization( + session, + demo.organization_id, + new_organization_members, + ) + + log_info(f"Added {len(new_organization_members)} organization members.") + # ------------------ + + # DEMO WORKSPACES + workspace_members = fetch_workspace_members( + session, + demo.workspace_id, + ) + + log_info(f"Found {len(workspace_members)} workspace members.") + + faulty_workspace_members = get_faulty_workspace_members( + workspace_members, + ) + + log_info(f"Found {len(faulty_workspace_members)} faulty workspace members.") + + remove_faulty_workspace_members( + session, + demo.workspace_id, + faulty_workspace_members, + ) + + log_info(f"Removed {len(faulty_workspace_members)} faulty workspace members.") + + new_workspace_members = get_new_workspace_members( + all_users, + workspace_members, + ) + + log_info(f"Missing {len(new_workspace_members)} workspace members.") + + add_new_members_to_workspace( + session, + demo.workspace_id, + new_workspace_members, + ) + + log_info(f"Added {len(new_workspace_members)} workspace members.") + # --------------- + + # DEMO PROJECTS + project_members = fetch_project_members( + session, + demo.project_id, + ) + + log_info(f"Found {len(project_members)} project members.") + + faulty_project_members = get_faulty_project_members( + project_members, + ) + + log_info(f"Found {len(faulty_project_members)} faulty project members.") + + remove_faulty_project_members( + session, + demo.project_id, + faulty_project_members, + ) + + log_info(f"Removed {len(faulty_project_members)} faulty project members.") + + new_project_members = get_new_project_members( + all_users, + project_members, + ) + + log_info(f"Missing {len(new_project_members)} project members.") + + add_new_members_to_project( + session, + demo.project_id, + new_project_members, + ) + + log_info(f"Added {len(new_project_members)} project members.") + # ------------- + + log_info(f"Done with demo #{i}.") + + log_info("Done with demos.") + + +@with_rollback() +def remove_users_from_demos(session: Connection) -> None: + log_info("Cleaning up demos.") + + all_demos = list_all_demos(session) + + for i, demo in enumerate(all_demos): + log_info(f"Cleaning up demo #{i}.") + + # DEMO PROJECTS + remove_all_members_from_project( + session, + demo.project_id, + ) + # ------------- + + log_info("Removed project members.") + + # DEMO WORKSPACES + remove_all_members_from_workspace( + session, + demo.workspace_id, + ) + # --------------- + + log_info("Removed workspace members.") + + # DEMO ORGANIZATIONS + remove_all_members_from_organization( + session, + demo.organization_id, + ) + # ------------------ + + log_info("Removed organization members.") + + log_info(f"Done with demo #{i}.") + + log_info("Done with demos.") diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/evaluators.py b/api/ee/databases/postgres/migrations/core/data_migrations/evaluators.py new file mode 100644 index 0000000000..c6b82d338c --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/evaluators.py @@ -0,0 +1,195 @@ +import uuid +import asyncio +import traceback +from typing import Optional + +import click +from sqlalchemy.future import select +from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine + + +from ee.src.models.db_models import WorkspaceMemberDB as WorkspaceMemberDBE +from oss.src.models.db_models import ProjectDB as ProjectDBE +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.core.evaluators.service import SimpleEvaluatorsService, EvaluatorsService +from oss.src.models.deprecated_models import ( + DeprecatedAutoEvaluatorConfigDBwProject as DeprecatedEvaluatorConfigDBwProject, +) +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.tracing.service import TracingService +from oss.src.apis.fastapi.tracing.router import TracingRouter +from oss.src.dbs.postgres.tracing.dao import TracingDAO + + +# Define constants +DEFAULT_BATCH_SIZE = 200 + +# Initialize plug-ins for migration +tracing_service = TracingService( + tracing_dao=TracingDAO(), +) +tracing = TracingRouter( + tracing_service=tracing_service, +) +evaluators_service = EvaluatorsService( + workflows_service=WorkflowsService( + workflows_dao=GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, + ), + ) +) +simple_evaluators_service = SimpleEvaluatorsService( + evaluators_service=evaluators_service, +) + + +async def _fetch_project_owner( + *, + project_id: uuid.UUID, + connection: AsyncConnection, +) -> Optional[uuid.UUID]: + """Fetch the owner user ID for a given project.""" + workspace_owner_query = ( + select(WorkspaceMemberDBE.user_id) + .select_from(WorkspaceMemberDBE, ProjectDBE) + .where( + WorkspaceMemberDBE.workspace_id == ProjectDBE.workspace_id, + WorkspaceMemberDBE.role == "owner", + ProjectDBE.id == project_id, + ) + ) + result = await connection.execute(workspace_owner_query) + owner = result.scalar_one_or_none() + return owner + + +async def migration_old_evaluator_configs_to_new_evaluator_configs( + connection: AsyncConnection, +): + """Migrate old evaluator configurations to new workflow-based system.""" + try: + offset = 0 + total_migrated = 0 + skipped_records = 0 + + # Count total rows with a non-null project_id + total_query = ( + select(func.count()) + .select_from(DeprecatedEvaluatorConfigDBwProject) + .filter(DeprecatedEvaluatorConfigDBwProject.project_id.isnot(None)) + ) + result = await connection.execute(total_query) + total_rows = result.scalar() + total_evaluators = total_rows or 0 + + click.echo( + click.style( + f"Total rows in evaluator_configs with project_id: {total_evaluators}", + fg="yellow", + ) + ) + + while offset < total_evaluators: + # STEP 1: Fetch evaluator configurations with non-null project_id + result = await connection.execute( + select(DeprecatedEvaluatorConfigDBwProject) + .filter(DeprecatedEvaluatorConfigDBwProject.project_id.isnot(None)) + .offset(offset) + .limit(DEFAULT_BATCH_SIZE) + ) + evaluator_configs_rows = result.fetchall() + + if not evaluator_configs_rows: + break + + # Process and transfer records to evaluator workflows + for old_evaluator in evaluator_configs_rows: + try: + # STEP 2: Get owner from project_id + owner = await _fetch_project_owner( + project_id=old_evaluator.project_id, # type: ignore + connection=connection, + ) + if not owner: + skipped_records += 1 + click.echo( + click.style( + f"Skipping record with ID {old_evaluator.id} due to missing owner in workspace member table", + fg="yellow", + ) + ) + continue + + # STEP 3: Migrate records using transfer_* util function + new_evaluator = await simple_evaluators_service.transfer( + project_id=old_evaluator.project_id, + user_id=owner, + evaluator_id=old_evaluator.id, + ) + if not new_evaluator: + skipped_records += 1 + click.echo( + click.style( + f"Skipping record with ID {old_evaluator.id} due to old evaluator not existing in database table", + fg="yellow", + ) + ) + continue + + except Exception as e: + click.echo( + click.style( + f"Failed to migrate evaluator {old_evaluator.id}: {str(e)}", + fg="red", + ) + ) + click.echo(click.style(traceback.format_exc(), fg="red")) + skipped_records += 1 + continue + + # Update progress tracking for current batch + batch_migrated = len(evaluator_configs_rows) + offset += DEFAULT_BATCH_SIZE + total_migrated += batch_migrated + + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch.", + fg="yellow", + ) + ) + + # Update progress tracking for all batches + remaining_records = total_evaluators - total_migrated + click.echo(click.style(f"Total migrated: {total_migrated}", fg="yellow")) + click.echo(click.style(f"Skipped records: {skipped_records}", fg="yellow")) + click.echo( + click.style(f"Records left to migrate: {remaining_records}", fg="yellow") + ) + + except Exception as e: + click.echo(f"Error occurred: {e}") + click.echo(click.style(traceback.format_exc(), fg="red")) + + +def run_migration(sqlalchemy_url: str): + import concurrent.futures + + async def _start(): + connection = create_async_engine(url=sqlalchemy_url) + async with connection.connect() as connection: + await migration_old_evaluator_configs_to_new_evaluator_configs( + connection=connection + ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, _start()) + future.result() diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/export_records.py b/api/ee/databases/postgres/migrations/core/data_migrations/export_records.py new file mode 100644 index 0000000000..f6aa6e3a0d --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/export_records.py @@ -0,0 +1,175 @@ +import traceback +import click +from sqlalchemy.future import select +from sqlalchemy import Connection, insert, func +from ee.src.models.db_models import OrganizationMemberDB # type: ignore +from ee.src.models.extended.deprecated_models import UserOrganizationDB # type: ignore + +BATCH_SIZE = 200 + + +def transfer_records_from_user_organization_to_organization_members( + session: Connection, +): + try: + offset = 0 + TOTAL_MIGRATED = 0 + + # Count total rows in user_organizations table + total_query = select(func.count()).select_from(UserOrganizationDB) + result = session.execute(total_query).scalar() + TOTAL_USERS_ORGANIZATIONS = result if result is not None else 0 + print(f"Total rows in UserOrganizationDB table: {TOTAL_USERS_ORGANIZATIONS}") + + while True: + # Fetch a batch of records from user_organizations with ordering + users_in_organizations = session.execute( + select(UserOrganizationDB).offset(offset).limit(BATCH_SIZE) + ).fetchall() + + actual_batch_size = len(users_in_organizations) + if actual_batch_size == 0: + break + + for user_organization in users_in_organizations: + # Check if the record already exists in OrganizationMemberDB + existing_record = session.execute( + select(OrganizationMemberDB).where( + OrganizationMemberDB.user_id == user_organization.user_id, + OrganizationMemberDB.organization_id + == user_organization.organization_id, + ) + ).fetchone() + if existing_record: + # Log that a duplicate was found + click.echo( + click.style( + f"Duplicate record found for user_id {user_organization.user_id} and organization_id {user_organization.organization_id}. Skipping.", + fg="yellow", + ) + ) + continue # Skip inserting this record + + # Insert a new record in OrganizationMemberDB + insert_statement = insert(OrganizationMemberDB).values( + user_id=user_organization.user_id, + organization_id=user_organization.organization_id, + ) + session.execute(insert_statement) + + # Commit the batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_USERS_ORGANIZATIONS - TOTAL_MIGRATED + + click.echo( + click.style( + f"Processed {actual_batch_size} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Check if there are still remaining records + remaining_records_query = select(func.count()).select_from(UserOrganizationDB) + remaining_count = session.execute(remaining_records_query).scalar() + records_left_count = remaining_count if remaining_count is not None else 0 + if records_left_count > 0: + click.echo( + click.style( + f"There are still {remaining_count} records left in UserOrganizationDB that were not migrated.", + fg="red", + ) + ) + + click.echo( + click.style( + "\nSuccessfully migrated records and handled duplicates in user_organization table to organization_members.", + fg="green", + ), + color=True, + ) + except Exception as e: + # Handle exceptions and rollback if necessary + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while transferring records: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e + + +def transfer_records_from_organization_members_to_user_organization( + session: Connection, +): + try: + offset = 0 + TOTAL_MIGRATED = 0 + + # Count total rows in OrganizationMemberDB + total_query = select(func.count()).select_from(OrganizationMemberDB) + result = session.execute(total_query).scalar() + TOTAL_ORGANIZATIONS_MEMBERS = result if result is not None else 0 + print( + f"Total rows in OrganizationMemberDB table: {TOTAL_ORGANIZATIONS_MEMBERS}" + ) + + while True: + # Retrieve a batch of records from OrganizationMemberDB + members_in_organizations = session.execute( + select(OrganizationMemberDB).offset(offset).limit(BATCH_SIZE) + ).fetchall() + actual_batch_size = len(members_in_organizations) + if not members_in_organizations: + break + + # Process each record in the current batch + for user_organization in members_in_organizations: + # Create a new record in UserOrganizationDB + insert_statement = insert(UserOrganizationDB).values( + user_id=user_organization.user_id, + organization_id=user_organization.organization_id, + ) + session.execute(insert_statement) + + # Commit the batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_ORGANIZATIONS_MEMBERS - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {actual_batch_size} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Break the loop if all records are migrated + if remaining_records <= 0: + break + + click.echo( + click.style( + "\nSuccessfully migrated records in organization_members table to user_organizations table.", + fg="green", + ), + color=True, + ) + except Exception as e: + # Handle exceptions and rollback if necessary + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while transferring records from organization_members to user_organizations: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/invitations.py b/api/ee/databases/postgres/migrations/core/data_migrations/invitations.py new file mode 100644 index 0000000000..802f2ef4fe --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/invitations.py @@ -0,0 +1,192 @@ +import os +import uuid +import traceback + +import click +from sqlalchemy.future import select +from sqlalchemy.orm import joinedload +from sqlalchemy import delete, Connection, insert, func + +from oss.src.models.db_models import UserDB, InvitationDB, ProjectDB +from ee.src.models.extended.deprecated_models import OldInvitationDB + + +BATCH_SIZE = 200 + + +def transfer_invitations_from_old_table_to_new_table(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + SKIPPED_INVITATIONS = 0 + + # Count total rows in OldInvitationDB table + count_query = select(func.count()).select_from(OldInvitationDB) + result = session.execute(count_query).scalar() + TOTAL_INVITATIONS = result if result is not None else 0 + print(f"Total rows in OldInvitationDB table is {TOTAL_INVITATIONS}") + + while True: + # Retrieve a batch of old invitations + query = session.execute( + select(OldInvitationDB).offset(offset).limit(BATCH_SIZE) + ) + old_invitations = query.fetchall() + actual_batch_size = len(old_invitations) + if not old_invitations: + break + + for old_invitation in old_invitations: + user = session.execute( + select(UserDB).where(UserDB.email == old_invitation.email) + ).fetchone() + + project = session.execute( + select(ProjectDB).where( + ProjectDB.workspace_id == uuid.UUID(old_invitation.workspace_id) + ) + ).fetchone() + if user and project: + print( + f"Found user {user.username} in workspace invitation ({str(old_invitation.id)})" + ) + print( + f"Found project {str(project.id)} that will be used to transfer workspace invitation into." + ) + # Map fields from OldInvitationDB to InvitationDB + statement = insert(InvitationDB).values( + id=old_invitation.id, + token=old_invitation.token, + email=old_invitation.email, + used=old_invitation.used, + role=old_invitation.workspace_roles[0], + user_id=user.id, + project_id=project.id, + expiration_date=old_invitation.expiration_date, + ) + + # Add the new invitation to the session + session.execute(statement) + + # Remove old invitation + session.execute( + delete(OldInvitationDB).where( + OldInvitationDB.id == old_invitation.id + ) + ) + else: + print( + f"Skipping unused workspace invitation {str(old_invitation.id)}. No matching user or project." + ) + SKIPPED_INVITATIONS += 1 + + # Commit the changes for the current batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_INVITATIONS - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {actual_batch_size} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if remaining_records <= 0: + break + + click.echo( + click.style( + f"\nSuccessfully transferred workspaces invitations to projects invitations table. Skipped {SKIPPED_INVITATIONS} records.", + fg="green", + ), + color=True, + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while transferring workspaces invitations: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e + + +def revert_invitations_transfer_from_new_table_to_old_table(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + + # Count total rows in invitations table + stmt = select(func.count()).select_from(InvitationDB) + result = session.execute(stmt).scalar() + TOTAL_INVITATIONS = result if result is not None else 0 + print(f"Total rows in project_invitations table is {TOTAL_INVITATIONS}") + + while True: + # Retrieve a batch of project invitations + project_invitations = session.execute( + select(InvitationDB) + .offset(offset) + .limit(BATCH_SIZE) + .options(joinedload(InvitationDB.project)) + ).fetchall() + if not project_invitations: + break + + for project_invitation in project_invitations: + # Map fields from InvitationDB to OldInvitationDB + statement = insert(OldInvitationDB).values( + id=project_invitation.id, + token=project_invitation.token, + email=project_invitation.email, + used=project_invitation.used, + organization_id=str(project_invitation.project.workspace_id), + workspace_id=str(project_invitation.project.workspace_id), + workspace_roles=[project_invitation.role], + expiration_date=project_invitation.expiration_date, + ) + session.execute(statement) + + # Remove previous invitation (that references project_id) + session.execute( + delete(InvitationDB).where(InvitationDB.id == project_invitation.id) + ) + + # Commit the changes for the current batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += BATCH_SIZE + offset += BATCH_SIZE + click.echo( + click.style( + f"Processed {offset} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {TOTAL_INVITATIONS - TOTAL_MIGRATED}", + fg="yellow", + ) + ) + + click.echo( + click.style( + "\nSuccessfully transferred projects invitations to the workspaces invitations table.", + fg="green", + ), + color=True, + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while transferring projects invitations: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/projects.py b/api/ee/databases/postgres/migrations/core/data_migrations/projects.py new file mode 100644 index 0000000000..293b05f52a --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/projects.py @@ -0,0 +1,501 @@ +import uuid +import traceback +from typing import Dict, Optional +from collections import defaultdict + +import click +from sqlalchemy.future import select +from sqlalchemy import Connection, update, func, or_ + +from ee.src.models.extended.deprecated_transfer_models import ( # type: ignore + ProjectDB, + AppDB, + AppVariantDB, + AppVariantRevisionsDB, + VariantBaseDB, + DeploymentDB, + AppEnvironmentDB, + AppEnvironmentRevisionDB, + EvaluationScenarioDB, + EvaluationDB, + EvaluatorConfigDB, + HumanEvaluationDB, + HumanEvaluationScenarioDB, + TestSetDB, +) + + +MODELS = [ + AppDB, # have workspace_id + AppVariantDB, # have workspace_id + AppVariantRevisionsDB, # doesn't have, but can make use of variant_id to get workspace_id + VariantBaseDB, # have workspace_id + DeploymentDB, # have workspace_id + AppEnvironmentDB, # have workspace_id + AppEnvironmentRevisionDB, # have workspace_id + EvaluationScenarioDB, # have workspace_id + EvaluationDB, # have workspace_id + EvaluatorConfigDB, # have workspace_id + HumanEvaluationDB, # have workspace_id + HumanEvaluationScenarioDB, # have workspace_id + TestSetDB, # have workspace_id +] + + +def get_workspace_project_by_id( + session: Connection, workspace_id: str +) -> Optional[str]: + workspace_project = session.execute( + select(ProjectDB).filter_by( + is_default=True, workspace_id=uuid.UUID(workspace_id) + ) + ).fetchone() + return str(workspace_project.id) if workspace_project is not None else None + + +def get_variant_by_id(session: Connection, variant_id: str) -> Optional[AppVariantDB]: + query = session.execute(select(AppVariantDB).filter_by(id=uuid.UUID(variant_id))) + return query.fetchone() # type: ignore + + +def get_app_by_id(session: Connection, app_id: str) -> Optional[AppDB]: + query = session.execute(select(AppDB).filter_by(id=uuid.UUID(app_id))) + return query.fetchone() # type: ignore + + +def get_evaluation_by_id( + session: Connection, evaluation_id: str +) -> Optional[EvaluationDB]: + query = session.execute(select(EvaluationDB).filter_by(id=uuid.UUID(evaluation_id))) + return query.fetchone() # type: ignore + + +def get_workspace_project_id(session: Connection, workspace_id: str) -> Optional[str]: + query = session.execute( + select(ProjectDB).filter_by( + workspace_id=uuid.UUID(workspace_id), is_default=True + ) + ) + workspace_project = query.fetchone() + return str(workspace_project.id) if workspace_project is not None else None + + +def repair_evaluation_scenario_to_have_project_id(session: Connection): + offset = 0 + BATCH_SIZE = 200 + TOTAL_MIGRATED = 0 + + # Count total rows for evaluation_scenarios with project_id = None + count_query = ( + select(func.count()) + .select_from(EvaluationScenarioDB) + .filter(EvaluationScenarioDB.project_id.is_(None)) + ) + result = session.execute(count_query).scalar() + TOTAL_ROWS_OF_TABLE = result if result is not None else 0 + print( + f"\nTotal rows in {EvaluationScenarioDB.__tablename__} table with no workspace_id: {TOTAL_ROWS_OF_TABLE}. Repairing rows to make use of workspace_id from either variant_id or evaluation_id..." + ) + + while True: + # Fetch records where project_id is None + records = session.execute( + select(EvaluationScenarioDB) + .filter( + EvaluationScenarioDB.project_id.is_(None), + or_( + EvaluationScenarioDB.variant_id.isnot(None), + EvaluationScenarioDB.evaluation_id.isnot(None), + ), + ) + .limit(BATCH_SIZE) + ).fetchall() + + # If no more records are returned, break the loop + if not records or len(records) == 0: + break + + # Update records with default project_id + for record in records: + workspace_id = None + + if hasattr(record, "variant_id") and record.variant_id is not None: + variant = get_variant_by_id( + session=session, variant_id=str(record.variant_id) + ) + if variant is None: + print( + f"ES {str(record.id)} did not return any variant to retrieve the workspace_id. Now, trying evaluation..." + ) + else: + workspace_id = str(variant.workspace_id) + + if ( + workspace_id is None + and hasattr(record, "evaluation_id") + and record.evaluation_id is not None + ): + evaluation = get_evaluation_by_id( + session=session, evaluation_id=str(record.evaluation_id) + ) + if evaluation is None: + print( + f"ES {str(record.id)} did not return any evaluation or variant to retrieve the workspace_id. Skipping record..." + ) + continue # Skip this record as no valid workspace_id found + + workspace_id = str(evaluation.workspace_id) + + # Update model record workspace_id field if a valid project_id was found + if workspace_id is not None: + workspace_project_id = get_workspace_project_by_id( + session=session, workspace_id=workspace_id + ) + session.execute( + update(EvaluationScenarioDB) + .where(EvaluationScenarioDB.id == record.id) + .values(project_id=uuid.UUID(workspace_project_id)) + ) + else: + print( + f"Evaluation scenario {str(record.id)} did not find a variant_id {record.variant_id} and evaluation {record.evaluation_id} to make use of." + ) + + session.commit() + + # Update migration progress + batch_migrated = len(records) + TOTAL_MIGRATED += batch_migrated + offset += batch_migrated + remaining_records = TOTAL_ROWS_OF_TABLE - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Break if all records have been processed + records_with_no_variant_and_workspace_count_query = ( + select(func.count()) + .select_from(EvaluationScenarioDB) + .filter( + EvaluationScenarioDB.project_id.is_(None), + EvaluationScenarioDB.evaluation_id.is_(None), + EvaluationScenarioDB.variant_id.is_(None), + ) + ) + result = session.execute( + records_with_no_variant_and_workspace_count_query + ).scalar() + UNREPAIRABLE_DATA = result if result is not None else 0 + click.echo( + click.style( + f"Total malformed records with no variant_id & evaluation_id: {UNREPAIRABLE_DATA}", + fg="yellow", + ) + ) + + # Final reporting + click.echo( + click.style( + f"Migration to repair evaluation_scenario to have project_id completed.", + fg="green", + ) + ) + + +def repair_evaluator_configs_to_have_project_id(session: Connection): + offset = 0 + BATCH_SIZE = 200 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS = 0 + + # Count total rows for evaluator_configs with workspace_id = None + count_query = ( + select(func.count()) + .select_from(EvaluatorConfigDB) + .filter(EvaluatorConfigDB.project_id.is_(None)) + ) + result = session.execute(count_query).scalar() + TOTAL_ROWS_OF_TABLE = result if result is not None else 0 + print( + f"\nTotal rows in {EvaluatorConfigDB.__tablename__} table with no workspace_id: {TOTAL_ROWS_OF_TABLE}. Repairing rows to make use of workspace_id from app..." + ) + + while True: + # Fetch records where project_id is None + records = session.execute( + select(EvaluatorConfigDB) + .filter(EvaluatorConfigDB.project_id.is_(None)) + .limit(BATCH_SIZE) + ).fetchall() + + # Update records with default project_id + for record in records: + workspace_id = None + + if hasattr(record, "app_id") and ( + record.app_id is None or record.app_id == "" + ): + print(f"Evaluator config {str(record.id)} have no app_id. Skipping...") + SKIPPED_RECORDS += 1 + continue + + if hasattr(record, "app_id") and record.app_id is not None: + app_db = get_app_by_id(session=session, app_id=str(record.app_id)) + if app_db is None: + print( + f"Evaluator config {str(record.id)} have an app_id, but no application was found with the ID. Skipping..." + ) + SKIPPED_RECORDS += 1 + continue + + workspace_id = str(app_db.workspace_id) + + # Update model record workspace_id field if a valid project_id was found + if workspace_id is not None: + workspace_project_id = get_workspace_project_by_id( + session=session, workspace_id=workspace_id + ) + session.execute( + update(EvaluatorConfigDB) + .where(EvaluatorConfigDB.id == record.id) + .values(project_id=uuid.UUID(workspace_project_id)) + ) + else: + print( + f"Evaluator config {str(record.id)} did not find a workspace_id to make use of." + ) + + session.commit() + + # Update migration progress + batch_migrated = len(records) + TOTAL_MIGRATED += batch_migrated + offset += batch_migrated + remaining_records = TOTAL_ROWS_OF_TABLE - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Break if all records have been processed + if batch_migrated <= 0: + break + + records_with_no_project_id = ( + select(func.count()) + .select_from(EvaluatorConfigDB) + .filter(EvaluatorConfigDB.project_id.is_(None)) + ) + result = session.execute(records_with_no_project_id).scalar() + TOTAL_ROWS_OF_RECORDS_WITH_NO_PROJECT_ID = result if result is not None else 0 + + # Final reporting + click.echo( + click.style( + f"Migration to repair evaluator_configs to have project_id completed. Total records with no project_id: {TOTAL_ROWS_OF_RECORDS_WITH_NO_PROJECT_ID}", + fg="green", + ) + ) + + +def add_project_id_to_db_entities(session: Connection): + try: + for model in MODELS: + offset = 0 + BATCH_SIZE = 200 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS: Dict[str, int] = defaultdict(int) + + def update_skipped_records_counter(model_tablename: str): + if SKIPPED_RECORDS.get(model_tablename, None) is None: + SKIPPED_RECORDS[model_tablename] = 1 + else: + SKIPPED_RECORDS[model_tablename] += 1 + + # Count total rows for tables with project_id = None + count_query = ( + select(func.count()) + .select_from(model) + .filter(model.project_id.is_(None)) + ) + result = session.execute(count_query).scalar() + TOTAL_ROWS_OF_TABLE = result if result is not None else 0 + print(f"Total rows in {model.__tablename__} table is {TOTAL_ROWS_OF_TABLE}") + + if hasattr(model, "workspace_id"): + query = select(model).filter( + model.project_id.is_(None), model.workspace_id.isnot(None) + ) + else: + # this will only be applied for AppVariantRevisionsDB model + query = select(model).filter(model.project_id.is_(None)) + + while True: + # Fetch records where project_id is None and workspace_id is not None + records = session.execute(query.limit(BATCH_SIZE)).fetchall() + actual_batch_size = len(records) + + # Add debugging logs for each batch + click.echo( + click.style( + f"Fetching {actual_batch_size} records starting from offset {offset} in {model.__tablename__}.", + fg="blue", + ) + ) + + # Update records with default project_id + for record in records: + if hasattr(record, "workspace_id"): + workspace_project_id = get_workspace_project_id( + session=session, workspace_id=str(record.workspace_id) + ) + elif ( + hasattr(record, "variant_id") and record.variant_id is not None + ) and not hasattr( + record, "workspace_id" + ): # this will only be applied for AppVariantRevisionsDB model + variant = get_variant_by_id( + session=session, variant_id=str(record.variant_id) + ) + if variant is not None: + workspace_project_id = get_workspace_project_id( + session=session, workspace_id=str(variant.workspace_id) + ) + else: + print( + f"Skipping record... {str(record.id)} in {model.__tablename__} table did not return any variant {str(record.variant_id)}." + ) + update_skipped_records_counter( + model_tablename=model.__tablename__ + ) + workspace_project_id = None + else: + print( + f"Skipping record... {str(record.id)} in {model.__tablename__} table due to no variant_id / workspace_id" + ) + actual_batch_size -= 1 # remove malformed record from records + update_skipped_records_counter( + model_tablename=model.__tablename__ + ) + workspace_project_id = None + + if workspace_project_id is not None: + # Update model record project_id field + session.execute( + update(model) + .where(model.id == record.id) + .values(project_id=uuid.UUID(workspace_project_id)) + ) + + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_ROWS_OF_TABLE - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {actual_batch_size} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if actual_batch_size <= 0: + break + + # Run migration to 'repair' evaluation_scenario to make use of workspace_id from either evalution or variant to get project_id + repair_evaluation_scenario_to_have_project_id(session=session) + + # Run migration to 'repair' evaluator_configs to make use of workspace_id from app to get project_id + repair_evaluator_configs_to_have_project_id(session=session) + + click.echo( + click.style( + f"Migration for adding project_id to all records listed in {[model.__tablename__ for model in MODELS]} tables are completed. Skipped records: {SKIPPED_RECORDS}", + fg="green", + ) + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"ERROR adding project_id to db entities: {traceback.format_exc()}", + fg="red", + ) + ) + raise e + + +def remove_project_id_from_db_entities(session: Connection): + try: + for model in MODELS: + offset = 0 + BATCH_SIZE = 200 + TOTAL_MIGRATED = 0 + + # Count total rows for tables where project_id is not None + count_query = ( + select(func.count()) + .select_from(model) + .where(model.project_id.isnot(None)) + ) + result = session.execute(count_query).scalar() + TOTAL_ROWS_OF_TABLE = result if result is not None else 0 + print(f"Total rows in {model.__tablename__} table is {TOTAL_ROWS_OF_TABLE}") + + while True: + # Retrieve records from model where its project_id is not None + records = session.execute( + select(model) + .where(model.project_id.isnot(None)) + .offset(offset) + .limit(BATCH_SIZE) + ).fetchall() + actual_batch_size = len(records) + if not records: + break + + # Update records project_id column with None + for record in records: + record.project_id = None + + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_ROWS_OF_TABLE - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {actual_batch_size} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records}", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if remaining_records <= 0: + break + + click.echo( + click.style( + f"Migration for removing project_id to all records listed in {[model.__tablename__ for model in MODELS]} tables are completed.", + fg="green", + ) + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"ERROR removing project_id to db entities: {traceback.format_exc()}", + fg="red", + ) + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/testsets.py b/api/ee/databases/postgres/migrations/core/data_migrations/testsets.py new file mode 100644 index 0000000000..add9acf809 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/testsets.py @@ -0,0 +1,191 @@ +import uuid +import asyncio +import traceback +from typing import Optional + +import click +from sqlalchemy.future import select +from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine + +from ee.src.models.db_models import WorkspaceMemberDB as WorkspaceMemberDBE +from oss.src.models.db_models import ProjectDB as ProjectDBE +from oss.src.dbs.postgres.testcases.dbes import ( + TestcaseBlobDBE, +) +from oss.src.dbs.postgres.blobs.dao import BlobsDAO +from oss.src.dbs.postgres.testsets.dbes import ( + TestsetArtifactDBE, + TestsetVariantDBE, + TestsetRevisionDBE, +) +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.core.testcases.service import TestcasesService +from oss.src.models.deprecated_models import DeprecatedTestSetDB +from oss.src.core.testsets.service import TestsetsService, SimpleTestsetsService + + +# Define constants +DEFAULT_BATCH_SIZE = 200 + +# Initialize plug-ins for migration +testcases_dao = BlobsDAO( + BlobDBE=TestcaseBlobDBE, +) +testsets_dao = GitDAO( + ArtifactDBE=TestsetArtifactDBE, + VariantDBE=TestsetVariantDBE, + RevisionDBE=TestsetRevisionDBE, +) +testcases_service = TestcasesService( + testcases_dao=testcases_dao, +) +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) +simple_testsets_service = SimpleTestsetsService( + testsets_service=testsets_service, +) + + +async def _fetch_project_owner( + *, + project_id: uuid.UUID, + connection: AsyncConnection, +) -> Optional[uuid.UUID]: + """Fetch the owner user ID for a given project.""" + workspace_owner_query = ( + select(WorkspaceMemberDBE.user_id) + .select_from(WorkspaceMemberDBE, ProjectDBE) + .where( + WorkspaceMemberDBE.workspace_id == ProjectDBE.workspace_id, + WorkspaceMemberDBE.role == "owner", + ProjectDBE.id == project_id, + ) + ) + result = await connection.execute(workspace_owner_query) + owner = result.scalar_one_or_none() + return owner + + +async def migration_old_testsets_to_new_testsets( + connection: AsyncConnection, +): + """Migrate old testsets to new testsets system.""" + try: + offset = 0 + total_migrated = 0 + skipped_records = 0 + + # Count total rows with a non-null project_id + total_query = ( + select(func.count()) + .select_from(DeprecatedTestSetDB) + .filter(DeprecatedTestSetDB.project_id.isnot(None)) + ) + result = await connection.execute(total_query) + total_rows = result.scalar() + total_testsets = total_rows or 0 + + click.echo( + click.style( + f"Total rows in testsets with project_id: {total_testsets}", + fg="yellow", + ) + ) + + while offset < total_testsets: + # STEP 1: Fetch evaluator configurations with non-null project_id + result = await connection.execute( + select(DeprecatedTestSetDB) + .filter(DeprecatedTestSetDB.project_id.isnot(None)) + .offset(offset) + .limit(DEFAULT_BATCH_SIZE) + ) + testsets_rows = result.fetchall() + + if not testsets_rows: + break + + # Process and transfer records to testset workflows + for testset in testsets_rows: + try: + # STEP 2: Get owner from project_id + owner = await _fetch_project_owner( + project_id=testset.project_id, # type: ignore + connection=connection, + ) + if not owner: + skipped_records += 1 + click.echo( + click.style( + f"Skipping record with ID {testset.id} due to missing owner in workspace member table", + fg="yellow", + ) + ) + continue + + # STEP 3: Migrate records using transfer_* util function + new_testset = await simple_testsets_service.transfer( + project_id=testset.project_id, + user_id=owner, + testset_id=testset.id, + ) + if not new_testset: + skipped_records += 1 + click.echo( + click.style( + f"Skipping record with ID {testset.id} due to old testset not existing in database table", + fg="yellow", + ) + ) + continue + + except Exception as e: + click.echo( + click.style( + f"Failed to migrate testset {testset.id}: {str(e)}", + fg="red", + ) + ) + click.echo(click.style(traceback.format_exc(), fg="red")) + skipped_records += 1 + continue + + # Update progress tracking for current batch + batch_migrated = len(testsets_rows) + offset += DEFAULT_BATCH_SIZE + total_migrated += batch_migrated + + click.echo( + click.style( + f"Processed {batch_migrated} records in this batch.", + fg="yellow", + ) + ) + + # Update progress tracking for all batches + remaining_records = total_testsets - total_migrated + click.echo(click.style(f"Total migrated: {total_migrated}", fg="yellow")) + click.echo(click.style(f"Skipped records: {skipped_records}", fg="yellow")) + click.echo( + click.style(f"Records left to migrate: {remaining_records}", fg="yellow") + ) + + except Exception as e: + click.echo(f"Error occurred: {e}") + click.echo(click.style(traceback.format_exc(), fg="red")) + + +def run_migration(sqlalchemy_url: str): + import concurrent.futures + + async def _start(): + connection = create_async_engine(url=sqlalchemy_url) + async with connection.connect() as connection: + await migration_old_testsets_to_new_testsets(connection=connection) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, _start()) + future.result() diff --git a/api/ee/databases/postgres/migrations/core/data_migrations/workspaces.py b/api/ee/databases/postgres/migrations/core/data_migrations/workspaces.py new file mode 100644 index 0000000000..2c5a241acc --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/data_migrations/workspaces.py @@ -0,0 +1,255 @@ +import traceback + +import click +from sqlalchemy.future import select +from sqlalchemy import delete, Connection, insert, func + +from ee.src.models.db_models import ProjectDB, WorkspaceDB +from ee.src.models.db_models import ( + WorkspaceMemberDB, + ProjectMemberDB, +) + +BATCH_SIZE = 200 + + +def get_or_create_workspace_default_project( + session: Connection, workspace: WorkspaceDB +) -> None: + project = session.execute( + select(ProjectDB).filter_by( + is_default=True, + workspace_id=workspace.id, + ) + ).fetchone() + + if project is None: + statement = insert(ProjectDB).values( + project_name="Default Project", + is_default=True, + workspace_id=workspace.id, + organization_id=workspace.organization_id, + ) + session.execute(statement) + + +def create_default_project_for_workspaces(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + + # Count total rows in workspaces table + stmt = select(func.count()).select_from(WorkspaceDB) + result = session.execute(stmt).scalar() + TOTAL_WORKSPACES = result if result is not None else 0 + print(f"Total rows in workspaces table is {TOTAL_WORKSPACES}") + + while True: + # Retrieve a batch of workspaces without a project + workspaces = session.execute( + select(WorkspaceDB).offset(offset).limit(BATCH_SIZE) + ).fetchall() + actual_batch_size = len(workspaces) + if not workspaces: + break + + for workspace in workspaces: + # Create a new default project for each workspace + get_or_create_workspace_default_project( + session=session, workspace=workspace # type: ignore + ) + + # Commit the changes for the current batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_WORKSPACES - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {offset} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records} ", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if remaining_records <= 0: + break + + click.echo( + click.style( + "\nSuccessfully created default projects for workspaces.", + fg="green", + ), + color=True, + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while creating default projects: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e + + +def create_default_project_memberships(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + SKIPPED_RECORDS = 0 + + # Count total rows in workspaces_members table + stmt = select(func.count()).select_from(WorkspaceMemberDB) + result = session.execute(stmt).scalar() + TOTAL_WORKSPACES_MEMBERS = result if result is not None else 0 + print(f"Total rows in workspaces_members table is {TOTAL_WORKSPACES_MEMBERS}") + + while True: + # Retrieve a batch of workspace members + workspace_members = session.execute( + select(WorkspaceMemberDB).offset(offset).limit(BATCH_SIZE) + ).fetchall() + actual_batch_size = len(workspace_members) + if not workspace_members: + break + + for workspace_member in workspace_members: + # Find the default project for the member's workspace + project_query = session.execute( + select(ProjectDB) + .where( + ProjectDB.workspace_id == workspace_member.workspace_id, + ProjectDB.is_default == True, + ) + .limit(1) + ) + default_project = project_query.fetchone() + if default_project: + # Create a new project membership for each workspace member + statement = insert(ProjectMemberDB).values( + user_id=workspace_member.user_id, + project_id=getattr(default_project, "id"), + role=workspace_member.role, + ) + session.execute(statement) + else: + print( + f"Skipping record... Did not find any default project for workspace {str(workspace_member.workspace_id)}" + ) + SKIPPED_RECORDS += 1 + + # Commit the changes for the current batch + session.commit() + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_WORKSPACES_MEMBERS - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {offset} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records} ", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if remaining_records <= 0: + break + + click.echo( + click.style( + f"\nSuccessfully created default project memberships for workspace members. Skipped {SKIPPED_RECORDS} records.", + fg="green", + ), + color=True, + ) + + except Exception as e: + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while creating project memberships: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e + + +def remove_default_projects_from_workspaces(session: Connection): + try: + offset = 0 + TOTAL_MIGRATED = 0 + + # Count total rows in projects table + stmt = ( + select(func.count()) + .select_from(ProjectDB) + .where(ProjectDB.is_default == True) + ) + result = session.execute(stmt).scalar() + TOTAL_PROJECTS = result if result is not None else 0 + print(f"Total rows in projects table is {TOTAL_PROJECTS}") + + while True: + # Retrieve a batch of workspaces with a default project + projects_to_delete = session.execute( + select(ProjectDB) + .where(ProjectDB.is_default == True) + .offset(offset) + .limit(BATCH_SIZE) # type: ignore + ).fetchall() + actual_batch_size = len(projects_to_delete) + if not projects_to_delete: + break + + for project in projects_to_delete: + if project is not None and len(project) >= 1: + # Remove associated project memberships + session.execute( + delete(ProjectMemberDB).where( + ProjectMemberDB.project_id == project.id + ) + ) + + # Remove the default project itself + session.execute(delete(ProjectDB).where(ProjectDB.id == project.id)) + + # Update migration progress + TOTAL_MIGRATED += actual_batch_size + offset += actual_batch_size + remaining_records = TOTAL_PROJECTS - TOTAL_MIGRATED + click.echo( + click.style( + f"Processed {offset} records in this batch. Total records migrated: {TOTAL_MIGRATED}. Records left to migrate: {remaining_records} ", + fg="yellow", + ) + ) + + # Stop the loop when all records have been processed + if remaining_records <= 0: + break + + click.echo( + click.style( + "\nSuccessfully removed default projects and associated memberships from existing workspaces.", + fg="green", + ), + color=True, + ) + except Exception as e: + # Handle exceptions and rollback if necessary + session.rollback() + click.echo( + click.style( + f"\nAn ERROR occurred while removing default projects and memberships: {traceback.format_exc()}", + fg="red", + ), + color=True, + ) + raise e diff --git a/api/ee/databases/postgres/migrations/core/env.py b/api/ee/databases/postgres/migrations/core/env.py new file mode 100644 index 0000000000..e5e251f801 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/env.py @@ -0,0 +1,126 @@ +import os +import asyncio +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.engine import Connection, create_engine +from sqlalchemy.ext.asyncio import async_engine_from_config, create_async_engine + +from alembic import context + +from oss.src.dbs.postgres.shared.engine import engine + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +config.set_main_option("sqlalchemy.url", engine.postgres_uri_core) # type: ignore + + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +from oss.src.dbs.postgres.shared.base import Base + +import oss.src.dbs.postgres.secrets.dbes +import oss.src.dbs.postgres.observability.dbes +import oss.src.dbs.postgres.tracing.dbes +import oss.src.dbs.postgres.testcases.dbes +import oss.src.dbs.postgres.testsets.dbes +import oss.src.dbs.postgres.queries.dbes +import oss.src.dbs.postgres.workflows.dbes +import oss.src.dbs.postgres.evaluations.dbes + +import ee.src.dbs.postgres.meters.dbes +import ee.src.dbs.postgres.subscriptions.dbes + + +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + connection = create_engine( + url=config.get_main_option("sqlalchemy.url"), + pool_size=10, # Maintain 10 connections in the pool + pool_timeout=43200, # Timeout of 12 hours + pool_recycle=43200, # Timeout of 12 hours + pool_pre_ping=True, + echo_pool=True, + pool_use_lifo=True, + ) + context.configure( + connection=connection, + transaction_per_migration=True, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure( + transaction_per_migration=True, + connection=connection, + target_metadata=target_metadata, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + connectable = create_async_engine( + url=config.get_main_option("sqlalchemy.url"), + pool_size=10, # Maintain 10 connections in the pool + pool_timeout=43200, # Timeout of 12 hours + pool_recycle=43200, # Timeout of 12 hours + pool_pre_ping=True, + echo_pool=True, + pool_use_lifo=True, + ) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/ee/databases/postgres/migrations/core/script.py.mako b/api/ee/databases/postgres/migrations/core/script.py.mako new file mode 100644 index 0000000000..fbc4b07dce --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/api/ee/databases/postgres/migrations/core/temp/80910d2fa9a4_migrate_old_testsets_to_new_.py b/api/ee/databases/postgres/migrations/core/temp/80910d2fa9a4_migrate_old_testsets_to_new_.py new file mode 100644 index 0000000000..43be6c1579 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/temp/80910d2fa9a4_migrate_old_testsets_to_new_.py @@ -0,0 +1,32 @@ +"""migrate old testsets to new testsets data structure + +Revision ID: 80910d2fa9a4 +Revises: ... +Create Date: 2025-07-25 07:35:57.319449 + +""" + +from typing import Sequence, Union + +from alembic import context +from ee.databases.postgres.migrations.core.data_migrations.testsets import ( + run_migration, +) + +# revision identifiers, used by Alembic. +revision: str = "80910d2fa9a4" +down_revision: Union[str, None] = "..." +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + run_migration(sqlalchemy_url=context.config.get_main_option("sqlalchemy.url")) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/temp/bd7937ee784d_migrate_old_evaluators_to_new_.py b/api/ee/databases/postgres/migrations/core/temp/bd7937ee784d_migrate_old_evaluators_to_new_.py new file mode 100644 index 0000000000..da71b370bb --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/temp/bd7937ee784d_migrate_old_evaluators_to_new_.py @@ -0,0 +1,32 @@ +"""migrate old evaluators to new evaluators data structure + +Revision ID: bd7937ee784d +Revises: ... +Create Date: 2025-07-25 07:35:57.319449 + +""" + +from typing import Sequence, Union + +from alembic import context +from ee.databases.postgres.migrations.core.data_migrations.evaluators import ( + run_migration, +) + +# revision identifiers, used by Alembic. +revision: str = "bd7937ee784d" +down_revision: Union[str, None] = "..." +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + run_migration(sqlalchemy_url=context.config.get_main_option("sqlalchemy.url")) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/utils.py b/api/ee/databases/postgres/migrations/core/utils.py new file mode 100644 index 0000000000..206e46db64 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/utils.py @@ -0,0 +1,196 @@ +import asyncio +import logging +import traceback + +import click +import asyncpg +from alembic import command +from sqlalchemy import Engine +from alembic.config import Config +from sqlalchemy import inspect, text +from alembic.script import ScriptDirectory +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + +from oss.src.utils.env import env + + +# Initializer logger +logger = logging.getLogger("alembic.env") + +# Initialize alembic config +alembic_cfg = Config(env.ALEMBIC_CFG_PATH_CORE) +script = ScriptDirectory.from_config(alembic_cfg) + +logger.info("license: ee") +logger.info("migrations: entities") +logger.info("ALEMBIC_CFG_PATH_CORE: %s", env.ALEMBIC_CFG_PATH_CORE) +logger.info("alembic_cfg: %s", alembic_cfg) +logger.info("script: %s", script) + + +def is_initial_setup(engine) -> bool: + """ + Check if the database is in its initial state by verifying the existence of required tables. + + This function inspects the current state of the database and determines if it needs initial setup by checking for the presence of a predefined set of required tables. + + Args: + engine (sqlalchemy.engine.base.Engine): The SQLAlchemy engine used to connect to the database. + + Returns: + bool: True if the database is in its initial state (i.e., not all required tables exist), False otherwise. + """ + + inspector = inspect(engine) + required_tables = [ + "users", + "app_db", + "deployments", + "bases", + "app_variants", + "ids_mapping", + ] # NOTE: The tables here were picked at random. Having all the tables in the database in the list \ + # will not change the behaviour of this function, so best to leave things as it is! + existing_tables = inspector.get_table_names() + + # Check if all required tables exist in the database + all_tables_exist = all(table in existing_tables for table in required_tables) + + return not all_tables_exist + + +async def get_current_migration_head_from_db(engine: AsyncEngine): + """ + Checks the alembic_version table to get the current migration head that has been applied. + + Args: + engine (Engine): The engine that connects to an sqlalchemy pool + + Returns: + the current migration head (where 'head' is the revision stored in the migration script) + """ + + async with engine.connect() as connection: + try: + result = await connection.execute(text("SELECT version_num FROM alembic_version")) # type: ignore + except (asyncpg.exceptions.UndefinedTableError, ProgrammingError): + # Note: If the alembic_version table does not exist, it will result in raising an UndefinedTableError exception. + # We need to suppress the error and return a list with the alembic_version table name to inform the user that there is a pending migration \ + # to make Alembic start tracking the migration changes. + # -------------------------------------------------------------------------------------- + # This effect (the exception raising) happens for both users (first-time and returning) + return "alembic_version" + + migration_heads = [row[0] for row in result.fetchall()] + assert ( + len(migration_heads) == 1 + ), "There can only be one migration head stored in the database." + return migration_heads[0] + + +async def get_pending_migration_head(): + """ + Gets the migration head that have not been applied. + + Returns: + the pending migration head + """ + + pending_migration_head = [] + + engine = create_async_engine(url=env.POSTGRES_URI_CORE) + try: + current_migration_script_head = script.get_current_head() + migration_head_from_db = await get_current_migration_head_from_db(engine=engine) + + if current_migration_script_head != migration_head_from_db: + pending_migration_head.append(current_migration_script_head) + if "alembic_version" == migration_head_from_db: + pending_migration_head.append("alembic_version") + finally: + await engine.dispose() + + return pending_migration_head + + +def run_alembic_migration(): + """ + Applies migration for first-time users and also checks the environment variable "AGENTA_AUTO_MIGRATIONS" to determine whether to apply migrations for returning users. + """ + + try: + pending_migration_head = asyncio.run(get_pending_migration_head()) + FIRST_TIME_USER = True if "alembic_version" in pending_migration_head else False + + if FIRST_TIME_USER or env.AGENTA_AUTO_MIGRATIONS: + command.upgrade(alembic_cfg, "head") + click.echo( + click.style( + "\nMigration applied successfully. The container will now exit.", + fg="green", + ), + color=True, + ) + else: + click.echo( + click.style( + "\nAll migrations are up-to-date. The container will now exit.", + fg="yellow", + ), + color=True, + ) + except Exception as e: + click.echo( + click.style( + f"\nAn ERROR occurred while applying migration: {traceback.format_exc()}\nThe container will now exit.", + fg="red", + ), + color=True, + ) + raise e + + +async def check_for_new_migrations(): + """ + Checks for new migrations and notify the user. + """ + + pending_migration_head = await get_pending_migration_head() + if len(pending_migration_head) >= 1 and isinstance(pending_migration_head[0], str): + click.echo( + click.style( + f"\nWe have detected that there are pending database migrations {pending_migration_head} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n", + fg="yellow", + ), + color=True, + ) + return + + +def unique_constraint_exists( + engine: Engine, table_name: str, constraint_name: str +) -> bool: + """ + The function checks if a unique constraint with a specific name exists on a table in a PostgreSQL + database. + + Args: + - engine (Engine): instance of a database engine that represents a connection to a database. + - table_name (str): name of the table to check the existence of the unique constraint. + - constraint_name (str): name of the unique constraint to check for existence. + + Returns: + - returns a boolean value indicating whether a unique constraint with the specified `constraint_name` exists in the table. + """ + + with engine.connect() as conn: + result = conn.execute( + text( + f""" + SELECT conname FROM pg_constraint + WHERE conname = '{constraint_name}' AND conrelid = '{table_name}'::regclass; + """ + ) + ) + return result.fetchone() is not None diff --git a/api/ee/databases/postgres/migrations/core/versions/0698355c7641_add_tables_for_testsets.py b/api/ee/databases/postgres/migrations/core/versions/0698355c7641_add_tables_for_testsets.py new file mode 100644 index 0000000000..c0b8756dec --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/0698355c7641_add_tables_for_testsets.py @@ -0,0 +1,388 @@ +"""add tables for testsets (artifacts, variants, & revisions) + +Revision ID: 0698355c7641 +Revises: 9698355c7649 +Create Date: 2025-04-24 07:27:45.801481 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "0698355c7641" +down_revision: Union[str, None] = "9698355c7649" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - ARTIFACTS -------------------------------------------------------------- + + op.create_table( + "testset_artifacts", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_testset_artifacts_project_id_slug", + "project_id", + "slug", + ), + ) + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.create_table( + "testset_variants", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["testset_artifacts.project_id", "testset_artifacts.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_testset_variants_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_testset_variants_project_id_artifact_id", + "project_id", + "artifact_id", + ), + ) + + # -------------------------------------------------------------------------- + + # - REVISIONS -------------------------------------------------------------- + + op.create_table( + "testset_revisions", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "variant_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "version", + sa.String(), + nullable=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.Column( + "message", + sa.String(), + nullable=True, + ), + sa.Column( + "author", + sa.UUID(), + nullable=False, + ), + sa.Column( + "date", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["testset_artifacts.project_id", "testset_artifacts.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "variant_id"], + ["testset_variants.project_id", "testset_variants.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_testset_revisions_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_testset_revisions_project_id_artifact_id", + "project_id", + "artifact_id", + ), + sa.Index( + "ix_testset_revisions_project_id_variant_id", + "project_id", + "variant_id", + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - REVISIONS -------------------------------------------------------------- + + op.drop_table("testset_revisions") + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.drop_table("testset_variants") + + # -------------------------------------------------------------------------- + + # - ARTIFACTS -------------------------------------------------------------- + + op.drop_table("testset_artifacts") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/0698355c7642_add_table_for_testcases.py b/api/ee/databases/postgres/migrations/core/versions/0698355c7642_add_table_for_testcases.py new file mode 100644 index 0000000000..c7a98fc712 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/0698355c7642_add_table_for_testcases.py @@ -0,0 +1,112 @@ +"""add tables for testcases (blobs) + +Revision ID: 0698355c7642 +Revises: 0698355c7641 +Create Date: 2025-04-24 07:27:45.801481 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "0698355c7642" +down_revision: Union[str, None] = "0698355c7641" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - BLOBS ------------------------------------------------------------------ + + op.create_table( + "testcase_blobs", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "set_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.UniqueConstraint( + "project_id", + "set_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "set_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "set_id"], + ["testset_artifacts.project_id", "testset_artifacts.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_testcase_blobs_project_id_blob_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_testcase_blobs_project_id_set_id", + "project_id", + "set_id", + ), + sa.Index( + "ix_testcase_blobs_project_id_set_id_id", + "project_id", + "set_id", + "id", + ), + sa.Index( + "ix_testcase_blobs_project_id_set_id_slug", + "project_id", + "set_id", + "slug", + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - BLOBS ------------------------------------------------------------------ + + op.drop_table("testcase_blobs") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/0f086ebc2f83_extend_app_type.py b/api/ee/databases/postgres/migrations/core/versions/0f086ebc2f83_extend_app_type.py new file mode 100644 index 0000000000..dd76961a2f --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/0f086ebc2f83_extend_app_type.py @@ -0,0 +1,58 @@ +"""Extend app_type + +Revision ID: 0f086ebc2f83 +Revises: 0f086ebc2f82 +Create Date: 2025-01-08 10:24:00 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "0f086ebc2f83" +down_revision: Union[str, None] = "425c68e8de6c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + # Define the new enum + temp_enum = sa.Enum( + "CHAT_TEMPLATE", + "COMPLETION_TEMPLATE", + "CHAT_SERVICE", + "COMPLETION_SERVICE", + "CUSTOM", + name="app_type_enum", + ) + temp_enum.create(op.get_bind(), checkfirst=True) + + # Update the column to use the new enum + op.execute( + "ALTER TABLE app_db ALTER COLUMN app_type TYPE app_type_enum USING app_type::text::app_type_enum" + ) + + # Drop the old enum + op.execute("DROP TYPE app_enumtype") + + +def downgrade(): + # Define the old enum + temp_enum = sa.Enum( + "CHAT_TEMPLATE", + "COMPLETION_TEMPLATE", + "CUSTOM", + name="app_enumtype", + ) + temp_enum.create(op.get_bind(), checkfirst=True) + + # Update the column to use the old enum + op.execute( + "ALTER TABLE app_db ALTER COLUMN app_type TYPE app_enumtype USING app_type::text::app_enumtype" + ) + + # Drop the new enum + op.execute("DROP TYPE app_type_enum") diff --git a/api/ee/databases/postgres/migrations/core/versions/12f477990f1e_add_meters.py b/api/ee/databases/postgres/migrations/core/versions/12f477990f1e_add_meters.py new file mode 100644 index 0000000000..2e5c4ef580 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/12f477990f1e_add_meters.py @@ -0,0 +1,54 @@ +"""add meters + +Revision ID: 12f477990f1e +Revises: 6965776e6940 +Create Date: 2025-01-25 16:51:06.233811 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +from sqlalchemy.sql import func + +# revision identifiers, used by Alembic. +revision: str = "12f477990f1e" +down_revision: Union[str, None] = "6965776e6940" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "meters", + sa.Column( + "key", + sa.Enum( + "USERS", + "APPLICATIONS", + "EVALUATIONS", + "TRACES", + name="meters_type", + ), + nullable=False, + ), + sa.Column("value", sa.BigInteger(), nullable=False), + sa.Column("synced", sa.BigInteger(), nullable=False), + sa.Column("organization_id", sa.UUID(), nullable=False), + sa.Column("year", sa.SmallInteger(), nullable=True, server_default="0"), + sa.Column("month", sa.SmallInteger(), nullable=True, server_default="0"), + sa.PrimaryKeyConstraint("organization_id", "key", "year", "month"), + sa.ForeignKeyConstraint(["organization_id"], ["subscriptions.organization_id"]), + ) + op.create_index("idx_synced_value", "meters", ["synced", "value"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("idx_synced_value", table_name="meters") + op.drop_table("meters") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/154098b1e56c_set_user_id_column_in_db_entities_to_be_.py b/api/ee/databases/postgres/migrations/core/versions/154098b1e56c_set_user_id_column_in_db_entities_to_be_.py new file mode 100644 index 0000000000..411101fa4d --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/154098b1e56c_set_user_id_column_in_db_entities_to_be_.py @@ -0,0 +1,69 @@ +"""Set user_id column in db entities to be optional --- prep for project_id scoping + +Revision ID: 154098b1e56c +Revises: ad0987a77380 +Create Date: 2024-09-17 06:44:31.061378 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "154098b1e56c" +down_revision: Union[str, None] = "ad0987a77380" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("docker_images", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("app_db", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("deployments", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("bases", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("app_variants", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("environments", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column("testsets", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column( + "evaluators_configs", "user_id", existing_type=sa.UUID, nullable=True + ) + op.alter_column( + "human_evaluations", "user_id", existing_type=sa.UUID, nullable=True + ) + op.alter_column( + "human_evaluations_scenarios", "user_id", existing_type=sa.UUID, nullable=True + ) + op.alter_column("evaluations", "user_id", existing_type=sa.UUID, nullable=True) + op.alter_column( + "evaluation_scenarios", "user_id", existing_type=sa.UUID, nullable=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column("docker_images", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("app_db", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("deployments", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("bases", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("app_variants", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("environments", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column("testsets", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column( + "evaluators_configs", "user_id", existing_type=sa.UUID, nullable=False + ) + op.alter_column( + "human_evaluations", "user_id", existing_type=sa.UUID, nullable=False + ) + op.alter_column( + "human_evaluations_scenarios", "user_id", existing_type=sa.UUID, nullable=False + ) + op.alter_column("evaluations", "user_id", existing_type=sa.UUID, nullable=False) + op.alter_column( + "evaluation_scenarios", "user_id", existing_type=sa.UUID, nullable=False + ) + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/1c2d3e4f5a6b_workspaces_migration_to_add_default_project_and_membership.py b/api/ee/databases/postgres/migrations/core/versions/1c2d3e4f5a6b_workspaces_migration_to_add_default_project_and_membership.py new file mode 100644 index 0000000000..2e52bcfdc9 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/1c2d3e4f5a6b_workspaces_migration_to_add_default_project_and_membership.py @@ -0,0 +1,40 @@ +"""workspaces migration to add default project and memberships + +Revision ID: 1c2d3e4f5a6b +Revises: 6aafdfc2befb +Create Date: 2024-09-03 08:05:58.870573 + +""" + +from typing import Sequence, Union + +from alembic import context +import sqlalchemy as sa + +from ee.databases.postgres.migrations.core.data_migrations.workspaces import ( + create_default_project_for_workspaces, + create_default_project_memberships, + remove_default_projects_from_workspaces, +) + + +# revision identifiers, used by Alembic. +revision: str = "1c2d3e4f5a6b" +down_revision: Union[str, None] = "6aafdfc2befb" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom migration ### + connection = context.get_bind() # get database connect from alembic context + create_default_project_for_workspaces(session=connection) + create_default_project_memberships(session=connection) + # ### end custom migration ### + + +def downgrade() -> None: + # ### custom migration ### + connection = context.get_bind() # get database connect from alembic context + remove_default_projects_from_workspaces(session=connection) + # ### end custom migration ### diff --git a/api/ee/databases/postgres/migrations/core/versions/24f8bdb390ee_added_the_app_type_column_to_the_app_db_.py b/api/ee/databases/postgres/migrations/core/versions/24f8bdb390ee_added_the_app_type_column_to_the_app_db_.py new file mode 100644 index 0000000000..de300ce7fa --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/24f8bdb390ee_added_the_app_type_column_to_the_app_db_.py @@ -0,0 +1,59 @@ +"""Added the 'app_type' column to the 'app_db' table + +Revision ID: 24f8bdb390ee +Revises: e9fa2135f3fb +Create Date: 2024-09-09 07:32:45.053125 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "24f8bdb390ee" +down_revision: Union[str, None] = "847972cfa14a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # Create the enum type first + app_enumtype = sa.Enum( + "CHAT_TEMPLATE", + "COMPLETION_TEMPLATE", + "CUSTOM", + name="app_enumtype", + ) + app_enumtype.create(op.get_bind(), checkfirst=True) + + # Then add the column using the enum type + op.add_column( + "app_db", + sa.Column( + "app_type", + app_enumtype, + nullable=True, + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + # Drop the column first + op.drop_column("app_db", "app_type") + + # Then drop the enum type + app_enumtype = sa.Enum( + "CHAT_TEMPLATE", + "COMPLETION_TEMPLATE", + "CUSTOM", + name="app_enumtype", + ) + app_enumtype.drop(op.get_bind(), checkfirst=True) + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/2a91436752f9_update_secrets_data_schema_type.py b/api/ee/databases/postgres/migrations/core/versions/2a91436752f9_update_secrets_data_schema_type.py new file mode 100644 index 0000000000..460986b788 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/2a91436752f9_update_secrets_data_schema_type.py @@ -0,0 +1,64 @@ +"""update secrets data schema type + +Revision ID: 2a91436752f9 +Revises: 0f086ebc2f83 +Create Date: 2025-02-10 10:38:31.555604 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import context, op + +from oss.databases.postgres.migrations.core.data_migrations.secrets import ( + rename_and_update_secrets_data_schema, + revert_rename_and_update_secrets_data_schema, +) + + +# revision identifiers, used by Alembic. +revision: str = "2a91436752f9" +down_revision: Union[str, None] = "0f086ebc2f83" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands to do data migration for secrets ### + connection = context.get_bind() + + # Define the new enum + secret_kinds = sa.Enum("PROVIDER_KEY", "CUSTOM_PROVIDER", name="secretkind_enum") + secret_kinds.create(bind=connection, checkfirst=True) + + # Update the column to make use of the new enum + op.execute( + "ALTER TABLE secrets ALTER COLUMN kind TYPE secretkind_enum USING kind::text::secretkind_enum" + ) + + # Drop the old enum + op.execute("DROP TYPE IF EXISTS secretkind") + + rename_and_update_secrets_data_schema(session=connection) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands to do data migration for secrets ### + connection = context.get_bind() + + # Define the new enum + secret_kinds = sa.Enum("PROVIDER_KEY", name="secretkind") + secret_kinds.create(bind=connection, checkfirst=True) + + # Update the column to make use of the new enum + op.execute( + "ALTER TABLE secrets ALTER COLUMN kind TYPE secretkind USING kind::text::secretkind" + ) + + # Drop the old enum + op.execute("DROP TYPE IF EXISTS secretkind_enum") + + revert_rename_and_update_secrets_data_schema(session=connection) + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/30dcf07de96a_add_tables_for_queries.py b/api/ee/databases/postgres/migrations/core/versions/30dcf07de96a_add_tables_for_queries.py new file mode 100644 index 0000000000..735a859ce0 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/30dcf07de96a_add_tables_for_queries.py @@ -0,0 +1,403 @@ +"""add tables for queries (artifacts, variants, & revisions) + +Revision ID: 30dcf07de96a +Revises: aa1b2c3d4e5f +Create Date: 2025-07-30 14:55:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "30dcf07de96a" +down_revision: Union[str, None] = "aa1b2c3d4e5f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - ARTIFACTS -------------------------------------------------------------- + + op.create_table( + "query_artifacts", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_query_artifacts_project_id_slug", + "project_id", + "slug", + ), + ) + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.create_table( + "query_variants", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["query_artifacts.project_id", "query_artifacts.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_query_variants_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_query_variants_project_id_artifact_id", + "project_id", + "artifact_id", + ), + ) + + # -------------------------------------------------------------------------- + + # - REVISIONS -------------------------------------------------------------- + + op.create_table( + "query_revisions", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "variant_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "version", + sa.String(), + nullable=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.Column( + "message", + sa.String(), + nullable=True, + ), + sa.Column( + "author", + sa.UUID(), + nullable=False, + ), + sa.Column( + "date", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["query_artifacts.project_id", "query_artifacts.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "variant_id"], + ["query_variants.project_id", "query_variants.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_query_revisions_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_query_revisions_project_id_artifact_id", + "project_id", + "artifact_id", + ), + sa.Index( + "ix_query_revisions_project_id_variant_id", + "project_id", + "variant_id", + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - REVISIONS -------------------------------------------------------------- + + op.drop_table("query_revisions") + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.drop_table("query_variants") + + # -------------------------------------------------------------------------- + + # - ARTIFACTS -------------------------------------------------------------- + + op.drop_table("query_artifacts") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/320a4a7ee0c7_set_columns_in_api_key_table_to_be_.py b/api/ee/databases/postgres/migrations/core/versions/320a4a7ee0c7_set_columns_in_api_key_table_to_be_.py new file mode 100644 index 0000000000..463285cacb --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/320a4a7ee0c7_set_columns_in_api_key_table_to_be_.py @@ -0,0 +1,61 @@ +"""set columns in api_key table to be nullable -- prep for access control + +Revision ID: 320a4a7ee0c7 +Revises: b3f6bff547d4 +Create Date: 2024-10-22 10:57:36.983190 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op, context + +from ee.databases.postgres.migrations.core.data_migrations.api_keys import ( + update_api_key_to_make_use_of_project_id, + revert_api_key_to_make_use_of_workspace_id, +) + + +# revision identifiers, used by Alembic. +revision: str = "320a4a7ee0c7" +down_revision: Union[str, None] = "b3f6bff547d4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + connection = context.get_bind() + op.alter_column("api_keys", "user_id", nullable=True) + op.alter_column("api_keys", "workspace_id", nullable=True) + op.add_column("api_keys", sa.Column("project_id", sa.UUID(), nullable=True)) + op.add_column("api_keys", sa.Column("created_by_id", sa.UUID(), nullable=True)) + # ================== Custom data migration ====================== # + update_api_key_to_make_use_of_project_id(session=connection) + # ================== Custom data migration ====================== # + op.drop_column("api_keys", "user_id") + op.drop_column("api_keys", "workspace_id") + op.alter_column("api_keys", "created_by_id", nullable=False) + op.alter_column("api_keys", "project_id", nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + connection = context.get_bind() + inspector = sa.inspect(connection) + columns = [column["name"] for column in inspector.get_columns("api_keys")] + if "user_id" not in columns: + op.add_column("api_keys", sa.Column("user_id", sa.String(), nullable=True)) + + if "workspace_id" not in columns: + op.add_column("api_keys", sa.Column("workspace_id", sa.String(), nullable=True)) + # ================== Custom data migration ====================== # + revert_api_key_to_make_use_of_workspace_id(session=connection) + # ================== Custom data migration ====================== # + op.drop_column("api_keys", "created_by_id") + op.drop_column("api_keys", "project_id") + op.alter_column("api_keys", "user_id", nullable=False) + op.alter_column("api_keys", "workspace_id", nullable=False) + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/3b5f5652f611_populate_runs_references.py b/api/ee/databases/postgres/migrations/core/versions/3b5f5652f611_populate_runs_references.py new file mode 100644 index 0000000000..bb43067ccb --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/3b5f5652f611_populate_runs_references.py @@ -0,0 +1,77 @@ +"""Populate runs references + +Revision ID: 3b5f5652f611 +Revises: b3f15a7140ab +Create Date: 2025-10-07 12:00:00 +""" + +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa +import json + +# revision identifiers, used by Alembic. +revision: str = "3b5f5652f611" +down_revision: Union[str, None] = "b3f15a7140ab" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + + rows = conn.execute( + sa.text('SELECT id, data, "references" FROM evaluation_runs') + ).fetchall() + + for run_id, data, existing_refs in rows: + if existing_refs not in (None, [], {}): + continue + if not data or "steps" not in data: + continue + + refs_out = [] + seen = set() + + for step in data.get("steps", []): + refs = step.get("references", {}) + if not isinstance(refs, dict): + continue + + for key, ref in refs.items(): + if not isinstance(ref, dict): + continue + + entry = {"key": key} + + if ref.get("id") is not None: + entry["id"] = ref["id"] + if ref.get("slug") is not None: + entry["slug"] = ref["slug"] + if ref.get("version") is not None: + entry["version"] = ref["version"] + + dedup_key = ( + entry.get("id"), + entry["key"], + entry.get("slug"), + entry.get("version"), + ) + if dedup_key in seen: + continue + seen.add(dedup_key) + + refs_out.append(entry) + + if refs_out: + conn.execute( + sa.text( + 'UPDATE evaluation_runs SET "references" = :refs WHERE id = :id' + ), + {"refs": json.dumps(refs_out), "id": run_id}, + ) + + +def downgrade() -> None: + conn = op.get_bind() + conn.execute(sa.text('UPDATE evaluation_runs SET "references" = NULL')) diff --git a/api/ee/databases/postgres/migrations/core/versions/425c68e8de6c_add_secrets_dbe_model.py b/api/ee/databases/postgres/migrations/core/versions/425c68e8de6c_add_secrets_dbe_model.py new file mode 100644 index 0000000000..b58d9cc9ce --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/425c68e8de6c_add_secrets_dbe_model.py @@ -0,0 +1,53 @@ +"""add secrets dbe model + +Revision ID: 425c68e8de6c +Revises: 73a2d8cfaa3d +Create Date: 2024-12-05 10:30:54.986714 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from oss.src.dbs.postgres.secrets.custom_fields import PGPString + +# revision identifiers, used by Alembic. +revision: str = "425c68e8de6c" +down_revision: Union[str, None] = "73a2d8cfaa3d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") + op.create_table( + "secrets", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("kind", sa.Enum("PROVIDER_KEY", name="secretkind"), nullable=True), + sa.Column("data", PGPString(), nullable=True), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("updated_by_id", sa.UUID(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("secrets") + op.execute("DROP TYPE IF EXISTS secretkind;") + op.execute("DROP EXTENSION IF EXISTS pgcrypto;") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/4d9a58ff8f98_add_default_project_to_scoped_model_.py b/api/ee/databases/postgres/migrations/core/versions/4d9a58ff8f98_add_default_project_to_scoped_model_.py new file mode 100644 index 0000000000..22b9387d66 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/4d9a58ff8f98_add_default_project_to_scoped_model_.py @@ -0,0 +1,42 @@ +"""add default project to scoped model entities + +Revision ID: 4d9a58ff8f98 +Revises: d0b8e05ca190 +Create Date: 2024-09-17 07:16:57.740642 + +""" + +from typing import Sequence, Union + +from alembic import context + +from ee.databases.postgres.migrations.core.data_migrations.projects import ( + add_project_id_to_db_entities, + remove_project_id_from_db_entities, + repair_evaluation_scenario_to_have_project_id, + repair_evaluator_configs_to_have_project_id, +) + + +# revision identifiers, used by Alembic. +revision: str = "4d9a58ff8f98" +down_revision: Union[str, None] = "d0b8e05ca190" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom command ### + connection = context.get_bind() # get database connect from alembic context + add_project_id_to_db_entities(session=connection) + repair_evaluation_scenario_to_have_project_id(session=connection) + repair_evaluator_configs_to_have_project_id(session=connection) + repair_evaluation_scenario_to_have_project_id(session=connection) + # ### end custom command ### + + +def downgrade() -> None: + # ### custom command ### + connection = context.get_bind() # get database connect from alembic context + remove_project_id_from_db_entities(session=connection) + # ### end custom command ### diff --git a/api/ee/databases/postgres/migrations/core/versions/54e81e9eed88_add_tables_for_evaluations.py b/api/ee/databases/postgres/migrations/core/versions/54e81e9eed88_add_tables_for_evaluations.py new file mode 100644 index 0000000000..f8549687ce --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/54e81e9eed88_add_tables_for_evaluations.py @@ -0,0 +1,514 @@ +"""add tables for evaluations + +Revision ID: 54e81e9eed88 +Revises: 9698355c7650 +Create Date: 2025-04-24 07:27:45.801481 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "54e81e9eed88" +down_revision: Union[str, None] = "9698355c7650" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.rename_table( + "evaluation_aggregated_results", + "auto_evaluation_aggregated_results", + ) + op.rename_table( + "evaluation_evaluator_configs", + "auto_evaluation_evaluator_configs", + ) + op.rename_table( + "evaluation_scenario_results", + "auto_evaluation_scenario_results", + ) + op.rename_table( + "evaluation_scenarios", + "auto_evaluation_scenarios", + ) + op.rename_table( + "evaluations", + "auto_evaluations", + ) + op.rename_table( + "evaluators_configs", + "auto_evaluator_configs", + ) + + op.create_table( + "evaluation_runs", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "status", + sa.VARCHAR, + nullable=False, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_evaluation_runs_project_id", + "project_id", + ), + ) + + op.create_table( + "evaluation_scenarios", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "status", + sa.VARCHAR, + nullable=False, + ), + sa.Column( + "run_id", + sa.UUID(), + nullable=False, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "run_id"], + ["evaluation_runs.project_id", "evaluation_runs.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_evaluation_scenarios_project_id", + "project_id", + ), + sa.Index( + "ix_evaluation_scenarios_run_id", + "run_id", + ), + ) + + op.create_table( + "evaluation_steps", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "status", + sa.VARCHAR, + nullable=False, + ), + sa.Column( + "timestamp", + sa.TIMESTAMP(timezone=True), + nullable=False, + ), + sa.Column( + "key", + sa.String(), + nullable=False, + ), + sa.Column( + "repeat_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "retry_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "hash_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "trace_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "testcase_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "error", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "scenario_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "run_id", + sa.UUID(), + nullable=False, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "run_id"], + ["evaluation_runs.project_id", "evaluation_runs.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "scenario_id"], + ["evaluation_scenarios.project_id", "evaluation_scenarios.id"], + ondelete="CASCADE", + ), + sa.UniqueConstraint( + "project_id", + "run_id", + "scenario_id", + "key", + "retry_id", + "retry_id", + ), + sa.Index( + "ix_evaluation_steps_project_id", + "project_id", + ), + sa.Index( + "ix_evaluation_steps_scenario_id", + "scenario_id", + ), + sa.Index( + "ix_evaluation_steps_run_id", + "run_id", + ), + ) + + op.create_table( + "evaluation_metrics", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "status", + sa.VARCHAR, + nullable=False, + ), + sa.Column( + "scenario_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "run_id", + sa.UUID(), + nullable=False, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "run_id"], + ["evaluation_runs.project_id", "evaluation_runs.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "scenario_id"], + ["evaluation_scenarios.project_id", "evaluation_scenarios.id"], + ondelete="CASCADE", + ), + sa.UniqueConstraint( + "project_id", + "run_id", + "scenario_id", + ), + sa.Index( + "ix_evaluation_metrics_project_id", + "project_id", + ), + sa.Index( + "ix_evaluation_metrics_run_id", + "run_id", + ), + sa.Index( + "ix_evaluation_metrics_scenario_id", + "scenario_id", + ), + ) + + +def downgrade() -> None: + op.drop_table("evaluation_metrics") + op.drop_table("evaluation_steps") + op.drop_table("evaluation_scenarios") + op.drop_table("evaluation_runs") + + op.rename_table( + "auto_evaluator_configs", + "evaluators_configs", + ) + + op.rename_table( + "auto_evaluations", + "evaluations", + ) + op.rename_table( + "auto_evaluation_scenarios", + "evaluation_scenarios", + ) + op.rename_table( + "auto_evaluation_scenario_results", + "evaluation_scenario_results", + ) + op.rename_table( + "auto_evaluation_evaluator_configs", + "evaluation_evaluator_configs", + ) + op.rename_table( + "auto_evaluation_aggregated_results", + "evaluation_aggregated_results", + ) diff --git a/api/ee/databases/postgres/migrations/core/versions/5a71b3f140ab_fix_all_preview_schemas.py b/api/ee/databases/postgres/migrations/core/versions/5a71b3f140ab_fix_all_preview_schemas.py new file mode 100644 index 0000000000..62d244d1e1 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/5a71b3f140ab_fix_all_preview_schemas.py @@ -0,0 +1,426 @@ +"""fix all preview schemas + +Revision ID: 5a71b3f140ab +Revises: 8089ee7692d1 +Create Date: 2025-09-03 14:28:06.362553 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision: str = "5a71b3f140ab" +down_revision: Union[str, None] = "8089ee7692d1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # EVALUATION RUNS ---------------------------------------------------------- + + op.add_column( + "evaluation_runs", + sa.Column( + "references", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + ) + + op.create_index( + "ix_evaluation_runs_references", + "evaluation_runs", + ["references"], + unique=False, + postgresql_using="gin", + postgresql_ops={"references": "jsonb_path_ops"}, + ) + op.create_index( + "ix_evaluation_runs_flags", + "evaluation_runs", + ["flags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_runs_tags", + "evaluation_runs", + ["tags"], + unique=False, + postgresql_using="gin", + ) + + # EVALUATION SCENARIOS ----------------------------------------------------- + + op.add_column( + "evaluation_scenarios", + sa.Column( + "interval", + postgresql.INTEGER(), + nullable=True, + ), + ) + + op.create_index( + "ix_evaluation_scenarios_timestamp_interval", + "evaluation_scenarios", + ["timestamp", "interval"], + unique=False, + ) + op.create_index( + "ix_evaluation_scenarios_flags", + "evaluation_scenarios", + ["flags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_scenarios_tags", + "evaluation_scenarios", + ["tags"], + unique=False, + postgresql_using="gin", + ) + + # EVALUATION RESULTS ------------------------------------------------------- + + op.alter_column( + "evaluation_steps", + "timestamp", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=True, + ) + op.add_column( + "evaluation_steps", + sa.Column( + "interval", + postgresql.INTEGER(), + nullable=True, + ), + ) + + op.create_unique_constraint( + "uq_evaluation_steps_project_run_scenario_step_repeat", + "evaluation_steps", + ["project_id", "run_id", "scenario_id", "step_key", "repeat_idx"], + ) + + op.create_index( + "ix_evaluation_steps_tags", + "evaluation_steps", + ["tags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_steps_flags", + "evaluation_steps", + ["flags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_steps_timestamp_interval", + "evaluation_steps", + ["timestamp", "interval"], + unique=False, + ) + op.create_index( + "ix_evaluation_steps_repeat_idx", + "evaluation_steps", + ["repeat_idx"], + unique=False, + ) + op.create_index( + "ix_evaluation_steps_step_key", + "evaluation_steps", + ["step_key"], + unique=False, + ) + + op.rename_table("evaluation_steps", "evaluation_results") + + op.execute( + "ALTER TABLE evaluation_results RENAME CONSTRAINT " + "uq_evaluation_steps_project_run_scenario_step_repeat TO " + "uq_evaluation_results_project_run_scenario_step_repeat" + ) + + op.execute( + "ALTER INDEX ix_evaluation_steps_project_id RENAME TO ix_evaluation_results_project_id" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_run_id RENAME TO ix_evaluation_results_run_id" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_scenario_id RENAME TO ix_evaluation_results_scenario_id" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_step_key RENAME TO ix_evaluation_results_step_key" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_repeat_idx RENAME TO ix_evaluation_results_repeat_idx" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_timestamp_interval RENAME TO ix_evaluation_results_timestamp_interval" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_flags RENAME TO ix_evaluation_results_flags" + ) + op.execute( + "ALTER INDEX ix_evaluation_steps_tags RENAME TO ix_evaluation_results_tags" + ) + + # EVALUATION METRICS ------------------------------------------------------- + + op.add_column( + "evaluation_metrics", + sa.Column( + "interval", + postgresql.INTEGER(), + nullable=True, + ), + ) + + op.drop_constraint( + op.f("evaluation_metrics_project_id_run_id_scenario_id_key"), + "evaluation_metrics", + type_="unique", + ) + + op.create_unique_constraint( + "uq_evaluation_metrics_project_run_scenario_timestamp_interval", + "evaluation_metrics", + ["project_id", "run_id", "scenario_id", "timestamp", "interval"], + ) + + op.create_index( + "ix_evaluation_metrics_timestamp_interval", + "evaluation_metrics", + ["timestamp", "interval"], + unique=False, + ) + op.create_index( + "ix_evaluation_metrics_flags", + "evaluation_metrics", + ["flags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_metrics_tags", + "evaluation_metrics", + ["tags"], + unique=False, + postgresql_using="gin", + ) + + # EVALUATION QUEUES -------------------------------------------------------- + + op.add_column( + "evaluation_queues", + sa.Column( + "name", + sa.String(), + nullable=True, + ), + ) + op.add_column( + "evaluation_queues", + sa.Column( + "description", + sa.String(), + nullable=True, + ), + ) + op.add_column( + "evaluation_queues", + sa.Column( + "status", + sa.VARCHAR(), + nullable=False, + server_default=sa.text("'pending'::varchar"), + ), + ) + + op.create_index( + "ix_evaluation_queues_flags", + "evaluation_queues", + ["flags"], + unique=False, + postgresql_using="gin", + ) + op.create_index( + "ix_evaluation_queues_tags", + "evaluation_queues", + ["tags"], + unique=False, + postgresql_using="gin", + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # EVALUATION QUEUES -------------------------------------------------------- + + op.drop_index( + "ix_evaluation_queues_tags", + table_name="evaluation_queues", + ) + op.drop_index( + "ix_evaluation_queues_flags", + table_name="evaluation_queues", + ) + + op.drop_column( + "evaluation_queues", + "status", + ) + op.drop_column( + "evaluation_queues", + "description", + ) + op.drop_column( + "evaluation_queues", + "name", + ) + + # EVALUATION METRICS ------------------------------------------------------- + + op.drop_index( + "ix_evaluation_metrics_tags", + table_name="evaluation_metrics", + ) + op.drop_index( + "ix_evaluation_metrics_flags", + table_name="evaluation_metrics", + ) + op.drop_index( + "ix_evaluation_metrics_timestamp_interval", + table_name="evaluation_metrics", + ) + + op.drop_constraint( + "uq_evaluation_metrics_project_run_scenario_timestamp_interval", + "evaluation_metrics", + type_="unique", + ) + + op.create_unique_constraint( + op.f("evaluation_metrics_project_id_run_id_scenario_id_key"), + "evaluation_metrics", + ["project_id", "run_id", "scenario_id"], + postgresql_nulls_not_distinct=False, + ) + + op.drop_column("evaluation_metrics", "interval") + + # EVALUATION RESULTS ------------------------------------------------------- + + op.execute( + "ALTER INDEX ix_evaluation_results_tags RENAME TO ix_evaluation_steps_tags" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_flags RENAME TO ix_evaluation_steps_flags" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_timestamp_interval RENAME TO ix_evaluation_steps_timestamp_interval" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_repeat_idx RENAME TO ix_evaluation_steps_repeat_idx" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_step_key RENAME TO ix_evaluation_steps_step_key" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_scenario_id RENAME TO ix_evaluation_steps_scenario_id" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_run_id RENAME TO ix_evaluation_steps_run_id" + ) + op.execute( + "ALTER INDEX ix_evaluation_results_project_id RENAME TO ix_evaluation_steps_project_id" + ) + + op.execute( + "ALTER TABLE evaluation_results RENAME CONSTRAINT uq_evaluation_results_project_run_scenario_step_repeat " + "TO uq_evaluation_steps_project_run_scenario_step_repeat" + ) + + op.rename_table("evaluation_results", "evaluation_steps") + + op.drop_index( + "ix_evaluation_steps_tags", + table_name="evaluation_steps", + ) + op.drop_index( + "ix_evaluation_steps_flags", + table_name="evaluation_steps", + ) + op.drop_index( + "ix_evaluation_steps_timestamp_interval", + table_name="evaluation_steps", + ) + op.drop_index( + "ix_evaluation_steps_repeat_idx", + table_name="evaluation_steps", + ) + op.drop_index( + "ix_evaluation_steps_step_key", + table_name="evaluation_steps", + ) + + op.drop_constraint( + "uq_evaluation_steps_project_run_scenario_step_repeat", + "evaluation_steps", + type_="unique", + ) + + op.alter_column( + "evaluation_steps", + "timestamp", + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=False, + ) + + op.drop_column("evaluation_steps", "interval") + + # EVALUATION SCENARIOS ----------------------------------------------------- + + op.drop_index( + "ix_evaluation_scenarios_tags", + table_name="evaluation_scenarios", + ) + op.drop_index( + "ix_evaluation_scenarios_flags", + table_name="evaluation_scenarios", + ) + op.drop_index( + "ix_evaluation_scenarios_timestamp_interval", + table_name="evaluation_scenarios", + ) + + op.drop_column("evaluation_scenarios", "interval") + + # EVALUATION RUNS ---------------------------------------------------------- + + op.drop_index( + "ix_evaluation_runs_tags", + table_name="evaluation_runs", + ) + op.drop_index( + "ix_evaluation_runs_flags", + table_name="evaluation_runs", + ) + op.drop_index( + "ix_evaluation_runs_references", + table_name="evaluation_runs", + ) + + op.drop_column("evaluation_runs", "references") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/6161b674688d_add_commit_message_column_to_app_.py b/api/ee/databases/postgres/migrations/core/versions/6161b674688d_add_commit_message_column_to_app_.py new file mode 100644 index 0000000000..81d1ee6046 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/6161b674688d_add_commit_message_column_to_app_.py @@ -0,0 +1,39 @@ +"""add commit_message column to app_variants, app_variant_revisions and environments_revisions table + +Revision ID: 6161b674688d +Revises: 2a91436752f9 +Create Date: 2025-03-27 08:23:07.894643 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "6161b674688d" +down_revision: Union[str, None] = "2a91436752f9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "app_variant_revisions", + sa.Column("commit_message", sa.String(length=255), nullable=True), + ) + op.add_column( + "environments_revisions", + sa.Column("commit_message", sa.String(length=255), nullable=True), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("environments_revisions", "commit_message") + op.drop_column("app_variant_revisions", "commit_message") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/6965776e6940_add_subscriptions.py b/api/ee/databases/postgres/migrations/core/versions/6965776e6940_add_subscriptions.py new file mode 100644 index 0000000000..b6b76bc89f --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/6965776e6940_add_subscriptions.py @@ -0,0 +1,40 @@ +"""add subscriptions + +Revision ID: 6965776e6940 +Revises: 425c68e8de6c +Create Date: 2025-01-23 13:42:47.716771 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "6965776e6940" +down_revision: Union[str, None] = "7cc66fc40298" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "subscriptions", + sa.Column("plan", sa.String(), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.Column("organization_id", sa.UUID(), nullable=False), + sa.Column("customer_id", sa.String(), nullable=True), + sa.Column("subscription_id", sa.String(), nullable=True), + sa.Column("anchor", sa.SmallInteger(), nullable=True), + sa.PrimaryKeyConstraint("organization_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("subscriptions") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/6aafdfc2befb_rename_user_organizations_to_organization_members.py b/api/ee/databases/postgres/migrations/core/versions/6aafdfc2befb_rename_user_organizations_to_organization_members.py new file mode 100644 index 0000000000..02fb6c9eef --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/6aafdfc2befb_rename_user_organizations_to_organization_members.py @@ -0,0 +1,63 @@ +"""created project_members table and added organization&workspace id to projects table + +Revision ID: 6aafdfc2befb +Revises: 8accbbea1d21 +Create Date: 2024-09-02 15:50:58.870573 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "6aafdfc2befb" +down_revision: Union[str, None] = "e14e8689cd03" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "organization_members", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + connection = op.get_bind() + inspector = sa.inspect(connection) + if "user_organizations" not in inspector.get_table_names(): + op.create_table( + "user_organizations", + sa.Column("id", sa.UUID(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + sa.Column("organization_id", sa.UUID(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + name="user_organizations_organization_id_fkey", + ), + sa.ForeignKeyConstraint( + ["user_id"], ["users.id"], name="user_organizations_user_id_fkey" + ), + sa.PrimaryKeyConstraint("id", name="user_organizations_pkey"), + ) + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3c_add_is_demo_flag.py b/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3c_add_is_demo_flag.py new file mode 100644 index 0000000000..94eed007df --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3c_add_is_demo_flag.py @@ -0,0 +1,30 @@ +"""add initial demo + +Revision ID: 73a2d8cfaa3c +Revises: 24f8bdb390ee +Create Date: 2024-12-02 9:00:00 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "73a2d8cfaa3c" +down_revision: Union[str, None] = "24f8bdb390ee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom data migrations ### + op.add_column("project_members", sa.Column("is_demo", sa.BOOLEAN(), nullable=True)) + # ### end of custom data commands ### + + +def downgrade() -> None: + # ### custom data migrations ### + op.drop_column("project_members", "is_demo") + # ### end of custom data commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3d_add_initial_demo.py b/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3d_add_initial_demo.py new file mode 100644 index 0000000000..f20dfb0e2d --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/73a2d8cfaa3d_add_initial_demo.py @@ -0,0 +1,36 @@ +"""add initial demo + +Revision ID: 73a2d8cfaa3d +Revises: 73a2d8cfaa3c +Create Date: 2024-12-02 9:00:00 + +""" + +from typing import Sequence, Union + +from alembic import context + +from ee.databases.postgres.migrations.core.data_migrations.demos import ( + add_users_to_demos, + remove_users_from_demos, +) + +# revision identifiers, used by Alembic. +revision: str = "73a2d8cfaa3d" +down_revision: Union[str, None] = "73a2d8cfaa3c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom data migrations ### + connection = context.get_bind() # get database connect from alembic context + add_users_to_demos(session=connection) + # ### end of custom data commands ### + + +def downgrade() -> None: + # ### custom data migrations ### + connection = context.get_bind() # get database connect from alembic context + remove_users_from_demos(session=connection) + # ### end of custom data commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/770d68410ab0_transfer_user_organization_to_.py b/api/ee/databases/postgres/migrations/core/versions/770d68410ab0_transfer_user_organization_to_.py new file mode 100644 index 0000000000..a69fbc2b6b --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/770d68410ab0_transfer_user_organization_to_.py @@ -0,0 +1,35 @@ +"""transfer user organization to organization members + +Revision ID: 770d68410ab0 +Revises: 79b9acb137a1 +Create Date: 2024-09-08 18:21:27.192472 + +""" + +from typing import Sequence, Union +from alembic import context +from alembic import op + + +from ee.databases.postgres.migrations.core.data_migrations.export_records import ( + transfer_records_from_user_organization_to_organization_members, + transfer_records_from_organization_members_to_user_organization, +) + + +# revision identifiers, used by Alembic. +revision: str = "770d68410ab0" +down_revision: Union[str, None] = "79b9acb137a1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + connection = context.get_bind() # get database connect from alembic context + transfer_records_from_user_organization_to_organization_members(session=connection) + + +def downgrade() -> None: + connection = context.get_bind() # get database connect from alembic context + transfer_records_from_organization_members_to_user_organization(session=connection) + op.drop_table("organization_members") diff --git a/api/ee/databases/postgres/migrations/core/versions/7990f1e12f47_create_free_plans.py b/api/ee/databases/postgres/migrations/core/versions/7990f1e12f47_create_free_plans.py new file mode 100644 index 0000000000..3061a4d230 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/7990f1e12f47_create_free_plans.py @@ -0,0 +1,360 @@ +"""create free plans + +Revision ID: 7990f1e12f47 +Revises: 12f477990f1e +Create Date: 2025-01-25 16:51:06.233811 + +""" + +from typing import Sequence, Union +from os import environ +from datetime import datetime, timezone +from time import time + +from alembic import context + +from sqlalchemy import Connection, func, insert, select, update + +import stripe + +from oss.src.utils.logging import get_module_logger +from oss.src.models.db_models import UserDB +from oss.src.models.db_models import AppDB +from ee.src.models.db_models import OrganizationDB +from ee.src.models.db_models import OrganizationMemberDB +from ee.src.models.db_models import ProjectDB +from ee.src.models.db_models import ProjectMemberDB +from ee.src.dbs.postgres.subscriptions.dbes import SubscriptionDBE +from ee.src.dbs.postgres.meters.dbes import MeterDBE +from ee.src.core.subscriptions.types import FREE_PLAN +from ee.src.core.entitlements.types import Gauge + +stripe.api_key = environ.get("STRIPE_API_KEY") + +log = get_module_logger(__name__) + +# revision identifiers, used by Alembic. +revision: str = "7990f1e12f47" +down_revision: Union[str, None] = "12f477990f1e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + try: + session: Connection = context.get_bind() + + now = datetime.now(timezone.utc) + + # --> GET ORGANIZATION COUNT + query = select(func.count()).select_from(OrganizationDB) + + nof_organizations = session.execute(query).scalar() + # <-- GET ORGANIZATION COUNT + + # --> ITERATE OVER ORGANIZATION BATCHES + organization_batch_size = 100 + organization_batch_index = 0 + + while True: + # --> GET ORGANIZATION BATCH + query = ( + select(OrganizationDB) + .limit(organization_batch_size) + .offset(organization_batch_index * organization_batch_size) + ) + + organizations = session.execute(query).all() + + organization_batch_index += 1 + + if not organizations: + break + # <-- GET ORGANIZATION BATCH + + # --> ITERATE OVER ORGANIZATIONS + for i, organization in enumerate(organizations): + log.info( + " %s / %s - %s", + (organization_batch_index - 1) * organization_batch_size + i + 1, + nof_organizations, + organization.id, + ) + + ti = time() + + # xti = time() + # --> GET ORGANIZATION INFO + owner = organization.owner + + if not owner: + continue + + query = select(UserDB).where( + UserDB.id == owner, + ) + + user = session.execute(query).first() + + if not user: + continue + + email = user.email + + if not email: + continue + # <-- GET ORGANIZATION INFO + # xtf = time() + # xdt = xtf - xti + # log.info(" - GET ORGANIZATION INFO: %s ms", int(xdt * 1000)) + + # xti = time() + # --> CHECK IF SUBSCRIPTION EXISTS + organization_id = organization.id + customer_id = None + subscription_id = None + plan = FREE_PLAN + active = True + anchor = now.day + + subscription_exists = ( + session.execute( + select(SubscriptionDBE).where( + SubscriptionDBE.organization_id == organization_id, + ) + ) + .scalars() + .first() + ) + # <-- CHECK IF SUBSCRIPTION EXISTS + # xtf = time() + # xdt = xtf - xti + # log.info(" - CHECK IF SUBSCRIPTION EXISTS: %s ms", int(xdt * 1000)) + + # xti = time() + # --> CREATE OR UPDATE SUBSCRIPTION + if not subscription_exists: + query = insert(SubscriptionDBE).values( + organization_id=organization_id, + subscription_id=subscription_id, + customer_id=customer_id, + plan=plan.value, + active=active, + anchor=anchor, + ) + + session.execute(query) + else: + query = ( + update(SubscriptionDBE) + .where( + SubscriptionDBE.organization_id == organization_id, + ) + .values( + subscription_id=subscription_id, + customer_id=customer_id, + plan=plan.value, + active=active, + anchor=anchor, + ) + ) + + session.execute(query) + # <-- CREATE OR UPDATE SUBSCRIPTION + # xtf = time() + # xdt = xtf - xti + # log.info(" - CREATE OR UPDATE SUBSCRIPTION: %s ms", int(xdt * 1000)) + + # xti = time() + # --> GET ORGANIZATION MEMBERS + query = ( + select(func.count()) + .select_from(OrganizationMemberDB) + .where( + OrganizationMemberDB.organization_id == organization.id, + ) + ) + + nof_members = session.execute(query).scalar() + # <-- GET ORGANIZATION MEMBERS + # xtf = time() + # xdt = xtf - xti + # log.info(" - GET ORGANIZATION MEMBERS: %s ms", int(xdt * 1000)) + + # xti = time() + # --> CHECK IF USERS METER EXISTS + key = Gauge.USERS + value = nof_members + synced = 0 + # organization_id = organization_id + year = 0 + month = 0 + + users_meter_exists = ( + session.execute( + select(MeterDBE).where( + MeterDBE.organization_id == organization_id, + MeterDBE.key == key, + MeterDBE.year == year, + MeterDBE.month == month, + ) + ) + .scalars() + .first() + ) + # <-- CHECK IF USERS METER EXISTS + # xtf = time() + # xdt = xtf - xti + # log.info(" - CHECK IF USERS METER EXISTS: %s ms", int(xdt * 1000)) + + # xti = time() + # --> CREATE OR UPDATE USERS METER + if not users_meter_exists: + query = insert(MeterDBE).values( + organization_id=organization_id, + key=key, + year=year, + month=month, + value=value, + synced=synced, + ) + + session.execute(query) + else: + query = ( + update(MeterDBE) + .where( + MeterDBE.organization_id == organization_id, + MeterDBE.key == key, + MeterDBE.year == year, + MeterDBE.month == month, + ) + .values( + value=value, + synced=synced, + ) + ) + + session.execute(query) + # <-- CREATE OR UPDATE USERS METER + # xtf = time() + # xdt = xtf - xti + # log.info(" - CREATE OR UPDATE USERS METER: %s ms", int(xdt * 1000)) + + # xti = time() + # --> GET ORGANIZATION PROJECTS + query = select(ProjectDB).where( + ProjectDB.organization_id == organization_id, + ) + + projects = session.execute(query).all() + # <-- GET ORGANIZATION PROJECTS + # xtf = time() + # xdt = xtf - xti + # log.info(" - GET ORGANIZATION PROJECTS: %s ms", int(xdt * 1000)) + + # xti = time() + # --> ITERATE OVER PROJECTS + value = 0 + + for project in projects: + # --> GET PROJECT APPLICATIONS + query = select(AppDB).where( + AppDB.project_id == project.id, + ) + + apps = session.execute(query).scalars().all() + # <-- GET PROJECT APPLICATIONS + + value += len(apps) + # <-- ITERATE OVER PROJECTS + # xtf = time() + # xdt = xtf - xti + # log.info(" - ITERATE OVER PROJECTS: %s ms", int(xdt * 1000)) + + # xti = time() + # --> CHECK IF APPLICATIONS METER EXISTS + key = Gauge.APPLICATIONS + # value = value + synced = 0 + # organization_id = organization_id + year = 0 + month = 0 + + applications_meter_exists = ( + session.execute( + select(MeterDBE).where( + MeterDBE.organization_id == organization_id, + MeterDBE.key == key, + MeterDBE.year == year, + MeterDBE.month == month, + ) + ) + .scalars() + .first() + ) + # <-- CHECK IF APPLICATIONS METER EXISTS + # xtf = time() + # xdt = xtf - xti + # log.info( + # " - CHECK IF APPLICATIONS METER EXISTS: %s ms", int(xdt * 1000) + # ) + + # xti = time() + # --> CREATE OR UPDATE APPLICATIONS METER + if not applications_meter_exists: + query = insert(MeterDBE).values( + organization_id=organization_id, + key=key, + year=year, + month=month, + value=value, + synced=synced, + ) + + session.execute(query) + else: + query = ( + update(MeterDBE) + .where( + MeterDBE.organization_id == organization_id, + MeterDBE.key == key, + MeterDBE.year == year, + MeterDBE.month == month, + ) + .values( + value=value, + synced=synced, + ) + ) + + session.execute(query) + # <-- CREATE OR UPDATE APPLICATIONS METER + # xtf = time() + # xdt = xtf - xti + # log.info( + # " - CREATE OR UPDATE APPLICATIONS METER: %s ms", int(xdt * 1000) + # ) + + tf = time() + dt = tf - ti + log.info( + " %s / %s - %s - %s ms", + (organization_batch_index - 1) * organization_batch_size + i + 1, + nof_organizations, + organization.id, + int(dt * 1000), + ) + # <-- ITERATE OVER ORGANIZATIONS + + # <-- ITERATE OVER ORGANIZATION BATCHES + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error during free plans migration: %s", e) + session.rollback() + raise e + + log.info("Free plans migration completed successfully.") + + +def downgrade() -> None: + pass diff --git a/api/ee/databases/postgres/migrations/core/versions/79b9acb137a1_transfer_workspace_invitations_to_.py b/api/ee/databases/postgres/migrations/core/versions/79b9acb137a1_transfer_workspace_invitations_to_.py new file mode 100644 index 0000000000..bade4cb395 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/79b9acb137a1_transfer_workspace_invitations_to_.py @@ -0,0 +1,37 @@ +"""transfer workspace invitations to project invitations + +Revision ID: 79b9acb137a1 +Revises: 9b0e1a740b88 +Create Date: 2024-09-05 17:16:29.480645 + +""" + +from typing import Sequence, Union + +from alembic import context + +from ee.databases.postgres.migrations.core.data_migrations.invitations import ( + transfer_invitations_from_old_table_to_new_table, + revert_invitations_transfer_from_new_table_to_old_table, +) + + +# revision identifiers, used by Alembic. +revision: str = "79b9acb137a1" +down_revision: Union[str, None] = "9b0e1a740b88" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom migration ### + connection = context.get_bind() # get database connect from alembic context + transfer_invitations_from_old_table_to_new_table(session=connection) + # ### end of custom migration ### + + +def downgrade() -> None: + # ### custom migration ### + connection = context.get_bind() # get database connect from alembic context + revert_invitations_transfer_from_new_table_to_old_table(session=connection) + # ### end of custom migration ### diff --git a/api/ee/databases/postgres/migrations/core/versions/7cc66fc40298_add_hidden_column_to_app_variants_table.py b/api/ee/databases/postgres/migrations/core/versions/7cc66fc40298_add_hidden_column_to_app_variants_table.py new file mode 100644 index 0000000000..d45ba53b3c --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/7cc66fc40298_add_hidden_column_to_app_variants_table.py @@ -0,0 +1,35 @@ +"""add 'hidden' column to app_variants table + +Revision ID: 7cc66fc40298 +Revises: 6161b674688d +Create Date: 2025-03-27 14:40:47.770949 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "7cc66fc40298" +down_revision: Union[str, None] = "6161b674688d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("app_variants", sa.Column("hidden", sa.Boolean(), nullable=True)) + op.add_column( + "app_variant_revisions", sa.Column("hidden", sa.Boolean(), nullable=True) + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("app_variants", "hidden") + op.drop_column("app_variant_revisions", "hidden") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/8089ee7692d1_cleanup_preview_entities.py b/api/ee/databases/postgres/migrations/core/versions/8089ee7692d1_cleanup_preview_entities.py new file mode 100644 index 0000000000..36e9e4edd4 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/8089ee7692d1_cleanup_preview_entities.py @@ -0,0 +1,168 @@ +"""clean up preview entities + +Revision ID: 8089ee7692d1 +Revises: fa07e07350bf +Create Date: 2025-08-20 16:00:00.00000000 + +""" + +from typing import Sequence, Union +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "8089ee7692d1" +down_revision: Union[str, None] = "fa07e07350bf" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +TABLES_WITH_DATA_MIGRATION = [ + "evaluation_runs", + "evaluation_metrics", + "evaluation_queues", + "testcase_blobs", + "testset_revisions", + "query_revisions", + "workflow_revisions", +] + +TABLES_WITH_META_MIGRATION = [ + "evaluation_runs", + "evaluation_scenarios", + "evaluation_steps", + "evaluation_metrics", + "evaluation_queues", + "testcase_blobs", + "testset_artifacts", + "testset_variants", + "testset_revisions", + "query_artifacts", + "query_variants", + "query_revisions", + "workflow_artifacts", + "workflow_variants", + "workflow_revisions", +] + + +def upgrade() -> None: + # Convert jsonb -> json for data columns + for table in TABLES_WITH_DATA_MIGRATION: + op.alter_column( + table_name=table, + column_name="data", + type_=sa.JSON(), + postgresql_using="data::json", + ) + + # Convert jsonb -> json for meta columns + for table in TABLES_WITH_META_MIGRATION: + op.alter_column( + table_name=table, + column_name="meta", + type_=sa.JSON(), + postgresql_using="meta::json", + ) + + # Add new timestamp column + op.add_column( + "evaluation_scenarios", + sa.Column( + "timestamp", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + ) + + # Add repeat_idx and drop old repeat_id + retry_id + op.add_column( + "evaluation_steps", + sa.Column( + "repeat_idx", + sa.Integer(), + nullable=True, + ), + ) + op.drop_column( + "evaluation_steps", + "repeat_id", + ) + op.drop_column( + "evaluation_steps", + "retry_id", + ) + + # Rename key -> step_key + op.alter_column( + "evaluation_steps", + "key", + new_column_name="step_key", + existing_type=sa.String(), # adjust if needed + existing_nullable=False, + ) + + op.drop_column( + "evaluation_metrics", + "interval", + ) + + +def downgrade() -> None: + op.add_column( + "evaluation_metrics", + sa.Column( + "interval", + sa.Integer(), + nullable=True, + ), + ) + + # Rename step_key back to key + op.alter_column( + "evaluation_steps", + "step_key", + new_column_name="key", + existing_type=sa.String(), # adjust if needed + existing_nullable=False, + ) + + # Recreate repeat_id and retry_id columns + op.add_column( + "evaluation_steps", + sa.Column("repeat_id", sa.UUID(), nullable=False), + ) + op.add_column( + "evaluation_steps", + sa.Column("retry_id", sa.UUID(), nullable=False), + ) + + # Drop repeat_idx column + op.drop_column( + "evaluation_steps", + "repeat_idx", + ) + + # Drop timestamp column + op.drop_column( + "evaluation_scenarios", + "timestamp", + ) + + # Convert meta columns back to jsonb + for table in TABLES_WITH_META_MIGRATION: + op.alter_column( + table_name=table, + column_name="meta", + type_=sa.dialects.postgresql.JSONB(), + postgresql_using="meta::jsonb", + ) + + # Convert data columns back to jsonb + for table in TABLES_WITH_DATA_MIGRATION: + op.alter_column( + table_name=table, + column_name="data", + type_=sa.dialects.postgresql.JSONB(), + postgresql_using="data::jsonb", + ) diff --git a/api/ee/databases/postgres/migrations/core/versions/847972cfa14a_add_nodes_dbe.py b/api/ee/databases/postgres/migrations/core/versions/847972cfa14a_add_nodes_dbe.py new file mode 100644 index 0000000000..239b9fb280 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/847972cfa14a_add_nodes_dbe.py @@ -0,0 +1,121 @@ +"""add_nodes_dbe + +Revision ID: 847972cfa14a +Revises: 320a4a7ee0c7 +Create Date: 2024-11-07 12:21:19.080345 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "847972cfa14a" +down_revision: Union[str, None] = "320a4a7ee0c7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "nodes", + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("updated_by_id", sa.UUID(), nullable=True), + sa.Column("root_id", sa.UUID(), nullable=False), + sa.Column("tree_id", sa.UUID(), nullable=False), + sa.Column("tree_type", sa.Enum("INVOCATION", name="treetype"), nullable=True), + sa.Column("node_id", sa.UUID(), nullable=False), + sa.Column("node_name", sa.String(), nullable=False), + sa.Column( + "node_type", + sa.Enum( + "AGENT", + "WORKFLOW", + "CHAIN", + "TASK", + "TOOL", + "EMBEDDING", + "QUERY", + "COMPLETION", + "CHAT", + "RERANK", + name="nodetype", + ), + nullable=True, + ), + sa.Column("parent_id", sa.UUID(), nullable=True), + sa.Column("time_start", sa.TIMESTAMP(), nullable=False), + sa.Column("time_end", sa.TIMESTAMP(), nullable=False), + sa.Column( + "status", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metrics", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "refs", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "exception", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "links", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column("content", sa.String(), nullable=True), + sa.Column( + "otel", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint("project_id", "node_id"), + ) + op.create_index( + "index_project_id_node_id", "nodes", ["project_id", "created_at"], unique=False + ) + op.create_index( + "index_project_id_root_id", "nodes", ["project_id", "root_id"], unique=False + ) + op.create_index( + "index_project_id_tree_id", "nodes", ["project_id", "tree_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("index_project_id_tree_id", table_name="nodes") + op.drop_index("index_project_id_root_id", table_name="nodes") + op.drop_index("index_project_id_node_id", table_name="nodes") + op.drop_table("nodes") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/8accbbea1d21_initial_migration.py b/api/ee/databases/postgres/migrations/core/versions/8accbbea1d21_initial_migration.py new file mode 100644 index 0000000000..d5f43f9f08 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/8accbbea1d21_initial_migration.py @@ -0,0 +1,1000 @@ +"""initial migration + +Revision ID: 8accbbea1d21 +Revises: +Create Date: 2024-07-27 16:20:33.077302 + +""" + +import os +from typing import Sequence, Union + +from alembic import op +from alembic import context + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from oss.src.utils.env import env +from ee.databases.postgres.migrations.core.utils import is_initial_setup + + +# revision identifiers, used by Alembic. +revision: str = "8accbbea1d21" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def unique_constraint_exists(engine: sa.Engine, table_name: str, constraint_name: str): + with engine.connect() as conn: + result = conn.execute( + sa.text( + f""" + SELECT conname FROM pg_constraint + WHERE conname = '{constraint_name}' AND conrelid = '{table_name}'::regclass; + """ + ) + ) + return result.fetchone() is not None + + +def first_time_user_from_agenta_v019_upwards_upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "api_keys", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("prefix", sa.String(), nullable=True), + sa.Column("hashed_key", sa.String(), nullable=True), + sa.Column("user_id", sa.String(), nullable=True), + sa.Column("workspace_id", sa.String(), nullable=True), + sa.Column("rate_limit", sa.Integer(), nullable=True), + sa.Column("hidden", sa.Boolean(), nullable=True), + sa.Column("expiration_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "ids_mapping", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("table_name", sa.String(), nullable=False), + sa.Column("objectid", sa.String(), nullable=False), + sa.Column("uuid", sa.UUID(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "invitations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("token", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("used", sa.Boolean(), nullable=True), + sa.Column("workspace_id", sa.String(), nullable=False), + sa.Column( + "workspace_roles", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("expiration_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("token"), + ) + op.create_table( + "organizations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("type", sa.String(), nullable=True), + sa.Column("owner", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_paying", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "templates", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("type", sa.Enum("IMAGE", "ZIP", name="templatetype"), nullable=False), + sa.Column("template_uri", sa.String(), nullable=True), + sa.Column("tag_id", sa.Integer(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("repo_name", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("size", sa.Integer(), nullable=True), + sa.Column("digest", sa.String(), nullable=True), + sa.Column("last_pushed", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("name"), + ) + op.create_table( + "users", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("uid", sa.String(), nullable=True), + sa.Column("username", sa.String(), nullable=True), + sa.Column("email", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + sa.UniqueConstraint("id"), + ) + op.create_index(op.f("ix_users_uid"), "users", ["uid"], unique=True) + op.create_table( + "user_organizations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "workspaces", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("type", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "app_db", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_name", sa.String(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "docker_images", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("type", sa.String(), nullable=True), + sa.Column("template_uri", sa.String(), nullable=True), + sa.Column("docker_id", sa.String(), nullable=True), + sa.Column("tags", sa.String(), nullable=True), + sa.Column("deletable", sa.Boolean(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_index( + op.f("ix_docker_images_docker_id"), "docker_images", ["docker_id"], unique=False + ) + op.create_table( + "workspace_members", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspaces.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "deployments", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("container_name", sa.String(), nullable=True), + sa.Column("container_id", sa.String(), nullable=True), + sa.Column("uri", sa.String(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("cloud_map_service_id", sa.String(), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluators_configs", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("evaluator_key", sa.String(), nullable=True), + sa.Column( + "settings_values", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "testsets", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("csvdata", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "bases", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("base_name", sa.String(), nullable=True), + sa.Column("image_id", sa.UUID(), nullable=True), + sa.Column("deployment_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["deployment_id"], ["deployments.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["image_id"], ["docker_images.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "human_evaluations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("evaluation_type", sa.String(), nullable=True), + sa.Column("testset_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["testset_id"], + ["testsets.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "app_variants", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("variant_name", sa.String(), nullable=True), + sa.Column("revision", sa.Integer(), nullable=True), + sa.Column("image_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("modified_by_id", sa.UUID(), nullable=True), + sa.Column("base_name", sa.String(), nullable=True), + sa.Column("base_id", sa.UUID(), nullable=True), + sa.Column("config_name", sa.String(), nullable=False), + sa.Column( + "config_parameters", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["base_id"], + ["bases.id"], + ), + sa.ForeignKeyConstraint( + ["image_id"], ["docker_images.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["modified_by_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "human_evaluations_scenarios", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("evaluation_id", sa.UUID(), nullable=True), + sa.Column("inputs", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("outputs", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("vote", sa.String(), nullable=True), + sa.Column("score", sa.String(), nullable=True), + sa.Column("correct_answer", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_pinned", sa.Boolean(), nullable=True), + sa.Column("note", sa.String(), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["evaluation_id"], ["human_evaluations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "app_variant_revisions", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("variant_id", sa.UUID(), nullable=True), + sa.Column("revision", sa.Integer(), nullable=True), + sa.Column("modified_by_id", sa.UUID(), nullable=True), + sa.Column("base_id", sa.UUID(), nullable=True), + sa.Column("config_name", sa.String(), nullable=False), + sa.Column( + "config_parameters", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["base_id"], + ["bases.id"], + ), + sa.ForeignKeyConstraint( + ["modified_by_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["variant_id"], ["app_variants.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "environments", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("name", sa.String(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("revision", sa.Integer(), nullable=True), + sa.Column("deployed_app_variant_id", sa.UUID(), nullable=True), + sa.Column("deployed_app_variant_revision_id", sa.UUID(), nullable=True), + sa.Column("deployment_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["deployed_app_variant_id"], ["app_variants.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["deployed_app_variant_revision_id"], + ["app_variant_revisions.id"], + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["deployment_id"], ["deployments.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("app_id", sa.UUID(), nullable=True), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("status", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("testset_id", sa.UUID(), nullable=True), + sa.Column("variant_id", sa.UUID(), nullable=True), + sa.Column("variant_revision_id", sa.UUID(), nullable=True), + sa.Column( + "average_cost", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("total_cost", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "average_latency", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint(["app_id"], ["app_db.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint(["testset_id"], ["testsets.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["variant_id"], ["app_variants.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["variant_revision_id"], ["app_variant_revisions.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "human_evaluation_variants", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("human_evaluation_id", sa.UUID(), nullable=True), + sa.Column("variant_id", sa.UUID(), nullable=True), + sa.Column("variant_revision_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["human_evaluation_id"], ["human_evaluations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["variant_id"], ["app_variants.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["variant_revision_id"], ["app_variant_revisions.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "environments_revisions", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("environment_id", sa.UUID(), nullable=True), + sa.Column("revision", sa.Integer(), nullable=True), + sa.Column("modified_by_id", sa.UUID(), nullable=True), + sa.Column("deployed_app_variant_revision_id", sa.UUID(), nullable=True), + sa.Column("deployment_id", sa.UUID(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["deployed_app_variant_revision_id"], + ["app_variant_revisions.id"], + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["deployment_id"], ["deployments.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["environment_id"], ["environments.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["modified_by_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluation_aggregated_results", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("evaluation_id", sa.UUID(), nullable=True), + sa.Column("evaluator_config_id", sa.UUID(), nullable=True), + sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.ForeignKeyConstraint( + ["evaluation_id"], ["evaluations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["evaluator_config_id"], ["evaluators_configs.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluation_evaluator_configs", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("evaluation_id", sa.UUID(), nullable=False), + sa.Column("evaluator_config_id", sa.UUID(), nullable=False), + sa.ForeignKeyConstraint( + ["evaluation_id"], ["evaluations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["evaluator_config_id"], ["evaluators_configs.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id", "evaluation_id", "evaluator_config_id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluation_scenarios", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("evaluation_id", sa.UUID(), nullable=True), + sa.Column("variant_id", sa.UUID(), nullable=True), + sa.Column("inputs", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("outputs", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column( + "correct_answers", postgresql.JSONB(astext_type=sa.Text()), nullable=True + ), + sa.Column("is_pinned", sa.Boolean(), nullable=True), + sa.Column("note", sa.String(), nullable=True), + sa.Column("latency", sa.Integer(), nullable=True), + sa.Column("cost", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["evaluation_id"], ["evaluations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["variant_id"], ["app_variants.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "evaluation_scenario_results", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("evaluation_scenario_id", sa.UUID(), nullable=True), + sa.Column("evaluator_config_id", sa.UUID(), nullable=True), + sa.Column("result", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.ForeignKeyConstraint( + ["evaluation_scenario_id"], ["evaluation_scenarios.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["evaluator_config_id"], ["evaluators_configs.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def first_time_user_from_agenta_v019_upwards_downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("evaluation_scenario_results") + op.drop_table("evaluation_scenarios") + op.drop_table("evaluation_evaluator_configs") + op.drop_table("evaluation_aggregated_results") + op.drop_table("environments_revisions") + op.drop_table("human_evaluation_variants") + op.drop_table("evaluations") + op.drop_table("environments") + op.drop_table("app_variant_revisions") + op.drop_table("human_evaluations_scenarios") + op.drop_table("app_variants") + op.drop_table("human_evaluations") + op.drop_table("bases") + op.drop_table("testsets") + op.drop_table("evaluators_configs") + op.drop_table("deployments") + op.drop_table("workspace_members") + op.drop_index(op.f("ix_docker_images_docker_id"), table_name="docker_images") + op.drop_table("docker_images") + op.drop_table("app_db") + op.drop_table("workspaces") + op.drop_table("user_organizations") + op.drop_index(op.f("ix_users_uid"), table_name="users") + op.drop_table("users") + op.drop_table("templates") + op.drop_table("organizations") + op.drop_table("invitations") + op.drop_table("ids_mapping") + op.drop_table("api_keys") + # ### end Alembic commands ### + + +def returning_user_from_agenta_v018_downwards_upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + engine = sa.create_engine(env.POSTGRES_URI_CORE) + if not unique_constraint_exists(engine, "app_db", "app_db_pkey"): + op.create_unique_constraint("app_db_pkey", "app_db", ["id"]) + + if not unique_constraint_exists( + engine, "app_variant_revisions", "app_variant_revisions_pkey" + ): + op.create_unique_constraint( + "app_variant_revisions_pkey", "app_variant_revisions", ["id"] + ) + + if not unique_constraint_exists(engine, "app_variants", "app_variants_pkey"): + op.create_unique_constraint("app_variants_pkey", "app_variants", ["id"]) + + if not unique_constraint_exists(engine, "bases", "bases_pkey"): + op.create_unique_constraint("bases_pkey", "bases", ["id"]) + + if not unique_constraint_exists(engine, "deployments", "deployments_pkey"): + op.create_unique_constraint("deployments_pkey", "deployments", ["id"]) + + if not unique_constraint_exists(engine, "docker_images", "docker_images_pkey"): + op.create_unique_constraint("docker_images_pkey", "docker_images", ["id"]) + + if not unique_constraint_exists(engine, "environments", "environments_pkey"): + op.create_unique_constraint("environments_pkey", "environments", ["id"]) + + if not unique_constraint_exists( + engine, "environments_revisions", "environments_revisions_pkey" + ): + op.create_unique_constraint( + "environments_revisions_pkey", "environments_revisions", ["id"] + ) + + if not unique_constraint_exists( + engine, "evaluation_aggregated_results", "evaluation_aggregated_results_pkey" + ): + op.create_unique_constraint( + "evaluation_aggregated_results_pkey", + "evaluation_aggregated_results", + ["id"], + ) + + if not unique_constraint_exists( + engine, "evaluation_scenario_results", "evaluation_scenario_results_pkey" + ): + op.create_unique_constraint( + "evaluation_scenario_results_pkey", "evaluation_scenario_results", ["id"] + ) + + if not unique_constraint_exists( + engine, "evaluation_scenarios", "evaluation_scenarios_pkey" + ): + op.create_unique_constraint( + "evaluation_scenarios_pkey", "evaluation_scenarios", ["id"] + ) + + if not unique_constraint_exists(engine, "evaluations", "evaluations_pkey"): + op.create_unique_constraint("evaluations_pkey", "evaluations", ["id"]) + + if not unique_constraint_exists( + engine, "evaluators_configs", "evaluators_configs_pkey" + ): + op.create_unique_constraint( + "evaluators_configs_pkey", "evaluators_configs", ["id"] + ) + + if not unique_constraint_exists( + engine, "human_evaluation_variants", "human_evaluation_variants_pkey" + ): + op.create_unique_constraint( + "human_evaluation_variants_pkey", "human_evaluation_variants", ["id"] + ) + + if not unique_constraint_exists( + engine, "human_evaluations", "human_evaluations_pkey" + ): + op.create_unique_constraint( + "human_evaluations_pkey", "human_evaluations", ["id"] + ) + + if not unique_constraint_exists( + engine, "human_evaluations_scenarios", "human_evaluations_scenarios_pkey" + ): + op.create_unique_constraint( + "human_evaluations_scenarios_pkey", "human_evaluations_scenarios", ["id"] + ) + + if not unique_constraint_exists(engine, "ids_mapping", "ids_mapping_pkey"): + op.create_unique_constraint("ids_mapping_pkey", "ids_mapping", ["id"]) + + if not unique_constraint_exists(engine, "templates", "templates_pkey"): + op.create_unique_constraint("templates_pkey", "templates", ["id"]) + + if not unique_constraint_exists(engine, "testsets", "testsets_pkey"): + op.create_unique_constraint("testsets_pkey", "testsets", ["id"]) + + if not unique_constraint_exists(engine, "users", "users_pkey"): + op.create_unique_constraint("users_pkey", "users", ["id"]) + + if not unique_constraint_exists(engine, "api_keys", "api_keys_pkey"): + op.create_unique_constraint("api_keys_pkey", "api_keys", ["id"]) + + if not unique_constraint_exists(engine, "invitations", "invitations_pkey"): + op.create_unique_constraint("invitations_pkey", "invitations", ["id"]) + + if not unique_constraint_exists(engine, "organizations", "organizations_pkey"): + op.create_unique_constraint("organizations_pkey", "organizations", ["id"]) + + if not unique_constraint_exists( + engine, "user_organizations", "user_organizations_pkey" + ): + op.create_unique_constraint( + "user_organizations_pkey", "user_organizations", ["id"] + ) + + if not unique_constraint_exists( + engine, "workspace_members", "workspace_members_pkey" + ): + op.create_unique_constraint( + "workspace_members_pkey", "workspace_members", ["id"] + ) + + if not unique_constraint_exists(engine, "workspaces", "workspaces_pkey"): + op.create_unique_constraint("workspaces_pkey", "workspaces", ["id"]) + + # ### end Alembic commands ### + + +def returning_user_from_agenta_v018_downwards_downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + engine = sa.create_engine(env.POSTGRES_URI_CORE) + if unique_constraint_exists(engine, "users", "users_pkey"): + op.drop_constraint("users_pkey", "users", type_="unique") + + if unique_constraint_exists(engine, "testsets", "testsets_pkey"): + op.drop_constraint("testsets_pkey", "testsets", type_="unique") + + if unique_constraint_exists(engine, "templates", "templates_pkey"): + op.drop_constraint("templates_pkey", "templates", type_="unique") + + if unique_constraint_exists(engine, "ids_mapping", "ids_mapping_pkey"): + op.drop_constraint("ids_mapping_pkey", "ids_mapping", type_="unique") + + if unique_constraint_exists( + engine, "human_evaluations_scenarios", "human_evaluations_scenarios_pkey" + ): + op.drop_constraint( + "human_evaluations_scenarios_pkey", + "human_evaluations_scenarios", + type_="unique", + ) + + if unique_constraint_exists(engine, "human_evaluations", "human_evaluations_pkey"): + op.drop_constraint( + "human_evaluations_pkey", "human_evaluations", type_="unique" + ) + + if unique_constraint_exists( + engine, "human_evaluation_variants", "human_evaluation_variants_pkey" + ): + op.drop_constraint( + "human_evaluation_variants_pkey", + "human_evaluation_variants", + type_="unique", + ) + + if unique_constraint_exists( + engine, "evaluators_configs", "evaluators_configs_pkey" + ): + op.drop_constraint( + "evaluators_configs_pkey", "evaluators_configs", type_="unique" + ) + + if unique_constraint_exists(engine, "evaluations", "evaluations_pkey"): + op.drop_constraint("evaluations_pkey", "evaluations", type_="unique") + + if unique_constraint_exists( + engine, "evaluation_scenarios", "evaluation_scenarios_pkey" + ): + op.drop_constraint( + "evaluation_scenarios_pkey", "evaluation_scenarios", type_="unique" + ) + + if unique_constraint_exists( + engine, "evaluation_scenario_results", "evaluation_scenario_results_pkey" + ): + op.drop_constraint( + "evaluation_scenario_results_pkey", + "evaluation_scenario_results", + type_="unique", + ) + + if unique_constraint_exists( + engine, "evaluation_aggregated_results", "evaluation_aggregated_results_pkey" + ): + op.drop_constraint( + "evaluation_aggregated_results_pkey", + "evaluation_aggregated_results", + type_="unique", + ) + + if unique_constraint_exists( + engine, "environments_revisions", "environments_revisions_pkey" + ): + op.drop_constraint( + "environments_revisions_pkey", "environments_revisions", type_="unique" + ) + + if unique_constraint_exists(engine, "environments", "environments_pkey"): + op.drop_constraint("environments_pkey", "environments", type_="unique") + + if unique_constraint_exists(engine, "docker_images", "docker_images_pkey"): + op.drop_constraint("docker_images_pkey", "docker_images", type_="unique") + + if unique_constraint_exists(engine, "deployments", "deployments_pkey"): + op.drop_constraint("deployments_pkey", "deployments", type_="unique") + + if unique_constraint_exists(engine, "bases", "bases_pkey"): + op.drop_constraint("bases_pkey", "bases", type_="unique") + + if unique_constraint_exists(engine, "app_variants", "app_variants_pkey"): + op.drop_constraint("app_variants_pkey", "app_variants", type_="unique") + + if unique_constraint_exists( + engine, "app_variant_revisions", "app_variant_revisions_pkey" + ): + op.drop_constraint( + "app_variant_revisions_pkey", "app_variant_revisions", type_="unique" + ) + + if unique_constraint_exists(engine, "app_db", "app_db_pkey"): + op.drop_constraint("app_db_pkey", "app_db", type_="unique") + + if unique_constraint_exists(engine, "workspaces", "workspaces_pkey"): + op.drop_constraint("workspaces_pkey", "workspaces", type_="unique") + + if unique_constraint_exists(engine, "workspace_members", "workspace_members_pkey"): + op.drop_constraint( + "workspace_members_pkey", "workspace_members", type_="unique" + ) + + if unique_constraint_exists( + engine, "user_organizations", "user_organizations_pkey" + ): + op.drop_constraint( + "user_organizations_pkey", "user_organizations", type_="unique" + ) + + if unique_constraint_exists(engine, "organizations", "organizations_pkey"): + op.drop_constraint("organizations_pkey", "organizations", type_="unique") + + if unique_constraint_exists(engine, "invitations", "invitations_pkey"): + op.drop_constraint("invitations_pkey", "invitations", type_="unique") + + if unique_constraint_exists(engine, "api_keys", "api_keys_pkey"): + op.drop_constraint("api_keys_pkey", "api_keys", type_="unique") + # ### end Alembic commands ### + + +def upgrade() -> None: + engine = sa.create_engine(context.config.get_main_option("sqlalchemy.url")) + if is_initial_setup(engine=engine): + first_time_user_from_agenta_v019_upwards_upgrade() + else: + returning_user_from_agenta_v018_downwards_upgrade() + + +def downgrade() -> None: + engine = sa.create_engine(context.config.get_main_option("sqlalchemy.url")) + if is_initial_setup(engine=engine): + first_time_user_from_agenta_v019_upwards_downgrade() + else: + returning_user_from_agenta_v018_downwards_downgrade() diff --git a/api/ee/databases/postgres/migrations/core/versions/91d3b4a8c27f_fix_ag_config.py b/api/ee/databases/postgres/migrations/core/versions/91d3b4a8c27f_fix_ag_config.py new file mode 100644 index 0000000000..1baa0b36fe --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/91d3b4a8c27f_fix_ag_config.py @@ -0,0 +1,61 @@ +"""Fix ag_config + +Revision ID: 91d3b4a8c27f +Revises: 7990f1e12f47 +Create Date: 2025-04-24 11:00:00 +""" + +from typing import Sequence, Union + +from alembic import op +from sqlalchemy import text + + +revision: str = "91d3b4a8c27f" +down_revision: Union[str, None] = "7990f1e12f47" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + batch_size = 100 + + conn = op.get_bind() + + while True: + # Update config_parameters in app_variant_revisions table + result = conn.execute( + text( + f""" + WITH updated AS ( + UPDATE app_variant_revisions + SET config_parameters = config_parameters->'ag_config' + WHERE id IN ( + SELECT id + FROM app_variant_revisions + WHERE config_parameters ? 'ag_config' + LIMIT {batch_size} + ) + RETURNING id + ) + SELECT COUNT(*) FROM updated; + """ + ) + ) + count = result.scalar() + if count == 0: + break + + # Clear the config_parameters column in app_variants table (execute once) + result = conn.execute( + text( + f""" + UPDATE app_variants + SET config_parameters = '{{}}'::jsonb + """ + ) + ) + + +def downgrade(): + pass diff --git a/api/ee/databases/postgres/migrations/core/versions/9698355c7649_add_tables_for_workflows.py b/api/ee/databases/postgres/migrations/core/versions/9698355c7649_add_tables_for_workflows.py new file mode 100644 index 0000000000..506fe0a1cb --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/9698355c7649_add_tables_for_workflows.py @@ -0,0 +1,388 @@ +"""add tables for workflows (artifacts, variants, & revisions) + +Revision ID: 9698355c7649 +Revises: 7990f1e12f47 +Create Date: 2025-04-24 07:27:45.801481 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "9698355c7649" +down_revision: Union[str, None] = "91d3b4a8c27f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - ARTIFACTS -------------------------------------------------------------- + + op.create_table( + "workflow_artifacts", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_workflow_artifacts_project_id_slug", + "project_id", + "slug", + ), + ) + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.create_table( + "workflow_variants", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["workflow_artifacts.project_id", "workflow_artifacts.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_workflow_variants_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_workflow_variants_project_id_artifact_id", + "project_id", + "artifact_id", + ), + ) + + # -------------------------------------------------------------------------- + + # - REVISIONS -------------------------------------------------------------- + + op.create_table( + "workflow_revisions", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "artifact_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "variant_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "slug", + sa.String(), + nullable=False, + ), + sa.Column( + "version", + sa.String(), + nullable=True, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metadata", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "name", + sa.String(), + nullable=True, + ), + sa.Column( + "description", + sa.String(), + nullable=True, + ), + sa.Column( + "message", + sa.String(), + nullable=True, + ), + sa.Column( + "author", + sa.UUID(), + nullable=False, + ), + sa.Column( + "date", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.UniqueConstraint( + "project_id", + "slug", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "artifact_id"], + ["workflow_artifacts.project_id", "workflow_artifacts.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "variant_id"], + ["workflow_variants.project_id", "workflow_variants.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_workflow_revisions_project_id_slug", + "project_id", + "slug", + ), + sa.Index( + "ix_workflow_revisions_project_id_artifact_id", + "project_id", + "artifact_id", + ), + sa.Index( + "ix_workflow_revisions_project_id_variant_id", + "project_id", + "variant_id", + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - REVISIONS -------------------------------------------------------------- + + op.drop_table("workflow_revisions") + + # -------------------------------------------------------------------------- + + # - VARIANTS --------------------------------------------------------------- + + op.drop_table("workflow_variants") + + # -------------------------------------------------------------------------- + + # - ARTIFACTS -------------------------------------------------------------- + + op.drop_table("workflow_artifacts") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/9698355c7650_rename_metadata_to_meta.py b/api/ee/databases/postgres/migrations/core/versions/9698355c7650_rename_metadata_to_meta.py new file mode 100644 index 0000000000..d0870f8288 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/9698355c7650_rename_metadata_to_meta.py @@ -0,0 +1,51 @@ +"""rename metadata to meta + +Revision ID: 9698355c7650 +Revises: 0698355c7642 +Create Date: 2025-05-21 07:27:45.801481 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "9698355c7650" +down_revision: Union[str, None] = "0698355c7642" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - WORKFLOWS -------------------------------------------------------------- + + op.execute("ALTER TABLE workflow_artifacts RENAME COLUMN metadata TO meta") + op.execute("ALTER TABLE workflow_variants RENAME COLUMN metadata TO meta") + op.execute("ALTER TABLE workflow_revisions RENAME COLUMN metadata TO meta") + + # - TESTSETS --------------------------------------------------------------- + + op.execute("ALTER TABLE testset_artifacts RENAME COLUMN metadata TO meta") + op.execute("ALTER TABLE testset_variants RENAME COLUMN metadata TO meta") + op.execute("ALTER TABLE testset_revisions RENAME COLUMN metadata TO meta") + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - WORKFLOWS -------------------------------------------------------------- + + op.execute("ALTER TABLE workflow_artifacts RENAME COLUMN meta TO metadata") + op.execute("ALTER TABLE workflow_variants RENAME COLUMN meta TO metadata") + op.execute("ALTER TABLE workflow_revisions RENAME COLUMN meta TO metadata") + + # - TESTSETS --------------------------------------------------------------- + + op.execute("ALTER TABLE testset_artifacts RENAME COLUMN meta TO metadata") + op.execute("ALTER TABLE testset_variants RENAME COLUMN meta TO metadata") + op.execute("ALTER TABLE testset_revisions RENAME COLUMN meta TO metadata") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/9b0e1a740b88_create_project_invitations_table.py b/api/ee/databases/postgres/migrations/core/versions/9b0e1a740b88_create_project_invitations_table.py new file mode 100644 index 0000000000..d265d52a12 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/9b0e1a740b88_create_project_invitations_table.py @@ -0,0 +1,60 @@ +"""create project_invitations table + +Revision ID: 9b0e1a740b88 +Revises: 1c2d3e4f5a6b +Create Date: 2024-09-05 16:08:04.440845 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "9b0e1a740b88" +down_revision: Union[str, None] = "1c2d3e4f5a6b" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Get the current connection + connection = op.get_bind() + inspector = sa.inspect(connection) + if "project_invitations" not in inspector.get_table_names(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "project_invitations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("token", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("used", sa.Boolean(), nullable=True), + sa.Column("role", sa.String(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("project_id", sa.UUID(), nullable=True), + sa.Column("expiration_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + sa.UniqueConstraint("token"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + connection = op.get_bind() + inspector = sa.inspect(connection) + if "project_invitations" in inspector.get_table_names(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("project_invitations") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/aa1b2c3d4e5f_migrate_config_parameters_jsonb_to_json.py b/api/ee/databases/postgres/migrations/core/versions/aa1b2c3d4e5f_migrate_config_parameters_jsonb_to_json.py new file mode 100644 index 0000000000..e0da80ce6b --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/aa1b2c3d4e5f_migrate_config_parameters_jsonb_to_json.py @@ -0,0 +1,132 @@ +"""Migrate config_parameters from JSONB to JSON + +Revision ID: aa1b2c3d4e5f +Revises: d5d4d6bf738f +Create Date: 2025-07-11 12:00:00 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = "aa1b2c3d4e5f" +down_revision: Union[str, None] = "d5d4d6bf738f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + """ + Migrate config_parameters from JSONB to JSON type to preserve key ordering. + This involves: + 1. Creating new JSON columns + 2. Copying data from JSONB to JSON + 3. Dropping old JSONB columns + 4. Renaming new columns to original names + """ + + # Step 1: Add new JSON columns with temporary names + op.add_column( + "app_variants", + sa.Column("config_parameters_json_temp", sa.JSON(), nullable=True), + ) + + op.add_column( + "app_variant_revisions", + sa.Column("config_parameters_json_temp", sa.JSON(), nullable=True), + ) + + # Step 2: Copy data from JSONB to JSON columns + # For app_variants table + op.execute( + """ + UPDATE app_variants + SET config_parameters_json_temp = config_parameters::json + """ + ) + + # For app_variant_revisions table + op.execute( + """ + UPDATE app_variant_revisions + SET config_parameters_json_temp = config_parameters::json + """ + ) + + # Step 3: Drop the old JSONB columns + op.drop_column("app_variants", "config_parameters") + op.drop_column("app_variant_revisions", "config_parameters") + + # Step 4: Rename the new JSON columns to the original names + op.alter_column( + "app_variants", + "config_parameters_json_temp", + new_column_name="config_parameters", + nullable=False, + server_default="{}", + ) + + op.alter_column( + "app_variant_revisions", + "config_parameters_json_temp", + new_column_name="config_parameters", + nullable=False, + ) + + +def downgrade(): + """ + Migrate config_parameters from JSON back to JSONB type. + """ + + # Step 1: Add new JSONB columns with temporary names + op.add_column( + "app_variants", + sa.Column("config_parameters_jsonb_temp", postgresql.JSONB(), nullable=True), + ) + + op.add_column( + "app_variant_revisions", + sa.Column("config_parameters_jsonb_temp", postgresql.JSONB(), nullable=True), + ) + + # Step 2: Copy data from JSON to JSONB columns + # For app_variants table + op.execute( + """ + UPDATE app_variants + SET config_parameters_jsonb_temp = config_parameters::jsonb + """ + ) + + # For app_variant_revisions table + op.execute( + """ + UPDATE app_variant_revisions + SET config_parameters_jsonb_temp = config_parameters::jsonb + """ + ) + + # Step 3: Drop the old JSON columns + op.drop_column("app_variants", "config_parameters") + op.drop_column("app_variant_revisions", "config_parameters") + + # Step 4: Rename the new JSONB columns to the original names + op.alter_column( + "app_variants", + "config_parameters_jsonb_temp", + new_column_name="config_parameters", + nullable=False, + ) + + op.alter_column( + "app_variant_revisions", + "config_parameters_jsonb_temp", + new_column_name="config_parameters", + nullable=False, + ) diff --git a/api/ee/databases/postgres/migrations/core/versions/ad0987a77380_update_evaluators_names_with_app_name_.py b/api/ee/databases/postgres/migrations/core/versions/ad0987a77380_update_evaluators_names_with_app_name_.py new file mode 100644 index 0000000000..42f949782b --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/ad0987a77380_update_evaluators_names_with_app_name_.py @@ -0,0 +1,35 @@ +"""Update evaluators names with app name as prefix + +Revision ID: ad0987a77380 +Revises: 770d68410ab0 +Create Date: 2024-09-17 06:32:38.238473 + +""" + +from typing import Sequence, Union + +from alembic import context + +from ee.databases.postgres.migrations.core.data_migrations.applications import ( + update_evaluators_with_app_name, +) + + +# revision identifiers, used by Alembic. +revision: str = "ad0987a77380" +down_revision: Union[str, None] = "770d68410ab0" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### custom command ### + connection = context.get_bind() # get database connect from alembic context + update_evaluators_with_app_name(session=connection) + # ### end custom command ### + + +def downgrade() -> None: + # ### custom command ### + pass + # ### end custom command ### diff --git a/api/ee/databases/postgres/migrations/core/versions/b3f15a7140ab_add_version_to_eval_entities.py b/api/ee/databases/postgres/migrations/core/versions/b3f15a7140ab_add_version_to_eval_entities.py new file mode 100644 index 0000000000..f6a9d6a9af --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/b3f15a7140ab_add_version_to_eval_entities.py @@ -0,0 +1,107 @@ +"""Add version to evaluation entities + +Revision ID: b3f15a7140ab +Revises: 5a71b3f140ab +Create Date: 2025-10-03 14:30:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +revision: str = "b3f15a7140ab" +down_revision: Union[str, None] = "5a71b3f140ab" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # BASED ON + # version = Column( + # String, + # nullable=True, + # ) + + # EVALUATION RUNS ---------------------------------------------------------- + + op.add_column( + "evaluation_runs", + sa.Column( + "version", + sa.String(), + nullable=True, + ), + ) + + # EVALUATION SCENARIOS ----------------------------------------------------- + + op.add_column( + "evaluation_scenarios", + sa.Column( + "version", + sa.String(), + nullable=True, + ), + ) + + # EVALUATION RESULTS ------------------------------------------------------- + + op.add_column( + "evaluation_results", + sa.Column( + "version", + sa.String(), + nullable=True, + ), + ) + + # EVALUATION METRICS ------------------------------------------------------- + + op.add_column( + "evaluation_metrics", + sa.Column( + "version", + sa.String(), + nullable=True, + ), + ) + + # EVALUATION QUEUES -------------------------------------------------------- + + op.add_column( + "evaluation_queues", + sa.Column( + "version", + sa.String(), + nullable=True, + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # EVALUATION QUEUES -------------------------------------------------------- + + op.drop_column("evaluation_queues", "version") + + # EVALUATION METRICS ------------------------------------------------------- + + op.drop_column("evaluation_metrics", "version") + + # EVALUATION RESULTS ------------------------------------------------------- + + op.drop_column("evaluation_results", "version") + + # EVALUATION SCENARIOS ----------------------------------------------------- + + op.drop_column("evaluation_scenarios", "version") + + # EVALUATION RUNS ---------------------------------------------------------- + + op.drop_column("evaluation_runs", "version") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/core/versions/b3f6bff547d4_remove_app_id_from_evaluators_configs.py b/api/ee/databases/postgres/migrations/core/versions/b3f6bff547d4_remove_app_id_from_evaluators_configs.py new file mode 100644 index 0000000000..647857d32d --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/b3f6bff547d4_remove_app_id_from_evaluators_configs.py @@ -0,0 +1,38 @@ +"""repair remaining malformed evaluation/evaluator data + +Revision ID: b3f6bff547d4 +Revises: 4d9a58ff8f98 +Create Date: 2024-10-10 21:56:26.901827 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "b3f6bff547d4" +down_revision: Union[str, None] = "4d9a58ff8f98" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + connection = op.get_bind() + inspector = sa.inspect(connection) + if "evaluators_configs" not in inspector.get_table_names(): + # Check if app_id exists in the evaluators_configs table + columns = [ + column["name"] for column in inspector.get_columns("evaluators_configs") + ] + if "app_id" in columns: + op.drop_column("evaluators_configs", "app_id") + + +def downgrade() -> None: + op.add_column( + "evaluators_configs", + sa.Column("app_id", sa.UUID(), autoincrement=False, nullable=True), + ) diff --git a/api/ee/databases/postgres/migrations/core/versions/d0b8e05ca190_scope_project_id_to_db_models_entities.py b/api/ee/databases/postgres/migrations/core/versions/d0b8e05ca190_scope_project_id_to_db_models_entities.py new file mode 100644 index 0000000000..c204c1dd65 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/d0b8e05ca190_scope_project_id_to_db_models_entities.py @@ -0,0 +1,348 @@ +"""scope project_id to db models/entities + +Revision ID: d0b8e05ca190 +Revises: 154098b1e56c +Create Date: 2024-09-17 07:11:16.704972 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from oss.src.utils.env import env +from ee.databases.postgres.migrations.core import utils + + +# revision identifiers, used by Alembic. +revision: str = "d0b8e05ca190" +down_revision: Union[str, None] = "154098b1e56c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + engine = sa.create_engine(env.POSTGRES_URI_CORE) + op.add_column("app_db", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("app_db_user_id_fkey", "app_db", type_="foreignkey") + op.create_foreign_key( + "app_db_projects_fkey", + "app_db", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("app_db", "user_id") + op.add_column( + "app_variant_revisions", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.create_foreign_key( + "app_variant_revisions_projects_fkey", + "app_variant_revisions", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.add_column("app_variants", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("app_variants_user_id_fkey", "app_variants", type_="foreignkey") + op.create_foreign_key( + "app_variants_projects_fkey", + "app_variants", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("app_variants", "user_id") + op.add_column("bases", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("bases_app_id_fkey", "bases", type_="foreignkey") + op.drop_constraint("bases_user_id_fkey", "bases", type_="foreignkey") + op.create_foreign_key( + "bases_projects_fkey", + "bases", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("bases", "user_id") + op.add_column("deployments", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("deployments_user_id_fkey", "deployments", type_="foreignkey") + op.create_foreign_key( + "deployments_projects_fkey", + "deployments", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("deployments", "user_id") + op.add_column("docker_images", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint( + "docker_images_user_id_fkey", "docker_images", type_="foreignkey" + ) + op.create_foreign_key( + "docker_images_projects_fkey", + "docker_images", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("docker_images", "user_id") + op.add_column("environments", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("environments_user_id_fkey", "environments", type_="foreignkey") + op.create_foreign_key( + "environments_projects_fkey", + "environments", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("environments", "user_id") + op.add_column( + "environments_revisions", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.create_foreign_key( + "environments_revisions_projects_fkey", + "environments_revisions", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.add_column( + "evaluation_scenarios", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.drop_constraint( + "evaluation_scenarios_user_id_fkey", "evaluation_scenarios", type_="foreignkey" + ) + op.create_foreign_key( + "evaluation_scenarios_projects_fkey", + "evaluation_scenarios", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("evaluation_scenarios", "user_id") + op.add_column("evaluations", sa.Column("project_id", sa.UUID(), nullable=True)) + op.drop_constraint("evaluations_user_id_fkey", "evaluations", type_="foreignkey") + op.create_foreign_key( + "evaluations_projects_fkey", + "evaluations", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("evaluations", "user_id") + op.add_column( + "evaluators_configs", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.drop_constraint( + "evaluators_configs_user_id_fkey", "evaluators_configs", type_="foreignkey" + ) + op.drop_constraint( + "evaluators_configs_app_id_fkey", "evaluators_configs", type_="foreignkey" + ) + op.create_foreign_key( + "evaluators_configs_projects_fkey", + "evaluators_configs", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("evaluators_configs", "user_id") + op.add_column( + "human_evaluations", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.drop_constraint( + "human_evaluations_user_id_fkey", "human_evaluations", type_="foreignkey" + ) + op.create_foreign_key( + "human_evaluations_projects_fkey", + "human_evaluations", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("human_evaluations", "user_id") + op.add_column( + "human_evaluations_scenarios", sa.Column("project_id", sa.UUID(), nullable=True) + ) + op.drop_constraint( + "human_evaluations_scenarios_user_id_fkey", + "human_evaluations_scenarios", + type_="foreignkey", + ) + op.create_foreign_key( + "human_evaluations_scenarios_projects_fkey", + "human_evaluations_scenarios", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("human_evaluations_scenarios", "user_id") + op.alter_column("projects", "is_default", existing_type=sa.BOOLEAN(), nullable=True) + op.add_column("testsets", sa.Column("project_id", sa.UUID(), nullable=True)) + if not utils.unique_constraint_exists(engine, "testsets", "testsets_user_id_fkey"): + op.drop_constraint("testsets_user_id_fkey", "testsets", type_="foreignkey") + op.drop_constraint("testsets_app_id_fkey", "testsets", type_="foreignkey") + + op.create_foreign_key( + "testsets_projects_fkey", + "testsets", + "projects", + ["project_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_column("testsets", "app_id") + op.drop_column("testsets", "user_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "testsets", sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True) + ) + op.add_column( + "testsets", sa.Column("app_id", sa.UUID(), autoincrement=False, nullable=True) + ) + op.create_foreign_key( + "testsets_app_id_fkey", + "testsets", + "app_db", + ["app_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "testsets_user_id_fkey", "testsets", "users", ["user_id"], ["id"] + ) + op.drop_column("testsets", "project_id") + op.alter_column( + "projects", "is_default", existing_type=sa.BOOLEAN(), nullable=False + ) + op.add_column( + "human_evaluations_scenarios", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "human_evaluations_scenarios_user_id_fkey", + "human_evaluations_scenarios", + "users", + ["user_id"], + ["id"], + ) + op.drop_column("human_evaluations_scenarios", "project_id") + op.add_column( + "human_evaluations", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "human_evaluations_user_id_fkey", + "human_evaluations", + "users", + ["user_id"], + ["id"], + ) + op.drop_column("human_evaluations", "project_id") + op.add_column( + "evaluators_configs", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "evaluators_configs_app_id_fkey", + "evaluators_configs", + "app_db", + ["app_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_foreign_key( + "evaluators_configs_user_id_fkey", + "evaluators_configs", + "users", + ["user_id"], + ["id"], + ) + op.drop_column("evaluators_configs", "project_id") + op.add_column( + "evaluations", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "evaluations_user_id_fkey", "evaluations", "users", ["user_id"], ["id"] + ) + op.drop_column("evaluations", "project_id") + op.add_column( + "evaluation_scenarios", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "evaluation_scenarios_user_id_fkey", + "evaluation_scenarios", + "users", + ["user_id"], + ["id"], + ) + op.drop_column("evaluation_scenarios", "project_id") + op.drop_column("environments_revisions", "project_id") + op.add_column( + "environments", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "environments_user_id_fkey", "environments", "users", ["user_id"], ["id"] + ) + op.drop_column("environments", "project_id") + op.add_column( + "docker_images", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "docker_images_user_id_fkey", "docker_images", "users", ["user_id"], ["id"] + ) + op.drop_column("docker_images", "project_id") + op.add_column( + "deployments", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "deployments_user_id_fkey", "deployments", "users", ["user_id"], ["id"] + ) + op.drop_column("deployments", "project_id") + op.add_column( + "bases", sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True) + ) + op.create_foreign_key("bases_user_id_fkey", "bases", "users", ["user_id"], ["id"]) + op.create_foreign_key( + "bases_app_id_fkey", "bases", "app_db", ["app_id"], ["id"], ondelete="CASCADE" + ) + op.drop_column("bases", "project_id") + op.add_column( + "app_variants", + sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True), + ) + op.create_foreign_key( + "app_variants_user_id_fkey", "app_variants", "users", ["user_id"], ["id"] + ) + op.drop_column("app_variants", "project_id") + op.drop_column("app_variant_revisions", "project_id") + op.add_column( + "app_db", sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True) + ) + op.create_foreign_key("app_db_user_id_fkey", "app_db", "users", ["user_id"], ["id"]) + op.drop_column("app_db", "project_id") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/d5d4d6bf738f_add_evaluation_queues.py b/api/ee/databases/postgres/migrations/core/versions/d5d4d6bf738f_add_evaluation_queues.py new file mode 100644 index 0000000000..6d39d973aa --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/d5d4d6bf738f_add_evaluation_queues.py @@ -0,0 +1,116 @@ +"""add evaluation queues + +Revision ID: d5d4d6bf738f +Revises: fd77265d65dc +Create Date: 2025-07-10 17:04:00.000000 +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "d5d4d6bf738f" +down_revision: Union[str, None] = "fd77265d65dc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "evaluation_queues", + sa.Column( + "project_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + sa.Column( + "run_id", + sa.UUID(), + nullable=False, + ), + sa.PrimaryKeyConstraint( + "project_id", + "id", + ), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["project_id", "run_id"], + ["evaluation_runs.project_id", "evaluation_runs.id"], + ondelete="CASCADE", + ), + sa.Index( + "ix_evaluation_queues_project_id", + "project_id", + ), + sa.Index( + "ix_evaluation_queues_run_id", + "run_id", + ), + ) + + +def downgrade() -> None: + op.drop_table("evaluation_queues") diff --git a/api/ee/databases/postgres/migrations/core/versions/e14e8689cd03_created_project_members_table_and_added_.py b/api/ee/databases/postgres/migrations/core/versions/e14e8689cd03_created_project_members_table_and_added_.py new file mode 100644 index 0000000000..a1eebc1154 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/e14e8689cd03_created_project_members_table_and_added_.py @@ -0,0 +1,68 @@ +"""created project_members table and added organization&workspace id to projects table + +Revision ID: e14e8689cd03 +Revises: e9fa2135f3fb +Create Date: 2024-09-02 15:50:58.870573 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "e14e8689cd03" +down_revision: Union[str, None] = "e9fa2135f3fb" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "projects", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_name", sa.String(), nullable=False), + sa.Column("is_default", sa.Boolean(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("workspace_id", sa.UUID(), nullable=True), + sa.Column("organization_id", sa.UUID(), nullable=True), + sa.ForeignKeyConstraint( + ["organization_id"], ["organizations.id"], ondelete="SET NULL" + ), + sa.ForeignKeyConstraint( + ["workspace_id"], ["workspaces.id"], ondelete="SET NULL" + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + op.create_table( + "project_members", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("project_id", sa.UUID(), nullable=True), + sa.Column("role", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("project_members") + op.drop_table("projects") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/e9fa2135f3fb_add_modified_by_id_column_to_apps_db_.py b/api/ee/databases/postgres/migrations/core/versions/e9fa2135f3fb_add_modified_by_id_column_to_apps_db_.py new file mode 100644 index 0000000000..cf9c02f606 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/e9fa2135f3fb_add_modified_by_id_column_to_apps_db_.py @@ -0,0 +1,31 @@ +"""add modified_by_id column to apps_db table + +Revision ID: e9fa2135f3fb +Revises: 8accbbea1d21 +Create Date: 2024-09-03 20:51:51.856509 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "e9fa2135f3fb" +down_revision: Union[str, None] = "8accbbea1d21" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("app_db", sa.Column("modified_by_id", sa.UUID(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("app_db", "modified_by_id") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/core/versions/fa07e07350bf_add_timestamp_to_metrics.py b/api/ee/databases/postgres/migrations/core/versions/fa07e07350bf_add_timestamp_to_metrics.py new file mode 100644 index 0000000000..c6d85c7467 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/fa07e07350bf_add_timestamp_to_metrics.py @@ -0,0 +1,34 @@ +"""add timestamp to metrics + +Revision ID: fa07e07350bf +Revises: 30dcf07de96a +Create Date: 2025-07-30 14:55:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "fa07e07350bf" +down_revision: Union[str, None] = "30dcf07de96a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "evaluation_metrics", + sa.Column("timestamp", sa.TIMESTAMP(timezone=True), nullable=True), + ) + op.add_column( + "evaluation_metrics", + sa.Column("interval", sa.INTEGER(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("evaluation_metrics", "interval") + op.drop_column("evaluation_metrics", "timestamp") diff --git a/api/ee/databases/postgres/migrations/core/versions/fd77265d65dc_fix_preview_entities.py b/api/ee/databases/postgres/migrations/core/versions/fd77265d65dc_fix_preview_entities.py new file mode 100644 index 0000000000..0e4666cc84 --- /dev/null +++ b/api/ee/databases/postgres/migrations/core/versions/fd77265d65dc_fix_preview_entities.py @@ -0,0 +1,232 @@ +"""fix previw entities + +Revision ID: fd77265d65dc +Revises: 54e81e9eed88 +Create Date: 2025-05-29 16:30:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "fd77265d65dc" +down_revision: Union[str, None] = "54e81e9eed88" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - WORKFLOWS -------------------------------------------------------------- + + op.add_column( + "workflow_artifacts", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "workflow_variants", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "workflow_revisions", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + + # - TESTSETS --------------------------------------------------------------- + + op.add_column( + "testset_artifacts", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "testset_variants", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "testset_revisions", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + + # - TESTCASES -------------------------------------------------------------- + + op.add_column( + "testcase_blobs", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "flags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.drop_column("testcase_blobs", "slug") + op.add_column( + "testcase_blobs", + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + ) + op.add_column( + "testcase_blobs", + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + ) + + # - EVALUATIONS ------------------------------------------------------------ + + op.add_column( + "evaluation_runs", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "evaluation_scenarios", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "evaluation_steps", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + op.add_column( + "evaluation_metrics", + sa.Column( + "tags", + postgresql.JSONB(none_as_null=True), + nullable=True, + ), + ) + + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - WORKFLOWS -------------------------------------------------------------- + + op.drop_column("workflow_artifacts", "tags") + op.drop_column("workflow_variants", "tags") + op.drop_column("workflow_revisions", "tags") + + # - TESTSETS --------------------------------------------------------------- + + op.drop_column("testset_artifacts", "tags") + op.drop_column("testset_variants", "tags") + op.drop_column("testset_revisions", "tags") + + # - TESTCASES -------------------------------------------------------------- + + op.drop_column("testcase_blobs", "flags") + op.drop_column("testcase_blobs", "tags") + op.drop_column("testcase_blobs", "meta") + op.add_column( + "testcase_blobs", + sa.Column( + "slug", + sa.String(), + nullable=True, + ), + ) + op.drop_column("testcase_blobs", "created_at") + op.drop_column("testcase_blobs", "updated_at") + op.drop_column("testcase_blobs", "deleted_at") + op.drop_column("testcase_blobs", "created_by_id") + op.drop_column("testcase_blobs", "updated_by_id") + op.drop_column("testcase_blobs", "deleted_by_id") + + # - EVALUATIONS ------------------------------------------------------------ + + op.drop_column("evaluation_runs", "tags") + op.drop_column("evaluation_scenarios", "tags") + op.drop_column("evaluation_steps", "tags") + op.drop_column("evaluation_metrics", "tags") + + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/find_head.py b/api/ee/databases/postgres/migrations/find_head.py new file mode 100644 index 0000000000..c435c485a9 --- /dev/null +++ b/api/ee/databases/postgres/migrations/find_head.py @@ -0,0 +1,48 @@ +import os +import re +from typing import Union, Dict, Set +import sys + +database = sys.argv[1] + +MIGRATIONS_DIR = f"./{database}/versions/" + +revision_pattern = re.compile(r'revision\s*:\s*str\s*=\s*"([a-f0-9]+)"') +down_revision_pattern = re.compile( + r'down_revision\s*:\s*Union\[str,\s*None\]\s*=\s*(?:"([^"]+)"|None)' +) + +revisions: Dict[str, Union[str, None]] = {} +all_down_revisions: Set[str] = set() + +for filename in os.listdir(MIGRATIONS_DIR): + if not filename.endswith(".py"): + continue + + print("---------") + print("file:", filename) + + with open(os.path.join(MIGRATIONS_DIR, filename), encoding="utf-8") as f: + content = f.read() + revision_match = revision_pattern.search(content) + down_revision_match = down_revision_pattern.search(content) + + print("revision:", revision_match) + print("down_revision:", down_revision_match) + if revision_match: + revision = revision_match.group(1) + down_revision = ( + down_revision_match.group(1) if down_revision_match else None + ) + if down_revision in ("None", ""): + down_revision = None + revisions[revision] = down_revision + if down_revision: + all_down_revisions.add(down_revision) + +# head(s) = revisions that are not anyone's down_revision +heads = [rev for rev in revisions if rev not in all_down_revisions] + +print("---------") +print() +print("Heads:", heads) diff --git a/api/ee/databases/postgres/migrations/runner.py b/api/ee/databases/postgres/migrations/runner.py new file mode 100644 index 0000000000..14baed1924 --- /dev/null +++ b/api/ee/databases/postgres/migrations/runner.py @@ -0,0 +1,21 @@ +import asyncio + +from ee.databases.postgres.migrations.utils import ( + split_core_and_tracing, + copy_nodes_from_core_to_tracing, +) +from ee.databases.postgres.migrations.core.utils import ( + run_alembic_migration as migrate_core, +) +from ee.databases.postgres.migrations.tracing.utils import ( + run_alembic_migration as migrate_tracing, +) + + +if __name__ == "__main__": + loop = asyncio.get_event_loop() + + loop.run_until_complete(split_core_and_tracing()) + migrate_core() + migrate_tracing() + loop.run_until_complete(copy_nodes_from_core_to_tracing()) diff --git a/api/ee/databases/postgres/migrations/tracing/README copy.md b/api/ee/databases/postgres/migrations/tracing/README copy.md new file mode 100644 index 0000000000..8d8552e3c3 --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/README copy.md @@ -0,0 +1,35 @@ +# Migrations with Alembic + +Generic single-database configuration with an async dbapi. + +## Autogenerate Migrations + +One of Alembic's key features is its ability to auto-generate migration scripts. By analyzing the current database state and comparing it with the application's table metadata, Alembic can automatically generate the necessary migration scripts using the `--autogenerate` flag in the alembic revision command. + +Note that autogenerate sometimes does not detect all database changes and it is always necessary to manually review (and correct if needed) the candidate migrations that autogenerate produces. + +### Making migrations + +To make migrations after creating a new table schema or modifying a current column in a table, run the following commands: + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini revision --autogenerate -m "migration message" +``` + +The above command will create a script that contains the changes that was made to the database schema. Kindly update "migration message" with a message that is clear to indicate what change was made. Here are some examples: + +- added username column in users table +- renamed template_uri to template_repository_uri +- etc + +### Applying Migrations + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini upgrade head +``` + +The above command will be used to apply the changes in the script created to the database table(s). If you'd like to revert the migration, run the following command: + +```bash +docker exec -e PYTHONPATH=/app -w /app/ee/databases/postgres/migrations/core agenta-ee-dev-api-1 alembic -c alembic.ini downgrade head +``` diff --git a/api/ee/databases/postgres/migrations/tracing/__init__.py b/api/ee/databases/postgres/migrations/tracing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/databases/postgres/migrations/tracing/alembic.ini b/api/ee/databases/postgres/migrations/tracing/alembic.ini new file mode 100644 index 0000000000..046889088d --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/alembic.ini @@ -0,0 +1,114 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = /app/ee/databases/postgres/migrations/tracing + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to migrations/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/api/ee/databases/postgres/migrations/tracing/env.py b/api/ee/databases/postgres/migrations/tracing/env.py new file mode 100644 index 0000000000..9376d4486d --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/env.py @@ -0,0 +1,100 @@ +import os +import asyncio +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.engine import Connection, create_engine +from sqlalchemy.ext.asyncio import async_engine_from_config, create_async_engine + +from alembic import context + +from oss.src.dbs.postgres.shared.engine import engine + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config +config.set_main_option("sqlalchemy.url", engine.postgres_uri_tracing) # type: ignore + + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +from oss.src.dbs.postgres.shared.base import Base + +import oss.src.dbs.postgres.tracing.dbes + +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + transaction_per_migration=True, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure( + transaction_per_migration=True, + connection=connection, + target_metadata=target_metadata, + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. + """ + + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/ee/databases/postgres/migrations/tracing/script.py.mako b/api/ee/databases/postgres/migrations/tracing/script.py.mako new file mode 100644 index 0000000000..fbc4b07dce --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/api/ee/databases/postgres/migrations/tracing/utils.py b/api/ee/databases/postgres/migrations/tracing/utils.py new file mode 100644 index 0000000000..15f3e66b5f --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/utils.py @@ -0,0 +1,188 @@ +import os +import asyncio +import logging +import traceback + +import click +import asyncpg +from alembic import command +from sqlalchemy import Engine +from alembic.config import Config +from sqlalchemy import inspect, text +from alembic.script import ScriptDirectory +from sqlalchemy.exc import ProgrammingError +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine + +from oss.src.utils.env import env + + +# Initializer logger +logger = logging.getLogger("alembic.env") + +# Initialize alembic config +alembic_cfg = Config(env.ALEMBIC_CFG_PATH_TRACING) +script = ScriptDirectory.from_config(alembic_cfg) + +logger.info("license: ee") +logger.info("migrations: tracing") +logger.info("ALEMBIC_CFG_PATH_TRACING: %s", env.ALEMBIC_CFG_PATH_TRACING) +logger.info("alembic_cfg: %s", alembic_cfg) +logger.info("script: %s", script) + + +def is_initial_setup(engine) -> bool: + """ + Check if the database is in its initial state by verifying the existence of required tables. + + This function inspects the current state of the database and determines if it needs initial setup by checking for the presence of a predefined set of required tables. + + Args: + engine (sqlalchemy.engine.base.Engine): The SQLAlchemy engine used to connect to the database. + + Returns: + bool: True if the database is in its initial state (i.e., not all required tables exist), False otherwise. + """ + + inspector = inspect(engine) + required_tables = ["spans"] + existing_tables = inspector.get_table_names() + + # Check if all required tables exist in the database + all_tables_exist = all(table in existing_tables for table in required_tables) + + return not all_tables_exist + + +async def get_current_migration_head_from_db(engine: AsyncEngine): + """ + Checks the alembic_version table to get the current migration head that has been applied. + + Args: + engine (Engine): The engine that connects to an sqlalchemy pool + + Returns: + the current migration head (where 'head' is the revision stored in the migration script) + """ + + async with engine.connect() as connection: + try: + result = await connection.execute(text("SELECT version_num FROM alembic_version")) # type: ignore + except (asyncpg.exceptions.UndefinedTableError, ProgrammingError): + # Note: If the alembic_version table does not exist, it will result in raising an UndefinedTableError exception. + # We need to suppress the error and return a list with the alembic_version table name to inform the user that there is a pending migration \ + # to make Alembic start tracking the migration changes. + # -------------------------------------------------------------------------------------- + # This effect (the exception raising) happens for both users (first-time and returning) + return "alembic_version" + + migration_heads = [row[0] for row in result.fetchall()] + assert ( + len(migration_heads) == 1 + ), "There can only be one migration head stored in the database." + return migration_heads[0] + + +async def get_pending_migration_head(): + """ + Gets the migration head that have not been applied. + + Returns: + the pending migration head + """ + + engine = create_async_engine(url=env.POSTGRES_URI_TRACING) + try: + current_migration_script_head = script.get_current_head() + migration_head_from_db = await get_current_migration_head_from_db(engine=engine) + + pending_migration_head = [] + if current_migration_script_head != migration_head_from_db: + pending_migration_head.append(current_migration_script_head) + if "alembic_version" == migration_head_from_db: + pending_migration_head.append("alembic_version") + finally: + await engine.dispose() + + return pending_migration_head + + +def run_alembic_migration(): + """ + Applies migration for first-time users and also checks the environment variable "AGENTA_AUTO_MIGRATIONS" to determine whether to apply migrations for returning users. + """ + + try: + pending_migration_head = asyncio.run(get_pending_migration_head()) + FIRST_TIME_USER = True if "alembic_version" in pending_migration_head else False + + if FIRST_TIME_USER or env.AGENTA_AUTO_MIGRATIONS: + command.upgrade(alembic_cfg, "head") + click.echo( + click.style( + "\nMigration applied successfully. The container will now exit.", + fg="green", + ), + color=True, + ) + else: + click.echo( + click.style( + "\nAll migrations are up-to-date. The container will now exit.", + fg="yellow", + ), + color=True, + ) + except Exception as e: + click.echo( + click.style( + f"\nAn ERROR occurred while applying migration: {traceback.format_exc()}\nThe container will now exit.", + fg="red", + ), + color=True, + ) + raise e + + +async def check_for_new_migrations(): + """ + Checks for new migrations and notify the user. + """ + + pending_migration_head = await get_pending_migration_head() + if len(pending_migration_head) >= 1 and isinstance(pending_migration_head[0], str): + click.echo( + click.style( + f"\nWe have detected that there are pending database migrations {pending_migration_head} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n", + fg="yellow", + ), + color=True, + ) + return + + +def unique_constraint_exists( + engine: Engine, table_name: str, constraint_name: str +) -> bool: + """ + The function checks if a unique constraint with a specific name exists on a table in a PostgreSQL + database. + + Args: + - engine (Engine): instance of a database engine that represents a connection to a database. + - table_name (str): name of the table to check the existence of the unique constraint. + - constraint_name (str): name of the unique constraint to check for existence. + + Returns: + - returns a boolean value indicating whether a unique constraint with the specified `constraint_name` exists in the table. + """ + + with engine.connect() as conn: + result = conn.execute( + text( + f""" + SELECT conname FROM pg_constraint + WHERE conname = '{constraint_name}' AND conrelid = '{table_name}'::regclass; + """ + ) + ) + return result.fetchone() is not None diff --git a/api/ee/databases/postgres/migrations/tracing/versions/58b1b61e5d6c_add_spans.py b/api/ee/databases/postgres/migrations/tracing/versions/58b1b61e5d6c_add_spans.py new file mode 100644 index 0000000000..d0b32e0008 --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/versions/58b1b61e5d6c_add_spans.py @@ -0,0 +1,202 @@ +"""Add Spans v2 + +Revision ID: 58b1b61e5d6c +Revises: +Create Date: 2025-03-28 12:22:05.104488 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "58b1b61e5d6c" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "spans", + sa.Column( + "project_id", + sa.UUID(), + # sa.ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.TIMESTAMP(timezone=True), + server_onupdate=sa.text("CURRENT_TIMESTAMP"), + nullable=True, + ), + sa.Column( + "deleted_at", + sa.TIMESTAMP(timezone=True), + nullable=True, + ), + sa.Column( + "created_by_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "updated_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "deleted_by_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "trace_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "span_id", + sa.UUID(), + nullable=False, + ), + sa.Column( + "parent_id", + sa.UUID(), + nullable=True, + ), + sa.Column( + "span_kind", + sa.Enum( + "SPAN_KIND_UNSPECIFIED", + "SPAN_KIND_INTERNAL", + "SPAN_KIND_SERVER", + "SPAN_KIND_CLIENT", + "SPAN_KIND_PRODUCER", + "SPAN_KIND_CONSUMER", + name="otelspankind", + ), + nullable=False, + ), + sa.Column( + "span_name", + sa.VARCHAR(), + nullable=False, + ), + sa.Column( + "start_time", + sa.TIMESTAMP(timezone=True), + nullable=False, + ), + sa.Column( + "end_time", + sa.TIMESTAMP(timezone=True), + nullable=False, + ), + sa.Column( + "status_code", + sa.Enum( + "STATUS_CODE_UNSET", + "STATUS_CODE_OK", + "STATUS_CODE_ERROR", + name="otelstatuscode", + ), + nullable=False, + ), + sa.Column( + "status_message", + sa.VARCHAR(), + nullable=True, + ), + sa.Column( + "attributes", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "events", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "links", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "references", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + # sa.Column( + # "content", + # sa.VARCHAR(), + # nullable=True, + # ), + sa.PrimaryKeyConstraint( + "project_id", + "trace_id", + "span_id", + ), + sa.Index( + "ix_project_id_trace_id", + "project_id", + "trace_id", + ), + sa.Index( + "ix_project_id_span_id", + "project_id", + "span_id", + ), + sa.Index( + "ix_project_id_start_time", + "project_id", + "start_time", + ), + sa.Index( + "ix_project_id", + "project_id", + ), + sa.Index( + "ix_attributes_gin", + "attributes", + postgresql_using="gin", + ), + sa.Index( + "ix_events_gin", + "events", + postgresql_using="gin", + ), + sa.Index( + "ix_links_gin", + "links", + postgresql_using="gin", + ), + sa.Index( + "ix_references_gin", + "references", + postgresql_using="gin", + ), + ) + + +def downgrade() -> None: + op.drop_index("ix_references_gin", table_name="spans") + op.drop_index("ix_links_gin", table_name="spans") + op.drop_index("ix_events_gin", table_name="spans") + op.drop_index("ix_attributes_gin", table_name="spans") + op.drop_index("ix_project_id", table_name="spans") + op.drop_index("ix_project_id_start_time", table_name="spans") + op.drop_index("ix_project_id_span_id", table_name="spans") + op.drop_index("ix_project_id_trace_id", table_name="spans") + op.drop_table("spans") diff --git a/api/ee/databases/postgres/migrations/tracing/versions/847972cfa14a_add_nodes.py b/api/ee/databases/postgres/migrations/tracing/versions/847972cfa14a_add_nodes.py new file mode 100644 index 0000000000..4b6903973b --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/versions/847972cfa14a_add_nodes.py @@ -0,0 +1,121 @@ +"""add_nodes_dbe + +Revision ID: 847972cfa14a +Revises: 58b1b61e5d6c +Create Date: 2024-11-07 12:21:19.080345 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "847972cfa14a" +down_revision: Union[str, None] = "58b1b61e5d6c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "nodes", + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("CURRENT_TIMESTAMP"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.Column("updated_by_id", sa.UUID(), nullable=True), + sa.Column("root_id", sa.UUID(), nullable=False), + sa.Column("tree_id", sa.UUID(), nullable=False), + sa.Column("tree_type", sa.Enum("INVOCATION", name="treetype"), nullable=True), + sa.Column("node_id", sa.UUID(), nullable=False), + sa.Column("node_name", sa.String(), nullable=False), + sa.Column( + "node_type", + sa.Enum( + "AGENT", + "WORKFLOW", + "CHAIN", + "TASK", + "TOOL", + "EMBEDDING", + "QUERY", + "COMPLETION", + "CHAT", + "RERANK", + name="nodetype", + ), + nullable=True, + ), + sa.Column("parent_id", sa.UUID(), nullable=True), + sa.Column("time_start", sa.TIMESTAMP(), nullable=False), + sa.Column("time_end", sa.TIMESTAMP(), nullable=False), + sa.Column( + "status", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "data", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "metrics", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "meta", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "refs", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "exception", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column( + "links", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.Column("content", sa.String(), nullable=True), + sa.Column( + "otel", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + sa.PrimaryKeyConstraint("project_id", "node_id"), + ) + op.create_index( + "index_project_id_node_id", "nodes", ["project_id", "created_at"], unique=False + ) + op.create_index( + "index_project_id_root_id", "nodes", ["project_id", "root_id"], unique=False + ) + op.create_index( + "index_project_id_tree_id", "nodes", ["project_id", "tree_id"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("index_project_id_tree_id", table_name="nodes") + op.drop_index("index_project_id_root_id", table_name="nodes") + op.drop_index("index_project_id_node_id", table_name="nodes") + op.drop_table("nodes") + # ### end Alembic commands ### diff --git a/api/ee/databases/postgres/migrations/tracing/versions/fd77265d65dc_fix_spans.py b/api/ee/databases/postgres/migrations/tracing/versions/fd77265d65dc_fix_spans.py new file mode 100644 index 0000000000..6cb4e3f963 --- /dev/null +++ b/api/ee/databases/postgres/migrations/tracing/versions/fd77265d65dc_fix_spans.py @@ -0,0 +1,202 @@ +"""fix spans + +Revision ID: fd77265d65dc +Revises: 847972cfa14a +Create Date: 2025-05-29 16:30:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from oss.src.core.tracing.dtos import SpanType +from oss.src.core.tracing.dtos import TraceType + +# revision identifiers, used by Alembic. +revision: str = "fd77265d65dc" +down_revision: Union[str, None] = "847972cfa14a" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # - SPANS ------------------------------------------------------------------ + trace_type_enum = sa.Enum(TraceType, name="tracetype") + span_type_enum = sa.Enum(SpanType, name="spantype") + + trace_type_enum.create(op.get_bind(), checkfirst=True) + span_type_enum.create(op.get_bind(), checkfirst=True) + + op.add_column( + "spans", + sa.Column( + "trace_type", + trace_type_enum, + nullable=True, + ), + ) + op.add_column( + "spans", + sa.Column( + "span_type", + span_type_enum, + nullable=True, + ), + ) + op.add_column( + "spans", + sa.Column( + "hashes", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "spans", + sa.Column( + "exception", + postgresql.JSONB(none_as_null=True, astext_type=sa.Text()), + nullable=True, + ), + ) + op.create_index( + "ix_spans_project_id_trace_type", + "spans", + ["project_id", "trace_type"], + if_not_exists=True, + ) + op.create_index( + "ix_spans_project_id_span_type", + "spans", + ["project_id", "span_type"], + if_not_exists=True, + ) + op.create_index( + "ix_spans_project_id_trace_id_created_at", + "spans", + ["project_id", "trace_id", sa.text("created_at DESC")], + if_not_exists=True, + ) + op.create_index( + "ix_spans_project_id_trace_id_start_time", + "spans", + ["project_id", "trace_id", sa.text("start_time DESC")], + if_not_exists=True, + ) + op.create_index( + "ix_hashes_gin", + "spans", + ["hashes"], + postgresql_using="gin", + postgresql_ops={"hashes": "jsonb_path_ops"}, + if_not_exists=True, + ) + op.drop_index( + "ix_events_gin", + table_name="spans", + if_exists=True, + ) + op.create_index( + "ix_events_gin", + "spans", # replace with your table name + ["events"], + postgresql_using="gin", + postgresql_ops={"events": "jsonb_path_ops"}, + if_not_exists=True, + ) + op.create_index( + "ix_spans_fts_attributes_gin", + "spans", + [sa.text("to_tsvector('simple', attributes)")], + postgresql_using="gin", + if_not_exists=True, + ) + op.create_index( + "ix_spans_fts_events_gin", + "spans", + [sa.text("to_tsvector('simple', events)")], + postgresql_using="gin", + if_not_exists=True, + ) + # -------------------------------------------------------------------------- + + +def downgrade() -> None: + # - SPANS ------------------------------------------------------------------ + op.drop_index( + "ix_spans_fts_events_gin", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_spans_fts_attributes_gin", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_events_gin", + table_name="spans", + if_exists=True, + ) + op.create_index( + "ix_events_gin", + "spans", + ["events"], + postgresql_using="gin", + if_not_exists=True, + ) + op.drop_index( + "ix_hashes_gin", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_spans_project_id_trace_id_start_time", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_spans_project_id_trace_id_created_at", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_spans_project_id_span_type", + table_name="spans", + if_exists=True, + ) + op.drop_index( + "ix_spans_project_id_trace_type", + table_name="spans", + if_exists=True, + ) + op.drop_column( + "spans", + "exception", + if_exists=True, + ) + op.drop_column( + "spans", + "hashes", + if_exists=True, + ) + op.drop_column( + "spans", + "span_type", + if_exists=True, + ) + op.drop_column( + "spans", + "trace_type", + if_exists=True, + ) + + span_type_enum = sa.Enum(SpanType, name="spantype") + trace_type_enum = sa.Enum(TraceType, name="tracetype") + + span_type_enum.drop(op.get_bind(), checkfirst=True) + trace_type_enum.drop(op.get_bind(), checkfirst=True) + # -------------------------------------------------------------------------- diff --git a/api/ee/databases/postgres/migrations/utils.py b/api/ee/databases/postgres/migrations/utils.py new file mode 100644 index 0000000000..f3874da1c8 --- /dev/null +++ b/api/ee/databases/postgres/migrations/utils.py @@ -0,0 +1,313 @@ +import os +import subprocess +import tempfile + +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from sqlalchemy.exc import ProgrammingError + +from oss.src.utils.env import env + + +# Config (can override via env) +POSTGRES_URI = ( + os.getenv("POSTGRES_URI") + or env.POSTGRES_URI_CORE + or env.POSTGRES_URI_TRACING + or "postgresql+asyncpg://username:password@localhost:5432/agenta_ee" +) +DB_PROTOCOL = POSTGRES_URI.split("://")[0] # .replace("+asyncpg", "") +DB_USER = POSTGRES_URI.split("://")[1].split(":")[0] +DB_PASS = POSTGRES_URI.split("://")[1].split(":")[1].split("@")[0] +DB_HOST = POSTGRES_URI.split("@")[1].split(":")[0] +DB_PORT = POSTGRES_URI.split(":")[-1].split("/")[0] +ADMIN_DB = "postgres" + +POSTGRES_URI_POSTGRES = ( + f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{ADMIN_DB}" +) + +# Rename/create map: {'old_name': 'new_name'} +RENAME_MAP = { + "agenta_ee": "agenta_ee_core", + "supertokens_ee": "agenta_ee_supertokens", + "agenta_ee_tracing": "agenta_ee_tracing", +} + + +NODES_TF = { + "agenta_ee_core": "agenta_ee_tracing", +} + + +async def copy_nodes_from_core_to_tracing(): + engine = create_async_engine( + POSTGRES_URI_POSTGRES, + isolation_level="AUTOCOMMIT", + ) + + async with engine.begin() as conn: + for old_name, new_name in NODES_TF.items(): + old_exists = ( + await conn.execute( + text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": old_name}, + ) + ).scalar() + + new_exists = ( + await conn.execute( + text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": new_name}, + ) + ).scalar() + + if old_exists and new_exists: + # Check if the nodes table exists in old_name database + check_url = f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{old_name}" + check_engine = create_async_engine(check_url) + async with check_engine.begin() as conn: + result = ( + await conn.execute( + text("SELECT to_regclass('public.nodes')"), + ) + ).scalar() + if result is None: + print( + f"⚠️ Table 'nodes' does not exist in '{old_name}'. Skipping copy." + ) + return + + count = ( + await conn.execute( + text("SELECT COUNT(*) FROM public.nodes"), + ) + ).scalar() + + if count == 0: + print( + f"⚠️ Table 'nodes' is empty in '{old_name}'. Skipping copy." + ) + return + + check_url = f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{new_name}" + check_engine = create_async_engine(check_url) + + async with check_engine.begin() as conn: + count = ( + await conn.execute( + text( + "SELECT COUNT(*) FROM public.nodes", + ) + ) + ).scalar() + + if (count or 0) > 0: + print( + f"⚠️ Table 'nodes' already exists in '{new_name}' with {count} rows. Skipping copy." + ) + return + + with tempfile.NamedTemporaryFile(suffix=".sql", delete=False) as tmp: + dump_file = tmp.name + + try: + # Step 1: Dump the 'nodes' table to file + subprocess.run( + [ + "pg_dump", + "-h", + DB_HOST, + "-p", + str(DB_PORT), + "-U", + DB_USER, + "-d", + old_name, + "-t", + "nodes", + "--format=custom", # requires -f, not stdout redirection + "--no-owner", + "--no-privileges", + "-f", + dump_file, + ], + check=True, + env={**os.environ, "PGPASSWORD": DB_PASS}, + ) + + print(f"✔ Dumped 'nodes' table to '{dump_file}'") + + # Step 2: Restore the dump into the new database + subprocess.run( + [ + "pg_restore", + "--data-only", + "--no-owner", + "--no-privileges", + "-h", + DB_HOST, + "-p", + str(DB_PORT), + "-U", + DB_USER, + "-d", + new_name, + dump_file, + ], + check=True, + env={**os.environ, "PGPASSWORD": DB_PASS}, + ) + + print(f"✔ Restored 'nodes' table into '{new_name}'") + + # Step 3: Verify 'nodes' exists in both DBs, then drop from old + source_engine = create_async_engine( + f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{old_name}" + ) + dest_engine = create_async_engine( + f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{new_name}" + ) + + async with source_engine.begin() as src, dest_engine.begin() as dst: + src = await src.execution_options(isolation_level="AUTOCOMMIT") + dst = await dst.execution_options(isolation_level="AUTOCOMMIT") + + src_exists = ( + await src.execute( + text("SELECT to_regclass('public.nodes')") + ) + ).scalar() + dst_exists = ( + await dst.execute( + text("SELECT to_regclass('public.nodes')"), + ) + ).scalar() + + if src_exists and dst_exists: + subprocess.run( + [ + "psql", + "-h", + DB_HOST, + "-p", + str(DB_PORT), + "-U", + DB_USER, + "-d", + old_name, + "-c", + "TRUNCATE TABLE public.nodes CASCADE", + ], + check=True, + env={**os.environ, "PGPASSWORD": DB_PASS}, + ) + + count = ( + await src.execute( + text("SELECT COUNT(*) FROM public.nodes"), + ) + ).scalar() + + print(f"✅ Remaining rows: {count}") + + except subprocess.CalledProcessError as e: + print(f"❌ pg_dump/psql failed: {e}") + finally: + if os.path.exists(dump_file): + os.remove(dump_file) + + +async def split_core_and_tracing(): + engine = create_async_engine( + POSTGRES_URI_POSTGRES, + isolation_level="AUTOCOMMIT", + ) + + async with engine.begin() as conn: + for old_name, new_name in RENAME_MAP.items(): + old_exists = ( + await conn.execute( + text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": old_name}, + ) + ).scalar() + + new_exists = ( + await conn.execute( + text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": new_name}, + ) + ).scalar() + + if old_exists and not new_exists: + print(f"Renaming database '{old_name}' → '{new_name}'...") + try: + await conn.execute( + text(f"ALTER DATABASE {old_name} RENAME TO {new_name}") + ) + print(f"✔ Renamed '{old_name}' to '{new_name}'") + except ProgrammingError as e: + print(f"❌ Failed to rename '{old_name}': {e}") + + elif not old_exists and new_exists: + print( + f"'{old_name}' does not exist, but '{new_name}' already exists. No action taken." + ) + + elif not old_exists and not new_exists: + print( + f"Neither '{old_name}' nor '{new_name}' exists. Creating '{new_name}'..." + ) + try: + # Ensure the role exists + await conn.execute( + text( + f""" + DO $$ + BEGIN + IF NOT EXISTS (SELECT FROM pg_roles WHERE rolname = '{DB_USER}') THEN + EXECUTE format('CREATE ROLE %I WITH LOGIN PASSWORD %L', '{DB_USER}', '{DB_PASS}'); + END IF; + END + $$; + """ + ) + ) + print(f"✔ Ensured role '{DB_USER}' exists") + + # Create the new database + await conn.execute(text(f"CREATE DATABASE {new_name}")) + print(f"✔ Created database '{new_name}'") + + # Grant privileges on the database to the role + await conn.execute( + text( + f"GRANT ALL PRIVILEGES ON DATABASE {new_name} TO {DB_USER}" + ) + ) + print( + f"✔ Granted privileges on database '{new_name}' to '{DB_USER}'" + ) + + # Connect to the new database to grant schema permissions + new_db_url = f"{DB_PROTOCOL}://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{new_name}" + + async with create_async_engine( + new_db_url, isolation_level="AUTOCOMMIT" + ).begin() as new_db_conn: + await new_db_conn.execute( + text(f"GRANT ALL ON SCHEMA public TO {DB_USER}") + ) + print( + f"✔ Granted privileges on schema 'public' in '{new_name}' to '{DB_USER}'" + ) + + except ProgrammingError as e: + print( + f"❌ Failed during creation or configuration of '{new_name}': {e}" + ) + + else: + print(f"Both '{old_name}' and '{new_name}' exist. No action taken.") diff --git a/api/ee/docker/Dockerfile.dev b/api/ee/docker/Dockerfile.dev new file mode 100644 index 0000000000..a650319e31 --- /dev/null +++ b/api/ee/docker/Dockerfile.dev @@ -0,0 +1,44 @@ +FROM python:3.11-slim-bullseye + +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y curl cron gnupg2 lsb-release && \ + echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \ + curl -fsSL https://www.postgresql.org/media/keys/ACCC4CF8.asc | \ + gpg --dearmor -o /etc/apt/trusted.gpg.d/postgresql.gpg && \ + apt-get update && \ + apt-get install -y postgresql-client-16 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN pip install --upgrade pip \ + && pip install poetry + +COPY ./ee /app/ee/ +COPY ./oss /app/oss/ +COPY ./entrypoint.py ./pyproject.toml /app/ + +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi + # && pip install -e /sdk/ + +# ENV PYTHONPATH=/sdk:$PYTHONPATH + +COPY ./ee/src/crons/meters.sh /meters.sh +COPY ./ee/src/crons/meters.txt /etc/cron.d/meters-cron +RUN sed -i -e '$a\' /etc/cron.d/meters-cron +RUN cat -A /etc/cron.d/meters-cron + +RUN chmod +x /meters.sh \ + && chmod 0644 /etc/cron.d/meters-cron + +COPY ./ee/src/crons/queries.sh /queries.sh +COPY ./ee/src/crons/queries.txt /etc/cron.d/queries-cron +RUN sed -i -e '$a\' /etc/cron.d/queries-cron +RUN cat -A /etc/cron.d/queries-cron + +RUN chmod +x /queries.sh \ + && chmod 0644 /etc/cron.d/queries-cron + +EXPOSE 8000 diff --git a/api/ee/docker/Dockerfile.gh b/api/ee/docker/Dockerfile.gh new file mode 100644 index 0000000000..8e8e6ec936 --- /dev/null +++ b/api/ee/docker/Dockerfile.gh @@ -0,0 +1,44 @@ +FROM python:3.11-slim-bullseye + +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y curl cron gnupg2 lsb-release && \ + echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \ + curl -fsSL https://www.postgresql.org/media/keys/ACCC4CF8.asc | \ + gpg --dearmor -o /etc/apt/trusted.gpg.d/postgresql.gpg && \ + apt-get update && \ + apt-get install -y postgresql-client-16 && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +RUN pip install --upgrade pip \ + && pip install poetry + +COPY ./ee /app/ee/ +COPY ./oss /app/oss/ +COPY ./entrypoint.py ./pyproject.toml /app/ + +RUN poetry config virtualenvs.create false \ + && poetry install --no-interaction --no-ansi +# + +# + +COPY ./ee/src/crons/meters.sh /meters.sh +COPY ./ee/src/crons/meters.txt /etc/cron.d/meters-cron +RUN sed -i -e '$a\' /etc/cron.d/meters-cron +RUN cat -A /etc/cron.d/meters-cron + +RUN chmod +x /meters.sh \ + && chmod 0644 /etc/cron.d/meters-cron + +COPY ./ee/src/crons/queries.sh /queries.sh +COPY ./ee/src/crons/queries.txt /etc/cron.d/queries-cron +RUN sed -i -e '$a\' /etc/cron.d/queries-cron +RUN cat -A /etc/cron.d/queries-cron + +RUN chmod +x /queries.sh \ + && chmod 0644 /etc/cron.d/queries-cron + +EXPOSE 8000 diff --git a/api/ee/src/__init__.py b/api/ee/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/apis/__init__.py b/api/ee/src/apis/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/apis/fastapi/__init__.py b/api/ee/src/apis/fastapi/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/apis/fastapi/billing/__init__.py b/api/ee/src/apis/fastapi/billing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/apis/fastapi/billing/models.py b/api/ee/src/apis/fastapi/billing/models.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/apis/fastapi/billing/router.py b/api/ee/src/apis/fastapi/billing/router.py new file mode 100644 index 0000000000..7ac23142c5 --- /dev/null +++ b/api/ee/src/apis/fastapi/billing/router.py @@ -0,0 +1,980 @@ +from typing import Any, Dict +from os import environ +from json import loads, decoder +from uuid import getnode +from datetime import datetime, timezone +from dateutil.relativedelta import relativedelta + +from fastapi import APIRouter, Request, status, HTTPException, Query +from fastapi.responses import JSONResponse + +import stripe + +from oss.src.utils.common import is_ee +from oss.src.utils.logging import get_module_logger +from oss.src.utils.exceptions import intercept_exceptions +from oss.src.utils.caching import get_cache, set_cache, invalidate_cache + +from oss.src.services.db_manager import ( + get_user_with_id, + get_organization_by_id, +) + +from ee.src.utils.permissions import check_action_access +from ee.src.models.shared_models import Permission +from ee.src.core.entitlements.types import ENTITLEMENTS, CATALOG, Tracker, Quota +from ee.src.core.subscriptions.types import Event, Plan +from ee.src.core.subscriptions.service import ( + SubscriptionsService, + SwitchException, + EventException, +) + + +log = get_module_logger(__name__) + +stripe.api_key = environ.get("STRIPE_API_KEY") + +MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xff:02x}" for ele in range(40, -1, -8)) +STRIPE_WEBHOOK_SECRET = environ.get("STRIPE_WEBHOOK_SECRET") +STRIPE_TARGET = environ.get("STRIPE_TARGET") or MAC_ADDRESS +AGENTA_PRICING = loads(environ.get("AGENTA_PRICING") or "{}") + +FORBIDDEN_RESPONSE = JSONResponse( + status_code=403, + content={ + "detail": "You do not have access to perform this action. Please contact your organization admin.", + }, +) + + +class SubscriptionsRouter: + def __init__( + self, + subscription_service: SubscriptionsService, + ): + self.subscription_service = subscription_service + + # ROUTER + self.router = APIRouter() + + # USES 'STRIPE_WEBHOOK_SECRET', SHOULD BE IN A DIFFERENT ROUTER + self.router.add_api_route( + "/stripe/events/", + self.handle_events, + methods=["POST"], + operation_id="handle_events", + ) + + self.router.add_api_route( + "/stripe/portals/", + self.create_portal_user_route, + methods=["POST"], + operation_id="create_portal", + ) + + self.router.add_api_route( + "/stripe/checkouts/", + self.create_checkout_user_route, + methods=["POST"], + operation_id="create_checkout", + ) + + self.router.add_api_route( + "/plans", + self.fetch_plan_user_route, + methods=["GET"], + operation_id="fetch_plans", + ) + + self.router.add_api_route( + "/plans/switch", + self.switch_plans_user_route, + methods=["POST"], + operation_id="switch_plans", + ) + + self.router.add_api_route( + "/subscription", + self.fetch_subscription_user_route, + methods=["GET"], + operation_id="fetch_subscription", + ) + + self.router.add_api_route( + "/subscription/cancel", + self.cancel_subscription_user_route, + methods=["POST"], + operation_id="cancel_plan", + ) + + self.router.add_api_route( + "/usage", + self.fetch_usage_user_route, + methods=["GET"], + operation_id="fetch_usage", + ) + + # ADMIN ROUTER + self.admin_router = APIRouter() + + self.admin_router.add_api_route( + "/stripe/portals/", + self.create_portal_admin_route, + methods=["POST"], + operation_id="admin_create_portal", + ) + + self.admin_router.add_api_route( + "/stripe/checkouts/", + self.create_checkout_admin_route, + methods=["POST"], + operation_id="admin_create_checkout", + ) + + self.admin_router.add_api_route( + "/plans/switch", + self.switch_plans_admin_route, + methods=["POST"], + operation_id="admin_switch_plans", + ) + + self.admin_router.add_api_route( + "/subscription/cancel", + self.cancel_subscription_admin_route, + methods=["POST"], + operation_id="admin_cancel_subscription", + ) + + # DOESN'T REQUIRE 'organization_id' + self.admin_router.add_api_route( + "/usage/report", + self.report_usage, + methods=["POST"], + operation_id="admin_report_usage", + ) + + # HANDLERS + + @intercept_exceptions() + async def handle_events( + self, + request: Request, + ): + if not stripe.api_key: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Missing Stripe API Key"}, + ) + + payload = await request.body() + stripe_event = None + + try: + stripe_event = loads(payload) + except decoder.JSONDecodeError: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"status": "error", "message": "Payload extraction failed"}, + ) + + try: + stripe_event = stripe.Event.construct_from( + stripe_event, + stripe.api_key, + ) + except ValueError as e: + log.error("Could not construct stripe event: %s", e) + raise HTTPException(status_code=400, detail="Invalid payload") from e + + try: + sig_header = request.headers.get("stripe-signature") + + if STRIPE_WEBHOOK_SECRET: + stripe_event = stripe.Webhook.construct_event( + payload, + sig_header, + STRIPE_WEBHOOK_SECRET, + ) + except stripe.error.SignatureVerificationError as e: + log.error("Webhook signature verification failed: %s", e) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"status": "error", "message": "Signature verification failed"}, + ) + + metadata = None + + if not stripe_event.type.startswith("invoice"): + if not hasattr(stripe_event.data.object, "metadata"): + log.warn("Skipping stripe event: %s (no metadata)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Metadata not found"}, + ) + else: + metadata = stripe_event.data.object.metadata + + if stripe_event.type.startswith("invoice"): + if not hasattr( + stripe_event.data.object, "subscription_details" + ) and not hasattr( + stripe_event.data.object.subscription_details, "metadata" + ): + log.warn("Skipping stripe event: %s (no metadata)", stripe_event.type) + + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Metadata not found"}, + ) + else: + metadata = stripe_event.data.object.subscription_details.metadata + + if "target" not in metadata: + log.warn("Skipping stripe event: %s (no target)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Target not found"}, + ) + + target = metadata.get("target") + + if target != STRIPE_TARGET: + log.warn( + "Skipping stripe event: %s (wrong target: %s)", + stripe_event.type, + target, + ) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Target mismatch"}, + ) + + if "organization_id" not in metadata: + log.warn("Skipping stripe event: %s (no organization)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Organization ID not found"}, + ) + + organization_id = metadata.get("organization_id") + + log.info( + "Stripe event: %s | %s | %s", + organization_id, + stripe_event.type, + target, + ) + + try: + event = None + subscription_id = None + plan = None + anchor = None + + if stripe_event.type == "customer.subscription.created": + event = Event.SUBSCRIPTION_CREATED + + if "id" not in stripe_event.data.object: + log.warn( + "Skipping stripe event: %s (no subscription)", + stripe_event.type, + ) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "status": "error", + "message": "Subscription ID not found", + }, + ) + + subscription_id = stripe_event.data.object.id + + if "plan" not in metadata: + log.warn("Skipping stripe event: %s (no plan)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "status": "error", + "message": "Plan not found", + }, + ) + + plan = Plan(metadata.get("plan")) + + if "billing_cycle_anchor" not in stripe_event.data.object: + log.warn("Skipping stripe event: %s (no anchor)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "status": "error", + "message": "Anchor not found", + }, + ) + + anchor = datetime.fromtimestamp( + stripe_event.data.object.billing_cycle_anchor + ).day + + elif stripe_event.type == "invoice.payment_failed": + event = Event.SUBSCRIPTION_PAUSED + + elif stripe_event.type == "invoice.payment_succeeded": + event = Event.SUBSCRIPTION_RESUMED + + elif stripe_event.type == "customer.subscription.deleted": + event = Event.SUBSCRIPTION_CANCELLED + + else: + log.warn("Skipping stripe event: %s (unsupported)", stripe_event.type) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"status": "error", "message": "Unsupported event"}, + ) + + subscription = await self.subscription_service.process_event( + organization_id=organization_id, + event=event, + subscription_id=subscription_id, + plan=plan, + anchor=anchor, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail="unexpected error") from e + + if not subscription: + raise HTTPException(status_code=500, detail="unexpected error") + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"status": "success"}, + ) + + async def create_portal( + self, + organization_id: str, + ): + if not stripe.api_key: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Missing Stripe API Key"}, + ) + + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"status": "error", "message": "Subscription not found"}, + ) + + if not subscription.customer_id: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + "status": "error", + "message": "Access denied: please subscribe to a plan to access the portal", + }, + ) + + portal = stripe.billing_portal.Session.create( + customer=subscription.customer_id, + ) + + return {"portal_url": portal.url} + + async def create_checkout( + self, + organization_id: str, + plan: Plan, + success_url: str, + ): + if not stripe.api_key: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"status": "error", "message": "Missing Stripe API Key"}, + ) + + if plan.name not in Plan.__members__.keys(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid plan", + ) + + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + "status": "error", + "message": "Subscription (Agenta) not found", + }, + ) + + if subscription.subscription_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={ + "status": "error", + "message": "Subscription (Stripe) already exists", + }, + ) + + if not subscription.customer_id: + organization = await get_organization_by_id( + organization_id=organization_id, + ) + + if not organization: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + "status": "error", + "message": "Organization not found", + }, + ) + + user = await get_user_with_id( + user_id=organization.owner, + ) + + if not user: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"status": "error", "message": "Owner not found"}, + ) + + customer = stripe.Customer.create( + name=organization.name, + email=user.email, + metadata={ + "organization_id": organization_id, + "target": STRIPE_TARGET, + }, + ) + + subscription.customer_id = customer.id + + await self.subscription_service.update( + subscription=subscription, + ) + + checkout = stripe.checkout.Session.create( + mode="subscription", + payment_method_types=["card"], + allow_promotion_codes=True, + customer_update={"address": "auto", "name": "auto"}, + billing_address_collection="required", + automatic_tax={"enabled": True}, + tax_id_collection={"enabled": True}, + # + customer=subscription.customer_id, + line_items=list(AGENTA_PRICING[plan].values()), + # + subscription_data={ + # "billing_cycle_anchor": anchor, + "metadata": { + "organization_id": organization_id, + "plan": plan.value, + "target": STRIPE_TARGET, + }, + }, + # + ui_mode="hosted", + success_url=success_url, + ) + + return {"checkout_url": checkout.url} + + async def fetch_plans( + self, + organization_id: str, + ): + plans = [] + + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription: + key = None + else: + key = subscription.plan.value + + for plan in CATALOG: + if plan["type"] == "standard": + plans.append(plan) + elif plan["type"] == "custom" and plan["plan"] == key: + plans.append(plan) + + return plans + + async def switch_plans( + self, + organization_id: str, + plan: Plan, + # force: bool, + ): + if plan.name not in Plan.__members__.keys(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid plan", + ) + + try: + subscription = await self.subscription_service.process_event( + organization_id=organization_id, + event=Event.SUBSCRIPTION_SWITCHED, + plan=plan.value, + # force=force, + ) + + if not subscription: + raise HTTPException(status_code=500, detail="unexpected error") + + except EventException as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + except SwitchException as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="unexpected error", + ) from e + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"status": "success"}, + ) + + async def fetch_subscription( + self, + organization_id: str, + ): + now = datetime.now(timezone.utc) + + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription or not subscription.plan: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + "status": "error", + "message": "Subscription (Agenta) not found", + }, + ) + + plan = subscription.plan + anchor = subscription.anchor + + _status: Dict[str, Any] = dict( + plan=plan.value, + type="standard", + ) + + if plan == Plan.CLOUD_V0_HOBBY: + return _status + + if not subscription.subscription_id: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={ + "status": "error", + "message": "Subscription (Agenta) not found", + }, + ) + + if not stripe.api_key: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "status": "error", + "message": "Missing Stripe API Key", + }, + ) + + try: + _subscription = stripe.Subscription.retrieve( + id=subscription.subscription_id, + ) + except Exception: + _subscription = None + + if _subscription: + _status["period_start"] = int(_subscription.current_period_start) + _status["period_end"] = int(_subscription.current_period_end) + _status["free_trial"] = _subscription.status == "trialing" + + return _status + + if not anchor or anchor < 1 or anchor > 31: + anchor = now.day + + last_day_this_month = ( + datetime( + now.year, + now.month, + 1, + tzinfo=timezone.utc, + ) + + relativedelta( + months=+1, + days=-1, + ) + ).day + + day_this_month = min(anchor, last_day_this_month) + + if now.day < anchor: + prev_month = now + relativedelta( + months=-1, + ) + + last_day_prev_month = ( + datetime( + prev_month.year, + prev_month.month, + 1, + tzinfo=timezone.utc, + ) + + relativedelta( + months=+1, + days=-1, + ) + ).day + + day_prev_month = min(anchor, last_day_prev_month) + + period_start = datetime( + year=prev_month.year, + month=prev_month.month, + day=day_prev_month, + tzinfo=timezone.utc, + ) + period_end = datetime( + year=now.year, + month=now.month, + day=day_this_month, + tzinfo=timezone.utc, + ) + else: + period_start = datetime( + year=now.year, + month=now.month, + day=day_this_month, + tzinfo=timezone.utc, + ) + + next_month = now + relativedelta( + months=+1, + ) + + last_day_next_month = ( + datetime( + next_month.year, + next_month.month, + 1, + tzinfo=timezone.utc, + ) + + relativedelta( + months=+1, + days=-1, + ) + ).day + + day_next_month = min(anchor, last_day_next_month) + + period_end = datetime( + year=next_month.year, + month=next_month.month, + day=day_next_month, + tzinfo=timezone.utc, + ) + + _status["period_start"] = int(period_start.timestamp()) + _status["period_end"] = int(period_end.timestamp()) + _status["free_trial"] = False + _status["type"] = "custom" + + return _status + + async def cancel_subscription( + self, + organization_id: str, + ): + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Subscription (Agenta) not found", + ) + + if not subscription.subscription_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Subscription (Stripe) not found", + ) + + try: + stripe.Subscription.cancel(subscription.subscription_id) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Could not cancel subscription. Please try again or contact support.", + ) from e + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"status": "success"}, + ) + + async def fetch_usage( + self, + organization_id: str, + ): + now = datetime.now(timezone.utc) + + subscription = await self.subscription_service.read( + organization_id=organization_id, + ) + + if not subscription: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"status": "error", "message": "Subscription not found"}, + ) + + plan = subscription.plan + anchor_day = subscription.anchor + anchor_month = (now.month + (1 if now.day >= anchor_day else 0)) % 12 + + entitlements = ENTITLEMENTS.get(plan) + + if not entitlements: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"status": "error", "message": "Plan not found"}, + ) + + meters = await self.subscription_service.meters_service.fetch( + organization_id=organization_id, + ) + + usage = {} + + for tracker in [Tracker.COUNTERS, Tracker.GAUGES]: + for key in list(entitlements[tracker].keys()): + quota: Quota = entitlements[tracker][key] + value = 0 + + for meter in meters: + if meter.key == key: + if meter.month != 0 and meter.month != anchor_month: + continue + + value = meter.value + + usage[key] = { + "value": value, + "limit": quota.limit, + "free": quota.free, + "monthly": quota.monthly is True, + "strict": quota.strict is True, + } + + return usage + + @intercept_exceptions() + async def report_usage( + self, + ): + try: + await self.subscription_service.meters_service.report() + except Exception as e: + raise HTTPException(status_code=500, detail="unexpected error") from e + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"status": "success"}, + ) + + # ROUTES + + @intercept_exceptions() + async def create_portal_user_route( + self, + request: Request, + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.create_portal( + organization_id=request.state.organization_id, + ) + + @intercept_exceptions() + async def create_portal_admin_route( + self, + organization_id: str = Query(...), + ): + return await self.create_portal( + organization_id=organization_id, + ) + + @intercept_exceptions() + async def create_checkout_user_route( + self, + request: Request, + plan: Plan = Query(...), + success_url: str = Query(...), # find a way to make this optional or moot + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.create_checkout( + organization_id=request.state.organization_id, + plan=plan, + success_url=success_url, + ) + + @intercept_exceptions() + async def create_checkout_admin_route( + self, + organization_id: str = Query(...), + plan: Plan = Query(...), + success_url: str = Query(...), # find a way to make this optional or moot + ): + return await self.create_checkout( + organization_id=organization_id, + plan=plan, + success_url=success_url, + ) + + @intercept_exceptions() + async def fetch_plan_user_route( + self, + request: Request, + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.fetch_plans( + organization_id=request.state.organization_id, + ) + + @intercept_exceptions() + async def switch_plans_user_route( + self, + request: Request, + plan: Plan = Query(...), + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.switch_plans( + organization_id=request.state.organization_id, + plan=plan, + ) + + @intercept_exceptions() + async def switch_plans_admin_route( + self, + organization_id: str = Query(...), + plan: Plan = Query(...), + ): + return await self.switch_plans( + organization_id=organization_id, + plan=plan, + ) + + @intercept_exceptions() + async def fetch_subscription_user_route( + self, + request: Request, + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.fetch_subscription( + organization_id=request.state.organization_id, + ) + + @intercept_exceptions() + async def cancel_subscription_user_route( + self, + request: Request, + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.EDIT_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.cancel_subscription( + organization_id=request.state.organization_id, + ) + + @intercept_exceptions() + async def cancel_subscription_admin_route( + self, + organization_id: str = Query(...), + ): + return await self.cancel_subscription( + organization_id=organization_id, + ) + + @intercept_exceptions() + async def fetch_usage_user_route( + self, + request: Request, + ): + if not await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_BILLING, + ): + return FORBIDDEN_RESPONSE + + return await self.fetch_usage( + organization_id=request.state.organization_id, + ) diff --git a/api/ee/src/core/__init__.py b/api/ee/src/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/core/entitlements/__init__.py b/api/ee/src/core/entitlements/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/core/entitlements/service.py b/api/ee/src/core/entitlements/service.py new file mode 100644 index 0000000000..f62b11fc74 --- /dev/null +++ b/api/ee/src/core/entitlements/service.py @@ -0,0 +1,97 @@ +from typing import Optional, Dict, List + +from ee.src.core.entitlements.types import ( + Tracker, + Constraint, + ENTITLEMENTS, + CONSTRAINTS, +) +from ee.src.core.entitlements.types import Quota, Gauge +from ee.src.core.subscriptions.types import Plan +from ee.src.core.meters.service import MetersService +from ee.src.core.meters.types import MeterDTO + + +class ConstaintsException(Exception): + issues: Dict[Gauge, int] = {} + + +class EntitlementsService: + def __init__( + self, + meters_service: MetersService, + ): + self.meters_service = meters_service + + async def enforce( + self, + *, + organization_id: str, + plan: str, + force: Optional[bool] = False, + ) -> None: + issues = await self.check( + organization_id=organization_id, + plan=plan, + ) + + if issues: + if not force: + raise ConstaintsException( + issues=issues, + ) + + await self.fix( + organization_id=organization_id, + issues=issues, + ) + + async def check( + self, + *, + organization_id: str, + plan: Plan, + ) -> Dict[Gauge, int]: + issues = {} + + for key in CONSTRAINTS[Constraint.BLOCKED][Tracker.GAUGES]: + quotas: List[Quota] = ENTITLEMENTS[plan][Tracker.GAUGES] + + if key in quotas: + meter = MeterDTO( + organization_id=organization_id, + key=key, + ) + quota: Quota = quotas[key] + + check, meter = await self.meters_service.check( + meter=meter, + quota=quota, + ) + + if not check: + issues[key] = quota.limit + + return issues + + async def fix( + self, + *, + organization_id: str, + issues: Dict[Gauge, int], + ) -> None: + # TODO: Implement fix + pass + + +# TODO: +# -- P0 / MUST +# - Add active : Optional[bool] = None to all scopes and users +# -- P1 / SHOULD +# - Add parent scopes to all child scope +# - Add parent scopes membership on child scope membership creation +# - Remove children scopes membership on parent scope membership removal +# -- P2 / COULD +# - Add created_at / updated_at to all scopes +# - Set updated_at on all updates + on creation +# - Move organization roles to memberships diff --git a/api/ee/src/core/entitlements/types.py b/api/ee/src/core/entitlements/types.py new file mode 100644 index 0000000000..791ddfd024 --- /dev/null +++ b/api/ee/src/core/entitlements/types.py @@ -0,0 +1,277 @@ +from typing import Optional +from enum import Enum +from pydantic import BaseModel + +from ee.src.core.subscriptions.types import Plan + + +class Tracker(str, Enum): + FLAGS = "flags" + COUNTERS = "counters" + GAUGES = "gauges" + + +class Flag(str, Enum): + # HISTORY = "history" + HOOKS = "hooks" + RBAC = "rbac" + + +class Counter(str, Enum): + TRACES = "traces" + EVALUATIONS = "evaluations" + EVALUATORS = "evaluators" + ANNOTATIONS = "annotations" + + +class Gauge(str, Enum): + USERS = "users" + APPLICATIONS = "applications" + + +class Constraint(str, Enum): + BLOCKED = "blocked" + READ_ONLY = "read_only" + + +class Quota(BaseModel): + free: Optional[int] = None + limit: Optional[int] = None + monthly: Optional[bool] = None + strict: Optional[bool] = False + + +class Probe(BaseModel): + monthly: Optional[bool] = False + delta: Optional[bool] = False + + +CATALOG = [ + { + "title": "Hobby", + "description": "Great for hobby projects and POCs.", + "type": "standard", + "plan": Plan.CLOUD_V0_HOBBY.value, + "price": { + "base": { + "type": "flat", + "currency": "USD", + "amount": 0.00, + }, + }, + "features": [ + "2 prompts", + "5k traces/month", + "20 evaluations/month", + "2 seats", + ], + }, + { + "title": "Pro", + "description": "For production projects.", + "type": "standard", + "plan": Plan.CLOUD_V0_PRO.value, + "price": { + "base": { + "type": "flat", + "currency": "USD", + "amount": 49.00, + }, + "users": { + "type": "tiered", + "currency": "USD", + "tiers": [ + { + "limit": 3, + "amount": 0.00, + }, + { + "limit": 10, + "amount": 20.00, + "rate": 1, + }, + ], + }, + "traces": { + "type": "tiered", + "currency": "USD", + "tiers": [ + { + "limit": 10_000, + "amount": 0.00, + }, + { + "amount": 5.00, + "rate": 10_000, + }, + ], + }, + }, + "features": [ + "Unlimited prompts", + "10k traces/month", + "Unlimited evaluations", + "3 seats included", + "Up to 10 seats", + ], + }, + # { + # "title": "Business", + # "description": "For scale, security, and support.", + # "type": "standard", + # "price": { + # "base": { + # "type": "flat", + # "currency": "USD", + # "amount": 399.00, + # "starting_at": True, + # }, + # }, + # "features": [ + # "Unlimited prompts", + # "Unlimited traces", + # "Unlimited evaluations", + # "Unlimited seats", + # ], + # }, + { + "title": "Enterprise", + "description": "For large organizations or custom needs.", + "type": "standard", + "features": [ + "Everything in Pro", + "Unlimited seats", + "SOC 2 reports", + "Security reviews", + "Dedicated support", + "Custom SLAs", + "Custom terms", + "Self-hosted deployment options", + ], + }, + { + "title": "Humanity Labs", + "description": "For Humanity Labs.", + "plan": Plan.CLOUD_V0_HUMANITY_LABS.value, + "type": "custom", + "features": [ + "Everything in Enterprise", + ], + }, + { + "title": "X Labs", + "description": "For X Labs.", + "plan": Plan.CLOUD_V0_X_LABS.value, + "type": "custom", + "features": [ + "Everything in Enterprise", + ], + }, + { + "title": "Agenta", + "description": "For Agenta.", + "plan": Plan.CLOUD_V0_AGENTA_AI.value, + "type": "custom", + "features": [ + "Everything in Enterprise", + ], + }, +] + +ENTITLEMENTS = { + Plan.CLOUD_V0_HOBBY: { + Tracker.FLAGS: { + Flag.HOOKS: False, + Flag.RBAC: False, + }, + Tracker.COUNTERS: { + Counter.TRACES: Quota(limit=5_000, monthly=True, free=5_000), + Counter.EVALUATIONS: Quota(limit=20, monthly=True, free=20, strict=True), + }, + Tracker.GAUGES: { + Gauge.USERS: Quota(limit=2, strict=True, free=2), + Gauge.APPLICATIONS: Quota(limit=2, strict=True, free=2), + }, + }, + Plan.CLOUD_V0_PRO: { + Tracker.FLAGS: { + Flag.HOOKS: True, + Flag.RBAC: False, + }, + Tracker.COUNTERS: { + Counter.TRACES: Quota(monthly=True, free=10_000), + Counter.EVALUATIONS: Quota(monthly=True, strict=True), + }, + Tracker.GAUGES: { + Gauge.USERS: Quota(limit=10, strict=True, free=3), + Gauge.APPLICATIONS: Quota(strict=True), + }, + }, + Plan.CLOUD_V0_HUMANITY_LABS: { + Tracker.FLAGS: { + Flag.HOOKS: True, + Flag.RBAC: True, + }, + Tracker.COUNTERS: { + Counter.TRACES: Quota(monthly=True), + Counter.EVALUATIONS: Quota(monthly=True, strict=True), + }, + Tracker.GAUGES: { + Gauge.USERS: Quota(strict=True), + Gauge.APPLICATIONS: Quota(strict=True), + }, + }, + Plan.CLOUD_V0_X_LABS: { + Tracker.FLAGS: { + Flag.HOOKS: False, + Flag.RBAC: False, + }, + Tracker.COUNTERS: { + Counter.TRACES: Quota(monthly=True), + Counter.EVALUATIONS: Quota(monthly=True, strict=True), + }, + Tracker.GAUGES: { + Gauge.USERS: Quota(strict=True), + Gauge.APPLICATIONS: Quota(strict=True), + }, + }, + Plan.CLOUD_V0_AGENTA_AI: { + Tracker.FLAGS: { + Flag.HOOKS: True, + Flag.RBAC: True, + }, + Tracker.COUNTERS: { + Counter.TRACES: Quota(monthly=True), + Counter.EVALUATIONS: Quota(monthly=True, strict=True), + }, + Tracker.GAUGES: { + Gauge.USERS: Quota(strict=True), + Gauge.APPLICATIONS: Quota(strict=True), + }, + }, +} + + +REPORTS = [ + Counter.TRACES.value, + Gauge.USERS.value, +] + +CONSTRAINTS = { + Constraint.BLOCKED: { + Tracker.FLAGS: [ + Flag.HOOKS, + Flag.RBAC, + ], + Tracker.GAUGES: [ + Gauge.USERS, + Gauge.APPLICATIONS, + ], + }, + Constraint.READ_ONLY: { + Tracker.COUNTERS: [ + Counter.TRACES, + Counter.EVALUATIONS, + ], + }, +} diff --git a/api/ee/src/core/meters/__init__.py b/api/ee/src/core/meters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/core/meters/interfaces.py b/api/ee/src/core/meters/interfaces.py new file mode 100644 index 0000000000..9f4e66605d --- /dev/null +++ b/api/ee/src/core/meters/interfaces.py @@ -0,0 +1,88 @@ +from typing import Tuple, Callable, Optional +from datetime import datetime + +from ee.src.core.entitlements.types import Quota +from ee.src.core.meters.types import MeterDTO + + +class MetersDAOInterface: + def __init__(self): + raise NotImplementedError + + async def dump( + self, + ) -> list[MeterDTO]: + """ + Dump all meters where 'synced' != 'value'. + + :return: A list of MeterDTO objects for meters where 'synced' != 'value'. + """ + raise NotImplementedError + + async def bump( + self, + meters: list[MeterDTO], + ) -> None: + """ + Update the 'synced' field for the given list of meters. + + :param meters: A list of MeterDTO objects containing the details of meters to update. + """ + raise NotImplementedError + + async def fetch( + self, + *, + organization_id: str, + ) -> list[MeterDTO]: + """ + Fetch all meters for a given organization. + + Parameters: + - organization_id: The ID of the organization to fetch meters for. + + Returns: + - List[MeterDTO]: A list of MeterDTO objects containing the meter details. + """ + raise NotImplementedError + + async def check( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO]: + """ + Check if the meter adjustment or absolute value is allowed. + + Parameters: + - meter: MeterDTO containing the current meter information and either `value` or `delta`. + - quota: QuotaDTO defining the allowed quota limits. + + Returns: + - allowed (bool): Whether the operation is within the allowed limits. + - meter (MeterDTO): The current meter value if found or 0 if not. + """ + raise NotImplementedError + + async def adjust( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO, Callable]: + """ + Adjust the meter value based on the quota. + + Parameters: + - meter: MeterDTO containing either `value` or `delta` for the adjustment. + - quota: QuotaDTO defining the allowed quota limits. + + Returns: + - allowed (bool): Whether the adjustment was within quota limits. + - meter (MeterDTO): The updated meter value after the adjustment. + - rollback (callable): A function to rollback the adjustment (optional, if applicable). + """ + raise NotImplementedError diff --git a/api/ee/src/core/meters/service.py b/api/ee/src/core/meters/service.py new file mode 100644 index 0000000000..ed2ef0fa33 --- /dev/null +++ b/api/ee/src/core/meters/service.py @@ -0,0 +1,173 @@ +from typing import Tuple, Callable, List, Optional +from datetime import datetime +from os import environ +from json import loads + +import stripe + +from oss.src.utils.logging import get_module_logger + +from ee.src.core.entitlements.types import Quota +from ee.src.core.entitlements.types import Counter, Gauge, REPORTS +from ee.src.core.meters.types import MeterDTO +from ee.src.core.meters.interfaces import MetersDAOInterface + +AGENTA_PRICING = loads(environ.get("AGENTA_PRICING") or "{}") + +log = get_module_logger(__name__) + +stripe.api_key = environ.get("STRIPE_API_KEY") + + +class MetersService: + def __init__( + self, + meters_dao: MetersDAOInterface, + ): + self.meters_dao = meters_dao + + async def dump( + self, + ) -> List[MeterDTO]: + return await self.meters_dao.dump() + + async def bump( + self, + *, + meters: List[MeterDTO], + ) -> None: + await self.meters_dao.bump(meters=meters) + + async def fetch( + self, + *, + organization_id: str, + ) -> List[MeterDTO]: + return await self.meters_dao.fetch(organization_id=organization_id) + + async def check( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO]: + return await self.meters_dao.check(meter=meter, quota=quota, anchor=anchor) + + async def adjust( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO, Callable]: + return await self.meters_dao.adjust(meter=meter, quota=quota, anchor=anchor) + + async def report(self): + if not stripe.api_key: + log.warn("Missing Stripe API Key.") + return + + try: + meters = await self.dump() + + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error dumping meters: %s", e) + return + + try: + for meter in meters: + if meter.subscription is None: + continue + + try: + if meter.key.value in REPORTS: + subscription_id = meter.subscription.subscription_id + customer_id = meter.subscription.customer_id + + if not subscription_id: + continue + + if not customer_id: + continue + + if meter.key.name in Gauge.__members__.keys(): + try: + price_id = ( + AGENTA_PRICING.get(meter.subscription.plan, {}) + .get("users", {}) + .get("price") + ) + + if not price_id: + continue + + _id = None + for item in stripe.SubscriptionItem.list( + subscription=subscription_id, + ).auto_paging_iter(): + if item.price.id == price_id: + _id = item.id + break + + if not _id: + continue + + quantity = meter.value + + items = [{"id": _id, "quantity": quantity}] + + stripe.Subscription.modify( + subscription_id, + items=items, + ) + + except ( + Exception # pylint: disable=broad-exception-caught + ) as e: + log.error("Error modifying subscription: %s", e) + continue + + log.info( + f"[stripe] updating: {meter.organization_id} | | {'sync ' if meter.key.value in REPORTS else ' '} | {meter.key}: {meter.value}" + ) + + if meter.key.name in Counter.__members__.keys(): + try: + event_name = meter.key.value + delta = meter.value - meter.synced + payload = {"delta": delta, "customer_id": customer_id} + + stripe.billing.MeterEvent.create( + event_name=event_name, + payload=payload, + ) + except ( + Exception # pylint: disable=broad-exception-caught + ) as e: + log.error("Error creating meter event: %s", e) + continue + + log.info( + f"[stripe] reporting: {meter.organization_id} | {(('0' if (meter.month != 0 and meter.month < 10) else '') + str(meter.month)) if meter.month != 0 else ' '}.{meter.year if meter.year else ' '} | {'sync ' if meter.key.value in REPORTS else ' '} | {meter.key}: {meter.value - meter.synced}" + ) + + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error reporting meter: %s", e) + + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error reporting meters: %s", e) + + try: + for meter in meters: + meter.synced = meter.value + + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error syncing meters: %s", e) + + try: + await self.bump(meters=meters) + + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Error bumping meters: %s", e) + return diff --git a/api/ee/src/core/meters/types.py b/api/ee/src/core/meters/types.py new file mode 100644 index 0000000000..a0ada9da16 --- /dev/null +++ b/api/ee/src/core/meters/types.py @@ -0,0 +1,32 @@ +from typing import Optional + +from uuid import UUID +from enum import Enum + +from pydantic import BaseModel + +from ee.src.core.entitlements.types import Counter, Gauge +from ee.src.core.subscriptions.types import SubscriptionDTO + + +class Meters(str, Enum): + # COUNTERS + TRACES = Counter.TRACES.value + EVALUATIONS = Counter.EVALUATIONS.value + # GAUGES + USERS = Gauge.USERS.value + APPLICATIONS = Gauge.APPLICATIONS.value + + +class MeterDTO(BaseModel): + organization_id: UUID + + year: Optional[int] = 0 + month: Optional[int] = 0 + + key: Meters + value: Optional[int] = None + synced: Optional[int] = None + delta: Optional[int] = None + + subscription: Optional[SubscriptionDTO] = None diff --git a/api/ee/src/core/subscriptions/__init__.py b/api/ee/src/core/subscriptions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/core/subscriptions/interfaces.py b/api/ee/src/core/subscriptions/interfaces.py new file mode 100644 index 0000000000..2c47a9a302 --- /dev/null +++ b/api/ee/src/core/subscriptions/interfaces.py @@ -0,0 +1,56 @@ +from typing import Optional + +from ee.src.core.subscriptions.types import SubscriptionDTO + + +class SubscriptionsDAOInterface: + def __init__(self): + raise NotImplementedError + + async def create( + self, + *, + subscription: SubscriptionDTO, + ) -> SubscriptionDTO: + """ + Create a new subscription. + + Parameters: + - subscription: SubscriptionDTO containing subscription details. + + Returns: + - SubscriptionDTO: The created subscription. + """ + raise NotImplementedError + + async def read( + self, + *, + organization_id: str, + ) -> Optional[SubscriptionDTO]: + """ + Read a subscription by organization ID. + + Parameters: + - organization_id: The ID of the organization to fetch. + + Returns: + - Optional[SubscriptionDTO]: The subscription if found, else None. + """ + raise NotImplementedError + + async def update( + self, + *, + subscription: SubscriptionDTO, + ) -> Optional[SubscriptionDTO]: + """ + Update an existing subscription. + + Parameters: + - subscription: SubscriptionDTO containing updated details. + + Returns: + - Optional[SubscriptionDTO]: The updated subscription if found, else None. + """ + raise NotImplementedError diff --git a/api/ee/src/core/subscriptions/service.py b/api/ee/src/core/subscriptions/service.py new file mode 100644 index 0000000000..f69adcbd74 --- /dev/null +++ b/api/ee/src/core/subscriptions/service.py @@ -0,0 +1,271 @@ +from typing import Optional +from json import loads +from uuid import getnode +from datetime import datetime, timezone, timedelta + +from os import environ + +import stripe + +from oss.src.utils.logging import get_module_logger + +from ee.src.core.subscriptions.types import ( + SubscriptionDTO, + Event, + Plan, + FREE_PLAN, + REVERSE_TRIAL_PLAN, + REVERSE_TRIAL_DAYS, +) +from ee.src.core.subscriptions.interfaces import SubscriptionsDAOInterface +from ee.src.core.entitlements.service import EntitlementsService +from ee.src.core.meters.service import MetersService + +log = get_module_logger(__name__) + +stripe.api_key = environ.get("STRIPE_SECRET_KEY") + +MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xff:02x}" for ele in range(40, -1, -8)) +STRIPE_TARGET = environ.get("STRIPE_TARGET") or MAC_ADDRESS +AGENTA_PRICING = loads(environ.get("AGENTA_PRICING") or "{}") + + +class SwitchException(Exception): + pass + + +class EventException(Exception): + pass + + +class SubscriptionsService: + def __init__( + self, + subscriptions_dao: SubscriptionsDAOInterface, + meters_service: MetersService, + ): + self.subscriptions_dao = subscriptions_dao + self.meters_service = meters_service + self.entitlements_service = EntitlementsService(meters_service=meters_service) + + async def create( + self, + *, + subscription: SubscriptionDTO, + ) -> Optional[SubscriptionDTO]: + return await self.subscriptions_dao.create(subscription=subscription) + + async def read( + self, + *, + organization_id: str, + ) -> Optional[SubscriptionDTO]: + return await self.subscriptions_dao.read(organization_id=organization_id) + + async def update( + self, + *, + subscription: SubscriptionDTO, + ) -> Optional[SubscriptionDTO]: + return await self.subscriptions_dao.update(subscription=subscription) + + async def start_reverse_trial( + self, + *, + organization_id: str, + organization_name: str, + organization_email: str, + ) -> Optional[SubscriptionDTO]: + now = datetime.now(tz=timezone.utc) + anchor = now + timedelta(days=REVERSE_TRIAL_DAYS) + + subscription = await self.read(organization_id=organization_id) + + if subscription: + return None + + subscription = await self.create( + subscription=SubscriptionDTO( + organization_id=organization_id, + plan=FREE_PLAN, + active=True, + anchor=anchor.day, + ) + ) + + if not subscription: + return None + + if not stripe.api_key: + log.warn("Missing Stripe API Key.") + return None + + customer = stripe.Customer.create( + name=organization_name, + email=organization_email, + metadata={ + "organization_id": organization_id, + "target": STRIPE_TARGET, + }, + ) + + customer_id = customer.id + + if not customer_id: + log.error( + "Failed to create Stripe customer for organization ID: %s", + organization_id, + ) + + return None + + stripe_subscription = stripe.Subscription.create( + customer=customer_id, + items=list(AGENTA_PRICING[REVERSE_TRIAL_PLAN].values()), + # + # automatic_tax={"enabled": True}, + metadata={ + "organization_id": organization_id, + "plan": REVERSE_TRIAL_PLAN.value, + "target": STRIPE_TARGET, + }, + # + trial_period_days=REVERSE_TRIAL_DAYS, + trial_settings={"end_behavior": {"missing_payment_method": "cancel"}}, + ) + + subscription = await self.update( + subscription=SubscriptionDTO( + organization_id=organization_id, + customer_id=customer_id, + subscription_id=stripe_subscription.id, + plan=REVERSE_TRIAL_PLAN, + active=True, + anchor=anchor.day, + ) + ) + + return subscription + + async def process_event( + self, + *, + organization_id: str, + event: Event, + subscription_id: Optional[str] = None, + plan: Optional[Plan] = None, + anchor: Optional[Plan] = None, + # force: Optional[bool] = True, + **kwargs, + ) -> SubscriptionDTO: + log.info( + "Billing event: %s | %s | %s", + organization_id, + event, + plan, + ) + + now = datetime.now(tz=timezone.utc) + + if not anchor: + anchor = now.day + + subscription = await self.read(organization_id=organization_id) + + if not subscription: + raise EventException( + "Subscription not found for organization ID: {organization_id}" + ) + + if event == Event.SUBSCRIPTION_CREATED: + subscription.active = True + subscription.plan = plan + subscription.subscription_id = subscription_id + subscription.anchor = anchor + + subscription = await self.update(subscription=subscription) + + elif subscription.plan != FREE_PLAN and event == Event.SUBSCRIPTION_PAUSED: + subscription.active = False + + subscription = await self.update(subscription=subscription) + + elif subscription.plan != FREE_PLAN and event == Event.SUBSCRIPTION_RESUMED: + subscription.active = True + + subscription = await self.update(subscription=subscription) + + elif subscription.plan != FREE_PLAN and event == Event.SUBSCRIPTION_SWITCHED: + if not stripe.api_key: + log.warn("Missing Stripe API Key.") + return None + + if subscription.plan == plan: + log.warn("Subscription already on the plan: %s", plan) + + raise EventException( + f"Same plan [{plan}] already exists for organization ID: {organization_id}" + ) + + if not subscription.subscription_id: + raise SwitchException( + f"Cannot switch plans without an existing subscription for organization ID: {organization_id}" + ) + + try: + _subscription = stripe.Subscription.retrieve( + id=subscription.subscription_id, + ) + except Exception as e: # pylint: disable=too-broad-exception + log.warn( + "Failed to retrieve subscription from Stripe: %s", subscription + ) + + raise EventException( + "Could not switch plans. Please try again or contact support.", + ) from e + + subscription.active = True + subscription.plan = plan + + # await self.entitlements_service.enforce( + # organization_id=organization_id, + # plan=plan, + # force=force, + # ) + + stripe.Subscription.modify( + subscription.subscription_id, + items=[ + {"id": item.id, "deleted": True} + for item in stripe.SubscriptionItem.list( + subscription=subscription.subscription_id, + ).data + ] + + list(AGENTA_PRICING[plan].values()), + ) + + subscription = await self.update(subscription=subscription) + + elif subscription.plan != FREE_PLAN and event == Event.SUBSCRIPTION_CANCELLED: + subscription.active = True + subscription.plan = FREE_PLAN + subscription.subscription_id = None + subscription.anchor = anchor + + # await self.entitlements_service.enforce( + # organization_id=organization_id, + # plan=FREE_PLAN, + # force=True, + # ) + + subscription = await self.update(subscription=subscription) + + else: + log.warn("Invalid subscription event: %s ", subscription) + + raise EventException( + f"Invalid subscription event {event} for organization ID: {organization_id}" + ) + + return subscription diff --git a/api/ee/src/core/subscriptions/types.py b/api/ee/src/core/subscriptions/types.py new file mode 100644 index 0000000000..1f55dbe386 --- /dev/null +++ b/api/ee/src/core/subscriptions/types.py @@ -0,0 +1,40 @@ +from typing import Optional + +from os import environ + +from uuid import UUID +from enum import Enum + +from pydantic import BaseModel + + +class Plan(str, Enum): + CLOUD_V0_HOBBY = "cloud_v0_hobby" + CLOUD_V0_PRO = "cloud_v0_pro" + # + CLOUD_V0_HUMANITY_LABS = "cloud_v0_humanity_labs" + CLOUD_V0_X_LABS = "cloud_v0_x_labs" + # + CLOUD_V0_AGENTA_AI = "cloud_v0_agenta_ai" + + +class Event(str, Enum): + SUBSCRIPTION_CREATED = "subscription_created" + SUBSCRIPTION_PAUSED = "subscription_paused" + SUBSCRIPTION_RESUMED = "subscription_resumed" + SUBSCRIPTION_SWITCHED = "subscription_switched" + SUBSCRIPTION_CANCELLED = "subscription_cancelled" + + +class SubscriptionDTO(BaseModel): + organization_id: UUID + customer_id: Optional[str] = None + subscription_id: Optional[str] = None + plan: Optional[Plan] = None + active: Optional[bool] = None + anchor: Optional[int] = None + + +FREE_PLAN = Plan.CLOUD_V0_HOBBY # Move to ENV FILE +REVERSE_TRIAL_PLAN = Plan.CLOUD_V0_PRO # move to ENV FILE +REVERSE_TRIAL_DAYS = 14 # move to ENV FILE diff --git a/api/ee/src/crons/meters.sh b/api/ee/src/crons/meters.sh new file mode 100644 index 0000000000..c0f7d8c5ae --- /dev/null +++ b/api/ee/src/crons/meters.sh @@ -0,0 +1,17 @@ +#!/bin/sh +set -eu + +AGENTA_AUTH_KEY=$(tr '\0' '\n' < /proc/1/environ | grep ^AGENTA_AUTH_KEY= | cut -d= -f2-) + +echo "--------------------------------------------------------" +echo "[$(date)] meters.sh running from cron" >> /proc/1/fd/1 + +# Make POST request, show status and response +curl \ + -s \ + -w "\nHTTP_STATUS:%{http_code}\n" \ + -X POST \ + -H "Authorization: Access ${AGENTA_AUTH_KEY}" \ + "http://api:8000/admin/billing/usage/report" || echo "❌ CURL failed" + +echo "[$(date)] meters.sh done" >> /proc/1/fd/1 \ No newline at end of file diff --git a/api/ee/src/crons/meters.txt b/api/ee/src/crons/meters.txt new file mode 100644 index 0000000000..f3acd78570 --- /dev/null +++ b/api/ee/src/crons/meters.txt @@ -0,0 +1,2 @@ +* * * * * root echo "cron test $(date)" >> /proc/1/fd/1 2>&1 +0 * * * * root sh /meters.sh >> /proc/1/fd/1 2>&1 diff --git a/api/ee/src/crons/queries.sh b/api/ee/src/crons/queries.sh new file mode 100644 index 0000000000..b9e8c7a6e1 --- /dev/null +++ b/api/ee/src/crons/queries.sh @@ -0,0 +1,24 @@ +#!/bin/sh +set -eu + +AGENTA_AUTH_KEY=$(tr '\0' '\n' < /proc/1/environ | grep ^AGENTA_AUTH_KEY= | cut -d= -f2-) +TRIGGER_INTERVAL=$(awk 'NR==2 {split($1, a, "/"); print (a[2] ? a[2] : 1)}' /etc/cron.d/queries-cron) +NOW_UTC=$(date -u "+%Y-%m-%dT%H:%M:00Z") +MINUTE=$(date -u "+%M" | sed 's/^0*//') +ROUNDED_MINUTE=$(( (MINUTE / TRIGGER_INTERVAL) * TRIGGER_INTERVAL )) +TRIGGER_DATETIME=$(date -u "+%Y-%m-%dT%H") +TRIGGER_DATETIME="${TRIGGER_DATETIME}:$(printf "%02d" $ROUNDED_MINUTE):00Z" + + +echo "--------------------------------------------------------" +echo "[$(date)] queries.sh running from cron" >> /proc/1/fd/1 + +# Make POST request, show status and response +curl \ + -s \ + -w "\nHTTP_STATUS:%{http_code}\n" \ + -X POST \ + -H "Authorization: Access ${AGENTA_AUTH_KEY}" \ + "http://api:8000/admin/evaluations/runs/refresh?trigger_interval=${TRIGGER_INTERVAL}&trigger_datetime=${TRIGGER_DATETIME}" || echo "❌ CURL failed" + +echo "[$(date)] queries.sh done" >> /proc/1/fd/1 \ No newline at end of file diff --git a/api/ee/src/crons/queries.txt b/api/ee/src/crons/queries.txt new file mode 100644 index 0000000000..586a61af8e --- /dev/null +++ b/api/ee/src/crons/queries.txt @@ -0,0 +1,2 @@ +* * * * * root echo "cron test $(date)" >> /proc/1/fd/1 2>&1 +*/1 * * * * root sh /queries.sh >> /proc/1/fd/1 2>&1 diff --git a/api/ee/src/dbs/__init__.py b/api/ee/src/dbs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/dbs/postgres/__init__.py b/api/ee/src/dbs/postgres/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/dbs/postgres/meters/__init__.py b/api/ee/src/dbs/postgres/meters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/dbs/postgres/meters/dao.py b/api/ee/src/dbs/postgres/meters/dao.py new file mode 100644 index 0000000000..5302329dc3 --- /dev/null +++ b/api/ee/src/dbs/postgres/meters/dao.py @@ -0,0 +1,290 @@ +from typing import Callable, Tuple, Optional +from collections import defaultdict +from datetime import datetime, timezone + +from sqlalchemy import update +from sqlalchemy.future import select +from sqlalchemy.orm import joinedload +from sqlalchemy import case, tuple_ +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy import func, literal + + +from oss.src.utils.logging import get_module_logger +from oss.src.dbs.postgres.shared.engine import engine + +from ee.src.core.entitlements.types import Quota +from ee.src.core.meters.types import MeterDTO +from ee.src.core.subscriptions.types import SubscriptionDTO +from ee.src.core.meters.interfaces import MetersDAOInterface +from ee.src.dbs.postgres.meters.dbes import MeterDBE + + +log = get_module_logger(__name__) + + +class MetersDAO(MetersDAOInterface): + def __init__(self): + pass + + async def dump(self) -> list[MeterDTO]: + async with engine.core_session() as session: + stmt = ( + select(MeterDBE) + .filter(MeterDBE.synced != MeterDBE.value) + .options(joinedload(MeterDBE.subscription)) + ) # NO RISK OF DEADLOCK + + result = await session.execute(stmt) + meters = result.scalars().all() + + return [ + MeterDTO( + organization_id=meter.organization_id, + year=meter.year, + month=meter.month, + value=meter.value, + key=meter.key, + synced=meter.synced, + subscription=( + SubscriptionDTO( + organization_id=meter.subscription.organization_id, + customer_id=meter.subscription.customer_id, + subscription_id=meter.subscription.subscription_id, + plan=meter.subscription.plan, + active=meter.subscription.active, + anchor=meter.subscription.anchor, + ) + if meter.subscription + else None + ), + ) + for meter in meters + ] + + async def bump( + self, + meters: list[MeterDTO], + ) -> None: + if not meters: + return + + # Sort for consistent lock acquisition + sorted_meters = sorted( + meters, + key=lambda m: ( + m.organization_id, + m.key, + m.year, + m.month, + ), + ) + + async with engine.core_session() as session: + for meter in sorted_meters: + stmt = ( + update(MeterDBE) + .where( + MeterDBE.organization_id == meter.organization_id, + MeterDBE.key == meter.key, + MeterDBE.year == meter.year, + MeterDBE.month == meter.month, + ) + .values(synced=meter.synced) + ) + + await session.execute(stmt) + + await session.commit() + + async def fetch( + self, + *, + organization_id: str, + ) -> list[MeterDTO]: + async with engine.core_session() as session: + stmt = select(MeterDBE).filter_by( + organization_id=organization_id, + ) # NO RISK OF DEADLOCK + + result = await session.execute(stmt) + meters = result.scalars().all() + + return [ + MeterDTO( + organization_id=meter.organization_id, + key=meter.key, + year=meter.year, + month=meter.month, + value=meter.value, + synced=meter.synced, + ) + for meter in meters + ] + + async def check( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO]: + if quota.monthly: + now = datetime.now(timezone.utc) + + if not anchor: + meter.year = now.year + meter.month = now.month + + if anchor: + if now.day < anchor: + meter.year = now.year + meter.month = now.month + else: + meter.year = now.year + now.month // 12 + meter.month = (now.month + 1) % 12 + + async with engine.core_session() as session: + stmt = select(MeterDBE).filter_by( + organization_id=meter.organization_id, + key=meter.key, + year=meter.year, + month=meter.month, + ) # NO RISK OF DEADLOCK + + result = await session.execute(stmt) + meter_record = result.scalar_one_or_none() + + current_value = meter_record.value if meter_record else 0 + + adjusted_value = current_value + (meter.delta or 0) + adjusted_value = adjusted_value if adjusted_value >= 0 else 0 + + if quota.limit is None: + allowed = True + else: + allowed = adjusted_value <= quota.limit + + return ( + allowed, + MeterDTO( + **meter.model_dump(exclude={"value", "synced"}), + value=current_value, + synced=meter_record.synced if meter_record else 0, + ), + ) + + async def adjust( + self, + *, + meter: MeterDTO, + quota: Quota, + anchor: Optional[int] = None, + ) -> Tuple[bool, MeterDTO, Callable]: + # 1. Normalize meter.year/month if monthly quota + if quota.monthly: + now = datetime.now(timezone.utc) + + if not anchor: + meter.year = now.year + meter.month = now.month + elif now.day < anchor: + meter.year = now.year + meter.month = now.month + else: + meter.year = now.year + now.month // 12 + meter.month = (now.month + 1) % 12 + + # 2. Calculate proposed value (starting from 0) + desired_value = meter.value if meter.value is not None else (meter.delta or 0) + desired_value = max(desired_value, 0) + + # 3. Block insert if quota exceeded + if quota.limit is not None and desired_value > quota.limit: + return ( + False, + MeterDTO( + **meter.model_dump(exclude={"value", "synced"}), + value=0, + synced=0, + ), + lambda: None, + ) + + where_clauses = [] + + # Handle unlimited quota case + if quota.limit is None: + where_clauses.append(literal(True)) + + # Strict mode: use the adjusted value check + elif quota.strict: + if meter.delta is not None: + adjusted_expr = func.greatest(MeterDBE.value + meter.delta, 0) + elif meter.value is not None: + adjusted_expr = func.greatest(meter.value, 0) + else: + raise ValueError("Either delta or value must be set") + + where_clauses.append(adjusted_expr <= quota.limit) + + # Soft mode: just compare current value + else: + where_clauses.append(MeterDBE.value <= quota.limit) + + # Now safely combine the conditions + where = None + for where_clause in where_clauses: + if where is None: + where = where_clause + else: + where = where | where_clause + + # 4. Build SQL statement (atomic upsert) + async with engine.core_session() as session: + stmt = ( + insert(MeterDBE) + .values( + organization_id=meter.organization_id, + key=meter.key, + year=meter.year, + month=meter.month, + value=desired_value, + synced=0, + ) + .on_conflict_do_update( + index_elements=[ + MeterDBE.organization_id, + MeterDBE.key, + MeterDBE.year, + MeterDBE.month, + ], + set_={ + "value": func.greatest( + ( + (MeterDBE.value + meter.delta) + if meter.delta is not None + else meter.value + ), + 0, + ) + }, + where=where, + ) + ) + + result = await session.execute(stmt) + await session.commit() + + # 5. Check if update was applied (strict mode) + allowed = result.rowcount > 0 + + return ( + allowed, + MeterDTO( + **meter.model_dump(exclude={"value", "synced"}), + value=desired_value, # not technically accurate in soft mode, but good enough + synced=0, + ), + lambda: None, # rollback not needed; no state was touched otherwise + ) diff --git a/api/ee/src/dbs/postgres/meters/dbas.py b/api/ee/src/dbs/postgres/meters/dbas.py new file mode 100644 index 0000000000..450e517d28 --- /dev/null +++ b/api/ee/src/dbs/postgres/meters/dbas.py @@ -0,0 +1,29 @@ +from sqlalchemy import Column, Enum as SQLEnum, SmallInteger, BigInteger + +from ee.src.core.meters.types import Meters + +from oss.src.dbs.postgres.shared.dbas import OrganizationScopeDBA + + +class PeriodDBA: + __abstract__ = True + + year = Column(SmallInteger, nullable=False) + month = Column(SmallInteger, nullable=False) + + +class MeterDBA( + OrganizationScopeDBA, + PeriodDBA, +): + __abstract__ = True + + key = Column( + SQLEnum( + Meters, + name="meters_type", + ), + nullable=False, + ) + value = Column(BigInteger, nullable=False) + synced = Column(BigInteger, nullable=False) diff --git a/api/ee/src/dbs/postgres/meters/dbes.py b/api/ee/src/dbs/postgres/meters/dbes.py new file mode 100644 index 0000000000..f1353ba022 --- /dev/null +++ b/api/ee/src/dbs/postgres/meters/dbes.py @@ -0,0 +1,29 @@ +from sqlalchemy import PrimaryKeyConstraint, ForeignKeyConstraint, Index, func +from sqlalchemy.orm import relationship + +from oss.src.dbs.postgres.shared.base import Base +from ee.src.dbs.postgres.meters.dbas import MeterDBA + + +class MeterDBE(Base, MeterDBA): + __tablename__ = "meters" + + __table_args__ = ( + PrimaryKeyConstraint( + "organization_id", + "key", + "year", + "month", + ), + ForeignKeyConstraint( + ["organization_id"], + ["subscriptions.organization_id"], + ), + Index( + "idx_synced_value", + "synced", + "value", + ), + ) + + subscription = relationship("SubscriptionDBE", back_populates="meters") diff --git a/api/ee/src/dbs/postgres/shared/__init__.py b/api/ee/src/dbs/postgres/shared/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/dbs/postgres/subscriptions/__init__.py b/api/ee/src/dbs/postgres/subscriptions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/dbs/postgres/subscriptions/dao.py b/api/ee/src/dbs/postgres/subscriptions/dao.py new file mode 100644 index 0000000000..485af2dde0 --- /dev/null +++ b/api/ee/src/dbs/postgres/subscriptions/dao.py @@ -0,0 +1,84 @@ +from typing import Optional, List + +from sqlalchemy.future import select + +from ee.src.core.subscriptions.types import SubscriptionDTO +from ee.src.core.subscriptions.interfaces import SubscriptionsDAOInterface + +from oss.src.dbs.postgres.shared.engine import engine +from ee.src.dbs.postgres.subscriptions.dbes import SubscriptionDBE +from ee.src.dbs.postgres.subscriptions.mappings import ( + map_dbe_to_dto, + map_dto_to_dbe, +) + + +class SubscriptionsDAO(SubscriptionsDAOInterface): + def __init__(self): + pass + + async def create( + self, + *, + subscription: SubscriptionDTO, + ) -> SubscriptionDTO: + async with engine.core_session() as session: + subscription_dbe = map_dto_to_dbe(subscription) + + session.add(subscription_dbe) + + await session.commit() + + subscription_dto = map_dbe_to_dto(subscription_dbe) + + return subscription_dto + + async def read( + self, + *, + organization_id: str, + ) -> Optional[SubscriptionDTO]: + async with engine.core_session() as session: + result = await session.execute( + select(SubscriptionDBE).where( + SubscriptionDBE.organization_id == organization_id, + ) + ) + + subscription_dbe = result.scalars().one_or_none() + + if not subscription_dbe: + return None + + subscription_dto = map_dbe_to_dto(subscription_dbe) + + return subscription_dto + + async def update( + self, + *, + subscription: SubscriptionDTO, + ) -> Optional[SubscriptionDTO]: + async with engine.core_session() as session: + result = await session.execute( + select(SubscriptionDBE).where( + SubscriptionDBE.organization_id == subscription.organization_id, + ) + ) + + subscription_dbe = result.scalars().one_or_none() + + if not subscription_dbe: + return None + + subscription_dbe.customer_id = subscription.customer_id + subscription_dbe.subscription_id = subscription.subscription_id + subscription_dbe.plan = subscription.plan + subscription_dbe.active = subscription.active + subscription_dbe.anchor = subscription.anchor + + await session.commit() + + subscription_dto = map_dbe_to_dto(subscription_dbe) + + return subscription_dto diff --git a/api/ee/src/dbs/postgres/subscriptions/dbas.py b/api/ee/src/dbs/postgres/subscriptions/dbas.py new file mode 100644 index 0000000000..7810907030 --- /dev/null +++ b/api/ee/src/dbs/postgres/subscriptions/dbas.py @@ -0,0 +1,19 @@ +from sqlalchemy import Column, String, Boolean, SmallInteger + +from oss.src.dbs.postgres.shared.dbas import OrganizationScopeDBA + + +class StripeDBA: + customer_id = Column(String, nullable=True) + subscription_id = Column(String, nullable=True) + + +class SubscriptionDBA( + OrganizationScopeDBA, + StripeDBA, +): + __abstract__ = True + + plan = Column(String, nullable=False) + active = Column(Boolean, nullable=False) + anchor = Column(SmallInteger, nullable=True) diff --git a/api/ee/src/dbs/postgres/subscriptions/dbes.py b/api/ee/src/dbs/postgres/subscriptions/dbes.py new file mode 100644 index 0000000000..b548dd1a56 --- /dev/null +++ b/api/ee/src/dbs/postgres/subscriptions/dbes.py @@ -0,0 +1,24 @@ +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.orm import relationship + + +from oss.src.dbs.postgres.shared.base import Base +from ee.src.dbs.postgres.subscriptions.dbas import SubscriptionDBA + + +from sqlalchemy import PrimaryKeyConstraint, Index, func + + +from ee.src.dbs.postgres.meters.dbas import MeterDBA + + +class SubscriptionDBE(Base, SubscriptionDBA): + __tablename__ = "subscriptions" + + __table_args__ = ( + PrimaryKeyConstraint( + "organization_id", + ), + ) + + meters = relationship("MeterDBE", back_populates="subscription") diff --git a/api/ee/src/dbs/postgres/subscriptions/mappings.py b/api/ee/src/dbs/postgres/subscriptions/mappings.py new file mode 100644 index 0000000000..b8d0b4e8b5 --- /dev/null +++ b/api/ee/src/dbs/postgres/subscriptions/mappings.py @@ -0,0 +1,26 @@ +from ee.src.core.subscriptions.types import SubscriptionDTO +from ee.src.dbs.postgres.subscriptions.dbes import SubscriptionDBE + +from ee.src.core.subscriptions.types import Plan + + +def map_dbe_to_dto(subscription_dbe: SubscriptionDBE) -> SubscriptionDTO: + return SubscriptionDTO( + organization_id=subscription_dbe.organization_id, + customer_id=subscription_dbe.customer_id, + subscription_id=subscription_dbe.subscription_id, + plan=Plan(subscription_dbe.plan), + active=subscription_dbe.active, + anchor=subscription_dbe.anchor, + ) + + +def map_dto_to_dbe(subscription_dto: SubscriptionDTO) -> SubscriptionDBE: + return SubscriptionDBE( + organization_id=subscription_dto.organization_id, + customer_id=subscription_dto.customer_id, + subscription_id=subscription_dto.subscription_id, + plan=subscription_dto.plan.value, + active=subscription_dto.active or False, + anchor=subscription_dto.anchor, + ) diff --git a/api/ee/src/main.py b/api/ee/src/main.py new file mode 100644 index 0000000000..86d8ecf618 --- /dev/null +++ b/api/ee/src/main.py @@ -0,0 +1,123 @@ +from fastapi import FastAPI + +from oss.src.utils.logging import get_module_logger + +from ee.src.routers import ( + workspace_router, + organization_router, + evaluation_router, + human_evaluation_router, +) + +from ee.src.dbs.postgres.meters.dao import MetersDAO +from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO + +from ee.src.core.meters.service import MetersService +from ee.src.core.subscriptions.service import SubscriptionsService + +from ee.src.apis.fastapi.billing.router import SubscriptionsRouter + +# DBS -------------------------------------------------------------------------- + +meters_dao = MetersDAO() + +subscriptions_dao = SubscriptionsDAO() + +# CORE ------------------------------------------------------------------------- + +meters_service = MetersService( + meters_dao=meters_dao, +) + +subscription_service = SubscriptionsService( + subscriptions_dao=subscriptions_dao, + meters_service=meters_service, +) + +# APIS ------------------------------------------------------------------------- + +subscriptions_router = SubscriptionsRouter( + subscription_service=subscription_service, +) + + +log = get_module_logger(__name__) + + +def extend_main(app: FastAPI): + # ROUTES ------------------------------------------------------------------- + + app.include_router( + router=subscriptions_router.router, + prefix="/billing", + tags=["Billing"], + ) + + app.include_router( + router=subscriptions_router.admin_router, + prefix="/admin/billing", + tags=["Admin", "Billing"], + ) + + # ROUTES (more) ------------------------------------------------------------ + + app.include_router( + organization_router.router, + prefix="/organizations", + ) + + app.include_router( + workspace_router.router, + prefix="/workspaces", + ) + + app.include_router( + evaluation_router.router, + prefix="/evaluations", + tags=["Evaluations"], + ) + + app.include_router( + human_evaluation_router.router, + prefix="/human-evaluations", + tags=["Human-Evaluations"], + ) + + # -------------------------------------------------------------------------- + + return app + + +def load_tasks(): + import ee.src.tasks.evaluations.live + import ee.src.tasks.evaluations.legacy + import ee.src.tasks.evaluations.batch + + +def extend_app_schema(app: FastAPI): + app.openapi()["info"]["title"] = "Agenta API" + app.openapi()["info"]["description"] = "Agenta API" + app.openapi()["info"]["contact"] = { + "name": "Agenta", + "url": "https://agenta.ai", + "email": "team@agenta.ai", + } + app.openapi()["components"]["securitySchemes"] = { + "APIKeyHeader": { + "type": "apiKey", + "name": "Authorization", + "in": "header", + } + } + app.openapi()["security"] = [ + { + "APIKeyHeader": [], + }, + ] + app.openapi()["servers"] = [ + { + "url": "https://cloud.agenta.ai/api", + }, + ] + + return app diff --git a/api/ee/src/models/api/api_models.py b/api/ee/src/models/api/api_models.py new file mode 100644 index 0000000000..f15c8ffacc --- /dev/null +++ b/api/ee/src/models/api/api_models.py @@ -0,0 +1,72 @@ +from typing import Optional, List +from pydantic import BaseModel, Field +from datetime import datetime, timezone + +from oss.src.models.api.api_models import ( + CreateApp, + AppVariant, + Environment, + AppVariantResponse, + AppVariantOutputExtended, + EnvironmentOutput, + EnvironmentRevision, + EnvironmentOutputExtended, +) + + +class TimestampModel(BaseModel): + created_at: str = Field(str(datetime.now(timezone.utc))) + updated_at: str = Field(str(datetime.now(timezone.utc))) + + +class InviteRequest(BaseModel): + email: str + roles: List[str] + + +class ReseendInviteRequest(BaseModel): + email: str + + +class InviteToken(BaseModel): + token: str + + +class CreateApp_(CreateApp): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class AppVariant_(AppVariant): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class Environment_(Environment): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class AppVariantResponse_(AppVariantResponse): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class AppVariantOutputExtended_(AppVariantOutputExtended): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class EnvironmentOutput_(EnvironmentOutput): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class EnvironmentRevision_(EnvironmentRevision): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None + + +class EnvironmentOutputExtended_(EnvironmentOutputExtended): + organization_id: Optional[str] = None + workspace_id: Optional[str] = None diff --git a/api/ee/src/models/api/organization_models.py b/api/ee/src/models/api/organization_models.py new file mode 100644 index 0000000000..1ce05a65fc --- /dev/null +++ b/api/ee/src/models/api/organization_models.py @@ -0,0 +1,33 @@ +from typing import Optional, List + +from pydantic import BaseModel, Field + + +class Organization(BaseModel): + id: str + name: str + description: str + type: Optional[str] = None + owner: str + workspaces: List[str] = Field(default_factory=list) + members: List[str] = Field(default_factory=list) + invitations: List = Field(default_factory=list) + is_paying: Optional[bool] = None + + +class CreateOrganization(BaseModel): + name: str + owner: str + description: Optional[str] = None + type: Optional[str] = None + + +class OrganizationUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + updated_at: Optional[str] = None + + +class OrganizationOutput(BaseModel): + id: str + name: str diff --git a/api/ee/src/models/api/user_models.py b/api/ee/src/models/api/user_models.py new file mode 100644 index 0000000000..8a0d702ad8 --- /dev/null +++ b/api/ee/src/models/api/user_models.py @@ -0,0 +1,9 @@ +from typing import List + +from pydantic import Field + +from oss.src.models.api.user_models import User + + +class User_(User): + organizations: List[str] = Field(default_factory=list) diff --git a/api/ee/src/models/api/workspace_models.py b/api/ee/src/models/api/workspace_models.py new file mode 100644 index 0000000000..56218eb38a --- /dev/null +++ b/api/ee/src/models/api/workspace_models.py @@ -0,0 +1,58 @@ +from datetime import datetime +from typing import Optional, List, Dict + +from pydantic import BaseModel + +from ee.src.models.api.api_models import TimestampModel +from ee.src.models.shared_models import WorkspaceRole, Permission + + +class WorkspacePermission(BaseModel): + role_name: WorkspaceRole + role_description: Optional[str] = None + permissions: Optional[List[Permission]] = None + + +class WorkspaceMember(BaseModel): + user_id: str + roles: List[WorkspacePermission] + + +class WorkspaceMemberResponse(BaseModel): + user: Dict + roles: List[WorkspacePermission] + + +class Workspace(BaseModel): + id: Optional[str] = None + name: str + description: Optional[str] = None + type: Optional[str] + members: Optional[List[WorkspaceMember]] = None + + +class WorkspaceResponse(TimestampModel): + id: str + name: str + description: Optional[str] = None + type: Optional[str] + organization: str + members: Optional[List[WorkspaceMemberResponse]] = None + + +class CreateWorkspace(BaseModel): + name: str + description: Optional[str] = None + type: Optional[str] = None + + +class UserRole(BaseModel): + email: str + organization_id: str + role: Optional[str] = None + + +class UpdateWorkspace(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + updated_at: Optional[datetime] = None diff --git a/api/ee/src/models/db_models.py b/api/ee/src/models/db_models.py new file mode 100644 index 0000000000..f09b9e0324 --- /dev/null +++ b/api/ee/src/models/db_models.py @@ -0,0 +1,518 @@ +from typing import Optional, List, Sequence +from datetime import datetime, timezone + +import uuid_utils.compat as uuid +from sqlalchemy.orm import relationship, backref +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey, Integer + +from ee.src.models.shared_models import ( + WorkspaceRole, + Permission, +) +from oss.src.models.db_models import ( + ProjectDB as OssProjectDB, + WorkspaceDB as OssWorkspaceDB, + OrganizationDB as OssOrganizationDB, + DeploymentDB as OssDeploymentDB, + # dependency + CASCADE_ALL_DELETE, + mutable_json_type, +) +from oss.src.dbs.postgres.shared.base import Base +from oss.src.dbs.postgres.observability.dbes import NodesDBE + + +class OrganizationDB(OssOrganizationDB): + is_paying = Column(Boolean, nullable=True, default=False) + + organization_members = relationship( + "OrganizationMemberDB", back_populates="organization" + ) + project = relationship( + "ee.src.models.db_models.ProjectDB", + back_populates="organization", + overlaps="organization", + ) + + +class WorkspaceDB(OssWorkspaceDB): + pass + + members = relationship("WorkspaceMemberDB", back_populates="workspace") + projects = relationship( + "ee.src.models.db_models.ProjectDB", + cascade="all, delete-orphan", + back_populates="workspace", + overlaps="workspace", + ) + organization = relationship( + "ee.src.models.db_models.OrganizationDB", back_populates="workspaces_relation" + ) + + def get_member_role(self, user_id: str) -> Optional[str]: + member: Optional[WorkspaceMemberDB] = next( + (member for member in self.members if str(member.user_id) == user_id), + None, + ) + return member.role if member else None # type: ignore + + def get_member_role_name(self, user_id: str) -> Optional[str]: + role = self.get_member_role(user_id) + return role + + def get_all_members(self) -> List[str]: + return [str(member.user_id) for member in self.members] + + def get_member_with_roles(self, user_id: str) -> Optional["WorkspaceMemberDB"]: + return next( + (member for member in self.members if str(member.user_id) == user_id), + None, + ) + + def get_member_permissions(self, user_id: str) -> List[Permission]: + user_role = self.get_member_role(user_id) + if user_role: + return Permission.default_permissions(user_role) + return [] + + def has_permission(self, user_id: str, permission: Permission) -> bool: + user_role = self.get_member_role(user_id) + if user_role and permission in Permission.default_permissions(user_role): + return True + return False + + def has_role(self, user_id: str, role_to_check: WorkspaceRole) -> bool: + user_role = self.get_member_role(user_id) + if user_role: + return user_role == role_to_check + return False + + def is_owner(self, user_id: str) -> bool: + return any( + str(member.user_id) == user_id + and WorkspaceRole.OWNER == self.get_member_role_name(user_id) + for member in self.members + ) + + +class ProjectDB(OssProjectDB): + workspace = relationship( + "ee.src.models.db_models.WorkspaceDB", + back_populates="projects", + overlaps="projects", + ) + organization = relationship( + "ee.src.models.db_models.OrganizationDB", + back_populates="project", + ) + project_members = relationship( + "ProjectMemberDB", cascade="all, delete-orphan", back_populates="project" + ) + invitations = relationship( + "InvitationDB", cascade="all, delete-orphan", back_populates="project" + ) + + def get_member_role( + self, user_id: str, members: Sequence["ProjectMemberDB"] + ) -> Optional[str]: + member: Optional["ProjectMemberDB"] = next( + (member for member in members if str(member.user_id) == user_id), + None, + ) + return member.role if member else None # type: ignore + + def get_member_role_name( + self, user_id: str, members: Sequence["ProjectMemberDB"] + ) -> Optional[str]: + role = self.get_member_role(user_id=user_id, members=members) + return role + + def get_all_members(self) -> List[str]: + return [str(member.user_id) for member in self.project_members] + + def get_member_with_roles(self, user_id: str) -> Optional["ProjectMemberDB"]: + return next( + ( + member + for member in self.project_members + if str(member.user_id) == user_id + ), + None, + ) + + def get_member_permissions( + self, user_id: str, members: Sequence["ProjectMemberDB"] + ) -> List[Permission]: + user_role = self.get_member_role(user_id, members) + if user_role: + return Permission.default_permissions(user_role) + return [] + + def has_permission( + self, user_id: str, permission: Permission, members: Sequence["ProjectMemberDB"] + ) -> bool: + user_role = self.get_member_role(user_id, members) + if user_role and permission in Permission.default_permissions(user_role): + return True + return False + + def has_role( + self, + user_id: str, + role_to_check: WorkspaceRole, + members: Sequence["ProjectMemberDB"], + ) -> bool: + user_role = self.get_member_role(user_id, members) + if user_role: + return user_role == role_to_check + return False + + def is_owner(self, user_id: str, members: Sequence["ProjectMemberDB"]) -> bool: + return any( + str(member.user_id) == user_id + and WorkspaceRole.OWNER == self.get_member_role_name(user_id, members) + for member in members + ) + + +class WorkspaceMemberDB(Base): + __tablename__ = "workspace_members" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id")) + role = Column(String, default="viewer") + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + user = relationship( + "UserDB", backref=backref("workspace_memberships", lazy="dynamic") + ) + workspace = relationship( + "ee.src.models.db_models.WorkspaceDB", back_populates="members" + ) + + +class OrganizationMemberDB(Base): + __tablename__ = "organization_members" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + organization_id = Column(UUID(as_uuid=True), ForeignKey("organizations.id")) + + user = relationship( + "UserDB", backref=backref("organization_members", lazy="dynamic") + ) + organization = relationship( + "ee.src.models.db_models.OrganizationDB", back_populates="organization_members" + ) + + +class ProjectMemberDB(Base): + __tablename__ = "project_members" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + project_id = Column(UUID(as_uuid=True), ForeignKey("projects.id")) + role = Column(String, default="viewer") + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + is_demo = Column(Boolean, nullable=True) + + user = relationship("UserDB") + project = relationship("ee.src.models.db_models.ProjectDB") + + +class DeploymentDB(OssDeploymentDB): + pass + + +class HumanEvaluationVariantDB(Base): + __tablename__ = "human_evaluation_variants" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + human_evaluation_id = Column( + UUID(as_uuid=True), ForeignKey("human_evaluations.id", ondelete="CASCADE") + ) + variant_id = Column( + UUID(as_uuid=True), ForeignKey("app_variants.id", ondelete="SET NULL") + ) + variant_revision_id = Column( + UUID(as_uuid=True), ForeignKey("app_variant_revisions.id", ondelete="SET NULL") + ) + + variant = relationship("AppVariantDB", backref="evaluation_variant") + variant_revision = relationship( + "AppVariantRevisionsDB", backref="evaluation_variant_revision" + ) + + +class HumanEvaluationDB(Base): + __tablename__ = "human_evaluations" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + app_id = Column(UUID(as_uuid=True), ForeignKey("app_db.id", ondelete="CASCADE")) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + status = Column(String) + evaluation_type = Column(String) + testset_id = Column(UUID(as_uuid=True), ForeignKey("testsets.id")) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + testset = relationship("TestSetDB") + evaluation_variant = relationship( + "HumanEvaluationVariantDB", + cascade=CASCADE_ALL_DELETE, + backref="human_evaluation", + ) + evaluation_scenario = relationship( + "HumanEvaluationScenarioDB", + cascade=CASCADE_ALL_DELETE, + backref="evaluation_scenario", + ) + + +class HumanEvaluationScenarioDB(Base): + __tablename__ = "human_evaluations_scenarios" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + evaluation_id = Column( + UUID(as_uuid=True), ForeignKey("human_evaluations.id", ondelete="CASCADE") + ) + inputs = Column( + mutable_json_type(dbtype=JSONB, nested=True) + ) # List of HumanEvaluationScenarioInput + outputs = Column( + mutable_json_type(dbtype=JSONB, nested=True) + ) # List of HumanEvaluationScenarioOutput + vote = Column(String) + score = Column(String) + correct_answer = Column(String) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + is_pinned = Column(Boolean) + note = Column(String) + + +class EvaluationAggregatedResultDB(Base): + __tablename__ = "auto_evaluation_aggregated_results" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + evaluation_id = Column( + UUID(as_uuid=True), ForeignKey("auto_evaluations.id", ondelete="CASCADE") + ) + evaluator_config_id = Column( + UUID(as_uuid=True), + ForeignKey("auto_evaluator_configs.id", ondelete="SET NULL"), + ) + result = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + + evaluator_config = relationship("EvaluatorConfigDB", backref="evaluator_config") + + +class EvaluationScenarioResultDB(Base): + __tablename__ = "auto_evaluation_scenario_results" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + evaluation_scenario_id = Column( + UUID(as_uuid=True), + ForeignKey("auto_evaluation_scenarios.id", ondelete="CASCADE"), + ) + evaluator_config_id = Column( + UUID(as_uuid=True), + ForeignKey("auto_evaluator_configs.id", ondelete="SET NULL"), + ) + result = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + + +class EvaluationDB(Base): + __tablename__ = "auto_evaluations" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + app_id = Column(UUID(as_uuid=True), ForeignKey("app_db.id", ondelete="CASCADE")) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + status = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + testset_id = Column( + UUID(as_uuid=True), ForeignKey("testsets.id", ondelete="SET NULL") + ) + variant_id = Column( + UUID(as_uuid=True), ForeignKey("app_variants.id", ondelete="SET NULL") + ) + variant_revision_id = Column( + UUID(as_uuid=True), ForeignKey("app_variant_revisions.id", ondelete="SET NULL") + ) + average_cost = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + total_cost = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + average_latency = Column(mutable_json_type(dbtype=JSONB, nested=True)) # Result + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + project = relationship("ee.src.models.db_models.ProjectDB") + testset = relationship("TestSetDB") + variant = relationship("AppVariantDB") + variant_revision = relationship("AppVariantRevisionsDB") + aggregated_results = relationship( + "EvaluationAggregatedResultDB", + cascade=CASCADE_ALL_DELETE, + backref="evaluation", + ) + evaluation_scenarios = relationship( + "EvaluationScenarioDB", cascade=CASCADE_ALL_DELETE, backref="evaluation" + ) + evaluator_configs = relationship( + "EvaluationEvaluatorConfigDB", + cascade=CASCADE_ALL_DELETE, + backref="evaluation", + ) + + +class EvaluationEvaluatorConfigDB(Base): + __tablename__ = "auto_evaluation_evaluator_configs" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + evaluation_id = Column( + UUID(as_uuid=True), + ForeignKey("auto_evaluations.id", ondelete="CASCADE"), + primary_key=True, + ) + evaluator_config_id = Column( + UUID(as_uuid=True), + ForeignKey("auto_evaluator_configs.id", ondelete="SET NULL"), + primary_key=True, + ) + + +class EvaluationScenarioDB(Base): + __tablename__ = "auto_evaluation_scenarios" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + evaluation_id = Column( + UUID(as_uuid=True), ForeignKey("auto_evaluations.id", ondelete="CASCADE") + ) + variant_id = Column( + UUID(as_uuid=True), ForeignKey("app_variants.id", ondelete="SET NULL") + ) + inputs = Column( + mutable_json_type(dbtype=JSONB, nested=True) + ) # List of EvaluationScenarioInput + outputs = Column( + mutable_json_type(dbtype=JSONB, nested=True) + ) # List of EvaluationScenarioOutput + correct_answers = Column( + mutable_json_type(dbtype=JSONB, nested=True) + ) # List of CorrectAnswer + is_pinned = Column(Boolean) + note = Column(String) + latency = Column(Integer) + cost = Column(Integer) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + project = relationship("ee.src.models.db_models.ProjectDB") + variant = relationship("AppVariantDB") + results = relationship( + "EvaluationScenarioResultDB", + cascade=CASCADE_ALL_DELETE, + backref="evaluation_scenario", + ) diff --git a/api/ee/src/models/extended/deprecated_models.py b/api/ee/src/models/extended/deprecated_models.py new file mode 100644 index 0000000000..c68a07e851 --- /dev/null +++ b/api/ee/src/models/extended/deprecated_models.py @@ -0,0 +1,101 @@ +from datetime import datetime, timezone + +import uuid_utils.compat as uuid + +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean, Integer + + +DeprecatedBase = declarative_base() + + +class DeprecatedAppDB(DeprecatedBase): + __tablename__ = "app_db" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + app_name = Column(String) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + modified_by_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=True) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + +class DeprecatedAPIKeyDB(DeprecatedBase): + __tablename__ = "api_keys" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + prefix = Column(String) + hashed_key = Column(String) + user_id = Column(String, nullable=True) + workspace_id = Column(String, nullable=True) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), nullable=True + ) + created_by_id = Column( + UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + rate_limit = Column(Integer, default=0) + hidden = Column(Boolean, default=False) + expiration_date = Column(DateTime(timezone=True), nullable=True) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + updated_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) + + +class UserOrganizationDB(DeprecatedBase): + __tablename__ = "user_organizations" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id")) + organization_id = Column(UUID(as_uuid=True), ForeignKey("organizations.id")) + + +class OldInvitationDB(DeprecatedBase): + __tablename__ = "invitations" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + token = Column(String, unique=True, nullable=False) + email = Column(String, nullable=False) + organization_id = Column(String, nullable=False) + used = Column(Boolean, default=False) + workspace_id = Column(String, nullable=False) + workspace_roles = Column(JSONB, nullable=True) + expiration_date = Column(DateTime(timezone=True), nullable=True) + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + ) diff --git a/api/ee/src/models/extended/deprecated_transfer_models.py b/api/ee/src/models/extended/deprecated_transfer_models.py new file mode 100644 index 0000000000..3657dddacd --- /dev/null +++ b/api/ee/src/models/extended/deprecated_transfer_models.py @@ -0,0 +1,347 @@ +from datetime import datetime, timezone + +import uuid_utils.compat as uuid + +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, String, DateTime, Boolean, ForeignKey + + +DeprecatedBase = declarative_base() + + +class WorkspaceDB(DeprecatedBase): + __tablename__ = "workspaces" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + + +class OrganizationDB(DeprecatedBase): + __tablename__ = "organizations" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + + +class ProjectDB(DeprecatedBase): + __tablename__ = "projects" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + is_default = Column(Boolean, default=False) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class AppDB(DeprecatedBase): + __tablename__ = "app_db" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class AppVariantDB(DeprecatedBase): + __tablename__ = "app_variants" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class AppVariantRevisionsDB(DeprecatedBase): + __tablename__ = "app_variant_revisions" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + variant_id = Column( + UUID(as_uuid=True), ForeignKey("app_variants.id", ondelete="CASCADE") + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + + +class VariantBaseDB(DeprecatedBase): + __tablename__ = "bases" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class DeploymentDB(DeprecatedBase): + __tablename__ = "deployments" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class AppEnvironmentDB(DeprecatedBase): + __tablename__ = "environments" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class AppEnvironmentRevisionDB(DeprecatedBase): + __tablename__ = "environments_revisions" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class EvaluationScenarioDB(DeprecatedBase): + __tablename__ = "evaluation_scenarios" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + evaluation_id = Column( + UUID(as_uuid=True), ForeignKey("evaluations.id", ondelete="CASCADE") + ) + variant_id = Column( + UUID(as_uuid=True), ForeignKey("app_variants.id", ondelete="SET NULL") + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class EvaluationDB(DeprecatedBase): + __tablename__ = "evaluations" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class EvaluatorConfigDB(DeprecatedBase): + __tablename__ = "evaluators_configs" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + + app_id = Column(UUID(as_uuid=True), ForeignKey("app_db.id", ondelete="CASCADE")) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class HumanEvaluationDB(DeprecatedBase): + __tablename__ = "human_evaluations" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class HumanEvaluationScenarioDB(DeprecatedBase): + __tablename__ = "human_evaluations_scenarios" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) + + +class TestSetDB(DeprecatedBase): + __tablename__ = "testsets" + __table_args__ = {"extend_existing": True} + + id = Column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid7, + unique=True, + nullable=False, + ) + project_id = Column( + UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE") + ) + workspace_id = Column( + UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="SET NULL") + ) + organization_id = Column( + UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL") + ) diff --git a/api/ee/src/models/shared_models.py b/api/ee/src/models/shared_models.py new file mode 100644 index 0000000000..4f7ed234da --- /dev/null +++ b/api/ee/src/models/shared_models.py @@ -0,0 +1,200 @@ +from enum import Enum +from typing import List + +from pydantic import BaseModel, Field + + +class WorkspaceRole(str, Enum): + OWNER = "owner" + VIEWER = "viewer" + EDITOR = "editor" + EVALUATOR = "evaluator" + WORKSPACE_ADMIN = "workspace_admin" + DEPLOYMENT_MANAGER = "deployment_manager" + + @classmethod + def is_valid_role(cls, role: str) -> bool: + return role.upper() in list(WorkspaceRole.__members__.keys()) + + @classmethod + def get_description(cls, role): + descriptions = { + cls.OWNER: "Can fully manage the workspace, including adding and removing members.", + cls.VIEWER: "Can view the workspace content but cannot make changes.", + cls.EDITOR: "Can edit workspace content, but cannot manage members or roles.", + cls.EVALUATOR: "Can evaluate models and provide feedback within the workspace.", + cls.WORKSPACE_ADMIN: "Can manage workspace settings and members but cannot delete the workspace.", + cls.DEPLOYMENT_MANAGER: "Can manage model deployments within the workspace.", + } + return descriptions.get(role, "Description not available, Role not found") + + +class Permission(str, Enum): + # general + READ_SYSTEM = "read_system" + + # App and variants + VIEW_APPLICATIONS = "view_applications" + EDIT_APPLICATIONS = "edit_application" + + CREATE_APP_VARIANT = "create_app_variant" + DELETE_APP_VARIANT = "delete_app_variant" + + MODIFY_VARIANT_CONFIGURATIONS = "modify_variant_configurations" + EDIT_APPLICATIONS_VARIANT = "delete_application_variant" + + # Service + RUN_SERVICE = "run_service" + + # Vault Secret + CREATE_SECRET = "create_secret" + VIEW_SECRET = "view_secret" + UPDATE_SECRET = "update_secret" + DELETE_SECRET = "delete_secret" + + # App environment deployment + VIEW_APP_ENVIRONMENT_DEPLOYMENT = "view_app_environment_deployment" + EDIT_APP_ENVIRONMENT_DEPLOYMENT = "edit_app_environment_deployment" + CREATE_APP_ENVIRONMENT_DEPLOYMENT = "create_app_environment_deployment" + + # Testset + VIEW_TESTSET = "view_testset" + EDIT_TESTSET = "edit_testset" + CREATE_TESTSET = "create_testset" + DELETE_TESTSET = "delete_testset" + + # Evaluation + VIEW_EVALUATION = "view_evaluation" + RUN_EVALUATIONS = "run_evaluations" + EDIT_EVALUATION = "edit_evaluation" + CREATE_EVALUATION = "create_evaluation" + DELETE_EVALUATION = "delete_evaluation" + + # Deployment + DEPLOY_APPLICATION = "deploy_application" + + # Workspace + VIEW_WORKSPACE = "view_workspace" + EDIT_WORKSPACE = "edit_workspace" + CREATE_WORKSPACE = "create_workspace" + DELETE_WORKSPACE = "delete_workspace" + MODIFY_USER_ROLES = "modify_user_roles" + ADD_USER_TO_WORKSPACE = "add_new_user_to_workspace" + + # Organization + EDIT_ORGANIZATION = "edit_organization" + DELETE_ORGANIZATION = "delete_organization" + ADD_USER_TO_ORGANIZATION = "add_new_user_to_organization" + + # User Profile + RESET_PASSWORD = "reset_password" + + # Billing (Plans, Subscriptions, Usage, etc) + VIEW_BILLING = "view_billing" + EDIT_BILLING = "edit_billing" + + # Workflows + VIEW_WORKFLOWS = "view_workflows" + EDIT_WORKFLOWS = "edit_workflows" + RUN_WORKFLOWS = "run_workflows" + + # Evaluators + VIEW_EVALUATORS = "view_evaluators" + EDIT_EVALUATORS = "edit_evaluators" + + # Queries + VIEW_QUERIES = "view_queries" + EDIT_QUERIES = "edit_queries" + + # Testsets + VIEW_TESTSETS = "view_testsets" + EDIT_TESTSETS = "edit_testsets" + + # Annotations + VIEW_ANNOTATIONS = "view_annotations" + EDIT_ANNOTATIONS = "edit_annotations" + + # Invocations + VIEW_INVOCATIONS = "view_invocations" + EDIT_INVOCATIONS = "edit_invocations" + + # Evaluations + VIEW_EVALUATION_RUNS = "view_evaluation_runs" + EDIT_EVALUATION_RUNS = "edit_evaluation_runs" + + VIEW_EVALUATION_SCENARIOS = "view_evaluation_scenarios" + EDIT_EVALUATION_SCENARIOS = "edit_evaluation_scenarios" + + VIEW_EVALUATION_RESULTS = "view_evaluation_results" + EDIT_EVALUATION_RESULTS = "edit_evaluation_results" + + VIEW_EVALUATION_METRICS = "view_evaluation_metrics" + EDIT_EVALUATION_METRICS = "edit_evaluation_metrics" + + VIEW_EVALUATION_QUEUES = "view_evaluation_queues" + EDIT_EVALUATION_QUEUES = "edit_evaluation_queues" + + @classmethod + def default_permissions(cls, role): + VIEWER_PERMISSIONS = [ + cls.READ_SYSTEM, + cls.VIEW_APPLICATIONS, + cls.VIEW_SECRET, + cls.VIEW_APP_ENVIRONMENT_DEPLOYMENT, + cls.VIEW_TESTSET, + cls.VIEW_EVALUATION, + cls.RUN_SERVICE, + cls.VIEW_BILLING, + # + cls.VIEW_WORKFLOWS, + cls.VIEW_EVALUATORS, + cls.VIEW_TESTSETS, + cls.VIEW_ANNOTATIONS, + ] + defaults = { + WorkspaceRole.OWNER: [p for p in cls], + WorkspaceRole.VIEWER: VIEWER_PERMISSIONS, + WorkspaceRole.EDITOR: [ + p + for p in cls + if p + not in [ + cls.DELETE_SECRET, + cls.RESET_PASSWORD, + cls.DELETE_TESTSET, + cls.DELETE_WORKSPACE, + cls.CREATE_WORKSPACE, + cls.EDIT_ORGANIZATION, + cls.DELETE_EVALUATION, + cls.MODIFY_USER_ROLES, + cls.EDIT_APPLICATIONS, + cls.DELETE_ORGANIZATION, + cls.ADD_USER_TO_WORKSPACE, + cls.ADD_USER_TO_ORGANIZATION, + cls.EDIT_BILLING, + ] + ], + WorkspaceRole.DEPLOYMENT_MANAGER: VIEWER_PERMISSIONS + + [cls.DEPLOY_APPLICATION], + WorkspaceRole.WORKSPACE_ADMIN: [ + p + for p in cls + if p + not in [ + cls.DELETE_WORKSPACE, + cls.DELETE_ORGANIZATION, + cls.EDIT_ORGANIZATION, + cls.ADD_USER_TO_ORGANIZATION, + cls.EDIT_BILLING, + ] + ], + WorkspaceRole.EVALUATOR: VIEWER_PERMISSIONS + + [cls.CREATE_EVALUATION, cls.RUN_EVALUATIONS], + } + + return defaults.get(role, []) + + +class WorkspaceMember(BaseModel): + role_name: WorkspaceRole + permissions: List[Permission] = Field(default_factory=list) diff --git a/api/ee/src/routers/evaluation_router.py b/api/ee/src/routers/evaluation_router.py new file mode 100644 index 0000000000..2cf6dc1da0 --- /dev/null +++ b/api/ee/src/routers/evaluation_router.py @@ -0,0 +1,519 @@ +from typing import Any, List +import random + +from fastapi.responses import JSONResponse +from fastapi import HTTPException, Request, status, Response, Query + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.caching import get_cache, set_cache + +from ee.src.services import converters +from ee.src.services import evaluation_service + +from ee.src.tasks.evaluations.legacy import ( + setup_evaluation, + annotate, +) +from oss.src.utils.common import APIRouter, is_ee +from oss.src.models.api.evaluation_model import ( + Evaluation, + EvaluationScenario, + NewEvaluation, + DeleteEvaluation, +) +from ee.src.services import db_manager_ee +from oss.src.services import app_manager, db_manager + +if is_ee(): + from ee.src.models.shared_models import Permission + from ee.src.utils.permissions import check_action_access + from ee.src.utils.entitlements import ( + check_entitlements, + Tracker, + Counter, + NOT_ENTITLED_RESPONSE, + ) + +from oss.src.routers.testset_router import _validate_testset_limits + + +from oss.src.apis.fastapi.evaluations.models import EvaluationRunsResponse + + +router = APIRouter() + + +log = get_module_logger(__name__) + + +@router.get( + "/by_resource/", + response_model=List[str], +) +async def fetch_evaluation_ids( + resource_type: str, + request: Request, + resource_ids: List[str] = Query(None), +): + """Fetches evaluation ids for a given resource type and id. + + Arguments: + app_id (str): The ID of the app for which to fetch evaluations. + resource_type (str): The type of resource for which to fetch evaluations. + resource_ids List[ObjectId]: The IDs of resource for which to fetch evaluations. + + Raises: + HTTPException: If the resource_type is invalid or access is denied. + + Returns: + List[str]: A list of evaluation ids. + """ + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + evaluations = await db_manager_ee.fetch_evaluations_by_resource( + resource_type, + request.state.project_id, + resource_ids, + ) + return list(map(lambda x: str(x.id), evaluations)) + + +@router.get( + "/{evaluation_id}/status/", + operation_id="fetch_evaluation_status", +) +async def fetch_evaluation_status( + evaluation_id: str, + request: Request, +): + """Fetches the status of the evaluation. + + Args: + evaluation_id (str): the evaluation id + request (Request): the request object + + Returns: + (str): the evaluation status + """ + + cache_key = { + "evaluation_id": evaluation_id, + } + + evaluation_status = await get_cache( + project_id=request.state.project_id, + namespace="fetch_evaluation_status", + key=cache_key, + retry=False, + ) + + if evaluation_status is not None: + return {"status": evaluation_status} + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + evaluation_status = await db_manager_ee.fetch_evaluation_status_by_id( + project_id=request.state.project_id, + evaluation_id=evaluation_id, + ) + + await set_cache( + project_id=request.state.project_id, + namespace="fetch_evaluation_status", + key=cache_key, + value=evaluation_status, + ttl=15, # 15 seconds + ) + + return {"status": evaluation_status} + + +@router.get( + "/{evaluation_id}/results/", + operation_id="fetch_legacy_evaluation_results", +) +async def fetch_evaluation_results( + evaluation_id: str, + request: Request, +): + """Fetches the results of the evaluation + + Args: + evaluation_id (str): the evaluation id + request (Request): the request object + + Returns: + _type_: _description_ + """ + + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=request.state.project_id, + evaluation_id=evaluation_id, + ) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + results = converters.aggregated_result_of_evaluation_to_pydantic( + evaluation.aggregated_results # type: ignore + ) + return {"results": results, "evaluation_id": evaluation_id} + + +@router.get( + "/{evaluation_id}/evaluation_scenarios/", + response_model=List[EvaluationScenario], + operation_id="fetch_legacy_evaluation_scenarios", +) +async def fetch_evaluation_scenarios( + evaluation_id: str, + request: Request, +): + """Fetches evaluation scenarios for a given evaluation ID. + + Arguments: + evaluation_id (str): The ID of the evaluation for which to fetch scenarios. + + Raises: + HTTPException: If the evaluation is not found or access is denied. + + Returns: + List[EvaluationScenario]: A list of evaluation scenarios. + """ + + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=request.state.project_id, + evaluation_id=evaluation_id, + ) + if not evaluation: + raise HTTPException( + status_code=404, detail=f"Evaluation with id {evaluation_id} not found" + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + eval_scenarios = await evaluation_service.fetch_evaluation_scenarios_for_evaluation( + evaluation_id=str(evaluation.id), project_id=str(evaluation.project_id) + ) + return eval_scenarios + + +@router.get( + "/", + response_model=List[Evaluation], + operation_id="fetch_legacy_evaluations", +) +async def fetch_list_evaluations( + app_id: str, + request: Request, +): + """Fetches a list of evaluations, optionally filtered by an app ID. + + Args: + app_id (Optional[str]): An optional app ID to filter the evaluations. + + Returns: + List[Evaluation]: A list of evaluations. + """ + + app = await db_manager.fetch_app_by_id(app_id) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(app.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + return await evaluation_service.fetch_list_evaluations(app, str(app.project_id)) + + +@router.get( + "/{evaluation_id}/", + response_model=Evaluation, + operation_id="fetch_legacy_evaluation", +) +async def fetch_evaluation( + evaluation_id: str, + request: Request, +): + """Fetches a single evaluation based on its ID. + + Args: + evaluation_id (str): The ID of the evaluation to fetch. + + Returns: + Evaluation: The fetched evaluation. + """ + + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=request.state.project_id, + evaluation_id=evaluation_id, + ) + if not evaluation: + raise HTTPException( + status_code=404, detail=f"Evaluation with id {evaluation_id} not found" + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + return await converters.evaluation_db_to_pydantic(evaluation) + + +@router.delete( + "/", + response_model=List[str], + operation_id="delete_legacy_evaluations", +) +async def delete_evaluations( + payload: DeleteEvaluation, + request: Request, +): + """ + Delete specific comparison tables based on their unique IDs. + + Args: + delete_evaluations (List[str]): The unique identifiers of the comparison tables to delete. + + Returns: + A list of the deleted comparison tables' IDs. + """ + + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=request.state.project_id, + evaluation_id=payload.evaluations_ids[0], + ) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.DELETE_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + # Update last_modified_by app information + await app_manager.update_last_modified_by( + user_uid=request.state.user_id, + object_id=random.choice(payload.evaluations_ids), + object_type="evaluation", + project_id=str(evaluation.project_id), + ) + + await evaluation_service.delete_evaluations(payload.evaluations_ids) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + +@router.get( + "/evaluation_scenarios/comparison-results/", + response_model=Any, + operation_id="fetch_legacy_evaluation_scenarios_comparison_results", +) +async def fetch_evaluation_scenarios_comparison_results( + evaluations_ids: str, + request: Request, +): + """Fetches evaluation scenarios for a given evaluation ID. + + Arguments: + evaluation_id (str): The ID of the evaluation for which to fetch scenarios. + + Raises: + HTTPException: If the evaluation is not found or access is denied. + + Returns: + List[EvaluationScenario]: A list of evaluation scenarios. + """ + + evaluations_ids_list = evaluations_ids.split(",") + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=request.state.project_id, + evaluation_id=evaluations_ids_list[0], + ) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your organization admin." + log.error(error_msg) + return JSONResponse( + {"detail": error_msg}, + status_code=403, + ) + + eval_scenarios = await evaluation_service.compare_evaluations_scenarios( + evaluations_ids_list, str(evaluation.project_id) + ) + + return eval_scenarios + + +@router.post( + "/preview/start", + response_model=EvaluationRunsResponse, + operation_id="start_evaluation", +) +async def start_evaluation( + request: Request, + payload: NewEvaluation, +) -> EvaluationRunsResponse: + try: + if is_ee(): + # Permissions Check ------------------------------------------------ + check = await check_action_access( + project_id=request.state.project_id, + user_uid=request.state.user_id, + permission=Permission.CREATE_EVALUATION, + ) + if not check: + raise HTTPException( + status_code=403, + detail="You do not have permission to perform this action. Please contact your organization admin.", + ) + # ------------------------------------------------------------------ + + # Entitlements Check ----------------------------------------------- + check, _, _ = await check_entitlements( + organization_id=request.state.organization_id, + key=Counter.EVALUATIONS, + delta=1, + ) + + if not check: + return NOT_ENTITLED_RESPONSE(Tracker.COUNTERS) + # ------------------------------------------------------------------ + + # Input Validation ----------------------------------------------------- + nof_runs = len(payload.revisions_ids) + + if nof_runs == 0: + raise HTTPException( + status_code=400, + detail="No revisions provided for evaluation. Please provide at least one revision.", + ) + # ---------------------------------------------------------------------- + + # Evaluation Run Execution --------------------------------------------- + runs = [] + + for i in range(nof_runs): + run = await setup_evaluation( + project_id=request.state.project_id, + user_id=request.state.user_id, + # + name=payload.name, + # + testset_id=payload.testset_id, + # + revision_id=payload.revisions_ids[i], + # + autoeval_ids=payload.evaluators_configs, + ) + + if not run: + continue + + runs.append(run) + + annotate.delay( + project_id=request.state.project_id, + user_id=request.state.user_id, + # + run_id=run.id, + # + testset_id=payload.testset_id, + # + revision_id=payload.revisions_ids[i], + # + autoeval_ids=payload.evaluators_configs, + # + run_config=payload.rate_limit.model_dump(mode="json"), + ) + # ---------------------------------------------------------------------- + + runs_response = EvaluationRunsResponse( + count=len(runs), + runs=runs, + ) + + return runs_response + + except KeyError as e: + log.error(e, exc_info=True) + + raise HTTPException( + status_code=400, + detail="Columns in the test set should match the names of the inputs in the variant", + ) from e diff --git a/api/ee/src/routers/human_evaluation_router.py b/api/ee/src/routers/human_evaluation_router.py new file mode 100644 index 0000000000..3b2904062c --- /dev/null +++ b/api/ee/src/routers/human_evaluation_router.py @@ -0,0 +1,460 @@ +from typing import List, Dict +from fastapi import HTTPException, Body, Request, status, Response + +from oss.src.utils.logging import get_module_logger +from ee.src.services import converters +from oss.src.services import db_manager +from ee.src.services import db_manager_ee +from ee.src.services import results_service +from ee.src.services import evaluation_service +from oss.src.utils.common import APIRouter, is_ee +from oss.src.models.api.evaluation_model import ( + DeleteEvaluation, + EvaluationScenarioScoreUpdate, + HumanEvaluation, + HumanEvaluationScenario, + HumanEvaluationScenarioUpdate, + EvaluationType, + HumanEvaluationUpdate, + NewHumanEvaluation, + SimpleEvaluationOutput, +) +from ee.src.services.evaluation_service import ( + update_human_evaluation_scenario, + update_human_evaluation_service, +) + +if is_ee(): + from ee.src.models.shared_models import ( + Permission, + ) # noqa pylint: disable-all + from ee.src.utils.permissions import ( + check_action_access, + ) # noqa pylint: disable-all + + +router = APIRouter() + +log = get_module_logger(__name__) + + +@router.post( + "/", response_model=SimpleEvaluationOutput, operation_id="create_human_evaluation" +) +async def create_human_evaluation( + payload: NewHumanEvaluation, + request: Request, +): + """Creates a new comparison table document + Raises: + HTTPException: _description_ + Returns: + _description_ + """ + + try: + app = await db_manager.fetch_app_by_id(app_id=payload.app_id) + if app is None: + raise HTTPException(status_code=404, detail="App not found") + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(app.project_id), + permission=Permission.CREATE_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + new_human_evaluation_db = await evaluation_service.create_new_human_evaluation( + payload + ) + return await converters.human_evaluation_db_to_simple_evaluation_output( + new_human_evaluation_db + ) + except KeyError: + raise HTTPException( + status_code=400, + detail="columns in the test set should match the names of the inputs in the variant", + ) + + +@router.get("/", response_model=List[HumanEvaluation]) +async def fetch_list_human_evaluations( + app_id: str, + request: Request, +): + """Fetches a list of evaluations, optionally filtered by an app ID. + + Args: + app_id (Optional[str]): An optional app ID to filter the evaluations. + + Returns: + List[HumanEvaluation]: A list of evaluations. + """ + + app = await db_manager.fetch_app_by_id(app_id=app_id) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(app.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + return await evaluation_service.fetch_list_human_evaluations( + app_id, str(app.project_id) + ) + + +@router.get("/{evaluation_id}/", response_model=HumanEvaluation) +async def fetch_human_evaluation( + evaluation_id: str, + request: Request, +): + """Fetches a single evaluation based on its ID. + + Args: + evaluation_id (str): The ID of the evaluation to fetch. + + Returns: + HumanEvaluation: The fetched evaluation. + """ + + human_evaluation = await db_manager_ee.fetch_human_evaluation_by_id(evaluation_id) + if not human_evaluation: + raise HTTPException(status_code=404, detail="Evaluation not found") + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(human_evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + return await evaluation_service.fetch_human_evaluation(human_evaluation) + + +@router.get( + "/{evaluation_id}/evaluation_scenarios/", + response_model=List[HumanEvaluationScenario], + operation_id="fetch_human_evaluation_scenarios", +) +async def fetch_human_evaluation_scenarios( + evaluation_id: str, + request: Request, +): + """Fetches evaluation scenarios for a given evaluation ID. + + Arguments: + evaluation_id (str): The ID of the evaluation for which to fetch scenarios. + + Raises: + HTTPException: If the evaluation is not found or access is denied. + + Returns: + List[EvaluationScenario]: A list of evaluation scenarios. + """ + + human_evaluation = await db_manager_ee.fetch_human_evaluation_by_id(evaluation_id) + if human_evaluation is None: + raise HTTPException( + status_code=404, + detail=f"Evaluation with id {evaluation_id} not found", + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(human_evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + eval_scenarios = ( + await evaluation_service.fetch_human_evaluation_scenarios_for_evaluation( + human_evaluation + ) + ) + + return eval_scenarios + + +@router.put("/{evaluation_id}/", operation_id="update_human_evaluation") +async def update_human_evaluation( + request: Request, + evaluation_id: str, + update_data: HumanEvaluationUpdate = Body(...), +): + """Updates an evaluation's status. + + Raises: + HTTPException: If the columns in the test set do not match with the inputs in the variant. + + Returns: + None: A 204 No Content status code, indicating that the update was successful. + """ + + try: + human_evaluation = await db_manager_ee.fetch_human_evaluation_by_id( + evaluation_id + ) + if not human_evaluation: + raise HTTPException(status_code=404, detail="Evaluation not found") + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(human_evaluation.project_id), + permission=Permission.EDIT_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + await update_human_evaluation_service(human_evaluation, update_data) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + except KeyError: + raise HTTPException( + status_code=400, + detail="columns in the test set should match the names of the inputs in the variant", + ) + + +@router.put( + "/{evaluation_id}/evaluation_scenario/{evaluation_scenario_id}/{evaluation_type}/" +) +async def update_evaluation_scenario_router( + evaluation_id: str, + evaluation_scenario_id: str, + evaluation_type: EvaluationType, + payload: HumanEvaluationScenarioUpdate, + request: Request, +): + """Updates an evaluation scenario's vote or score based on its type. + + Raises: + HTTPException: If update fails or unauthorized. + + Returns: + None: 204 No Content status code upon successful update. + """ + + evaluation_scenario_db = await db_manager_ee.fetch_human_evaluation_scenario_by_id( + evaluation_scenario_id + ) + if evaluation_scenario_db is None: + raise HTTPException( + status_code=404, + detail=f"Evaluation scenario with id {evaluation_scenario_id} not found", + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation_scenario_db.project_id), + permission=Permission.EDIT_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + await update_human_evaluation_scenario( + evaluation_scenario_db, + payload, + evaluation_type, + ) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + +@router.get("/evaluation_scenario/{evaluation_scenario_id}/score/") +async def get_evaluation_scenario_score_router( + evaluation_scenario_id: str, + request: Request, +) -> Dict[str, str]: + """ + Fetch the score of a specific evaluation scenario. + + Args: + evaluation_scenario_id: The ID of the evaluation scenario to fetch. + + Returns: + Dictionary containing the scenario ID and its score. + """ + + evaluation_scenario = db_manager_ee.fetch_evaluation_scenario_by_id( + evaluation_scenario_id + ) + if evaluation_scenario is None: + raise HTTPException( + status_code=404, + detail=f"Evaluation scenario with id {evaluation_scenario_id} not found", + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation_scenario.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + return { + "scenario_id": str(evaluation_scenario.id), + "score": evaluation_scenario.score, + } + + +@router.put("/evaluation_scenario/{evaluation_scenario_id}/score/") +async def update_evaluation_scenario_score_router( + evaluation_scenario_id: str, + payload: EvaluationScenarioScoreUpdate, + request: Request, +): + """Updates the score of an evaluation scenario. + + Raises: + HTTPException: Server error if the evaluation update fails. + + Returns: + None: 204 No Content status code upon successful update. + """ + + evaluation_scenario = await db_manager_ee.fetch_evaluation_scenario_by_id( + evaluation_scenario_id + ) + if evaluation_scenario is None: + raise HTTPException( + status_code=404, + detail=f"Evaluation scenario with id {evaluation_scenario_id} not found", + ) + + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation_scenario.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + await db_manager.update_human_evaluation_scenario( + evaluation_scenario_id=str(evaluation_scenario.id), # type: ignore + values_to_update=payload.model_dump(), + ) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + +@router.get("/{evaluation_id}/results/", operation_id="fetch_results") +async def fetch_results( + evaluation_id: str, + request: Request, +): + """Fetch all the results for one the comparison table + + Arguments: + evaluation_id -- _description_ + + Returns: + _description_ + """ + + evaluation = await db_manager_ee.fetch_human_evaluation_by_id(evaluation_id) + if evaluation is None: + raise HTTPException( + status_code=404, + detail=f"Evaluation with id {evaluation_id} not found", + ) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.VIEW_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + if evaluation.evaluation_type == EvaluationType.human_a_b_testing: + results = await results_service.fetch_results_for_evaluation(evaluation) + return {"votes_data": results} + + elif evaluation.evaluation_type == EvaluationType.single_model_test: + results = await results_service.fetch_results_for_single_model_test( + evaluation_id + ) + return {"results_data": results} + + +@router.delete("/", response_model=List[str]) +async def delete_evaluations( + payload: DeleteEvaluation, + request: Request, +): + """ + Delete specific comparison tables based on their unique IDs. + + Args: + payload (List[str]): The unique identifiers of the comparison tables to delete. + + Returns: + A list of the deleted comparison tables' IDs. + """ + + evaluation = await db_manager_ee.fetch_human_evaluation_by_id( + payload.evaluations_ids[0] + ) + if is_ee(): + has_permission = await check_action_access( + user_uid=request.state.user_id, + project_id=str(evaluation.project_id), + permission=Permission.DELETE_EVALUATION, + ) + if not has_permission: + error_msg = f"You do not have permission to perform this action. Please contact your Organization Admin." + raise HTTPException( + detail=error_msg, + status_code=403, + ) + + await evaluation_service.delete_human_evaluations(payload.evaluations_ids) + return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/api/ee/src/routers/organization_router.py b/api/ee/src/routers/organization_router.py new file mode 100644 index 0000000000..7b265a692a --- /dev/null +++ b/api/ee/src/routers/organization_router.py @@ -0,0 +1,239 @@ +from fastapi.responses import JSONResponse +from fastapi import HTTPException, Request + +from oss.src.utils.logging import get_module_logger +from oss.src.services import db_manager +from ee.src.services import db_manager_ee +from oss.src.utils.common import APIRouter +from ee.src.services import workspace_manager +from ee.src.models.db_models import Permission +from ee.src.services.selectors import ( + get_user_own_org, + get_user_org_and_workspace_id, +) +from ee.src.models.api.workspace_models import ( + CreateWorkspace, + UpdateWorkspace, + WorkspaceResponse, +) +from ee.src.utils.permissions import ( + check_user_org_access, + check_rbac_permission, +) +from ee.src.models.api.organization_models import ( + CreateOrganization, + OrganizationUpdate, + OrganizationOutput, +) +from ee.src.services.organization_service import ( + update_an_organization, + get_organization_details, +) + + +router = APIRouter() + +log = get_module_logger(__name__) + + +@router.get("/own/", response_model=OrganizationOutput, operation_id="get_own_org") +async def get_user_organization( + request: Request, +): + try: + user_org_workspace_data: dict = await get_user_org_and_workspace_id( + request.state.user_id + ) + org_db = await get_user_own_org(user_uid=user_org_workspace_data["uid"]) + if org_db is None: + raise HTTPException(404, detail="User does not have an organization") + + return OrganizationOutput(id=str(org_db.id), name=org_db.name) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/{org_id}/", operation_id="fetch_ee_organization_details") +async def fetch_organization_details( + org_id: str, + request: Request, +): + """Get an organization's details. + + Raises: + HTTPException: _description_ + Permission Denied + + Returns: + OrganizationDB Instance + """ + + try: + workspace_id = await db_manager_ee.get_default_workspace_id_from_organization( + organization_id=org_id + ) + + project_id = await db_manager.get_default_project_id_from_workspace( + workspace_id=workspace_id + ) + + project_memberships = await db_manager_ee.fetch_project_memberships_by_user_id( + user_id=str(request.state.user_id) + ) + + membership = None + for project_membership in project_memberships: + if str(project_membership.project_id) == project_id: + membership = project_membership + break + + if not membership: + return JSONResponse( + status_code=403, + content={"detail": "You do not have access to this organization"}, + ) + + user_org_workspace_data = await get_user_org_and_workspace_id( + request.state.user_id + ) + has_permission = await check_user_org_access(user_org_workspace_data, org_id) + if not has_permission: + return JSONResponse( + status_code=403, + content={"detail": "You do not have access to this organization"}, + ) + + organization = await get_organization_details(org_id) + + if membership.role == "viewer" or membership.is_demo: + if "default_workspace" in organization: + organization["default_workspace"].members = [] + + return organization + + except Exception as e: + import traceback + + traceback.print_exc() + raise HTTPException( + status_code=500, + detail=str(e), + ) + + +@router.put("/{org_id}/", operation_id="update_organization") +async def update_organization( + org_id: str, + payload: OrganizationUpdate, + request: Request, +): + if not payload.name and not payload.description: + return JSONResponse( + {"detail": "Please provide a name or description to update"}, + status_code=400, + ) + + try: + user_org_workspace_data: dict = await get_user_org_and_workspace_id( + request.state.user_id + ) + has_permission = await check_user_org_access( + user_org_workspace_data, org_id, check_owner=True + ) + if not has_permission: + return JSONResponse( + {"detail": "You do not have permission to perform this action"}, + status_code=403, + ) + + organization = await update_an_organization(org_id, payload) + + return organization + + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + + +@router.post( + "/{org_id}/workspaces/", + operation_id="create_workspace", + response_model=WorkspaceResponse, +) +async def create_workspace( + org_id: str, + payload: CreateWorkspace, + request: Request, +) -> WorkspaceResponse: + try: + user_org_workspace_data: dict = await get_user_org_and_workspace_id( + request.state.user_id + ) + has_permission = await check_user_org_access( + user_org_workspace_data, org_id, check_owner=True + ) + if not has_permission: + return JSONResponse( + {"detail": "You do not have permission to perform this action"}, + status_code=403, + ) + + if not payload.name: + return JSONResponse( + {"detail": "Please provide a name to create a workspace"}, + status_code=400, + ) + workspace = await workspace_manager.create_new_workspace( + payload, org_id, user_org_workspace_data["uid"] + ) + return workspace + + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + + +@router.put( + "/{org_id}/workspaces/{workspace_id}/", + operation_id="update_workspace", + response_model=WorkspaceResponse, +) +async def update_workspace( + org_id: str, + workspace_id: str, + payload: UpdateWorkspace, + request: Request, +) -> WorkspaceResponse: + try: + user_org_workspace_data: dict = await get_user_org_and_workspace_id( + request.state.user_id + ) + project = await db_manager_ee.get_project_by_workspace(workspace_id) + has_permission = await check_rbac_permission( + user_org_workspace_data=user_org_workspace_data, + project_id=str(project.id), + permission=Permission.EDIT_WORKSPACE, + ) + if not has_permission: + return JSONResponse( + {"detail": "You do not have permission to update this workspace"}, + status_code=403, + ) + + if not payload.name and not payload.description: + return JSONResponse( + {"detail": "Please provide a name or description to update"}, + status_code=400, + ) + workspace = await workspace_manager.update_workspace(payload, workspace_id) + return workspace + + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) diff --git a/api/ee/src/routers/workspace_router.py b/api/ee/src/routers/workspace_router.py new file mode 100644 index 0000000000..40e0e17885 --- /dev/null +++ b/api/ee/src/routers/workspace_router.py @@ -0,0 +1,173 @@ +from typing import List + +from fastapi import HTTPException, Request +from fastapi.responses import JSONResponse + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.common import APIRouter +from ee.src.utils.permissions import check_rbac_permission +from ee.src.services import workspace_manager, db_manager_ee +from ee.src.services.selectors import get_user_org_and_workspace_id + +from ee.src.models.api.workspace_models import ( + UserRole, + Permission, + WorkspaceRole, +) + +router = APIRouter() + +log = get_module_logger(__name__) + + +@router.get( + "/permissions/", + operation_id="get_all_workspace_permissions", + response_model=List[Permission], +) +async def get_all_workspace_permissions() -> List[Permission]: + """ + Get all workspace permissions. + + Returns a list of all available workspace permissions. + + Returns: + List[Permission]: A list of Permission objects representing the available workspace permissions. + + Raises: + HTTPException: If there is an error retrieving the workspace permissions. + + """ + try: + workspace_permissions = await workspace_manager.get_all_workspace_permissions() + return sorted(workspace_permissions) + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + + +@router.post("/{workspace_id}/roles/", operation_id="assign_role_to_user") +async def assign_role_to_user( + payload: UserRole, + workspace_id: str, + request: Request, +): + """ + Assigns a role to a user in a workspace. + + Args: + payload (UserRole): The payload containing the organization id, user email, and role to assign. + workspace_id (str): The ID of the workspace. + request (Request): The FastAPI request object. + + Returns: + bool: True if the role was successfully assigned, False otherwise. + + Raises: + HTTPException: If the user does not have permission to perform this action. + HTTPException: If there is an error assigning the role to the user. + """ + + try: + user_org_workspace_data = await get_user_org_and_workspace_id( + request.state.user_id + ) + project = await db_manager_ee.get_project_by_workspace(workspace_id) + has_permission = await check_rbac_permission( + user_org_workspace_data=user_org_workspace_data, + project_id=str(project.id), + role=WorkspaceRole.WORKSPACE_ADMIN, + ) + if not has_permission: + return JSONResponse( + status_code=403, + content={ + "detail": "You do not have permission to perform this action. Please contact your Organization Owner" + }, + ) + + if not WorkspaceRole.is_valid_role(payload.role): # type: ignore + return JSONResponse( + status_code=400, content={"detail": "Workspace role is invalid."} + ) + + create_user_role = await db_manager_ee.update_user_roles( + workspace_id, + payload, + ) + return create_user_role + except HTTPException as ex: + raise ex + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) + + +@router.delete("/{workspace_id}/roles/", operation_id="unassign_role_from_user") +async def unassign_role_from_user( + email: str, + org_id: str, + role: str, + workspace_id: str, + request: Request, +): + """ + Delete a role assignment from a user in a workspace. + + Args: + workspace_id (str): The ID of the workspace. + email (str): The email of the user to remove the role from. + org_id (str): The ID of the organization. + role (str): The role to remove from the user. + request (Request): The FastAPI request object. + + Returns: + bool: True if the role assignment was successfully deleted. + + Raises: + HTTPException: If there is an error in the request or the user does not have permission to perform the action. + HTTPException: If there is an error in updating the user's roles. + + """ + try: + user_org_workspace_data = await get_user_org_and_workspace_id( + request.state.user_id + ) + project = await db_manager_ee.get_project_by_workspace(workspace_id) + has_permission = await check_rbac_permission( + user_org_workspace_data=user_org_workspace_data, + project_id=str(project.id), + role=WorkspaceRole.WORKSPACE_ADMIN, + ) + if not has_permission: + return JSONResponse( + status_code=403, + content={ + "detail": "You do not have permission to perform this action. Please contact your Organization Owner" + }, + ) + + payload = UserRole( + email=email, + organization_id=org_id, + role=role, + ) + + delete_user_role = await db_manager_ee.update_user_roles( + workspace_id, + payload, + delete=True, + ) + + return delete_user_role + except HTTPException as ex: + raise ex + except Exception as e: + raise HTTPException( + status_code=500, + detail=str(e), + ) diff --git a/api/ee/src/services/admin_manager.py b/api/ee/src/services/admin_manager.py new file mode 100644 index 0000000000..57af9d8ef6 --- /dev/null +++ b/api/ee/src/services/admin_manager.py @@ -0,0 +1,404 @@ +from typing import Optional, Literal, Any +from uuid import UUID + +from pydantic import BaseModel +import uuid_utils.compat as uuid +from sqlalchemy.future import select + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.common import is_ee + +from oss.src.dbs.postgres.shared.engine import engine + +from oss.src.models.db_models import UserDB +from oss.src.services.api_key_service import create_api_key + +from ee.src.models.db_models import ( + WorkspaceDB, + ProjectDB, + OrganizationDB, + ProjectMemberDB as ProjectMembershipDB, + WorkspaceMemberDB as WorkspaceMembershipDB, + OrganizationMemberDB as OrganizationMembershipDB, +) + +log = get_module_logger(__name__) + + +class Reference(BaseModel): + id: Optional[UUID] = None + slug: Optional[str] = None + + class Config: + json_encoders = {UUID: str} + + def encode(self, data: Any) -> Any: + if isinstance(data, dict): + return {k: self.encode(v) for k, v in data.items()} + elif isinstance(data, list): + return [self.encode(item) for item in data] + for type_, encoder in self.Config.json_encoders.items(): + if isinstance(data, type_): + return encoder(data) + return data + + def model_dump(self, *args, **kwargs) -> dict: + kwargs.setdefault("exclude_none", True) + + return self.encode(super().model_dump(*args, **kwargs)) + + +class UserRequest(BaseModel): + name: str + email: str + + +Tier = str + + +class OrganizationRequest(BaseModel): + name: str + description: str + is_paying: bool + + +class WorkspaceRequest(BaseModel): + name: str + description: str + is_default: bool + # + organization_ref: Reference + + +class ProjectRequest(BaseModel): + name: str + description: str + is_default: bool + # + workspace_ref: Reference + organization_ref: Reference + + +OrganizationRole = Literal[ + "owner", + "viewer", + "editor", + "evaluator", + "workspace_admin", + "deployment_manager", +] # update list + + +class OrganizationMembershipRequest(BaseModel): + role: OrganizationRole + is_demo: bool + # + user_ref: Reference + organization_ref: Reference + + +WorkspaceRole = Literal[ # update list + "owner", + "viewer", + "editor", + "evaluator", + "workspace_admin", + "deployment_manager", +] + + +class WorkspaceMembershipRequest(BaseModel): + role: WorkspaceRole + is_demo: bool + # + user_ref: Reference + workspace_ref: Reference + + +ProjectRole = Literal[ # update list + "owner", + "viewer", + "editor", + "evaluator", + "workspace_admin", + "deployment_manager", +] + + +class ProjectMembershipRequest(BaseModel): + role: ProjectRole + is_demo: bool + # + user_ref: Reference + project_ref: Reference + + +Credentials = str + + +async def check_user( + request: UserRequest, +) -> Optional[UserRequest]: + async with engine.core_session() as session: + result = await session.execute( + select(UserDB).filter_by( + email=request.email, + ) + ) + + user_db = result.scalars().first() + + reference = Reference(id=user_db.id) if user_db else None + + return reference + + +async def create_user( + request: UserRequest, +) -> Reference: + async with engine.core_session() as session: + user_db = UserDB( + # id=uuid7() # use default + # + uid=str(uuid.uuid7()), + username=request.name, # rename to 'name' + email=request.email, + ) + + session.add(user_db) + + log.info( + "[scopes] user created", + user_id=user_db.id, + ) + + await session.commit() + + response = Reference(id=user_db.id) + + return response + + +async def create_organization( + request: OrganizationRequest, +) -> Reference: + async with engine.core_session() as session: + organization_db = OrganizationDB( + # id=uuid7() # use default + # + name=request.name, + description=request.description, + # + owner="", # move 'owner' from here to membership 'role' + # type=... # remove 'type' + ) + + if is_ee(): + organization_db.is_paying = request.is_paying + + session.add(organization_db) + + log.info( + "[scopes] organization created", + organization_id=organization_db.id, + ) + + await session.commit() + + response = Reference(id=organization_db.id) + + return response + + +async def create_workspace( + request: WorkspaceRequest, +) -> Reference: + async with engine.core_session() as session: + workspace_db = WorkspaceDB( + # id=uuid7() # use default + # + name=request.name, + description=request.description, + type=("default" if request.is_default else None), # rename to 'is_default' + # + organization_id=request.organization_ref.id, + ) + + session.add(workspace_db) + + log.info( + "[scopes] workspace created", + organization_id=workspace_db.organization_id, + workspace_id=workspace_db.id, + ) + + await session.commit() + + response = Reference(id=workspace_db.id) + + return response + + +async def create_project( + request: ProjectRequest, +) -> Reference: + async with engine.core_session() as session: + project_db = ProjectDB( + # id=uuid7() # use default + # + project_name=request.name, # rename to 'name' + # description=... # missing 'description' + is_default=request.is_default, + # + workspace_id=request.workspace_ref.id, + organization_id=request.organization_ref.id, + ) + + session.add(project_db) + + log.info( + "[scopes] project created", + organization_id=project_db.organization_id, + workspace_id=project_db.workspace_id, + project_id=project_db.id, + ) + + await session.commit() + + response = Reference(id=project_db.id) + + return response + + +async def create_organization_membership( + request: OrganizationMembershipRequest, +) -> Reference: + async with engine.core_session() as session: + membership_db = OrganizationMembershipDB( + # id=uuid7() # use default + # + # role=request.role, # move 'owner' from organization to here as 'role' + # is_demo=request.is_demo, # add 'is_demo' + # + user_id=request.user_ref.id, + organization_id=request.organization_ref.id, + ) + + session.add(membership_db) + + log.info( + "[scopes] organization membership created", + organization_id=request.organization_ref.id, + user_id=request.user_ref.id, + membership_id=membership_db.id, + ) + + await session.commit() + + if request.role == "owner": + result = await session.execute( + select(OrganizationDB).filter_by( + id=request.organization_ref.id, + ) + ) + + organization_db = result.scalars().first() + + organization_db.owner = str(request.user_ref.id) + + await session.commit() + + response = Reference(id=membership_db.id) + + return response + + +async def create_workspace_membership( + request: WorkspaceMembershipRequest, +) -> Reference: + async with engine.core_session() as session: + workspace = await session.execute( + select(WorkspaceDB).filter_by( + id=request.workspace_ref.id, + ) + ) + workspace_db = workspace.scalars().first() + + membership_db = WorkspaceMembershipDB( + # id=uuid7() # use default + # + role=request.role, + # is_demo=request.is_demo, # add 'is_demo' + # + user_id=request.user_ref.id, + workspace_id=request.workspace_ref.id, + ) + + session.add(membership_db) + + log.info( + "[scopes] workspace membership created", + organization_id=workspace_db.organization_id, + workspace_id=request.workspace_ref.id, + user_id=request.user_ref.id, + membership_id=membership_db.id, + ) + + await session.commit() + + response = Reference(id=membership_db.id) + + return response + + +async def create_project_membership( + request: ProjectMembershipRequest, +) -> Reference: + async with engine.core_session() as session: + project = await session.execute( + select(ProjectDB).filter_by( + id=request.project_ref.id, + ) + ) + project_db = project.scalars().first() + + membership_db = ProjectMembershipDB( + # id=uuid7() # use default + # + role=request.role, + is_demo=request.is_demo, + # + user_id=request.user_ref.id, + project_id=request.project_ref.id, + ) + + session.add(membership_db) + + log.info( + "[scopes] project membership created", + organization_id=project_db.organization_id, + workspace_id=project_db.workspace_id, + project_id=request.project_ref.id, + user_id=request.user_ref.id, + membership_id=membership_db.id, + ) + + await session.commit() + + response = Reference(id=membership_db.id) + + return response + + +async def create_credentials( + user_id: UUID, + project_id: UUID, +) -> Credentials: + apikey_token = await create_api_key( + user_id=str(user_id), + project_id=str(project_id), + ) + + credentials = f"ApiKey {apikey_token}" + + return credentials diff --git a/api/ee/src/services/aggregation_service.py b/api/ee/src/services/aggregation_service.py new file mode 100644 index 0000000000..55a14e5f8f --- /dev/null +++ b/api/ee/src/services/aggregation_service.py @@ -0,0 +1,135 @@ +import re +import traceback +from typing import List, Optional + +from oss.src.models.shared_models import InvokationResult, Result, Error + + +def aggregate_ai_critique(results: List[Result]) -> Result: + """Aggregates the results for the ai critique evaluation. + + Args: + results (List[Result]): list of result objects + + Returns: + Result: aggregated result + """ + + try: + numeric_scores = [] + for result in results: + # Extract the first number found in the result value + match = re.search(r"\d+", result.value) + if match: + try: + score = int(match.group()) + numeric_scores.append(score) + except ValueError: + # Ignore if the extracted value is not an integer + continue + + # Calculate the average of numeric scores if any are present + average_value = ( + sum(numeric_scores) / len(numeric_scores) if numeric_scores else None + ) + return Result( + type="number", + value=average_value, + ) + except Exception as exc: + return Result( + type="error", + value=None, + error=Error(message=str(exc), stacktrace=str(traceback.format_exc())), + ) + + +def aggregate_binary(results: List[Result]) -> Result: + """Aggregates the results for the binary (auto regex) evaluation. + + Args: + results (List[Result]): list of result objects + + Returns: + Result: aggregated result + """ + + if all(isinstance(result.value, bool) for result in results): + average_value = sum(int(result.value) for result in results) / len(results) + else: + average_value = None + return Result(type="number", value=average_value) + + +def aggregate_float(results: List[Result]) -> Result: + """Aggregates the results for evaluations aside from auto regex and ai critique. + + Args: + results (List[Result]): list of result objects + + Returns: + Result: aggregated result + """ + + try: + average_value = sum(result.value for result in results) / len(results) + return Result(type="number", value=average_value) + except Exception as exc: + return Result( + type="error", + value=None, + error=Error(message=str(exc), stacktrace=str(traceback.format_exc())), + ) + + +def aggregate_float_from_llm_app_response( + invocation_results: List[InvokationResult], key: Optional[str] +) -> Result: + try: + if not key: + raise ValueError("Key is required to aggregate InvokationResult objects.") + + values = [ + getattr(inv_result, key) + for inv_result in invocation_results + if hasattr(inv_result, key) and getattr(inv_result, key) is not None + ] + + if not values: + return Result(type=key, value=None) + + average_value = sum(values) / len(values) + return Result(type=key, value=average_value) + except Exception as exc: + return Result( + type="error", + value=None, + error=Error(message=str(exc), stacktrace=str(traceback.format_exc())), + ) + + +def sum_float_from_llm_app_response( + invocation_results: List[InvokationResult], key: Optional[str] +) -> Result: + try: + if not key: + raise ValueError("Key is required to aggregate InvokationResult objects.") + + values = [ + getattr(inv_result, key) + for inv_result in invocation_results + if hasattr(inv_result, key) and getattr(inv_result, key) is not None + ] + + if not values: + return Result(type=key, value=None) + + total_value = sum(values) + + return Result(type=key, value=total_value) + except Exception as exc: + return Result( + type="error", + value=None, + error=Error(message=str(exc), stacktrace=str(traceback.format_exc())), + ) diff --git a/api/ee/src/services/commoners.py b/api/ee/src/services/commoners.py new file mode 100644 index 0000000000..45e5643d78 --- /dev/null +++ b/api/ee/src/services/commoners.py @@ -0,0 +1,179 @@ +from os import getenv +from json import loads +from typing import List +from traceback import format_exc + +from pydantic import BaseModel + +from oss.src.utils.logging import get_module_logger +from oss.src.services import db_manager +from oss.src.utils.common import is_ee +from ee.src.services import workspace_manager +from ee.src.services.db_manager_ee import ( + create_organization, + add_user_to_organization, + add_user_to_workspace, + add_user_to_project, +) +from ee.src.services.selectors import ( + user_exists, +) +from ee.src.models.api.organization_models import CreateOrganization +from oss.src.services.user_service import create_new_user +from ee.src.services.email_helper import ( + add_contact_to_loops, +) + +log = get_module_logger(__name__) + +from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO +from ee.src.core.subscriptions.service import SubscriptionsService +from ee.src.dbs.postgres.meters.dao import MetersDAO +from ee.src.core.meters.service import MetersService + +subscription_service = SubscriptionsService( + subscriptions_dao=SubscriptionsDAO(), + meters_service=MetersService( + meters_dao=MetersDAO(), + ), +) + +from ee.src.utils.entitlements import check_entitlements, Gauge + +DEMOS = "AGENTA_DEMOS" +DEMO_ROLE = "viewer" + + +class Demo(BaseModel): + organization_id: str + workspace_id: str + project_id: str + + +async def list_all_demos() -> List[Demo]: + demos = [] + + try: + demo_project_ids = loads(getenv(DEMOS) or "[]") + + for project_id in demo_project_ids: + project = await db_manager.get_project_by_id(project_id) + + try: + demos.append( + Demo( + organization_id=str(project.organization_id), + workspace_id=str(project.workspace_id), + project_id=str(project.id), + ) + ) + + except: # pylint: disable=bare-except + log.error(format_exc()) + + except: # pylint: disable=bare-except + log.error(format_exc()) + + return demos + + +async def add_user_to_demos(user_id: str) -> None: + try: + demos = await list_all_demos() + + for organization_id in {demo.organization_id for demo in demos}: + await add_user_to_organization( + organization_id, + user_id, + # is_demo=True, + ) + + for workspace_id in {demo.workspace_id for demo in demos}: + await add_user_to_workspace( + workspace_id, + user_id, + DEMO_ROLE, + # is_demo=True, + ) + + for project_id in {demo.project_id for demo in demos}: + await add_user_to_project( + project_id, + user_id, + DEMO_ROLE, + is_demo=True, + ) + + except Exception as exc: + raise exc # TODO: handle exceptions + + +async def create_accounts(payload: dict): + """Creates a user account and an associated organization based on the + provided payload. + + Arguments: + payload (dict): The required payload. It consists of; user_id and user_email + """ + + user_dict = { + **payload, + "username": payload["email"].split("@")[0], + } + + user = await db_manager.get_user_with_email(email=user_dict["email"]) + if user is None: + log.info("[scopes] Yey! A new user is signing up!") + + # Create user first + user = await create_new_user(user_dict) + + log.info("[scopes] User [%s] created", user.id) + + # Prepare payload to create organization + create_org_payload = CreateOrganization( + name=user_dict["username"], + description="My Default Organization", + owner=str(user.id), + type="default", + ) + + # Create the user's default organization and workspace + organization = await create_organization( + payload=create_org_payload, + user=user, + ) + + log.info("[scopes] Organization [%s] created", organization.id) + + # Add the user to demos + await add_user_to_demos(str(user.id)) + + # Start reverse trial + try: + await subscription_service.start_reverse_trial( + organization_id=str(organization.id), + organization_name=organization.name, + organization_email=user_dict["email"], + ) + + except Exception as exc: + raise exc # TODO: handle exceptions + # await subscription_service.start_free_plan( + # organization_id=str(organization.id), + # ) + + await check_entitlements( + organization_id=str(organization.id), + key=Gauge.USERS, + delta=1, + ) + + log.info("[scopes] User [%s] authenticated", user.id) + + if is_ee(): + try: + # Adds contact to loops for marketing emails. TODO: Add opt-in checkbox to supertokens + add_contact_to_loops(user_dict["email"]) # type: ignore + except ConnectionError as ex: + log.warn("Error adding contact to loops %s", ex) diff --git a/api/ee/src/services/converters.py b/api/ee/src/services/converters.py new file mode 100644 index 0000000000..5b120899fc --- /dev/null +++ b/api/ee/src/services/converters.py @@ -0,0 +1,321 @@ +import uuid +from typing import List, Dict, Any +from datetime import datetime, timezone + +from oss.src.services import db_manager +from oss.src.models.api.evaluation_model import ( + CorrectAnswer, + Evaluation, + HumanEvaluation, + EvaluationScenario, + SimpleEvaluationOutput, + EvaluationScenarioInput, + HumanEvaluationScenario, + EvaluationScenarioOutput, +) +from ee.src.services import db_manager_ee +from ee.src.models.api.workspace_models import ( + WorkspaceRole, + WorkspaceResponse, +) +from ee.src.models.shared_models import Permission +from ee.src.models.db_models import ( + EvaluationDB, + HumanEvaluationDB, + EvaluationScenarioDB, + HumanEvaluationScenarioDB, +) +from oss.src.models.db_models import WorkspaceDB + + +async def get_workspace_in_format( + workspace: WorkspaceDB, +) -> WorkspaceResponse: + """Converts the workspace object to the WorkspaceResponse model. + + Arguments: + workspace (WorkspaceDB): The workspace object + project_id (str): The project ID + + Returns: + WorkspaceResponse: The workspace object in the WorkspaceResponse model + """ + + members = [] + + project = await db_manager_ee.get_project_by_workspace( + workspace_id=str(workspace.id) + ) + project_members = await db_manager_ee.get_project_members( + project_id=str(project.id) + ) + invitations = await db_manager_ee.get_project_invitations( + project_id=str(project.id), invitation_used=False + ) + + if len(invitations) > 0: + for invitation in invitations: + if not invitation.used and str(invitation.project_id) == str(project.id): + user = await db_manager.get_user_with_email(invitation.email) + member_dict = { + "user": { + "id": str(user.id) if user else invitation.email, + "email": user.email if user else invitation.email, + "username": ( + user.username if user else invitation.email.split("@")[0] + ), + "status": ( + "pending" + if invitation.expiration_date > datetime.now(timezone.utc) + else "expired" + ), + "created_at": ( + str(user.created_at) + if user + else ( + str(invitation.created_at) + if str(invitation.created_at) + else None + ) + ), + }, + "roles": [ + { + "role_name": invitation.role, + "role_description": WorkspaceRole.get_description( + invitation.role + ), + } + ], + } + members.append(member_dict) + + for project_member in project_members: + member_role = project_member.role + member_dict = { + "user": { + "id": str(project_member.user.id), + "email": project_member.user.email, + "username": project_member.user.username, + "status": "member", + "created_at": str(project_member.user.created_at), + }, + "roles": ( + [ + { + "role_name": member_role, + "role_description": WorkspaceRole.get_description(member_role), + "permissions": Permission.default_permissions(member_role), + } + ] + if member_role + else [] + ), + } + members.append(member_dict) + + workspace_response = WorkspaceResponse( + id=str(workspace.id), + name=workspace.name, + description=workspace.description, + type=workspace.type, + members=members, + organization=str(workspace.organization_id), + created_at=str(workspace.created_at), + updated_at=str(workspace.updated_at), + ) + return workspace_response + + +async def get_all_workspace_permissions() -> List[Permission]: + """ + Retrieve all workspace permissions. + + Returns: + List[Permission]: A list of all workspace permissions in the DB. + """ + workspace_permissions = list(Permission) + return workspace_permissions + + +def get_all_workspace_permissions_by_role(role_name: str) -> Dict[str, List[Any]]: + """ + Retrieve all workspace permissions. + + Returns: + List[Permission]: A list of all workspace permissions in the DB. + """ + workspace_permissions = Permission.default_permissions( + getattr(WorkspaceRole, role_name.upper()) + ) + return workspace_permissions + + +async def human_evaluation_db_to_simple_evaluation_output( + human_evaluation_db: HumanEvaluationDB, +) -> SimpleEvaluationOutput: + evaluation_variants = await db_manager_ee.fetch_human_evaluation_variants( + human_evaluation_id=str(human_evaluation_db.id) + ) + return SimpleEvaluationOutput( + id=str(human_evaluation_db.id), + app_id=str(human_evaluation_db.app_id), + project_id=str(human_evaluation_db.project_id), + status=human_evaluation_db.status, # type: ignore + evaluation_type=human_evaluation_db.evaluation_type, # type: ignore + variant_ids=[ + str(evaluation_variant.variant_id) + for evaluation_variant in evaluation_variants + ], + ) + + +async def evaluation_db_to_pydantic( + evaluation_db: EvaluationDB, +) -> Evaluation: + variant_name = ( + evaluation_db.variant.variant_name + if evaluation_db.variant.variant_name + else str(evaluation_db.variant_id) + ) + aggregated_results = aggregated_result_of_evaluation_to_pydantic( + evaluation_db.aggregated_results + ) + + return Evaluation( + id=str(evaluation_db.id), + app_id=str(evaluation_db.app_id), + project_id=str(evaluation_db.project_id), + status=evaluation_db.status, + variant_ids=[str(evaluation_db.variant_id)], + variant_revision_ids=[str(evaluation_db.variant_revision_id)], + revisions=[str(evaluation_db.variant_revision.revision)], + variant_names=[variant_name], + testset_id=str(evaluation_db.testset_id), + testset_name=evaluation_db.testset.name, + aggregated_results=aggregated_results, + created_at=str(evaluation_db.created_at), + updated_at=str(evaluation_db.updated_at), + average_cost=evaluation_db.average_cost, + total_cost=evaluation_db.total_cost, + average_latency=evaluation_db.average_latency, + ) + + +async def human_evaluation_db_to_pydantic( + evaluation_db: HumanEvaluationDB, +) -> HumanEvaluation: + evaluation_variants = await db_manager_ee.fetch_human_evaluation_variants( + human_evaluation_id=str(evaluation_db.id) # type: ignore + ) + + revisions = [] + variants_ids = [] + variants_names = [] + variants_revision_ids = [] + for evaluation_variant in evaluation_variants: + variant_name = ( + evaluation_variant.variant.variant_name + if isinstance(evaluation_variant.variant_id, uuid.UUID) + else str(evaluation_variant.variant_id) + ) + variants_names.append(str(variant_name)) + variants_ids.append(str(evaluation_variant.variant_id)) + variant_revision = ( + str(evaluation_variant.variant_revision.revision) + if isinstance(evaluation_variant.variant_revision_id, uuid.UUID) + else " None" + ) + revisions.append(variant_revision) + variants_revision_ids.append(str(evaluation_variant.variant_revision_id)) + + return HumanEvaluation( + id=str(evaluation_db.id), + app_id=str(evaluation_db.app_id), + project_id=str(evaluation_db.project_id), + status=evaluation_db.status, # type: ignore + evaluation_type=evaluation_db.evaluation_type, # type: ignore + variant_ids=variants_ids, + variant_names=variants_names, + testset_id=str(evaluation_db.testset_id), + testset_name=evaluation_db.testset.name, + variants_revision_ids=variants_revision_ids, + revisions=revisions, + created_at=str(evaluation_db.created_at), # type: ignore + updated_at=str(evaluation_db.updated_at), # type: ignore + ) + + +def human_evaluation_scenario_db_to_pydantic( + evaluation_scenario_db: HumanEvaluationScenarioDB, evaluation_id: str +) -> HumanEvaluationScenario: + return HumanEvaluationScenario( + id=str(evaluation_scenario_db.id), + evaluation_id=evaluation_id, + inputs=evaluation_scenario_db.inputs, # type: ignore + outputs=evaluation_scenario_db.outputs, # type: ignore + vote=evaluation_scenario_db.vote, # type: ignore + score=evaluation_scenario_db.score, # type: ignore + correct_answer=evaluation_scenario_db.correct_answer, # type: ignore + is_pinned=evaluation_scenario_db.is_pinned or False, # type: ignore + note=evaluation_scenario_db.note or "", # type: ignore + ) + + +def aggregated_result_of_evaluation_to_pydantic( + evaluation_aggregated_results: List, +) -> List[dict]: + transformed_results = [] + for aggregated_result in evaluation_aggregated_results: + evaluator_config_dict = ( + { + "id": str(aggregated_result.evaluator_config.id), + "name": aggregated_result.evaluator_config.name, + "evaluator_key": aggregated_result.evaluator_config.evaluator_key, + "settings_values": aggregated_result.evaluator_config.settings_values, + "created_at": str(aggregated_result.evaluator_config.created_at), + "updated_at": str(aggregated_result.evaluator_config.updated_at), + } + if isinstance(aggregated_result.evaluator_config_id, uuid.UUID) + else None + ) + transformed_results.append( + { + "evaluator_config": ( + {} if evaluator_config_dict is None else evaluator_config_dict + ), + "result": aggregated_result.result, + } + ) + return transformed_results + + +async def evaluation_scenario_db_to_pydantic( + evaluation_scenario_db: EvaluationScenarioDB, evaluation_id: str +) -> EvaluationScenario: + scenario_results = [ + { + "evaluator_config": str(scenario_result.evaluator_config_id), + "result": scenario_result.result, + } + for scenario_result in evaluation_scenario_db.results + ] + return EvaluationScenario( + id=str(evaluation_scenario_db.id), + evaluation_id=evaluation_id, + inputs=[ + EvaluationScenarioInput(**scenario_input) # type: ignore + for scenario_input in evaluation_scenario_db.inputs + ], + outputs=[ + EvaluationScenarioOutput(**scenario_output) # type: ignore + for scenario_output in evaluation_scenario_db.outputs + ], + correct_answers=[ + CorrectAnswer(**correct_answer) # type: ignore + for correct_answer in evaluation_scenario_db.correct_answers + ], + is_pinned=evaluation_scenario_db.is_pinned or False, # type: ignore + note=evaluation_scenario_db.note or "", # type: ignore + results=scenario_results, # type: ignore + ) diff --git a/api/ee/src/services/db_manager.py b/api/ee/src/services/db_manager.py new file mode 100644 index 0000000000..1091c4f736 --- /dev/null +++ b/api/ee/src/services/db_manager.py @@ -0,0 +1,35 @@ +import uuid + +from oss.src.dbs.postgres.shared.engine import engine +from ee.src.models.db_models import DeploymentDB_ as DeploymentDB + + +async def create_deployment( + app_id: str, + project_id: str, + uri: str, +) -> DeploymentDB: + """Create a new deployment. + Args: + app_id (str): The app variant to create the deployment for. + project_id (str): The project variant to create the deployment for. + uri (str): The URI of the service. + Returns: + DeploymentDB: The created deployment. + """ + + async with engine.core_session() as session: + try: + deployment = DeploymentDB( + app_id=uuid.UUID(app_id), + project_id=uuid.UUID(project_id), + uri=uri, + ) + + session.add(deployment) + await session.commit() + await session.refresh(deployment) + + return deployment + except Exception as e: + raise Exception(f"Error while creating deployment: {e}") diff --git a/api/ee/src/services/db_manager_ee.py b/api/ee/src/services/db_manager_ee.py new file mode 100644 index 0000000000..c0076afac3 --- /dev/null +++ b/api/ee/src/services/db_manager_ee.py @@ -0,0 +1,2129 @@ +import uuid +from typing import List, Dict, Union, Any, NoReturn, Optional, Tuple + +import sendgrid +from fastapi import HTTPException +from sendgrid.helpers.mail import Mail + +from sqlalchemy import func, asc +from sqlalchemy.future import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only, aliased +from sqlalchemy.exc import NoResultFound, MultipleResultsFound + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.common import is_ee + +from oss.src.dbs.postgres.shared.engine import engine +from oss.src.services import db_manager, evaluator_manager +from ee.src.models.api.workspace_models import ( + UserRole, + UpdateWorkspace, + CreateWorkspace, + WorkspaceResponse, +) +from ee.src.models.api.organization_models import ( + Organization, + CreateOrganization, + OrganizationUpdate, +) +from ee.src.models.shared_models import WorkspaceRole +from ee.src.models.db_models import ( + ProjectDB, + WorkspaceDB, + EvaluationDB, + OrganizationDB, + ProjectMemberDB, + WorkspaceMemberDB, + HumanEvaluationDB, + OrganizationMemberDB, + EvaluationScenarioDB, + HumanEvaluationScenarioDB, + HumanEvaluationVariantDB, + EvaluationScenarioResultDB, + EvaluationEvaluatorConfigDB, + EvaluationAggregatedResultDB, +) +from oss.src.models.db_models import ( + AppVariantDB, + UserDB, + AppDB, + TestSetDB, + InvitationDB, + EvaluatorConfigDB, + AppVariantRevisionsDB, +) +from oss.src.models.shared_models import ( + Result, + CorrectAnswer, + AggregatedResult, + EvaluationScenarioResult, + EvaluationScenarioInput, + EvaluationScenarioOutput, + HumanEvaluationScenarioInput, +) +from ee.src.services.converters import get_workspace_in_format +from ee.src.services.selectors import get_org_default_workspace + +from oss.src.utils.env import env + + +# Initialize sendgrid api client +sg = sendgrid.SendGridAPIClient(api_key=env.SENDGRID_API_KEY) + +log = get_module_logger(__name__) + + +async def get_organization(organization_id: str) -> OrganizationDB: + """ + Fetches an organization by its ID. + + Args: + organization_id (str): The ID of the organization to fetch. + + Returns: + OrganizationDB: The fetched organization. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) + ) + organization = result.scalars().first() + return organization + + +async def get_organizations_by_list_ids(organization_ids: List) -> List[OrganizationDB]: + """ + Retrieve organizations from the database by their IDs. + + Args: + organization_ids (List): A list of organization IDs to retrieve. + + Returns: + List: A list of dictionaries representing the retrieved organizations. + """ + + async with engine.core_session() as session: + organization_uuids = [uuid.UUID(org_id) for org_id in organization_ids] + query = select(OrganizationDB).where(OrganizationDB.id.in_(organization_uuids)) + result = await session.execute(query) + organizations = result.scalars().all() + return organizations + + +async def get_default_workspace_id(user_id: str) -> str: + """ + Retrieve the default workspace ID for a user. + + Args: + user_id (str): The user id. + + Returns: + str: The default workspace ID. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(WorkspaceMemberDB) + .filter_by(user_id=uuid.UUID(user_id), role=WorkspaceRole.OWNER) + .options(load_only(WorkspaceMemberDB.workspace_id)) # type: ignore + ) + member_in_workspace = result.scalars().first() + return str(member_in_workspace.workspace_id) + + +async def get_organization_workspaces(organization_id: str): + """ + Retries workspaces belonging to an organization. + + Args: + organization_id (str): The ID of the organization + """ + + async with engine.core_session() as session: + result = await session.execute( + select(WorkspaceDB) + .filter_by(organization_id=uuid.UUID(organization_id)) + .options(load_only(WorkspaceDB.organization_id)) # type: ignore + ) + workspaces = result.scalars().all() + return workspaces + + +async def get_workspace_administrators(workspace: WorkspaceDB) -> List[UserDB]: + """ + Retrieve the administrators of a workspace. + + Args: + workspace (WorkspaceDB): The workspace to retrieve the administrators for. + + Returns: + List[UserDB]: A list of UserDB objects representing the administrators of the workspace. + """ + + administrators = [] + for member in workspace.members: + if workspace.has_role( + member.user_id, WorkspaceRole.WORKSPACE_ADMIN + ) or workspace.has_role(member.user_id, WorkspaceRole.OWNER): + user = await db_manager.get_user_with_id(member.user_id) + administrators.append(user) + return administrators + + +async def create_project( + project_name: str, workspace_id: str, organization_id: str, session: AsyncSession +) -> WorkspaceDB: + """ + Create a new project. + + Args: + project_name (str): The name of the project. + workspace_id (str): The ID of the workspace. + organization_id (str): The ID of the organization. + session (AsyncSession): The database session. + + Returns: + WorkspaceDB: The created project. + """ + + project_db = ProjectDB( + project_name=project_name, + is_default=True, + organization_id=uuid.UUID(organization_id), + workspace_id=uuid.UUID(workspace_id), + ) + + session.add(project_db) + + log.info( + "[scopes] project created", + organization_id=organization_id, + workspace_id=workspace_id, + project_id=project_db.id, + ) + + await session.commit() + + return project_db + + +async def create_default_project( + organization_id: str, workspace_id: str, session: AsyncSession +) -> WorkspaceDB: + """ + Create a default project for an organization. + + Args: + organization_id (str): The ID of the organization. + workspace_id (str): The ID of the workspace. + session (AsyncSession): The database session. + + Returns: + WorkspaceDB: The created default project. + """ + + project_db = await create_project( + "Default", + workspace_id=workspace_id, + organization_id=organization_id, + session=session, + ) + return project_db + + +async def get_default_workspace_id_from_organization( + organization_id: str, +) -> Union[str, NoReturn]: + """ + Get the default (first) workspace ID belonging to a user from a organization. + + Args: + organization_id (str): The ID of the organization. + + Returns: + str: The default (first) workspace ID. + """ + + async with engine.core_session() as session: + workspace_query = await session.execute( + select(WorkspaceDB) + .where( + WorkspaceDB.organization_id == uuid.UUID(organization_id), + ) + .options(load_only(WorkspaceDB.id)) + ) + workspace = workspace_query.scalars().first() + if workspace is None: + raise NoResultFound( + f"No default workspace for the provided organization_id {organization_id} found" + ) + return str(workspace.id) + + +async def get_project_by_workspace(workspace_id: str) -> ProjectDB: + """Get the project from database using the organization id and workspace id. + + Args: + workspace_id (str): The ID of the workspace + + Returns: + ProjectDB: The retrieved project + """ + + assert workspace_id is not None, "Workspace ID is required to retrieve project" + async with engine.core_session() as session: + project_query = await session.execute( + select(ProjectDB).where( + ProjectDB.workspace_id == uuid.UUID(workspace_id), + ) + ) + project = project_query.scalars().first() + if project is None: + raise NoResultFound(f"No project with workspace IDs ({workspace_id}) found") + return project + + +async def create_project_member( + user_id: str, project_id: str, role: str, session: AsyncSession +) -> None: + """ + Create a new project member. + + Args: + user_id (str): The ID of the user. + project_id (str): The ID of the project. + role (str): The role of the user in the workspace. + session (AsyncSession): The database session. + """ + + project = await db_manager.fetch_project_by_id( + project_id=project_id, + ) + + if not project: + raise Exception(f"No project found with ID {project_id}") + + project_member = ProjectMemberDB( + user_id=uuid.UUID(user_id), + project_id=uuid.UUID(project_id), + role=role, + ) + + session.add(project_member) + + log.info( + "[scopes] project membership created", + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=project_id, + user_id=user_id, + membership_id=project_member.id, + ) + + await session.commit() + + +async def fetch_project_memberships_by_user_id( + user_id: str, +) -> List[ProjectMemberDB]: + async with engine.core_session() as session: + result = await session.execute( + select(ProjectMemberDB) + .filter_by(user_id=uuid.UUID(user_id)) + .options( + joinedload(ProjectMemberDB.project).joinedload(ProjectDB.workspace), + joinedload(ProjectMemberDB.project).joinedload(ProjectDB.organization), + ) + ) + project_memberships = result.scalars().all() + + return project_memberships + + +async def create_workspace_db_object( + session: AsyncSession, + payload: CreateWorkspace, + organization: OrganizationDB, + user: UserDB, + return_wrk_prj: bool = False, +) -> WorkspaceDB: + """Create a new workspace. + + Args: + payload (Workspace): The workspace payload. + organization (OrganizationDB): The organization that the workspace belongs to. + user (UserDB): The user that the workspace belongs to. + + Returns: + Workspace: The created workspace. + """ + + workspace = WorkspaceDB( + name=payload.name, + type=payload.type if payload.type else "", + description=payload.description if payload.description else "", + organization_id=organization.id, + ) + + session.add(workspace) + + log.info( + "[scopes] workspace created", + organization_id=organization.id, + workspace_id=workspace.id, + ) + + await session.commit() + + # add user as a member to the workspace with the owner role + workspace_member = WorkspaceMemberDB( + user_id=user.id, + workspace_id=workspace.id, + role="owner", + ) + session.add(workspace_member) + + log.info( + "[scopes] workspace membership created", + organization_id=workspace.organization_id, + workspace_id=workspace.id, + user_id=user.id, + membership_id=workspace_member.id, + ) + + await session.commit() + + await session.refresh(workspace, attribute_names=["organization"]) + + project_db = await create_default_project( + organization_id=str(organization.id), + workspace_id=str(workspace.id), + session=session, + ) + + # add user as a member to the project member with the owner role + await create_project_member( + user_id=str(user.id), + project_id=str(project_db.id), + role=workspace_member.role, + session=session, + ) + + # add default testset and evaluators + await db_manager.add_testset_to_app_variant( + template_name="completion", # type: ignore + app_name="completion", # type: ignore + project_id=str(project_db.id), + ) + await evaluator_manager.create_ready_to_use_evaluators( + project_id=str(project_db.id) + ) + + if return_wrk_prj: + return workspace, project_db + + return workspace + + +async def create_workspace( + payload: CreateWorkspace, organization_id: str, user_uid: str +) -> WorkspaceResponse: + """ + Create a new workspace. + + Args: + payload (CreateWorkspace): The workspace payload. + organization_id (str): The organization id. + user_uid (str): The user uid. + + Returns: + Workspace: The created workspace. + + """ + try: + user = await db_manager.get_user(user_uid) + organization = await get_organization(organization_id) + + async with engine.core_session() as session: + user_result = await session.execute(select(UserDB).filter_by(uid=user_uid)) + user = user_result.scalars().first() + + organization_result = await session.execute( + select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) + ) + organization = organization_result.scalars().first() + + # create workspace + workspace_db = await create_workspace_db_object( + session, payload, organization, user + ) + + return await get_workspace_in_format(workspace_db) + except Exception as e: + raise e + + +async def update_workspace( + payload: UpdateWorkspace, workspace: WorkspaceDB +) -> WorkspaceResponse: + """ + Update a workspace's details. + + Args: + workspace (WorkspaceDB): The workspace to update. + payload (UpdateWorkspace): The data to update the workspace with. + """ + + async with engine.core_session() as session: + result = await session.execute(select(WorkspaceDB).filter_by(id=workspace.id)) + workspace = result.scalars().first() + + if not workspace: + raise NoResultFound(f"Workspace with id {str(workspace.id)} not found") + + for key, value in payload.dict(exclude_unset=True).items(): + if hasattr(workspace, key): + setattr(workspace, key, value) + + await session.commit() + await session.refresh(workspace) + + return await get_workspace_in_format(workspace) + + +async def check_user_in_workspace_with_email(email: str, workspace_id: str) -> bool: + """ + Check if a user belongs to a workspace. + + Args: + email (str): The email of the user to check. + workspace_id (str): The workspace to check. + + Raises: + Exception: If there is an error checking if the user belongs to the workspace. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(WorkspaceMemberDB) + .join(UserDB, UserDB.id == WorkspaceMemberDB.user_id) + .where( + UserDB.email == email, + WorkspaceMemberDB.workspace_id == uuid.UUID(workspace_id), + ) + ) + workspace_member = result.scalars().first() + return False if workspace_member is None else True + + +async def update_user_roles( + workspace_id: str, + payload: UserRole, + delete: bool = False, +) -> bool: + """ + Update a user's roles in a workspace. + + Args: + workspace_id (str): The ID of the workspace. + payload (UserRole): The payload containing the user email and role to update. + delete (bool): Whether to delete the user's role or not. + + Returns: + bool: True if the user's roles were successfully updated, False otherwise. + + Raises: + Exception: If there is an error updating the user's roles. + """ + + user = await db_manager.get_user_with_email(payload.email) + project_id = await db_manager.get_default_project_id_from_workspace( + workspace_id=workspace_id + ) + + async with engine.core_session() as session: + # Ensure that an admin can not remove the owner of the workspace/project + project_owner_result = await session.execute( + select(ProjectMemberDB) + .filter_by(project_id=uuid.UUID(project_id), role="owner") + .options( + load_only( + ProjectMemberDB.user_id, # type: ignore + ProjectMemberDB.role, # type: ignore + ) + ) + ) + project_owner = project_owner_result.scalars().first() + if user.id == project_owner.user_id and project_owner.role == "owner": + raise HTTPException( + 403, + { + "message": "You do not have permission to perform this action. Please contact your Organization Owner" + }, + ) + + project_member_result = await session.execute( + select(ProjectMemberDB).filter_by( + project_id=uuid.UUID(project_id), user_id=user.id + ) + ) + project_member = project_member_result.scalars().first() + if not project_member: + raise NoResultFound( + f"User with id {str(user.id)} is not part of the workspace member." + ) + + workspace_member_result = await session.execute( + select(WorkspaceMemberDB).filter_by( + workspace_id=uuid.UUID(workspace_id), user_id=user.id + ) + ) + workspace_member = workspace_member_result.scalars().first() + if not workspace_member: + raise NoResultFound( + f"User with id {str(user.id)} is not part of the workspace member." + ) + + if not delete: + # Update the member's role + project_member.role = payload.role + workspace_member.role = payload.role + + await session.commit() + await session.refresh(project_member) + return True + + +async def add_user_to_workspace_and_org( + organization: OrganizationDB, + workspace: WorkspaceDB, + user: UserDB, + project_id: str, + role: str, +) -> bool: + async with engine.core_session() as session: + # create joined organization for user + user_organization = OrganizationMemberDB( + user_id=user.id, organization_id=organization.id + ) + session.add(user_organization) + + log.info( + "[scopes] organization membership created", + organization_id=organization.id, + user_id=user.id, + membership_id=user_organization.id, + ) + + # add user to workspace + workspace_member = WorkspaceMemberDB( + user_id=user.id, + workspace_id=workspace.id, + role=role, + ) + + session.add(workspace_member) + + log.info( + "[scopes] workspace membership created", + organization_id=organization.id, + workspace_id=workspace.id, + user_id=user.id, + membership_id=workspace_member.id, + ) + + # add user to project + await create_project_member( + user_id=str(user.id), project_id=project_id, role=role, session=session + ) + + return True + + +async def remove_user_from_workspace( + workspace_id: str, + email: str, +) -> WorkspaceResponse: + """ + Remove a user from a workspace. + + Args: + workspace_id (str): The ID of the workspace. + payload (UserRole): The payload containing the user email and role to remove. + + Returns: + workspace (WorkspaceResponse): The updated workspace. + + Raises: + HTTPException -- 403, from fastapi import Request + """ + + user = await db_manager.get_user_with_email(email) + project_id = await db_manager.get_default_project_id_from_workspace( + workspace_id=workspace_id + ) + project = await db_manager.get_project_by_id(project_id=project_id) + + async with engine.core_session() as session: + if ( + not user + ): # User is an invited user who has not yet created an account and therefore does not have a user object + pass + else: + # Ensure that a user can not remove the owner of the workspace + workspace_owner_result = await session.execute( + select(WorkspaceMemberDB) + .filter_by( + workspace_id=project.workspace_id, user_id=user.id, role="owner" + ) + .options( + load_only( + WorkspaceMemberDB.user_id, # type: ignore + WorkspaceMemberDB.role, # type: ignore + ) + ) + ) + workspace_owner = workspace_owner_result.scalars().first() + if (workspace_owner is not None and user is not None) and ( + user.id == workspace_owner.user_id and workspace_owner.role == "owner" + ): + raise HTTPException( + status_code=403, + detail={ + "message": "You do not have permission to perform this action. Please contact your Organization Owner" + }, + ) + + # remove user from workspace + workspace_member_result = await session.execute( + select(WorkspaceMemberDB).filter( + WorkspaceMemberDB.workspace_id == project.workspace_id, + WorkspaceMemberDB.user_id == user.id, + WorkspaceMemberDB.role != "owner", + ) + ) + workspace_member = workspace_member_result.scalars().first() + if workspace_member: + await session.delete(workspace_member) + + log.info( + "[scopes] workspace membership deleted", + organization_id=project.organization_id, + workspace_id=workspace_id, + user_id=user.id, + membership_id=workspace_member.id, + ) + + # remove user from project + project_member_result = await session.execute( + select(ProjectMemberDB).filter( + ProjectMemberDB.project_id == project.id, + ProjectMemberDB.user_id == user.id, + ProjectMemberDB.role != "owner", + ) + ) + project_member = project_member_result.scalars().first() + if project_member: + await session.delete(project_member) + + log.info( + "[scopes] project membership deleted", + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=project.id, + user_id=user.id, + membership_id=project_member.id, + ) + + # remove user from organization + joined_org_result = await session.execute( + select(OrganizationMemberDB).filter_by( + user_id=user.id, organization_id=project.organization_id + ) + ) + member_joined_org = joined_org_result.scalars().first() + if member_joined_org: + await session.delete(member_joined_org) + + log.info( + "[scopes] organization membership deleted", + organization_id=project.organization_id, + user_id=user.id, + membership_id=member_joined_org.id, + ) + + await session.commit() + + # If there's an invitation for the provided email address, delete it + user_workspace_invitations_query = await session.execute( + select(InvitationDB) + .filter_by(project_id=project.id, email=email) + .options( + load_only( + InvitationDB.id, # type: ignore + InvitationDB.project_id, # type: ignore + ) + ) + ) + user_invitations = user_workspace_invitations_query.scalars().all() + for invitation in user_invitations: + await delete_invitation(str(invitation.id)) + + workspace_db = await db_manager.get_workspace(workspace_id=workspace_id) + return await get_workspace_in_format(workspace_db) + + +async def create_organization( + payload: CreateOrganization, + user: UserDB, + return_org_wrk: Optional[bool] = False, + return_org_wrk_prj: Optional[bool] = False, +) -> Union[ + OrganizationDB, + Tuple[OrganizationDB, WorkspaceDB], + Tuple[OrganizationDB, WorkspaceDB, ProjectDB], +]: + """Create a new organization. + + Args: + payload (Organization): The organization payload. + + Returns: + Organization: The created organization. + Optional[Workspace]: The created workspace if return_org_wrk is True. + + Raises: + Exception: If there is an error creating the organization. + """ + + async with engine.core_session() as session: + create_org_data = payload.model_dump(exclude_unset=True) + if "owner" not in create_org_data: + create_org_data["owner"] = str(user.id) + + # create organization + organization_db = OrganizationDB(**create_org_data) + session.add(organization_db) + + log.info( + "[scopes] organization created", + organization_id=organization_db.id, + ) + + await session.commit() + + # create joined organization for user + user_organization = OrganizationMemberDB( + user_id=user.id, organization_id=organization_db.id + ) + session.add(user_organization) + + log.info( + "[scopes] organization membership created", + organization_id=organization_db.id, + user_id=user.id, + membership_id=user_organization.id, + ) + + await session.commit() + + # construct workspace payload + workspace_payload = CreateWorkspace( + name=payload.name, + type=payload.type if payload.type else "", + description=( + "My Default Workspace" + if payload.type == "default" + else payload.description + if payload.description + else "" + ), + ) + + # create workspace + workspace, project = await create_workspace_db_object( + session, + workspace_payload, + organization_db, + user, + return_wrk_prj=True, + ) + + if return_org_wrk_prj: + return organization_db, workspace, project + + if return_org_wrk: + return organization_db, workspace + + return organization_db + + +async def update_organization( + organization_id: str, payload: OrganizationUpdate +) -> OrganizationDB: + """ + Update an organization's details. + + Args: + organization_id (str): The organization to update. + payload (OrganizationUpdate): The data to update the organization with. + + Returns: + Organization: The updated organization. + + Raises: + Exception: If there is an error updating the organization. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(OrganizationDB).filter_by(id=uuid.UUID(organization_id)) + ) + organization = result.scalars().first() + + if not organization: + raise NoResultFound(f"Organization with id {organization_id} not found") + + for key, value in payload.model_dump(exclude_unset=True).items(): + if hasattr(organization, key): + setattr(organization, key, value) + + await session.commit() + await session.refresh(organization) + return organization + + +async def delete_invitation(invitation_id: str) -> bool: + """ + Delete an invitation from an organization. + + Args: + invitation (str): The invitation to delete. + + Returns: + bool: True if the invitation was successfully deleted, False otherwise. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(InvitationDB).filter_by(id=uuid.UUID(invitation_id)) + ) + + try: + invitation = result.scalars().one_or_none() + except MultipleResultsFound as e: + log.error( + f"Critical error: Database returned two rows when retrieving invitation with ID {invitation_id} to delete from Invitations table. Error details: {str(e)}" + ) + raise HTTPException( + 500, + { + "message": f"Error occured while trying to delete invitation with ID {invitation_id} from Invitations table. Error details: {str(e)}" + }, + ) + + project = await session.execute( + select(ProjectDB).filter_by(id=invitation.project_id) + ) + project = project.scalars().one_or_none() + + if not project: + log.error(f"Project with ID {invitation.project_id} not found.") + raise Exception(f"No project found with ID {invitation.project_id}") + + await session.delete(invitation) + + log.info( + "[scopes] invitation deleted", + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=invitation.project_id, + user_id=invitation.user_id, + membership_id=invitation.id, + ) + + await session.commit() + + return True + + +async def mark_invitation_as_used( + project_id: str, user_id: str, invitation: InvitationDB +) -> bool: + """ + Mark an invitation as used. + + Args: + project_id (str): The ID of the project. + user_id (str): the ID of the user. + invitation (InvitationDB): The invitation to mark as used. + + Returns: + bool: True if the invitation was successfully marked as used, False otherwise. + + Raises: + HTTPException: If there is an error marking the invitation as used. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(InvitationDB).filter_by( + project_id=uuid.UUID(project_id), token=invitation.token + ) + ) + organization_invitation = result.scalars().first() + if not organization_invitation: + return False + + organization_invitation.used = True + organization_invitation.user_id = uuid.UUID(user_id) + + await session.commit() + return True + + +async def get_org_details(organization: Organization) -> dict: + """ + Retrieve details of an organization. + + Args: + organization (Organization): The organization to retrieve details for. + project_id (str): The project_id to retrieve details for. + + Returns: + dict: A dictionary containing the organization's details. + """ + + default_workspace_db = await get_org_default_workspace(organization) + default_workspace = await get_workspace_details(default_workspace_db) + workspaces = await get_organization_workspaces(organization_id=str(organization.id)) + + sample_organization = { + "id": str(organization.id), + "name": organization.name, + "description": organization.description, + "type": organization.type, + "owner": organization.owner, + "workspaces": [str(workspace.id) for workspace in workspaces], + "default_workspace": default_workspace, + "is_paying": organization.is_paying if is_ee() else None, + } + return sample_organization + + +async def get_workspace_details(workspace: WorkspaceDB) -> WorkspaceResponse: + """ + Retrieve details of a workspace. + + Args: + workspace (Workspace): The workspace to retrieve details for. + project_id (str): The project_id to retrieve details for. + + Returns: + dict: A dictionary containing the workspace's details. + + Raises: + Exception: If there is an error retrieving the workspace details. + """ + + try: + workspace_response = await get_workspace_in_format(workspace) + return workspace_response + except Exception as e: + import traceback + + traceback.print_exc() + raise e + + +async def get_organization_invitations(organization_id: str): + """ + Gets the organization invitations. + + Args: + organization_id (str): The ID of the organization + """ + + async with engine.core_session() as session: + result = await session.execute( + select(InvitationDB).filter_by(organization_id=organization_id) + ) + invitations = result.scalars().all() + return invitations + + +async def get_project_invitations(project_id: str, **kwargs): + """ + Gets the project invitations. + + Args: + project_id (str): The ID of the project + """ + + async with engine.core_session() as session: + stmt = select(InvitationDB).filter( + InvitationDB.project_id == uuid.UUID(project_id) + ) + if kwargs.get("has_pending", False): + stmt = stmt.filter(InvitationDB.used == kwargs["invitation_used"]) + + result = await session.execute(stmt) + invitations = result.scalars().all() + return invitations + + +async def get_all_pending_invitations(email: str): + """ + Gets all pending invitations for a given email. + + Args: + email (str): The email address of the user. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(InvitationDB).filter( + InvitationDB.email == email, + InvitationDB.used == False, + ) + ) + invitations = result.scalars().all() + return invitations + + +async def get_project_invitation( + project_id: str, token: str, email: str +) -> InvitationDB: + """Get project invitation by project ID, token and email. + + Args: + project_id (str): The ID of the project. + token (str): The invitation token. + email (str): The email address of the invited user. + + Returns: + InvitationDB: invitation object + """ + + async with engine.core_session() as session: + result = await session.execute( + select(InvitationDB).filter_by( + project_id=uuid.UUID(project_id), token=token, email=email + ) + ) + invitation = result.scalars().first() + return invitation + + +async def get_project_members(project_id: str): + """Gets the members of a project. + + Args: + project_id (str): The ID of the project + """ + + async with engine.core_session() as session: + members_query = await session.execute( + select(ProjectMemberDB) + .filter(ProjectMemberDB.project_id == uuid.UUID(project_id)) + .options(joinedload(ProjectMemberDB.user)) + ) + project_members = members_query.scalars().all() + return project_members + + +async def create_org_workspace_invitation( + workspace_role: str, + token: str, + email: str, + project_id: str, + expiration_date, +) -> InvitationDB: + """ + Create an organization invitation. + + Args: + - workspace_role (str): The role to assign the invited user in the project/workspace. + - token (str): The token for the invitation. + - email (str): The email address of the invited user. + - expiration_date: The expiration date of the invitation. + + Returns: + InvitationDB: The created invitation. + + """ + + user = await db_manager.get_user_with_email(email=email) + + user_id = None + if user: + user_id = user.id + + project = await db_manager.fetch_project_by_id( + project_id=project_id, + ) + + if not project: + raise Exception(f"No project found with ID {project_id}") + + async with engine.core_session() as session: + invitation = InvitationDB( + token=token, + email=email, + project_id=uuid.UUID(project_id), + expiration_date=expiration_date, + role=workspace_role, + used=False, + ) + + session.add(invitation) + + log.info( + "[scopes] invitation created", + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=project_id, + user_id=user_id, + invitation_id=invitation.id, + ) + + await session.commit() + + return invitation + + +async def get_all_workspace_roles() -> List[WorkspaceRole]: + """ + Retrieve all workspace roles. + + Returns: + List[WorkspaceRole]: A list of all workspace roles in the DB. + """ + workspace_roles = list(WorkspaceRole) + return workspace_roles + + +# async def get_project_id_from_db_entity( +# object_id: str, type: str, project_id: str +# ) -> dict: +# """ +# Get the project id of the object. + +# Args: +# object_id (str): The ID of the object. +# type (str): The type of the object. + +# Returns: +# dict: The project_id of the object. + +# Raises: +# ValueError: If the object type is unknown. +# Exception: If there is an error retrieving the project_id. +# """ +# try: +# if type == "app": +# app = await db_manager.fetch_app_by_id(object_id) +# project_id = app.project_id + +# elif type == "app_variant": +# app_variant = await db_manager.fetch_app_variant_by_id(object_id) +# project_id = app_variant.project_id + +# elif type == "base": +# base = await db_manager.fetch_base_by_id(object_id) +# project_id = base.project_id + +# elif type == "deployment": +# deployment = await db_manager.get_deployment_by_id(object_id) +# project_id = deployment.project_id + +# elif type == "testset": +# testset = await db_manager.fetch_testset_by_id(object_id) +# project_id = testset.project_id + +# elif type == "evaluation": +# evaluation = await db_manager.fetch_evaluation_by_id(object_id) +# project_id = evaluation.project_id + +# elif type == "evaluation_scenario": +# evaluation_scenario = await db_manager.fetch_evaluation_scenario_by_id( +# object_id +# ) +# project_id = evaluation_scenario.project_id + +# elif type == "evaluator_config": +# evaluator_config = await db_manager.fetch_evaluator_config(object_id) +# project_id = evaluator_config.project_id + +# elif type == "human_evaluation": +# human_evaluation = await db_manager.fetch_human_evaluation_by_id(object_id) +# project_id = human_evaluation.project_id + +# elif type == "human_evaluation_scenario": +# human_evaluation_scenario = ( +# await db_manager.fetch_human_evaluation_scenario_by_id(object_id) +# ) +# project_id = human_evaluation_scenario.project_id + +# elif type == "human_evaluation_scenario_by_evaluation_id": +# human_evaluation_scenario_by_evaluation = ( +# await db_manager.fetch_human_evaluation_scenario_by_evaluation_id( +# object_id +# ) +# ) +# project_id = human_evaluation_scenario_by_evaluation.project_id + +# else: +# raise ValueError(f"Unknown object type: {type}") + +# return str(project_id) + +# except Exception as e: +# raise e + + +async def add_user_to_organization( + organization_id: str, + user_id: str, + # is_demo: bool = False, +) -> None: + async with engine.core_session() as session: + organization_member = OrganizationMemberDB( + user_id=user_id, + organization_id=organization_id, + ) + + session.add(organization_member) + + log.info( + "[scopes] organization membership created", + organization_id=organization_id, + user_id=user_id, + membership_id=organization_member.id, + ) + + await session.commit() + + +async def add_user_to_workspace( + workspace_id: str, + user_id: str, + role: str, + # is_demo: bool = False, +) -> None: + async with engine.core_session() as session: + # fetch workspace by workspace_id (SQL) + stmt = select(WorkspaceDB).filter_by(id=workspace_id) + workspace = await session.execute(stmt) + workspace = workspace.scalars().first() + + if not workspace: + raise Exception(f"No workspace found with ID {workspace_id}") + + workspace_member = WorkspaceMemberDB( + user_id=user_id, + workspace_id=workspace_id, + role=role, + ) + + session.add(workspace_member) + + # TODO: add organization_id + log.info( + "[scopes] workspace membership created", + organization_id=workspace.organization_id, + workspace_id=workspace_id, + user_id=user_id, + membership_id=workspace_member.id, + ) + + await session.commit() + + +async def add_user_to_project( + project_id: str, + user_id: str, + role: str, + is_demo: bool = False, +) -> None: + project = await db_manager.fetch_project_by_id( + project_id=project_id, + ) + + if not project: + raise Exception(f"No project found with ID {project_id}") + + async with engine.core_session() as session: + project_member = ProjectMemberDB( + user_id=user_id, + project_id=project_id, + role=role, + is_demo=is_demo, + ) + + session.add(project_member) + + log.info( + "[scopes] project membership created", + organization_id=project.organization_id, + workspace_id=project.workspace_id, + project_id=project_id, + user_id=user_id, + membership_id=project_member.id, + ) + + await session.commit() + + +async def fetch_evaluation_status_by_id( + project_id: str, + evaluation_id: str, +) -> Optional[str]: + """Fetch only the status of an evaluation by its ID.""" + assert evaluation_id is not None, "evaluation_id cannot be None" + + async with engine.core_session() as session: + query = ( + select(EvaluationDB) + .filter_by(project_id=project_id, id=uuid.UUID(evaluation_id)) + .options(load_only(EvaluationDB.status)) + ) + + result = await session.execute(query) + evaluation = result.scalars().first() + return evaluation.status if evaluation else None + + +async def fetch_evaluation_by_id( + project_id: str, + evaluation_id: str, +) -> Optional[EvaluationDB]: + """Fetches a evaluation by its ID. + + Args: + evaluation_id (str): The ID of the evaluation to fetch. + + Returns: + EvaluationDB: The fetched evaluation, or None if no evaluation was found. + """ + + assert evaluation_id is not None, "evaluation_id cannot be None" + async with engine.core_session() as session: + base_query = select(EvaluationDB).filter_by( + project_id=project_id, + id=uuid.UUID(evaluation_id), + ) + query = base_query.options( + joinedload(EvaluationDB.testset.of_type(TestSetDB)).load_only(TestSetDB.id, TestSetDB.name), # type: ignore + ) + + result = await session.execute( + query.options( + joinedload(EvaluationDB.variant.of_type(AppVariantDB)).load_only(AppVariantDB.id, AppVariantDB.variant_name), # type: ignore + joinedload(EvaluationDB.variant_revision.of_type(AppVariantRevisionsDB)).load_only(AppVariantRevisionsDB.revision), # type: ignore + joinedload( + EvaluationDB.aggregated_results.of_type( + EvaluationAggregatedResultDB + ) + ).joinedload(EvaluationAggregatedResultDB.evaluator_config), + ) + ) + evaluation = result.unique().scalars().first() + return evaluation + + +async def list_human_evaluations(app_id: str, project_id: str): + """ + Fetches human evaluations belonging to an App. + + Args: + app_id (str): The application identifier + """ + + async with engine.core_session() as session: + base_query = ( + select(HumanEvaluationDB) + .filter_by(app_id=uuid.UUID(app_id), project_id=uuid.UUID(project_id)) + .filter(HumanEvaluationDB.testset_id.isnot(None)) + ) + query = base_query.options( + joinedload(HumanEvaluationDB.testset.of_type(TestSetDB)).load_only(TestSetDB.id, TestSetDB.name), # type: ignore + ) + + result = await session.execute(query) + human_evaluations = result.scalars().all() + return human_evaluations + + +async def create_human_evaluation( + app: AppDB, + status: str, + evaluation_type: str, + testset_id: str, + variants_ids: List[str], +): + """ + Creates a human evaluation. + + Args: + app (AppDB: The app object + status (str): The status of the evaluation + evaluation_type (str): The evaluation type + testset_id (str): The ID of the evaluation testset + variants_ids (List[str]): The IDs of the variants for the evaluation + """ + + async with engine.core_session() as session: + human_evaluation = HumanEvaluationDB( + app_id=app.id, + project_id=app.project_id, + status=status, + evaluation_type=evaluation_type, + testset_id=testset_id, + ) + + session.add(human_evaluation) + await session.commit() + await session.refresh(human_evaluation, attribute_names=["testset"]) + + # create variants for human evaluation + await create_human_evaluation_variants( + human_evaluation_id=str(human_evaluation.id), + variants_ids=variants_ids, + ) + return human_evaluation + + +async def fetch_human_evaluation_variants(human_evaluation_id: str): + """ + Fetches human evaluation variants. + + Args: + human_evaluation_id (str): The human evaluation ID + + Returns: + The human evaluation variants. + """ + + async with engine.core_session() as session: + base_query = select(HumanEvaluationVariantDB).filter_by( + human_evaluation_id=uuid.UUID(human_evaluation_id) + ) + query = base_query.options( + joinedload(HumanEvaluationVariantDB.variant.of_type(AppVariantDB)).load_only(AppVariantDB.id, AppVariantDB.variant_name), # type: ignore + joinedload(HumanEvaluationVariantDB.variant_revision.of_type(AppVariantRevisionsDB)).load_only(AppVariantRevisionsDB.id, AppVariantRevisionsDB.revision), # type: ignore + ) + + result = await session.execute(query) + evaluation_variants = result.scalars().all() + return evaluation_variants + + +async def create_human_evaluation_variants( + human_evaluation_id: str, variants_ids: List[str] +): + """ + Creates human evaluation variants. + + Args: + human_evaluation_id (str): The human evaluation identifier + variants_ids (List[str]): The variants identifiers + project_id (str): The project ID + """ + + variants_dict = {} + for variant_id in variants_ids: + variant = await db_manager.fetch_app_variant_by_id(app_variant_id=variant_id) + if variant: + variants_dict[variant_id] = variant + + variants_revisions_dict = {} + for variant_id, variant in variants_dict.items(): + variant_revision = await db_manager.fetch_app_variant_revision_by_variant( + app_variant_id=str(variant.id), project_id=str(variant.project_id), revision=variant.revision # type: ignore + ) + if variant_revision: + variants_revisions_dict[variant_id] = variant_revision + + if set(variants_dict.keys()) != set(variants_revisions_dict.keys()): + raise ValueError("Mismatch between variants and their revisions") + + async with engine.core_session() as session: + for variant_id in variants_ids: + variant = variants_dict[variant_id] + variant_revision = variants_revisions_dict[variant_id] + human_evaluation_variant = HumanEvaluationVariantDB( + human_evaluation_id=uuid.UUID(human_evaluation_id), + variant_id=variant.id, # type: ignore + variant_revision_id=variant_revision.id, # type: ignore + ) + session.add(human_evaluation_variant) + + await session.commit() + + +async def fetch_human_evaluation_by_id( + evaluation_id: str, +) -> Optional[HumanEvaluationDB]: + """ + Fetches a evaluation by its ID. + + Args: + evaluation_id (str): The ID of the evaluation to fetch. + + Returns: + EvaluationDB: The fetched evaluation, or None if no evaluation was found. + """ + + assert evaluation_id is not None, "evaluation_id cannot be None" + async with engine.core_session() as session: + base_query = select(HumanEvaluationDB).filter_by(id=uuid.UUID(evaluation_id)) + query = base_query.options( + joinedload(HumanEvaluationDB.testset.of_type(TestSetDB)).load_only(TestSetDB.id, TestSetDB.name), # type: ignore + ) + result = await session.execute(query) + evaluation = result.scalars().first() + return evaluation + + +async def update_human_evaluation(evaluation_id: str, values_to_update: dict): + """Updates human evaluation with the specified values. + + Args: + evaluation_id (str): The evaluation ID + values_to_update (dict): The values to update + + Exceptions: + NoResultFound: if human evaluation is not found + """ + + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationDB).filter_by(id=uuid.UUID(evaluation_id)) + ) + human_evaluation = result.scalars().first() + if not human_evaluation: + raise NoResultFound(f"Human evaluation with id {evaluation_id} not found") + + for key, value in values_to_update.items(): + if hasattr(human_evaluation, key): + setattr(human_evaluation, key, value) + + await session.commit() + await session.refresh(human_evaluation) + + +async def delete_human_evaluation(evaluation_id: str): + """Delete the evaluation by its ID. + + Args: + evaluation_id (str): The ID of the evaluation to delete. + """ + + assert evaluation_id is not None, "evaluation_id cannot be None" + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationDB).filter_by(id=uuid.UUID(evaluation_id)) + ) + evaluation = result.scalars().first() + if not evaluation: + raise NoResultFound(f"Human evaluation with id {evaluation_id} not found") + + await session.delete(evaluation) + await session.commit() + + +async def create_human_evaluation_scenario( + inputs: List[HumanEvaluationScenarioInput], + project_id: str, + evaluation_id: str, + evaluation_extend: Dict[str, Any], +): + """ + Creates a human evaluation scenario. + + Args: + inputs (List[HumanEvaluationScenarioInput]): The inputs. + evaluation_id (str): The evaluation identifier. + evaluation_extend (Dict[str, any]): An extended required payload for the evaluation scenario. Contains score, vote, and correct_answer. + """ + + async with engine.core_session() as session: + evaluation_scenario = HumanEvaluationScenarioDB( + **evaluation_extend, + project_id=uuid.UUID(project_id), + evaluation_id=uuid.UUID(evaluation_id), + inputs=[input.model_dump() for input in inputs], + outputs=[], + ) + + session.add(evaluation_scenario) + await session.commit() + + +async def update_human_evaluation_scenario( + evaluation_scenario_id: str, values_to_update: dict +): + """Updates human evaluation scenario with the specified values. + + Args: + evaluation_scenario_id (str): The evaluation scenario ID + values_to_update (dict): The values to update + + Exceptions: + NoResultFound: if human evaluation scenario is not found + """ + + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationScenarioDB).filter_by( + id=uuid.UUID(evaluation_scenario_id) + ) + ) + human_evaluation_scenario = result.scalars().first() + if not human_evaluation_scenario: + raise NoResultFound( + f"Human evaluation scenario with id {evaluation_scenario_id} not found" + ) + + for key, value in values_to_update.items(): + if hasattr(human_evaluation_scenario, key): + setattr(human_evaluation_scenario, key, value) + + await session.commit() + await session.refresh(human_evaluation_scenario) + + +async def fetch_human_evaluation_scenarios(evaluation_id: str): + """ + Fetches human evaluation scenarios. + + Args: + evaluation_id (str): The evaluation identifier + + Returns: + The evaluation scenarios. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationScenarioDB) + .filter_by(evaluation_id=uuid.UUID(evaluation_id)) + .order_by(asc(HumanEvaluationScenarioDB.created_at)) + ) + evaluation_scenarios = result.scalars().all() + return evaluation_scenarios + + +async def fetch_evaluation_scenarios(evaluation_id: str, project_id: str): + """ + Fetches evaluation scenarios. + + Args: + evaluation_id (str): The evaluation identifier + project_id (str): The ID of the project + + Returns: + The evaluation scenarios. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(EvaluationScenarioDB) + .filter_by( + evaluation_id=uuid.UUID(evaluation_id), project_id=uuid.UUID(project_id) + ) + .options(joinedload(EvaluationScenarioDB.results)) + ) + evaluation_scenarios = result.unique().scalars().all() + return evaluation_scenarios + + +async def fetch_evaluation_scenario_by_id( + evaluation_scenario_id: str, +) -> Optional[EvaluationScenarioDB]: + """Fetches and evaluation scenario by its ID. + + Args: + evaluation_scenario_id (str): The ID of the evaluation scenario to fetch. + + Returns: + EvaluationScenarioDB: The fetched evaluation scenario, or None if no evaluation scenario was found. + """ + + assert evaluation_scenario_id is not None, "evaluation_scenario_id cannot be None" + async with engine.core_session() as session: + result = await session.execute( + select(EvaluationScenarioDB).filter_by(id=uuid.UUID(evaluation_scenario_id)) + ) + evaluation_scenario = result.scalars().first() + return evaluation_scenario + + +async def fetch_human_evaluation_scenario_by_id( + evaluation_scenario_id: str, +) -> Optional[HumanEvaluationScenarioDB]: + """Fetches and evaluation scenario by its ID. + + Args: + evaluation_scenario_id (str): The ID of the evaluation scenario to fetch. + + Returns: + EvaluationScenarioDB: The fetched evaluation scenario, or None if no evaluation scenario was found. + """ + + assert evaluation_scenario_id is not None, "evaluation_scenario_id cannot be None" + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationScenarioDB).filter_by( + id=uuid.UUID(evaluation_scenario_id) + ) + ) + evaluation_scenario = result.scalars().first() + return evaluation_scenario + + +async def fetch_human_evaluation_scenario_by_evaluation_id( + evaluation_id: str, +) -> Optional[HumanEvaluationScenarioDB]: + """Fetches and evaluation scenario by its ID. + Args: + evaluation_id (str): The ID of the evaluation object to use in fetching the human evaluation. + Returns: + EvaluationScenarioDB: The fetched evaluation scenario, or None if no evaluation scenario was found. + """ + + evaluation = await fetch_human_evaluation_by_id(evaluation_id) + async with engine.core_session() as session: + result = await session.execute( + select(HumanEvaluationScenarioDB).filter_by( + evaluation_id=evaluation.id # type: ignore + ) + ) + human_eval_scenario = result.scalars().first() + return human_eval_scenario + + +async def create_new_evaluation( + app: AppDB, + project_id: str, + testset: TestSetDB, + status: Result, + variant: str, + variant_revision: str, +) -> EvaluationDB: + """Create a new evaluation scenario. + Returns: + EvaluationScenarioDB: The created evaluation scenario. + """ + + async with engine.core_session() as session: + evaluation = EvaluationDB( + app_id=app.id, + project_id=uuid.UUID(project_id), + testset_id=testset.id, + status=status.model_dump(), + variant_id=uuid.UUID(variant), + variant_revision_id=uuid.UUID(variant_revision), + ) + + session.add(evaluation) + await session.commit() + await session.refresh( + evaluation, + attribute_names=[ + "testset", + "variant", + "variant_revision", + "aggregated_results", + ], + ) + + return evaluation + + +async def list_evaluations(app_id: str, project_id: str): + """Retrieves evaluations of the specified app from the db. + + Args: + app_id (str): The ID of the app + project_id (str): The ID of the project + """ + + async with engine.core_session() as session: + base_query = select(EvaluationDB).filter_by( + app_id=uuid.UUID(app_id), project_id=uuid.UUID(project_id) + ) + query = base_query.options( + joinedload(EvaluationDB.testset.of_type(TestSetDB)).load_only(TestSetDB.id, TestSetDB.name), # type: ignore + ) + + result = await session.execute( + query.options( + joinedload(EvaluationDB.variant.of_type(AppVariantDB)).load_only(AppVariantDB.id, AppVariantDB.variant_name), # type: ignore + joinedload(EvaluationDB.variant_revision.of_type(AppVariantRevisionsDB)).load_only(AppVariantRevisionsDB.revision), # type: ignore + joinedload( + EvaluationDB.aggregated_results.of_type( + EvaluationAggregatedResultDB + ) + ).joinedload(EvaluationAggregatedResultDB.evaluator_config), + ) + ) + evaluations = result.unique().scalars().all() + return evaluations + + +async def fetch_evaluations_by_resource( + resource_type: str, project_id: str, resource_ids: List[str] +): + """ + Fetches an evaluations by resource. + + Args: + resource_type (str): The resource type + project_id (str): The ID of the project + resource_ids (List[str]): The resource identifiers + + Returns: + The evaluations by resource. + + Raises: + HTTPException:400 resource_type {type} is not supported + """ + + ids = list(map(uuid.UUID, resource_ids)) + + async with engine.core_session() as session: + if resource_type == "variant": + result_evaluations = await session.execute( + select(EvaluationDB) + .filter( + EvaluationDB.variant_id.in_(ids), + EvaluationDB.project_id == uuid.UUID(project_id), + ) + .options(load_only(EvaluationDB.id)) # type: ignore + ) + result_human_evaluations = await session.execute( + select(HumanEvaluationDB) + .join(HumanEvaluationVariantDB) + .filter( + HumanEvaluationVariantDB.variant_id.in_(ids), + HumanEvaluationDB.project_id == uuid.UUID(project_id), + ) + .options(load_only(HumanEvaluationDB.id)) # type: ignore + ) + res_evaluations = result_evaluations.scalars().all() + res_human_evaluations = result_human_evaluations.scalars().all() + return res_evaluations + res_human_evaluations + + elif resource_type == "testset": + result_evaluations = await session.execute( + select(EvaluationDB) + .filter( + EvaluationDB.testset_id.in_(ids), + EvaluationDB.project_id == uuid.UUID(project_id), + ) + .options(load_only(EvaluationDB.id)) # type: ignore + ) + result_human_evaluations = await session.execute( + select(HumanEvaluationDB) + .filter( + HumanEvaluationDB.testset_id.in_(ids), + HumanEvaluationDB.project_id + == uuid.UUID(project_id), # Fixed to match HumanEvaluationDB + ) + .options(load_only(HumanEvaluationDB.id)) # type: ignore + ) + res_evaluations = result_evaluations.scalars().all() + res_human_evaluations = result_human_evaluations.scalars().all() + return res_evaluations + res_human_evaluations + + elif resource_type == "evaluator_config": + query = ( + select(EvaluationDB) + .join(EvaluationDB.evaluator_configs) + .filter( + EvaluationEvaluatorConfigDB.evaluator_config_id.in_(ids), + EvaluationDB.project_id == uuid.UUID(project_id), + ) + ) + result = await session.execute(query) + res = result.scalars().all() + return res + + raise HTTPException( + status_code=400, + detail=f"resource_type {resource_type} is not supported", + ) + + +async def delete_evaluations(evaluation_ids: List[str]) -> None: + """Delete evaluations based on the ids provided from the db. + + Args: + evaluations_ids (list[str]): The IDs of the evaluation + """ + + async with engine.core_session() as session: + query = select(EvaluationDB).where(EvaluationDB.id.in_(evaluation_ids)) + result = await session.execute(query) + evaluations = result.scalars().all() + for evaluation in evaluations: + await session.delete(evaluation) + await session.commit() + + +async def create_new_evaluation_scenario( + project_id: str, + evaluation_id: str, + variant_id: str, + inputs: List[EvaluationScenarioInput], + outputs: List[EvaluationScenarioOutput], + correct_answers: Optional[List[CorrectAnswer]], + is_pinned: Optional[bool], + note: Optional[str], + results: List[EvaluationScenarioResult], +) -> EvaluationScenarioDB: + """Create a new evaluation scenario. + + Returns: + EvaluationScenarioDB: The created evaluation scenario. + """ + + async with engine.core_session() as session: + evaluation_scenario = EvaluationScenarioDB( + project_id=uuid.UUID(project_id), + evaluation_id=uuid.UUID(evaluation_id), + variant_id=uuid.UUID(variant_id), + inputs=[input.model_dump() for input in inputs], + outputs=[output.model_dump() for output in outputs], + correct_answers=( + [correct_answer.model_dump() for correct_answer in correct_answers] + if correct_answers is not None + else [] + ), + is_pinned=is_pinned, + note=note, + ) + + session.add(evaluation_scenario) + await session.commit() + await session.refresh(evaluation_scenario) + + # create evaluation scenario result + for result in results: + evaluation_scenario_result = EvaluationScenarioResultDB( + evaluation_scenario_id=evaluation_scenario.id, + evaluator_config_id=uuid.UUID(result.evaluator_config), + result=result.result.model_dump(), + ) + + session.add(evaluation_scenario_result) + + await session.commit() # ensures that scenario results insertion is committed + await session.refresh(evaluation_scenario) + + return evaluation_scenario + + +async def update_evaluation_with_aggregated_results( + evaluation_id: str, aggregated_results: List[AggregatedResult] +): + async with engine.core_session() as session: + for result in aggregated_results: + aggregated_result = EvaluationAggregatedResultDB( + evaluation_id=uuid.UUID(evaluation_id), + evaluator_config_id=uuid.UUID(result.evaluator_config), + result=result.result.model_dump(), + ) + session.add(aggregated_result) + + await session.commit() + + +async def fetch_eval_aggregated_results(evaluation_id: str): + """ + Fetches an evaluation aggregated results by evaluation identifier. + + Args: + evaluation_id (str): The evaluation identifier + + Returns: + The evaluation aggregated results by evaluation identifier. + """ + + async with engine.core_session() as session: + base_query = select(EvaluationAggregatedResultDB).filter_by( + evaluation_id=uuid.UUID(evaluation_id) + ) + query = base_query.options( + joinedload( + EvaluationAggregatedResultDB.evaluator_config.of_type(EvaluatorConfigDB) + ).load_only( + EvaluatorConfigDB.id, # type: ignore + EvaluatorConfigDB.name, # type: ignore + EvaluatorConfigDB.evaluator_key, # type: ignore + EvaluatorConfigDB.settings_values, # type: ignore + EvaluatorConfigDB.created_at, # type: ignore + EvaluatorConfigDB.updated_at, # type: ignore + ) + ) + + result = await session.execute(query) + aggregated_results = result.scalars().all() + return aggregated_results + + +async def update_evaluation( + evaluation_id: str, project_id: str, updates: Dict[str, Any] +) -> EvaluationDB: + """ + Update an evaluator configuration in the database with the provided id. + + Arguments: + evaluation_id (str): The ID of the evaluator configuration to be updated. + project_id (str): The ID of the project. + updates (Dict[str, Any]): The updates to apply to the evaluator configuration. + + Returns: + EvaluatorConfigDB: The updated evaluator configuration object. + """ + + async with engine.core_session() as session: + result = await session.execute( + select(EvaluationDB).filter_by( + id=uuid.UUID(evaluation_id), project_id=uuid.UUID(project_id) + ) + ) + evaluation = result.scalars().first() + for key, value in updates.items(): + if hasattr(evaluation, key): + setattr(evaluation, key, value) + + await session.commit() + await session.refresh(evaluation) + + return evaluation + + +async def check_if_evaluation_contains_failed_evaluation_scenarios( + evaluation_id: str, +) -> bool: + async with engine.core_session() as session: + EvaluationResultAlias = aliased(EvaluationScenarioResultDB) + query = ( + select(func.count(EvaluationScenarioDB.id)) + .join(EvaluationResultAlias, EvaluationScenarioDB.results) + .where( + EvaluationScenarioDB.evaluation_id == uuid.UUID(evaluation_id), + EvaluationResultAlias.result["type"].astext == "error", + ) + ) + + result = await session.execute(query) + count = result.scalar() + if not count: + return False + return count > 0 diff --git a/api/ee/src/services/email_helper.py b/api/ee/src/services/email_helper.py new file mode 100644 index 0000000000..4316160ddf --- /dev/null +++ b/api/ee/src/services/email_helper.py @@ -0,0 +1,51 @@ +import time + +import requests + +from oss.src.utils.env import env + + +def add_contact_to_loops(email, max_retries=5, initial_delay=1): + """ + Add a contact to Loops audience with retry and exponential backoff. + + Args: + email (str): Email address of the contact to be added. + max_retries (int): Maximum number of retries in case of rate limiting. + initial_delay (int): Initial delay in seconds before retrying. + + Raises: + ConnectionError: If max retries reached and unable to connect to Loops API. + + Returns: + requests.Response: Response object from the Loops API. + """ + + # Endpoint URL + url = "https://app.loops.so/api/v1/contacts/create" + + # Request headers + headers = {"Authorization": f"Bearer {env.LOOPS_API_KEY}"} + + # Request payload/body + data = {"email": email} + + retries = 0 + delay = initial_delay + + while retries < max_retries: + # Making the POST request + response = requests.post(url, json=data, headers=headers, timeout=20) + + # If response code is 429, it indicates rate limiting + if response.status_code == 429: + print(f"Rate limit hit. Retrying in {delay} seconds...") + time.sleep(delay) + retries += 1 + delay *= 2 # Double the delay for exponential backoff + else: + # If response is not 429, return it + return response + + # If max retries reached, raise an exception or handle as needed + raise ConnectionError("Max retries reached. Unable to connect to Loops API.") diff --git a/api/ee/src/services/evaluation_service.py b/api/ee/src/services/evaluation_service.py new file mode 100644 index 0000000000..e2cd9a2d8f --- /dev/null +++ b/api/ee/src/services/evaluation_service.py @@ -0,0 +1,502 @@ +from typing import Dict, List, Any + +from fastapi import HTTPException + +from oss.src.utils.logging import get_module_logger +from ee.src.services import converters +from oss.src.services import db_manager +from ee.src.services import db_manager_ee + +from oss.src.models.api.evaluation_model import ( + Evaluation, + EvaluationType, + HumanEvaluation, + HumanEvaluationScenario, + HumanEvaluationUpdate, + EvaluationScenarioUpdate, + EvaluationStatusEnum, + NewHumanEvaluation, +) +from oss.src.models.db_models import AppDB +from ee.src.models.db_models import ( + EvaluationDB, + HumanEvaluationDB, + HumanEvaluationScenarioDB, +) + +from oss.src.models.shared_models import ( + HumanEvaluationScenarioInput, + HumanEvaluationScenarioOutput, + Result, +) + +log = get_module_logger(__name__) + + +class UpdateEvaluationScenarioError(Exception): + """Custom exception for update evaluation scenario errors.""" + + pass + + +async def prepare_csvdata_and_create_evaluation_scenario( + csvdata: List[Dict[str, str]], + payload_inputs: List[str], + project_id: str, + evaluation_type: EvaluationType, + new_evaluation: HumanEvaluationDB, +): + """ + Prepares CSV data and creates evaluation scenarios based on the inputs, evaluation + type, and other parameters provided. + + Args: + csvdata: A list of dictionaries representing the CSV data. + payload_inputs: A list of strings representing the names of the inputs in the variant. + project_id (str): The ID of the project + evaluation_type: The type of evaluation + new_evaluation: The instance of EvaluationDB + """ + + for datum in csvdata: + # Check whether the inputs in the test set match the inputs in the variant + try: + inputs = [ + {"input_name": name, "input_value": datum[name]} + for name in payload_inputs + ] + except KeyError: + await db_manager_ee.delete_human_evaluation( + evaluation_id=str(new_evaluation.id) + ) + msg = f""" + Columns in the test set should match the names of the inputs in the variant. + Inputs names in variant are: {[variant_input for variant_input in payload_inputs]} while + columns in test set are: {[col for col in datum.keys() if col != 'correct_answer']} + """ + raise HTTPException( + status_code=400, + detail=msg, + ) + + # Prepare scenario inputs + list_of_scenario_input = [] + for scenario_input in inputs: + eval_scenario_input_instance = HumanEvaluationScenarioInput( + input_name=scenario_input["input_name"], + input_value=scenario_input["input_value"], + ) + list_of_scenario_input.append(eval_scenario_input_instance) + + evaluation_scenario_extend_payload = { + **_extend_with_evaluation(evaluation_type), + **_extend_with_correct_answer(evaluation_type, datum), + } + await db_manager_ee.create_human_evaluation_scenario( + inputs=list_of_scenario_input, + project_id=project_id, + evaluation_id=str(new_evaluation.id), + evaluation_extend=evaluation_scenario_extend_payload, + ) + + +async def update_human_evaluation_service( + evaluation: EvaluationDB, update_payload: HumanEvaluationUpdate +) -> None: + """ + Update an existing evaluation based on the provided payload. + + Args: + evaluation (EvaluationDB): The evaluation instance. + update_payload (EvaluationUpdate): The payload for the update. + """ + + # Update the evaluation + await db_manager_ee.update_human_evaluation( + evaluation_id=str(evaluation.id), values_to_update=update_payload.model_dump() + ) + + +async def fetch_evaluation_scenarios_for_evaluation( + evaluation_id: str, project_id: str +): + """ + Fetch evaluation scenarios for a given evaluation ID. + + Args: + evaluation_id (str): The ID of the evaluation. + project_id (str): The ID of the project. + + Returns: + List[EvaluationScenario]: A list of evaluation scenarios. + """ + + evaluation_scenarios = await db_manager_ee.fetch_evaluation_scenarios( + evaluation_id=evaluation_id, project_id=project_id + ) + return [ + await converters.evaluation_scenario_db_to_pydantic( + evaluation_scenario_db=evaluation_scenario, evaluation_id=evaluation_id + ) + for evaluation_scenario in evaluation_scenarios + ] + + +async def fetch_human_evaluation_scenarios_for_evaluation( + human_evaluation: HumanEvaluationDB, +) -> List[HumanEvaluationScenario]: + """ + Fetch evaluation scenarios for a given evaluation ID. + + Args: + evaluation_id (str): The ID of the evaluation. + + Raises: + HTTPException: If the evaluation is not found or access is denied. + + Returns: + List[EvaluationScenario]: A list of evaluation scenarios. + """ + human_evaluation_scenarios = await db_manager_ee.fetch_human_evaluation_scenarios( + evaluation_id=str(human_evaluation.id) + ) + eval_scenarios = [ + converters.human_evaluation_scenario_db_to_pydantic( + evaluation_scenario_db=human_evaluation_scenario, + evaluation_id=str(human_evaluation.id), + ) + for human_evaluation_scenario in human_evaluation_scenarios + ] + return eval_scenarios + + +async def update_human_evaluation_scenario( + evaluation_scenario_db: HumanEvaluationScenarioDB, + evaluation_scenario_data: EvaluationScenarioUpdate, + evaluation_type: EvaluationType, +) -> None: + """ + Updates an evaluation scenario. + + Args: + evaluation_scenario_db (EvaluationScenarioDB): The evaluation scenario instance. + evaluation_scenario_data (EvaluationScenarioUpdate): New data for the scenario. + evaluation_type (EvaluationType): Type of the evaluation. + + Raises: + HTTPException: If evaluation scenario not found or access denied. + """ + + values_to_update = {} + payload = evaluation_scenario_data.model_dump(exclude_unset=True) + + if "score" in payload and evaluation_type == EvaluationType.single_model_test: + values_to_update["score"] = str(payload["score"]) + + if "vote" in payload and evaluation_type == EvaluationType.human_a_b_testing: + values_to_update["vote"] = payload["vote"] + + if "outputs" in payload: + new_outputs: List[Dict[str, Any]] = [ + HumanEvaluationScenarioOutput( + variant_id=output["variant_id"], + variant_output=output["variant_output"], + ).model_dump() + for output in payload["outputs"] + ] + values_to_update["outputs"] = new_outputs # type: ignore + + if "inputs" in payload: + new_inputs: List[Dict[str, Any]] = [ + HumanEvaluationScenarioInput( + input_name=input_item["input_name"], + input_value=input_item["input_value"], + ).model_dump() + for input_item in payload["inputs"] + ] + values_to_update["inputs"] = new_inputs # type: ignore + + if "is_pinned" in payload: + values_to_update["is_pinned"] = payload["is_pinned"] + + if "note" in payload: + values_to_update["note"] = payload["note"] + + if "correct_answer" in payload: + values_to_update["correct_answer"] = payload["correct_answer"] + + await db_manager_ee.update_human_evaluation_scenario( + evaluation_scenario_id=str(evaluation_scenario_db.id), + values_to_update=values_to_update, + ) + + +def _extend_with_evaluation(evaluation_type: EvaluationType): + evaluation = {} + if evaluation_type == EvaluationType.single_model_test: + evaluation["score"] = "" + + if evaluation_type == EvaluationType.human_a_b_testing: + evaluation["vote"] = "" + return evaluation + + +def _extend_with_correct_answer(evaluation_type: EvaluationType, row: dict): + correct_answer = {"correct_answer": ""} + if row.get("correct_answer") is not None: + correct_answer["correct_answer"] = row["correct_answer"] + return correct_answer + + +async def fetch_list_evaluations(app: AppDB, project_id: str) -> List[Evaluation]: + """ + Fetches a list of evaluations based on the provided filtering criteria. + + Args: + app (AppDB): An app to filter the evaluations. + project_id (str): The ID of the project + + Returns: + List[Evaluation]: A list of evaluations. + """ + + evaluations_db = await db_manager_ee.list_evaluations( + app_id=str(app.id), project_id=project_id + ) + return [ + await converters.evaluation_db_to_pydantic(evaluation) + for evaluation in evaluations_db + ] + + +async def fetch_list_human_evaluations( + app_id: str, project_id: str +) -> List[HumanEvaluation]: + """ + Fetches a list of evaluations based on the provided filtering criteria. + + Args: + app_id (Optional[str]): An optional app ID to filter the evaluations. + project_id (str): The ID of the project. + + Returns: + List[Evaluation]: A list of evaluations. + """ + + evaluations_db = await db_manager_ee.list_human_evaluations( + app_id=app_id, project_id=project_id + ) + return [ + await converters.human_evaluation_db_to_pydantic(evaluation) + for evaluation in evaluations_db + ] + + +async def fetch_human_evaluation(human_evaluation_db) -> HumanEvaluation: + """ + Fetches a single evaluation based on its ID. + + Args: + human_evaluation_db (HumanEvaluationDB): The evaluation instance. + + Returns: + Evaluation: The fetched evaluation. + """ + + return await converters.human_evaluation_db_to_pydantic(human_evaluation_db) + + +async def delete_human_evaluations(evaluation_ids: List[str]) -> None: + """ + Delete evaluations by their IDs. + + Args: + evaluation_ids (List[str]): A list of evaluation IDs. + project_id (str): The ID of the project. + + Raises: + NoResultFound: If evaluation not found or access denied. + """ + + for evaluation_id in evaluation_ids: + await db_manager_ee.delete_human_evaluation(evaluation_id=evaluation_id) + + +async def delete_evaluations(evaluation_ids: List[str]) -> None: + """ + Delete evaluations by their IDs. + + Args: + evaluation_ids (List[str]): A list of evaluation IDs. + + Raises: + HTTPException: If evaluation not found or access denied. + """ + + await db_manager_ee.delete_evaluations(evaluation_ids=evaluation_ids) + + +async def create_new_human_evaluation(payload: NewHumanEvaluation) -> HumanEvaluationDB: + """ + Create a new evaluation based on the provided payload and additional arguments. + + Args: + payload (NewHumanEvaluation): The evaluation payload. + + Returns: + HumanEvaluationDB + """ + + app = await db_manager.fetch_app_by_id(app_id=payload.app_id) + if app is None: + raise HTTPException( + status_code=404, + detail=f"App with id {payload.app_id} does not exist", + ) + + human_evaluation = await db_manager_ee.create_human_evaluation( + app=app, + status=payload.status, + evaluation_type=payload.evaluation_type, + testset_id=payload.testset_id, + variants_ids=payload.variant_ids, + ) + if human_evaluation is None: + raise HTTPException( + status_code=500, detail="Failed to create evaluation_scenario" + ) + + await prepare_csvdata_and_create_evaluation_scenario( + human_evaluation.testset.csvdata, + payload.inputs, + str(app.project_id), + payload.evaluation_type, + human_evaluation, + ) + return human_evaluation + + +async def create_new_evaluation( + app_id: str, + project_id: str, + revision_id: str, + testset_id: str, +) -> Evaluation: + """ + Create a new evaluation in the db + + Args: + app_id (str): The ID of the app. + project_id (str): The ID of the project. + revision_id (str): The ID of the variant revision. + testset_id (str): The ID of the testset. + + Returns: + Evaluation: The newly created evaluation. + """ + + app = await db_manager.fetch_app_by_id(app_id=app_id) + testset = await db_manager.fetch_testset_by_id(testset_id=testset_id) + variant_revision = await db_manager.fetch_app_variant_revision_by_id( + variant_revision_id=revision_id + ) + + assert ( + variant_revision and variant_revision.revision is not None + ), f"Variant revision with {revision_id} cannot be None" + + evaluation_db = await db_manager_ee.create_new_evaluation( + app=app, + project_id=project_id, + testset=testset, + status=Result( + value=EvaluationStatusEnum.EVALUATION_INITIALIZED, type="status", error=None + ), + variant=str(variant_revision.variant_id), + variant_revision=str(variant_revision.id), + ) + return await converters.evaluation_db_to_pydantic(evaluation_db) + + +async def compare_evaluations_scenarios(evaluations_ids: List[str], project_id: str): + evaluation = await db_manager_ee.fetch_evaluation_by_id( + project_id=project_id, + evaluation_id=evaluations_ids[0], + ) + testset = evaluation.testset + unique_testset_datapoints = remove_duplicates(testset.csvdata) + formatted_inputs = extract_inputs_values_from_testset(unique_testset_datapoints) + # # formatted_inputs: [{'input_name': 'country', 'input_value': 'Nauru'}] + + all_scenarios = [] + + for evaluation_id in evaluations_ids: + eval_scenarios = await fetch_evaluation_scenarios_for_evaluation( + evaluation_id=evaluation_id, project_id=project_id + ) + all_scenarios.append(eval_scenarios) + + grouped_scenarios_by_inputs = find_scenarios_by_input( + formatted_inputs, all_scenarios + ) + + return grouped_scenarios_by_inputs + + +def extract_inputs_values_from_testset(testset): + extracted_values = [] + + input_keys = testset[0].keys() + + for entry in testset: + for key in input_keys: + if key != "correct_answer": + extracted_values.append({"input_name": key, "input_value": entry[key]}) + + return extracted_values + + +def find_scenarios_by_input(formatted_inputs, all_scenarios): + results = [] + flattened_scenarios = [ + scenario for sublist in all_scenarios for scenario in sublist + ] + + for formatted_input in formatted_inputs: + input_name = formatted_input["input_name"] + input_value = formatted_input["input_value"] + + matching_scenarios = [ + scenario + for scenario in flattened_scenarios + if any( + input_item.name == input_name and input_item.value == input_value + for input_item in scenario.inputs + ) + ] + + results.append( + { + "input_name": input_name, + "input_value": input_value, + "scenarios": matching_scenarios, + } + ) + + return { + "inputs": formatted_inputs, + "data": results, + } + + +def remove_duplicates(csvdata): + unique_data = set() + unique_entries = [] + + for entry in csvdata: + entry_tuple = tuple(entry.items()) + if entry_tuple not in unique_data: + unique_data.add(entry_tuple) + unique_entries.append(entry) + + return unique_entries diff --git a/api/ee/src/services/llm_apps_service.py b/api/ee/src/services/llm_apps_service.py new file mode 100644 index 0000000000..15267ec378 --- /dev/null +++ b/api/ee/src/services/llm_apps_service.py @@ -0,0 +1,578 @@ +import json +import asyncio +import traceback +import aiohttp +from datetime import datetime +from typing import Any, Dict, List, Optional + +from oss.src.utils.logging import get_module_logger +from oss.src.utils import common +from oss.src.services import helpers +from oss.src.services.auth_helper import sign_secret_token +from oss.src.services.db_manager import get_project_by_id +from oss.src.apis.fastapi.tracing.utils import make_hash_id +from oss.src.models.shared_models import InvokationResult, Result, Error + +log = get_module_logger(__name__) + + +def get_nested_value(d: dict, keys: list, default=None): + """ + Helper function to safely retrieve nested values. + """ + try: + for key in keys: + if isinstance(d, dict): + d = d.get(key, default) + else: + return default + return d + except Exception as e: + log.error(f"Error accessing nested value: {e}") + return default + + +def extract_result_from_response(response: dict): + # Initialize default values + value = None + latency = None + cost = None + tokens = None + + try: + # Validate input + if not isinstance(response, dict): + raise ValueError("The response must be a dictionary.") + + # Handle version 3.0 response + if response.get("version") == "3.0": + value = response + # Ensure 'data' is a dictionary or convert it to a string + if not isinstance(value.get("data"), dict): + value["data"] = str(value.get("data")) + + if "tree" in response: + trace_tree = response.get("tree", {}).get("nodes", [])[0] + + duration_ms = get_nested_value( + trace_tree, ["metrics", "acc", "duration", "total"] + ) + if duration_ms: + duration_seconds = duration_ms / 1000 + else: + start_time = get_nested_value(trace_tree, ["time", "start"]) + end_time = get_nested_value(trace_tree, ["time", "end"]) + + if start_time and end_time: + duration_seconds = ( + datetime.fromisoformat(end_time) + - datetime.fromisoformat(start_time) + ).total_seconds() + else: + duration_seconds = None + + latency = duration_seconds + cost = get_nested_value( + trace_tree, ["metrics", "acc", "costs", "total"] + ) + tokens = get_nested_value( + trace_tree, ["metrics", "acc", "tokens", "total"] + ) + + # Handle version 2.0 response + elif response.get("version") == "2.0": + value = response + if not isinstance(value.get("data"), dict): + value["data"] = str(value.get("data")) + + if "trace" in response: + latency = response["trace"].get("latency", None) + cost = response["trace"].get("cost", None) + tokens = response["trace"].get("tokens", None) + + # Handle generic response (neither 2.0 nor 3.0) + else: + value = {"data": str(response.get("message", ""))} + latency = response.get("latency", None) + cost = response.get("cost", None) + tokens = response.get("tokens", None) + + # Determine the type of 'value' (either 'text' or 'object') + kind = "text" if isinstance(value, str) else "object" + + except ValueError as ve: + log.error(f"Input validation error: {ve}") + value = {"error": str(ve)} + kind = "error" + + except KeyError as ke: + log.error(f"Missing key: {ke}") + value = {"error": f"Missing key: {ke}"} + kind = "error" + + except TypeError as te: + log.error(f"Type error: {te}") + value = {"error": f"Type error: {te}"} + kind = "error" + + except Exception as e: + log.error(f"Unexpected error: {e}") + value = {"error": f"Unexpected error: {e}"} + kind = "error" + + return value, kind, cost, tokens, latency + + +async def make_payload( + datapoint: Any, parameters: Dict, openapi_parameters: List[Dict] +) -> Dict: + """ + Constructs the payload for invoking an app based on OpenAPI parameters. + + Args: + datapoint (Any): The data to be sent to the app. + parameters (Dict): The parameters required by the app taken from the db. + openapi_parameters (List[Dict]): The OpenAPI parameters of the app. + + Returns: + Dict: The constructed payload for the app. + """ + payload = {} + inputs = {} + messages = [] + + for param in openapi_parameters: + if param["name"] == "ag_config": + payload["ag_config"] = parameters + elif param["type"] == "input": + item = datapoint.get(param["name"], parameters.get(param["name"], "")) + assert ( + param["name"] != "ag_config" + ), "ag_config should be handled separately" + payload[param["name"]] = item + + # in case of dynamic inputs (as in our templates) + elif param["type"] == "dict": + # let's get the list of the dynamic inputs + if ( + param["name"] in parameters + ): # in case we have modified in the playground the default list of inputs (e.g. country_name) + input_names = [_["name"] for _ in parameters[param["name"]]] + else: # otherwise we use the default from the openapi + input_names = param["default"] + + for input_name in input_names: + item = datapoint.get(input_name, "") + inputs[input_name] = item + elif param["type"] == "messages": + # TODO: Right now the FE is saving chats always under the column name chats. The whole logic for handling chats and dynamic inputs is convoluted and needs rework in time. + chat_data = datapoint.get("chat", "") + item = json.loads(chat_data) + payload[param["name"]] = item + elif param["type"] == "file_url": + item = datapoint.get(param["name"], "") + payload[param["name"]] = item + else: + if param["name"] in parameters: # hotfix + log.warn( + f"Processing other param type '{param['type']}': {param['name']}" + ) + item = parameters[param["name"]] + payload[param["name"]] = item + + try: + input_keys = helpers.find_key_occurrences(parameters, "input_keys") or [] + inputs = {key: datapoint.get(key, None) for key in input_keys} + + messages_data = datapoint.get("messages", "[]") + messages = json.loads(messages_data) + payload["messages"] = messages + except Exception as e: # pylint: disable=broad-exception-caught + log.warn(f"Error making payload: {e}") + + payload["inputs"] = inputs + + return payload + + +async def invoke_app( + uri: str, + datapoint: Any, + parameters: Dict, + openapi_parameters: List[Dict], + user_id: str, + project_id: str, + **kwargs, +) -> InvokationResult: + """ + Invokes an app for one datapoint using the openapi_parameters to determine + how to invoke the app. + + Args: + uri (str): The URI of the app to invoke. + datapoint (Any): The data to be sent to the app. + parameters (Dict): The parameters required by the app taken from the db. + openapi_parameters (List[Dict]): The OpenAPI parameters of the app. + + Returns: + InvokationResult: The output of the app. + + Raises: + aiohttp.ClientError: If the POST request fails. + """ + + url = f"{uri}/test" + if "application_id" in kwargs: + url = url + f"?application_id={kwargs.get('application_id')}" + + payload = await make_payload(datapoint, parameters, openapi_parameters) + + project = await get_project_by_id( + project_id=project_id, + ) + + secret_token = await sign_secret_token( + user_id=str(user_id), + project_id=str(project_id), + workspace_id=str(project.workspace_id), + organization_id=str(project.organization_id), + ) + + headers = {} + if secret_token: + headers = {"Authorization": f"Secret {secret_token}"} + headers["ngrok-skip-browser-warning"] = "1" + + async with aiohttp.ClientSession() as client: + app_response = {} + + try: + log.info("Invoking workflow...", url=url) + response = await client.post( + url, + json=payload, + headers=headers, + timeout=900, + ) + app_response = await response.json() + response.raise_for_status() + + ( + value, + kind, + cost, + tokens, + latency, + ) = extract_result_from_response(app_response) + + trace_id = app_response.get("trace_id", None) + span_id = app_response.get("span_id", None) + + return InvokationResult( + result=Result( + type=kind, + value=value, + error=None, + ), + latency=latency, + cost=cost, + tokens=tokens, + trace_id=trace_id, + span_id=span_id, + ) + + except aiohttp.ClientResponseError as e: + error_message = app_response.get("detail", {}).get( + "error", f"HTTP error {e.status}: {e.message}" + ) + stacktrace = app_response.get("detail", {}).get( + "message" + ) or app_response.get("detail", {}).get( + "traceback", "".join(traceback.format_exception_only(type(e), e)) + ) + log.error(f"HTTP error occurred during request: {error_message}") + except aiohttp.ServerTimeoutError as e: + error_message = "Request timed out" + stacktrace = "".join(traceback.format_exception_only(type(e), e)) + log.error(error_message) + except aiohttp.ClientConnectionError as e: + error_message = f"Connection error: {str(e)}" + stacktrace = "".join(traceback.format_exception_only(type(e), e)) + log.error(error_message) + except json.JSONDecodeError as e: + error_message = "Failed to decode JSON from response" + stacktrace = "".join(traceback.format_exception_only(type(e), e)) + log.error(error_message) + except Exception as e: + error_message = f"Unexpected error: {str(e)}" + stacktrace = "".join(traceback.format_exception_only(type(e), e)) + log.error(error_message) + + return InvokationResult( + result=Result( + type="error", + error=Error( + message=error_message, + stacktrace=stacktrace, + ), + ) + ) + + +async def run_with_retry( + uri: str, + input_data: Any, + parameters: Dict, + max_retry_count: int, + retry_delay: int, + openapi_parameters: List[Dict], + user_id: str, + project_id: str, + **kwargs, +) -> InvokationResult: + """ + Runs the specified app with retry mechanism. + + Args: + uri (str): The URI of the app. + input_data (Any): The input data for the app. + parameters (Dict): The parameters for the app. + max_retry_count (int): The maximum number of retries. + retry_delay (int): The delay between retries in seconds. + openapi_parameters (List[Dict]): The OpenAPI parameters for the app. + + Returns: + InvokationResult: The invokation result. + + """ + + if "references" in kwargs and "testcase_id" in input_data: + kwargs["references"]["testcase"] = {"id": input_data["testcase_id"]} + + references = kwargs.get("references", None) + links = kwargs.get("links", None) + # hash_id = make_hash_id(references=references, links=links) + + retries = 0 + last_exception = None + while retries < max_retry_count: + try: + result = await invoke_app( + uri, + input_data, + parameters, + openapi_parameters, + user_id, + project_id, + **kwargs, + ) + return result + except aiohttp.ClientError as e: + last_exception = e + log.error(f"Error in evaluation. Retrying in {retry_delay} seconds:", e) + await asyncio.sleep(retry_delay) + retries += 1 + except Exception as e: + last_exception = e + log.warn(f"Error processing datapoint: {input_data}. {str(e)}") + log.warn("".join(traceback.format_exception_only(type(e), e))) + retries += 1 + + # If max retries is reached or an exception that isn't in the second block, + # update & return the last exception + log.warn("Max retries reached") + exception_message = ( + "Max retries reached" + if retries == max_retry_count + else f"Error processing {input_data} datapoint" + ) + + return InvokationResult( + result=Result( + type="error", + value=None, + error=Error(message=exception_message, stacktrace=str(last_exception)), + ) + ) + + +async def batch_invoke( + uri: str, + testset_data: List[Dict], + parameters: Dict, + rate_limit_config: Dict, + user_id: str, + project_id: str, + **kwargs, +) -> List[InvokationResult]: + """ + Invokes the LLm apps in batches, processing the testset data. + + Args: + uri (str): The URI of the LLm app. + testset_data (List[Dict]): The testset data to be processed. + parameters (Dict): The parameters for the LLm app. + rate_limit_config (Dict): The rate limit configuration. + + Returns: + List[InvokationResult]: The list of app outputs after running all batches. + """ + batch_size = rate_limit_config[ + "batch_size" + ] # Number of testset to make in each batch + max_retries = rate_limit_config[ + "max_retries" + ] # Maximum number of times to retry the failed llm call + retry_delay = rate_limit_config[ + "retry_delay" + ] # Delay before retrying the failed llm call (in seconds) + delay_between_batches = rate_limit_config[ + "delay_between_batches" + ] # Delay between batches (in seconds) + + list_of_app_outputs: List[ + InvokationResult + ] = [] # Outputs after running all batches + + project = await get_project_by_id( + project_id=project_id, + ) + + secret_token = await sign_secret_token( + user_id=str(user_id), + project_id=str(project_id), + workspace_id=str(project.workspace_id), + organization_id=str(project.organization_id), + ) + + headers = {} + if secret_token: + headers = {"Authorization": f"Secret {secret_token}"} + headers["ngrok-skip-browser-warning"] = "1" + + openapi_parameters = None + max_recursive_depth = 5 + runtime_prefix = uri + route_path = "" + + while max_recursive_depth > 0 and not openapi_parameters: + try: + openapi_parameters = await get_parameters_from_openapi( + runtime_prefix + "/openapi.json", + route_path, + headers, + ) + except Exception: # pylint: disable=broad-exception-caught + openapi_parameters = None + + if not openapi_parameters: + max_recursive_depth -= 1 + if not runtime_prefix.endswith("/"): + route_path = "/" + runtime_prefix.split("/")[-1] + route_path + runtime_prefix = "/".join(runtime_prefix.split("/")[:-1]) + else: + route_path = "" + runtime_prefix = runtime_prefix[:-1] + + # Final attempt to fetch OpenAPI parameters + openapi_parameters = await get_parameters_from_openapi( + runtime_prefix + "/openapi.json", + route_path, + headers, + ) + + # 🆕 Rewritten loop instead of recursion + for start_idx in range(0, len(testset_data), batch_size): + tasks = [] + + end_idx = min(start_idx + batch_size, len(testset_data)) + for index in range(start_idx, end_idx): + task = asyncio.ensure_future( + run_with_retry( + uri, + testset_data[index], + parameters, + max_retries, + retry_delay, + openapi_parameters, + user_id, + project_id, + **kwargs, + ) + ) + tasks.append(task) + + results = await asyncio.gather(*tasks) + + for result in results: + list_of_app_outputs.append(result) + + # Delay between batches if more to come + if end_idx < len(testset_data): + await asyncio.sleep(delay_between_batches) + + return list_of_app_outputs + + +async def get_parameters_from_openapi( + runtime_prefix: str, + route_path: str, + headers: Optional[Dict[str, str]], +) -> List[Dict]: + """ + Parse the OpenAI schema of an LLM app to return list of parameters that it takes with their type as determined by the x-parameter + Args: + uri (str): The URI of the OpenAPI schema. + + Returns: + list: A list of parameters. Each a dict with name and type. + Type can be one of: input, text, choice, float, dict, bool, int, file_url, messages. + + Raises: + KeyError: If the required keys are not found in the schema. + + """ + + schema = await _get_openai_json_from_uri(runtime_prefix, headers) + + try: + body_schema_name = ( + schema["paths"][route_path + "/test"]["post"]["requestBody"]["content"][ + "application/json" + ]["schema"]["$ref"] + .split("/") + .pop() + ) + except KeyError: + body_schema_name = "" + + try: + properties = schema["components"]["schemas"][body_schema_name]["properties"] + except KeyError: + properties = {} + + parameters = [] + for name, param in properties.items(): + parameters.append( + { + "name": name, + "type": param.get("x-parameter", "input"), + "default": param.get("default", []), + } + ) + return parameters + + +async def _get_openai_json_from_uri( + uri: str, + headers: Optional[Dict[str, str]], +): + if headers is None: + headers = {} + headers["ngrok-skip-browser-warning"] = "1" + + async with aiohttp.ClientSession() as client: + resp = await client.get(uri, headers=headers, timeout=5) + resp_text = await resp.text() + json_data = json.loads(resp_text) + return json_data diff --git a/api/ee/src/services/organization_service.py b/api/ee/src/services/organization_service.py new file mode 100644 index 0000000000..7ee4fdb150 --- /dev/null +++ b/api/ee/src/services/organization_service.py @@ -0,0 +1,121 @@ +from urllib.parse import quote + +from ee.src.services import db_manager_ee +from oss.src.services import email_service +from oss.src.models.db_models import UserDB +from ee.src.models.db_models import ( + WorkspaceDB, + OrganizationDB, +) +from ee.src.models.api.organization_models import ( + OrganizationUpdate, +) + +from oss.src.utils.env import env + + +async def update_an_organization( + org_id: str, payload: OrganizationUpdate +) -> OrganizationDB: + org = await db_manager_ee.get_organization(org_id) + if org is not None: + await db_manager_ee.update_organization(str(org.id), payload) + return org + raise NotFound("Organization not found") + + +class NotFound(Exception): + """Custom exception for credentials not found""" + + pass + + +async def send_invitation_email( + email: str, + token: str, + project_id: str, + workspace: WorkspaceDB, + organization: OrganizationDB, + user: UserDB, +): + """ + Sends an invitation email to the specified email address, containing a link to accept the invitation. + + Args: + email (str): The email address to send the invitation to. + token (str): The token to include in the invitation link. + project_id (str): The ID of the project that the user is being invited to join. + workspace (WorkspaceDB): The workspace that the user is being invited to join. + user (UserDB): The user who is sending the invitation. + + Returns: + bool: True if the email was sent successfully, False otherwise. + """ + + html_template = email_service.read_email_template("./templates/send_email.html") + + token_param = quote(token, safe="") + email_param = quote(email, safe="") + org_param = quote(str(organization.id), safe="") + workspace_param = quote(str(workspace.id), safe="") + project_param = quote(project_id, safe="") + + invite_link = ( + f"{env.AGENTA_WEB_URL}/auth?token={token_param}&email={email_param}" + f"&org_id={org_param}&workspace_id={workspace_param}&project_id={project_param}" + ) + + html_content = html_template.format( + username_placeholder=user.username, + action_placeholder="invited you to join", + workspace_placeholder=workspace.name, + call_to_action=( + "Click the link below to accept the invitation:


" + f'Accept Invitation' + ), + ) + + await email_service.send_email( + from_email="account@hello.agenta.ai", + to_email=email, + subject=f"{user.username} invited you to join {workspace.name}", + html_content=html_content, + ) + return True + + +async def notify_org_admin_invitation(workspace: WorkspaceDB, user: UserDB) -> bool: + """ + Sends an email notification to the owner of an organization when a new member joins. + + Args: + workspace (WorkspaceDB): The workspace that the user has joined. + user (UserDB): The user who has joined the organization. + + Returns: + bool: True if the email was sent successfully, False otherwise. + """ + + html_template = email_service.read_email_template("./templates/send_email.html") + html_content = html_template.format( + username_placeholder=user.username, + action_placeholder="joined your Workspace", + workspace_placeholder=f'"{workspace.name}"', + call_to_action=f'Click the link below to view your Workspace:


View Workspace', + ) + + workspace_admins = await db_manager_ee.get_workspace_administrators(workspace) + for workspace_admin in workspace_admins: + await email_service.send_email( + from_email="account@hello.agenta.ai", + to_email=workspace_admin.email, + subject=f"New Member Joined {workspace.name}", + html_content=html_content, + ) + + return True + + +async def get_organization_details(org_id: str) -> dict: + organization = await db_manager_ee.get_organization(org_id) + return await db_manager_ee.get_org_details(organization) diff --git a/api/ee/src/services/results_service.py b/api/ee/src/services/results_service.py new file mode 100644 index 0000000000..ca52151315 --- /dev/null +++ b/api/ee/src/services/results_service.py @@ -0,0 +1,116 @@ +import uuid +from typing import Sequence, Dict, Any + +from ee.src.services import db_manager_ee +from oss.src.models.api.evaluation_model import EvaluationType +from ee.src.models.db_models import ( + HumanEvaluationDB, + EvaluationScenarioDB, +) + + +async def fetch_results_for_evaluation(evaluation: HumanEvaluationDB): + evaluation_scenarios = await db_manager_ee.fetch_human_evaluation_scenarios( + evaluation_id=str(evaluation.id) + ) + + results: Dict[str, Any] = {} + if len(evaluation_scenarios) == 0: + return results + + evaluation_variants = await db_manager_ee.fetch_human_evaluation_variants( + human_evaluation_id=str(evaluation.id) + ) + results["variants"] = [ + str(evaluation_variant.variant_id) for evaluation_variant in evaluation_variants + ] + + variant_names: list[str] = [] + for evaluation_variant in evaluation_variants: + variant_name = ( + evaluation_variant.variant.variant_name + if isinstance(evaluation_variant.variant_id, uuid.UUID) + else str(evaluation_variant.variant_id) + ) + variant_names.append(str(variant_name)) + + results["variant_names"] = variant_names + results["nb_of_rows"] = len(evaluation_scenarios) + + if evaluation.evaluation_type == EvaluationType.human_a_b_testing: # type: ignore + results.update( + await _compute_stats_for_human_a_b_testing_evaluation(evaluation_scenarios) + ) + + return results + + +async def _compute_stats_for_evaluation(evaluation_scenarios: list, classes: list): + results = {} + for cl in classes: + results[cl] = [ + scenario for scenario in evaluation_scenarios if scenario.score == cl + ] + return results + + +async def _compute_stats_for_human_a_b_testing_evaluation( + evaluation_scenarios: Sequence[EvaluationScenarioDB], +): + results: Dict[str, Any] = {} + results["variants_votes_data"] = {} + results["flag_votes"] = {} + results["positive_votes"] = {} + + flag_votes_nb = [ + scenario for scenario in evaluation_scenarios if scenario.vote == "0" + ] + + positive_votes_nb = [ + scenario for scenario in evaluation_scenarios if scenario.vote == "1" + ] + + results["positive_votes"]["number_of_votes"] = len(positive_votes_nb) + results["positive_votes"]["percentage"] = ( + round(len(positive_votes_nb) / len(evaluation_scenarios) * 100, 2) + if len(evaluation_scenarios) + else 0 + ) + + results["flag_votes"]["number_of_votes"] = len(flag_votes_nb) + results["flag_votes"]["percentage"] = ( + round(len(flag_votes_nb) / len(evaluation_scenarios) * 100, 2) + if len(evaluation_scenarios) + else 0 + ) + + for scenario in evaluation_scenarios: + if scenario.vote not in results["variants_votes_data"]: + results["variants_votes_data"][scenario.vote] = {} + results["variants_votes_data"][scenario.vote]["number_of_votes"] = 1 + else: + results["variants_votes_data"][scenario.vote]["number_of_votes"] += 1 + + for key, value in results["variants_votes_data"].items(): + value["percentage"] = round( + value["number_of_votes"] / len(evaluation_scenarios) * 100, 2 + ) + return results + + +async def fetch_results_for_single_model_test(evaluation_id: str): + evaluation_scenarios = await db_manager_ee.fetch_human_evaluation_scenarios( + evaluation_id=str(evaluation_id) + ) + scores_and_counts: Dict[str, Any] = {} + for evaluation_scenario in evaluation_scenarios: + score = evaluation_scenario.score + if isinstance(score, str): + if score.isdigit(): # Check if the string is a valid integer + score = int(score) + else: + continue # Skip if the string is not a valid integer + + scores_and_counts[score] = scores_and_counts.get(score, 0) + 1 + + return scores_and_counts diff --git a/api/ee/src/services/selectors.py b/api/ee/src/services/selectors.py new file mode 100644 index 0000000000..f8a10ceecb --- /dev/null +++ b/api/ee/src/services/selectors.py @@ -0,0 +1,125 @@ +from typing import Dict, List, Union + +from sqlalchemy.future import select +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import load_only, joinedload + +from oss.src.services import db_manager +from oss.src.utils.logging import get_module_logger + +from oss.src.dbs.postgres.shared.engine import engine +from ee.src.models.api.organization_models import Organization +from ee.src.models.db_models import ( + WorkspaceDB, + OrganizationDB, + WorkspaceMemberDB, + OrganizationMemberDB, +) + +log = get_module_logger(__name__) + + +async def get_user_org_and_workspace_id(user_uid) -> Dict[str, Union[str, List[str]]]: + """ + Retrieves the user ID and organization IDs associated with a given user UID. + + Args: + user_uid (str): The UID of the user. + + Returns: + dict: A dictionary containing the user UID, ID, list of workspace IDS and list of organization IDS associated with a user. + If the user is not found, returns None + + Example Usage: + result = await get_user_org_and_workspace_id("user123") + print(result) + + Output: + { "id": "123", "uid": "user123", "organization_ids": [], "workspace_ids": []} + """ + + async with engine.core_session() as session: + user = await db_manager.get_user_with_id(user_id=user_uid) + if not user: + raise NoResultFound(f"User with uid {user_uid} not found") + + user_org_result = await session.execute( + select(OrganizationMemberDB) + .filter_by(user_id=user.id) + .options(load_only(OrganizationMemberDB.organization_id)) # type: ignore + ) + orgs = user_org_result.scalars().all() + organization_ids = [str(user_org.organization_id) for user_org in orgs] + + member_in_workspaces_result = await session.execute( + select(WorkspaceMemberDB) + .filter_by(user_id=user.id) + .options(load_only(WorkspaceMemberDB.workspace_id)) # type: ignore + ) + workspaces_ids = [ + str(user_workspace.workspace_id) + for user_workspace in member_in_workspaces_result.scalars().all() + ] + + return { + "id": str(user.id), + "uid": str(user.uid), + "workspace_ids": workspaces_ids, + "organization_ids": organization_ids, + } + + +async def user_exists(user_email: str) -> bool: + """Check if user exists in the database. + + Arguments: + user_email (str): The email address of the logged-in user + + Returns: + bool: confirming if the user exists or not. + """ + + user = await db_manager.get_user_with_email(email=user_email) + return False if not user else True + + +async def get_user_own_org(user_uid: str) -> OrganizationDB: + """Get's the default users' organization from the database. + + Arguments: + user_uid (str): The uid of the user + + Returns: + Organization: Instance of OrganizationDB + """ + + user = await db_manager.get_user_with_id(user_id=user_uid) + async with engine.core_session() as session: + result = await session.execute( + select(OrganizationDB).filter_by( + owner=str(user.id), + type="default", + ) + ) + org = result.scalars().first() + return org + + +async def get_org_default_workspace(organization: Organization) -> WorkspaceDB: + """Get's the default workspace for an organization from the database. + + Arguments: + organization (Organization): The organization + + Returns: + WorkspaceDB: Instance of WorkspaceDB + """ + + async with engine.core_session() as session: + result = await session.execute( + select(WorkspaceDB) + .filter_by(organization_id=organization.id, type="default") + .options(joinedload(WorkspaceDB.members)) + ) + workspace = result.scalars().first() + return workspace diff --git a/api/ee/src/services/templates/send_email.html b/api/ee/src/services/templates/send_email.html new file mode 100644 index 0000000000..7d124ffd8a --- /dev/null +++ b/api/ee/src/services/templates/send_email.html @@ -0,0 +1,7 @@ +

Hello,

+

+ {username_placeholder} has {action_placeholder} {workspace_placeholder} on + Agenta. +

+

{call_to_action}

+

Thank you for using Agenta!

diff --git a/api/ee/src/services/utils.py b/api/ee/src/services/utils.py new file mode 100644 index 0000000000..0eaedde4ff --- /dev/null +++ b/api/ee/src/services/utils.py @@ -0,0 +1,21 @@ +# Stdlib Imports +import asyncio +from functools import partial +from typing import Callable, Coroutine + + +async def run_in_separate_thread(func: Callable, *args, **kwargs) -> Coroutine: + """ + Run a synchronous function in a separate thread. + + Args: + func (callable): The synchronous function to be executed. + args (tuple): Positional arguments to be passed to `func`. + kwargs (dict): Keyword arguments to be passed to `func`. + + Returns: + The result of the synchronous function. + """ + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, partial(func, *args, **kwargs)) diff --git a/api/ee/src/services/workspace_manager.py b/api/ee/src/services/workspace_manager.py new file mode 100644 index 0000000000..d446729804 --- /dev/null +++ b/api/ee/src/services/workspace_manager.py @@ -0,0 +1,355 @@ +import asyncio + +from typing import List +from fastapi import HTTPException +from fastapi.responses import JSONResponse + +from oss.src.utils.logging import get_module_logger +from oss.src.services import db_manager +from ee.src.services import db_manager_ee, converters +from ee.src.models.db_models import ( + WorkspaceDB, + OrganizationDB, +) +from oss.src.models.db_models import UserDB +from ee.src.models.api.api_models import ( + InviteRequest, + ReseendInviteRequest, +) +from ee.src.models.api.workspace_models import ( + Permission, + WorkspaceRole, + WorkspaceResponse, + CreateWorkspace, + UpdateWorkspace, +) +from oss.src.models.db_models import InvitationDB +from oss.src.services.organization_service import ( + create_invitation, + check_existing_invitation, + check_valid_invitation, +) +from ee.src.services.organization_service import send_invitation_email + +log = get_module_logger(__name__) + + +async def get_workspace(workspace_id: str) -> WorkspaceDB: + """ + Get the workspace object based on the provided workspace ID. + + Parameters: + - workspace_id (str): The ID of the workspace. + + Returns: + - WorkspaceDB: The workspace object corresponding to the provided ID. + + Raises: + - HTTPException: If the workspace with the provided ID is not found. + + """ + + workspace = await db_manager.get_workspace(workspace_id) + if workspace is not None: + return workspace + raise HTTPException( + status_code=404, detail=f"Workspace by id {workspace_id} not found" + ) + + +async def create_new_workspace( + payload: CreateWorkspace, organization_id: str, user_uid: str +) -> WorkspaceResponse: + """ + Create a new workspace. + + Args: + payload (CreateWorkspace): The workspace payload. + organization_id (str): The organization id. + user_uid (str): The user uid. + + Returns: + WorkspaceResponse: The created workspace. + """ + + workspace = await db_manager_ee.create_workspace(payload, organization_id, user_uid) + return workspace + + +async def update_workspace( + payload: UpdateWorkspace, workspace_id: str +) -> WorkspaceResponse: + """ + Update a workspace's details. + + Args: + payload (UpdateWorkspace): The data to update the workspace with. + workspace_id (str): The ID of the workspace to update. + + Returns: + WorkspaceResponse: The updated workspace. + + Raises: + HTTPException: If the workspace with the given ID is not found. + """ + + workspace = await get_workspace(workspace_id) + if workspace is not None: + updated_workspace = await db_manager_ee.update_workspace(payload, workspace) + return updated_workspace + raise HTTPException( + status_code=404, detail=f"Workspace by id {workspace_id} not found" + ) + + +async def get_all_workspace_roles() -> List[WorkspaceRole]: + """ + Retrieve all workspace roles. + + Returns: + List[WorkspaceRole]: A list of all workspace roles in the DB. + """ + + workspace_roles_from_db = await db_manager_ee.get_all_workspace_roles() + return workspace_roles_from_db + + +async def get_all_workspace_permissions() -> List[Permission]: + """ + Retrieve all workspace permissions. + + Returns: + List[Permission]: A list of all workspace permissions in the DB. + """ + + workspace_permissions_from_db = await converters.get_all_workspace_permissions() + return workspace_permissions_from_db + + +async def invite_user_to_workspace( + payload: List[InviteRequest], + org_id: str, + project_id: str, + workspace_id: str, + user_uid: str, +) -> JSONResponse: + """ + Invite a user to a workspace. + + Args: + user_uid (str): The user uid. + org_id (str): The ID of the organization that the workspace belongs to. + project_id (str): The ID of the project that belongs to the workspace. + workspace_id (str): The ID of the workspace. + payload (InviteRequest): The payload containing the email address of the user to invite. + + Returns: + JSONResponse: The response containing the invitation details. + + Raises: + HTTPException: If there is an error retrieving the workspace. + """ + + try: + workspace = await get_workspace(workspace_id) + organization = await db_manager_ee.get_organization(org_id) + user_performing_action = await db_manager.get_user(user_uid) + + for payload_invite in payload: + # Check that the user is not inviting themselves + if payload_invite.email == user_performing_action.email: + return JSONResponse( + status_code=400, + content={"error": "You cannot invite yourself to a workspace"}, + ) + + # Check if the user is already a member of the workspace + if await db_manager_ee.check_user_in_workspace_with_email( + payload_invite.email, str(workspace.id) + ): + return JSONResponse( + status_code=400, + content={"error": "User is already a member of the workspace"}, + ) + + # Check if the email address already has a valid, unused invitation for the workspace + existing_invitation, existing_role = await check_existing_invitation( + project_id, payload_invite.email + ) + if not existing_invitation and not existing_role: + # Create a new invitation + invitation = await create_invitation( + payload_invite.roles[0], project_id, payload_invite.email + ) + + # Send the invitation email + send_email = await send_invitation_email( + payload_invite.email, + invitation.token, # type: ignore + project_id, + workspace, + organization, + user_performing_action, + ) + + if not send_email: + return JSONResponse( + {"detail": "Failed to invite user to organization"}, + status_code=400, + ) + else: + return JSONResponse( + status_code=200, + content={ + "message": "Invitation already exists", + }, + ) + + return JSONResponse( + {"message": "Invited users to organization"}, status_code=200 + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +async def resend_user_workspace_invite( + payload: ReseendInviteRequest, + project_id: str, + org_id: str, + workspace_id: str, + user_uid: str, +) -> JSONResponse: + """ + Resend an invitation to a user to a workspace. + + Args: + org_id (str): The ID of the organization that the workspace belongs to. + project_id (str): The ID of the project. + workspace_id (str): The ID of the workspace. + payload (ReseendInviteRequest): The payload containing the email address of the user to invite. + + Returns: + JSONResponse: The response containing the invitation details. + + Raises: + HTTPException: If there is an error retrieving the workspace. + """ + + try: + workspace = await get_workspace(workspace_id) + organization = await db_manager_ee.get_organization(org_id) + user_performing_action = await db_manager.get_user(user_uid) + + # Check if the email address already has a valid, unused invitation for the workspace + existing_invitation, existing_role = await check_existing_invitation( + project_id, payload.email + ) + if existing_invitation: + invitation = existing_invitation + elif existing_role: + # Create a new invitation + invitation = await create_invitation( + existing_role, project_id, payload.email + ) + else: + raise HTTPException( + status_code=404, + detail="No existing invitation found for the user", + ) + + # Send the invitation email + send_email = await send_invitation_email( + payload.email, + invitation.token, + project_id, + workspace, + organization, + user_performing_action, + ) + + if send_email: + return JSONResponse( + {"message": "Invited user to organization"}, status_code=200 + ) + else: + return JSONResponse( + {"detail": "Failed to invite user to organization"}, status_code=400 + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +async def accept_workspace_invitation( + token: str, + project_id: str, + organization: OrganizationDB, + workspace: WorkspaceDB, + user: UserDB, +) -> bool: + """ + Accept an invitation to a workspace. + + Args: + token (str): The invitation token. + project_id (str): The ID of the project. + organization_id (str): The ID of the organization that the workspace belongs to. + workspace_id (str): The ID of the workspace. + user_uid (str): The user uid. + + Returns: + bool: True if the user was successfully added to the workspace, False otherwise + + Raises: + HTTPException: If there is an error retrieving the workspace. + """ + + try: + # Check if the user is already a member of the workspace + if await db_manager_ee.check_user_in_workspace_with_email( + user.email, str(workspace.id) + ): + raise HTTPException( + status_code=409, + detail="User is already a member of the workspace", + ) + + invitation = await check_valid_invitation(project_id, user.email, token) + if invitation is not None: + assert ( + invitation.role is not None + ), "Invitation does not have any workspace role" + await db_manager_ee.add_user_to_workspace_and_org( + organization, workspace, user, project_id, invitation.role + ) + + await db_manager_ee.mark_invitation_as_used( + project_id, str(user.id), invitation + ) + return True + + else: + # Existing invitation is expired + raise Exception("Invitation has expired or does not exist") + except Exception as e: + raise e + + +async def remove_user_from_workspace( + workspace_id: str, + email: str, +) -> WorkspaceResponse: + """ + Remove a user from a workspace. + + Args: + workspace_id (str): The ID of the workspace. + payload (UserRole): The payload containing the user ID and role to remove. + + Returns: + WorkspaceResponse: The updated workspace. + """ + + remove_user = await db_manager_ee.remove_user_from_workspace(workspace_id, email) + return remove_user diff --git a/api/ee/src/tasks/__init__.py b/api/ee/src/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/tasks/evaluations/__init__.py b/api/ee/src/tasks/evaluations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/src/tasks/evaluations/batch.py b/api/ee/src/tasks/evaluations/batch.py new file mode 100644 index 0000000000..cf65107b6a --- /dev/null +++ b/api/ee/src/tasks/evaluations/batch.py @@ -0,0 +1,254 @@ +from typing import Dict, List, Optional +from uuid import UUID +import asyncio +import traceback +from json import dumps + +from celery import shared_task, states, Task + +from fastapi import Request + +from oss.src.utils.helpers import parse_url, get_slug_from_name_and_id +from oss.src.utils.logging import get_module_logger +from oss.src.services.auth_helper import sign_secret_token +from ee.src.services import llm_apps_service +from oss.src.models.shared_models import InvokationResult +from oss.src.services.db_manager import ( + fetch_app_by_id, + fetch_app_variant_by_id, + fetch_app_variant_revision_by_id, + get_deployment_by_id, + get_project_by_id, +) +from oss.src.core.secrets.utils import get_llm_providers_secrets +from ee.src.utils.entitlements import check_entitlements, Counter + +from oss.src.dbs.postgres.queries.dbes import ( + QueryArtifactDBE, + QueryVariantDBE, + QueryRevisionDBE, +) +from oss.src.dbs.postgres.testcases.dbes import ( + TestcaseBlobDBE, +) +from oss.src.dbs.postgres.testsets.dbes import ( + TestsetArtifactDBE, + TestsetVariantDBE, + TestsetRevisionDBE, +) +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) + +from oss.src.dbs.postgres.tracing.dao import TracingDAO +from oss.src.dbs.postgres.blobs.dao import BlobsDAO +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO + +from oss.src.core.tracing.service import TracingService +from oss.src.core.queries.service import QueriesService +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService +from oss.src.core.testsets.service import SimpleTestsetsService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.evaluators.service import EvaluatorsService +from oss.src.core.evaluators.service import SimpleEvaluatorsService +from oss.src.core.evaluations.service import EvaluationsService +from oss.src.core.annotations.service import AnnotationsService + +# from oss.src.apis.fastapi.tracing.utils import make_hash_id +from oss.src.apis.fastapi.tracing.router import TracingRouter +from oss.src.apis.fastapi.testsets.router import SimpleTestsetsRouter +from oss.src.apis.fastapi.evaluators.router import SimpleEvaluatorsRouter +from oss.src.apis.fastapi.annotations.router import AnnotationsRouter + +from oss.src.core.annotations.types import ( + AnnotationOrigin, + AnnotationKind, + AnnotationChannel, +) +from oss.src.apis.fastapi.annotations.models import ( + AnnotationCreate, + AnnotationCreateRequest, +) + +from oss.src.core.evaluations.types import ( + EvaluationStatus, + EvaluationRun, + EvaluationRunCreate, + EvaluationRunEdit, + EvaluationScenarioCreate, + EvaluationScenarioEdit, + EvaluationResultCreate, + EvaluationMetricsCreate, +) + +from oss.src.core.shared.dtos import Reference +from oss.src.core.tracing.dtos import ( + Filtering, + Windowing, + Formatting, + Format, + Focus, + TracingQuery, +) +from oss.src.core.workflows.dtos import ( + WorkflowServiceData, + WorkflowServiceRequest, + WorkflowServiceResponse, + WorkflowServiceInterface, + WorkflowRevisionData, + WorkflowRevision, + WorkflowVariant, + Workflow, +) + +from oss.src.core.queries.dtos import ( + QueryRevision, + QueryVariant, + Query, +) + +from oss.src.core.workflows.dtos import Tree + +from oss.src.core.evaluations.utils import get_metrics_keys_from_schema + + +log = get_module_logger(__name__) + + +# DBS -------------------------------------------------------------------------- + +tracing_dao = TracingDAO() + +testcases_dao = BlobsDAO( + BlobDBE=TestcaseBlobDBE, +) + +queries_dao = GitDAO( + ArtifactDBE=QueryArtifactDBE, + VariantDBE=QueryVariantDBE, + RevisionDBE=QueryRevisionDBE, +) + +testsets_dao = GitDAO( + ArtifactDBE=TestsetArtifactDBE, + VariantDBE=TestsetVariantDBE, + RevisionDBE=TestsetRevisionDBE, +) + +workflows_dao = GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, +) + +evaluations_dao = EvaluationsDAO() + +# CORE ------------------------------------------------------------------------- + +tracing_service = TracingService( + tracing_dao=tracing_dao, +) + +queries_service = QueriesService( + queries_dao=queries_dao, +) + +testcases_service = TestcasesService( + testcases_dao=testcases_dao, +) + +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) + +simple_testsets_service = SimpleTestsetsService( + testsets_service=testsets_service, +) + +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) + +workflows_service = WorkflowsService( + workflows_dao=workflows_dao, +) + +evaluators_service = EvaluatorsService( + workflows_service=workflows_service, +) + +simple_evaluators_service = SimpleEvaluatorsService( + evaluators_service=evaluators_service, +) + +evaluations_service = EvaluationsService( + evaluations_dao=evaluations_dao, + tracing_service=tracing_service, + queries_service=queries_service, + testsets_service=testsets_service, + evaluators_service=evaluators_service, +) + +# APIS ------------------------------------------------------------------------- + +tracing_router = TracingRouter( + tracing_service=tracing_service, +) + +simple_testsets_router = SimpleTestsetsRouter( + simple_testsets_service=simple_testsets_service, +) # TODO: REMOVE/REPLACE ONCE TRANSFER IS MOVED TO 'core' + +simple_evaluators_router = SimpleEvaluatorsRouter( + simple_evaluators_service=simple_evaluators_service, +) # TODO: REMOVE/REPLACE ONCE TRANSFER IS MOVED TO 'core' + +annotations_service = AnnotationsService( + tracing_router=tracing_router, + evaluators_service=evaluators_service, + simple_evaluators_service=simple_evaluators_service, +) + +annotations_router = AnnotationsRouter( + annotations_service=annotations_service, +) # TODO: REMOVE/REPLACE ONCE ANNOTATE IS MOVED TO 'core' + +# ------------------------------------------------------------------------------ + + +@shared_task( + name="src.tasks.evaluations.batch.evaluate_testsets", + queue="src.tasks.evaluations.batch.evaluate_testsets", + bind=True, +) +def evaluate_testsets( + self, + *, + project_id: UUID, + user_id: UUID, + # + run_id: UUID, +): + pass + + +@shared_task( + name="src.tasks.evaluations.batch.evaluate_queries", + queue="src.tasks.evaluations.batch.evaluate_queries", + bind=True, +) +def evaluate_queries( + self: Task, + *, + project_id: UUID, + user_id: UUID, + # + run_id: UUID, +): + pass diff --git a/api/ee/src/tasks/evaluations/legacy.py b/api/ee/src/tasks/evaluations/legacy.py new file mode 100644 index 0000000000..50d211f713 --- /dev/null +++ b/api/ee/src/tasks/evaluations/legacy.py @@ -0,0 +1,1391 @@ +from typing import Dict, List, Optional +from uuid import UUID +from json import dumps +from asyncio import get_event_loop + +from celery import shared_task, states + +from fastapi import Request + +from oss.src.utils.helpers import parse_url, get_slug_from_name_and_id +from oss.src.utils.logging import get_module_logger +from oss.src.services.auth_helper import sign_secret_token +from ee.src.services import llm_apps_service +from oss.src.models.shared_models import InvokationResult +from oss.src.services.db_manager import ( + fetch_app_by_id, + fetch_app_variant_by_id, + fetch_app_variant_revision_by_id, + fetch_evaluator_config, + get_deployment_by_id, + get_project_by_id, +) +from oss.src.core.secrets.utils import get_llm_providers_secrets +from ee.src.utils.entitlements import check_entitlements, Counter + +from oss.src.dbs.postgres.queries.dbes import ( + QueryArtifactDBE, + QueryVariantDBE, + QueryRevisionDBE, +) +from oss.src.dbs.postgres.testcases.dbes import ( + TestcaseBlobDBE, +) +from oss.src.dbs.postgres.testsets.dbes import ( + TestsetArtifactDBE, + TestsetVariantDBE, + TestsetRevisionDBE, +) +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) + +from oss.src.dbs.postgres.tracing.dao import TracingDAO +from oss.src.dbs.postgres.blobs.dao import BlobsDAO +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO + +from oss.src.core.tracing.service import TracingService +from oss.src.core.queries.service import QueriesService +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService, SimpleTestsetsService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.evaluators.service import EvaluatorsService +from oss.src.core.evaluators.service import SimpleEvaluatorsService +from oss.src.core.evaluations.service import EvaluationsService +from oss.src.core.annotations.service import AnnotationsService + +from oss.src.apis.fastapi.tracing.utils import make_hash_id +from oss.src.apis.fastapi.tracing.router import TracingRouter +from oss.src.apis.fastapi.testsets.router import SimpleTestsetsRouter +from oss.src.apis.fastapi.evaluators.router import SimpleEvaluatorsRouter +from oss.src.apis.fastapi.annotations.router import AnnotationsRouter + +from oss.src.core.annotations.types import ( + AnnotationOrigin, + AnnotationKind, + AnnotationChannel, +) +from oss.src.apis.fastapi.annotations.models import ( + AnnotationCreate, + AnnotationCreateRequest, +) + +from oss.src.core.evaluations.types import ( + EvaluationStatus, + EvaluationRunDataMappingStep, + EvaluationRunDataMappingColumn, + EvaluationRunDataMapping, + EvaluationRunDataStepInput, + EvaluationRunDataStep, + EvaluationRunData, + EvaluationRunFlags, + EvaluationRun, + EvaluationRunCreate, + EvaluationRunEdit, + EvaluationScenarioCreate, + EvaluationScenarioEdit, + EvaluationResultCreate, + EvaluationMetricsCreate, +) + +from oss.src.core.shared.dtos import Reference +from oss.src.core.workflows.dtos import ( + WorkflowServiceData, + WorkflowServiceRequest, + WorkflowServiceResponse, + WorkflowServiceInterface, + WorkflowRevisionData, + WorkflowRevision, + WorkflowVariant, + Workflow, +) + +from oss.src.core.queries.dtos import ( + QueryRevision, + QueryVariant, + Query, +) + +from oss.src.core.workflows.dtos import Tree + +from oss.src.core.evaluations.utils import get_metrics_keys_from_schema + + +log = get_module_logger(__name__) + + +# DBS -------------------------------------------------------------------------- + +tracing_dao = TracingDAO() + +testcases_dao = BlobsDAO( + BlobDBE=TestcaseBlobDBE, +) + +queries_dao = GitDAO( + ArtifactDBE=QueryArtifactDBE, + VariantDBE=QueryVariantDBE, + RevisionDBE=QueryRevisionDBE, +) + +testsets_dao = GitDAO( + ArtifactDBE=TestsetArtifactDBE, + VariantDBE=TestsetVariantDBE, + RevisionDBE=TestsetRevisionDBE, +) + +workflows_dao = GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, +) + +evaluations_dao = EvaluationsDAO() + +# CORE ------------------------------------------------------------------------- + +tracing_service = TracingService( + tracing_dao=tracing_dao, +) + +queries_service = QueriesService( + queries_dao=queries_dao, +) + +testcases_service = TestcasesService( + testcases_dao=testcases_dao, +) + +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) + +simple_testsets_service = SimpleTestsetsService( + testsets_service=testsets_service, +) + +workflows_service = WorkflowsService( + workflows_dao=workflows_dao, +) + +evaluators_service = EvaluatorsService( + workflows_service=workflows_service, +) + +simple_evaluators_service = SimpleEvaluatorsService( + evaluators_service=evaluators_service, +) + +evaluations_service = EvaluationsService( + evaluations_dao=evaluations_dao, + tracing_service=tracing_service, + queries_service=queries_service, + testsets_service=testsets_service, + evaluators_service=evaluators_service, +) + +# APIS ------------------------------------------------------------------------- + +tracing_router = TracingRouter( + tracing_service=tracing_service, +) + +simple_testsets_router = SimpleTestsetsRouter( + simple_testsets_service=simple_testsets_service, +) # TODO: REMOVE/REPLACE ONCE TRANSFER IS MOVED TO 'core' + +simple_evaluators_router = SimpleEvaluatorsRouter( + simple_evaluators_service=simple_evaluators_service, +) # TODO: REMOVE/REPLACE ONCE TRANSFER IS MOVED TO 'core' + +annotations_service = AnnotationsService( + tracing_router=tracing_router, + evaluators_service=evaluators_service, + simple_evaluators_service=simple_evaluators_service, +) + +annotations_router = AnnotationsRouter( + annotations_service=annotations_service, +) # TODO: REMOVE/REPLACE ONCE ANNOTATE IS MOVED TO 'core' + +# ------------------------------------------------------------------------------ + + +async def setup_evaluation( + *, + project_id: UUID, + user_id: UUID, + # + name: Optional[str] = None, + description: Optional[str] = None, + # + testset_id: Optional[str] = None, + query_id: Optional[str] = None, + # + revision_id: Optional[str] = None, + # + autoeval_ids: Optional[List[str]] = None, +) -> Optional[EvaluationRun]: + request = Request(scope={"type": "http", "http_version": "1.1", "scheme": "http"}) + request.state.project_id = project_id + request.state.user_id = user_id + + run = None + + # -------------------------------------------------------------------------- + log.info("[SETUP] ", project_id=project_id, user_id=user_id) + log.info("[TESTSET] ", ids=[testset_id]) + log.info("[QUERY] ", ids=[query_id]) + log.info("[INVOCATON] ", ids=[revision_id]) + log.info("[ANNOTATION]", ids=autoeval_ids) + # -------------------------------------------------------------------------- + + try: + # create evaluation run ------------------------------------------------ + runs_create = [ + EvaluationRunCreate( + name=name, + description=description, + # + flags=( + EvaluationRunFlags( + is_closed=None, + is_live=True, + is_active=True, + ) + if query_id + else None + ), + # + status=EvaluationStatus.PENDING, + ) + ] + + runs = await evaluations_service.create_runs( + project_id=project_id, + user_id=user_id, + # + runs=runs_create, + ) + + assert len(runs) == 1, "Failed to create evaluation run." + + run = runs[0] + # ---------------------------------------------------------------------- + + # just-in-time transfer of testset ------------------------------------- + testset_input_steps_keys = list() + + testset_references = dict() + testset = None + + if testset_id: + testset_ref = Reference(id=UUID(testset_id)) + + testset_response = await simple_testsets_router.transfer_simple_testset( + request=request, + testset_id=UUID(testset_id), + ) + + assert ( + testset_response.count != 0 + ), f"Testset with id {testset_id} not found!" + + testset = testset_response.testset + testcases = testset.data.testcases + + testset_references["artifact"] = testset_ref + + testset_input_steps_keys.append( + get_slug_from_name_and_id(testset.name, testset.id) + ) + # ---------------------------------------------------------------------- + + # fetch query ---------------------------------------------------------- + query_input_steps_keys = list() + + query_references = dict() + query_revision = None + + if query_id: + query_ref = Reference(id=UUID(query_id)) + + query = await queries_service.fetch_query( + project_id=project_id, + # + query_ref=query_ref, + ) + + assert query is not None, f"Query with id {query_id} not found!" + + query_references["artifact"] = Reference( + id=query.id, + slug=query.slug, + ) + + query_revision = await queries_service.fetch_query_revision( + project_id=project_id, + # + query_ref=query_ref, + ) + + assert ( + query_revision is not None + ), f"Query revision with id {query_id} not found!" + + query_revision_ref = Reference( + id=query_revision.id, + slug=query_revision.slug, + ) + + query_references["revision"] = query_revision_ref + + query_variant = await queries_service.fetch_query_variant( + project_id=project_id, + query_variant_ref=Reference( + id=query_revision.variant_id, + ), + ) + + assert ( + query_variant is not None + ), f"Query variant with id {query_revision.variant_id} not found!" + + query_variant_ref = Reference( + id=query_variant.id, + slug=query_variant.slug, + ) + + query_references["variant"] = query_variant_ref + + query_input_steps_keys.append(query_revision.slug) + # ---------------------------------------------------------------------- + + # fetch application ---------------------------------------------------- + invocation_steps_keys = list() + + application_references = dict() + + if revision_id: + revision = await fetch_app_variant_revision_by_id(revision_id) + + assert ( + revision is not None + ), f"App revision with id {revision_id} not found!" + + application_references["revision"] = Reference( + id=UUID(str(revision.id)), + ) + + variant = await fetch_app_variant_by_id(str(revision.variant_id)) + + assert ( + variant is not None + ), f"App variant with id {revision.variant_id} not found!" + + application_references["variant"] = Reference( + id=UUID(str(variant.id)), + ) + + app = await fetch_app_by_id(str(variant.app_id)) + + assert app is not None, f"App with id {variant.app_id} not found!" + + application_references["artifact"] = Reference( + id=UUID(str(app.id)), + ) + + deployment = await get_deployment_by_id(str(revision.base.deployment_id)) + + assert ( + deployment is not None + ), f"Deployment with id {revision.base.deployment_id} not found!" + + uri = parse_url(url=deployment.uri) + + assert uri is not None, f"Invalid URI for deployment {deployment.id}!" + + revision_parameters = revision.config_parameters + + assert ( + revision_parameters is not None + ), f"Revision parameters for variant {variant.id} not found!" + + invocation_steps_keys.append( + get_slug_from_name_and_id(app.app_name, revision.id) + ) + # ---------------------------------------------------------------------- + + # fetch evaluators ----------------------------------------------------- + annotation_steps_keys = [] + + if autoeval_ids: + autoeval_configs = [] + + for autoeval_id in autoeval_ids: + autoeval_config = await fetch_evaluator_config(autoeval_id) + + autoeval_configs.append(autoeval_config) + + for autoeval_config in autoeval_configs: + annotation_steps_keys.append( + get_slug_from_name_and_id(autoeval_config.name, autoeval_config.id) + ) + # ---------------------------------------------------------------------- + + # just-in-time transfer of evaluators ---------------------------------- + annotation_metrics_keys = {key: {} for key in annotation_steps_keys} + evaluator_references = dict() + + for jdx, autoeval_id in enumerate(autoeval_ids): + annotation_step_key = annotation_steps_keys[jdx] + + evaluator_response = ( + await simple_evaluators_router.transfer_simple_evaluator( + request=request, + evaluator_id=UUID(autoeval_id), + ) + ) + + evaluator = evaluator_response.evaluator + + assert evaluator is not None, f"Evaluator with id {autoeval_id} not found!" + + evaluator_references[annotation_step_key] = {} + + evaluator_references[annotation_step_key]["artifact"] = Reference( + id=evaluator.id, + slug=evaluator.slug, + ) + + metrics_keys = get_metrics_keys_from_schema( + schema=(evaluator.data.schemas.get("outputs")), + ) + + annotation_metrics_keys[annotation_step_key] = [ + { + "path": metric_key.get("path", "").replace("outputs.", "", 1), + "type": metric_key.get("type", ""), + } + for metric_key in metrics_keys + ] + # ---------------------------------------------------------------------- + + # fetch evaluator workflows -------------------------------------------- + evaluators = dict() + + for annotation_step_key, references in evaluator_references.items(): + evaluators[annotation_step_key] = {} + + workflow_ref = references["artifact"] + + workflow = await workflows_service.fetch_workflow( + project_id=project_id, + # + workflow_ref=workflow_ref, + ) + + evaluators[annotation_step_key]["workflow"] = workflow + + workflow_revision = await workflows_service.fetch_workflow_revision( + project_id=project_id, + # + workflow_ref=workflow_ref, + ) + + assert ( + workflow_revision is not None + ), f"Workflow revision with id {workflow_ref.id} not found!" + + workflow_revision_ref = Reference( + id=workflow_revision.id, + slug=workflow_revision.slug, + ) + + evaluator_references[annotation_step_key][ + "revision" + ] = workflow_revision_ref + + evaluators[annotation_step_key]["revision"] = workflow_revision + + workflow_variant = await workflows_service.fetch_workflow_variant( + project_id=project_id, + workflow_variant_ref=Reference( + id=workflow_revision.variant_id, + ), + ) + + assert ( + workflow_variant is not None + ), f"Workflow variant with id {workflow_revision.variant_id} not found!" + + workflow_variant_ref = Reference( + id=workflow_variant.id, + slug=workflow_variant.slug, + ) + + evaluator_references[annotation_step_key]["variant"] = workflow_variant_ref + + evaluators[annotation_step_key]["variant"] = workflow_variant + + # ---------------------------------------------------------------------- + + # initialize steps/mappings in run ------------------------------------- + testset_input_step = ( + EvaluationRunDataStep( + key=testset_input_steps_keys[0], + type="input", + origin="auto", + references={ + "testset": testset_references["artifact"], + # "testset_variant": + # "testset_revision": + }, + ) + if testset and testset.id + else None + ) + + query_input_step = ( + EvaluationRunDataStep( + key=query_input_steps_keys[0], + type="input", + origin="auto", + references={ + "query": query_references["artifact"], + "query_variant": query_references["variant"], + "query_revision": query_references["revision"], + }, + ) + if query_id + else None + ) + + invocation_step = ( + EvaluationRunDataStep( + key=invocation_steps_keys[0], + type="invocation", + origin="auto", + references={ + "application": application_references["artifact"], + "application_variant": application_references["variant"], + "application_revision": application_references["revision"], + }, + inputs=[ + EvaluationRunDataStepInput( + key=testset_input_steps_keys[0], + ), + ], + ) + if revision_id + else None + ) + + annotation_steps = [ + EvaluationRunDataStep( + key=step_key, + type="annotation", + origin="auto", + references={ + "evaluator": evaluator_references[step_key]["artifact"], + "evaluator_variant": evaluator_references[step_key]["variant"], + "evaluator_revision": evaluator_references[step_key]["revision"], + }, + inputs=( + [ + EvaluationRunDataStepInput( + key=testset_input_steps_keys[0], + ), + EvaluationRunDataStepInput( + key=invocation_steps_keys[0], + ), + ] + if testset_id and revision_id + else [ + EvaluationRunDataStepInput( + key=query_input_steps_keys[0], + ), + ] + ), + ) + for step_key in annotation_steps_keys + ] + + steps: List[EvaluationRunDataStep] = list() + + if testset_id and testset_input_step: + steps.append(testset_input_step) + if query_id and query_input_step: + steps.append(query_input_step) + if revision_id and invocation_step: + steps.append(invocation_step) + + steps.extend(annotation_steps) + + testset_input_mappings = ( + [ + EvaluationRunDataMapping( + column=EvaluationRunDataMappingColumn( + kind="testset", + name=key, + ), + step=EvaluationRunDataMappingStep( + key=testset_input_steps_keys[0], + path=f"data.{key}", + ), + ) + for key in testcases[0].data.keys() + ] + if testset_id + else [] + ) + + query_input_mappings = ( + [ + EvaluationRunDataMapping( + column=EvaluationRunDataMappingColumn( + kind="query", + name="data", + ), + step=EvaluationRunDataMappingStep( + key=query_input_steps_keys[0], + path="attributes.ag.data", + ), + ) + ] + if query_id + else [] + ) + + invocation_mappings = ( + [ + EvaluationRunDataMapping( + column=EvaluationRunDataMappingColumn( + kind="invocation", + name="outputs", + ), + step=EvaluationRunDataMappingStep( + key=step_key, + path="attributes.ag.data.outputs", + ), + ) + for step_key in invocation_steps_keys + ] + if invocation_steps_keys + else [] + ) + + annotation_mappings = [ + EvaluationRunDataMapping( + column=EvaluationRunDataMappingColumn( + kind="annotation", + name=metric_key["path"], + ), + step=EvaluationRunDataMappingStep( + key=step_key, + path=f"attributes.ag.data.outputs{'.' + metric_key['path'] if metric_key['path'] else ''}", + ), + ) + for step_key in annotation_steps_keys + for metric_key in annotation_metrics_keys[step_key] + ] + + mappings: List[EvaluationRunDataMapping] = ( + testset_input_mappings + + query_input_mappings + + invocation_mappings + + annotation_mappings + ) + + run_edit = EvaluationRunEdit( + id=run.id, + # + name=run.name, + description=run.description, + # + flags=run.flags, + tags=run.tags, + meta=run.meta, + # + status=EvaluationStatus.RUNNING, + # + data=EvaluationRunData( + steps=steps, + mappings=mappings, + ), + ) + + run = await evaluations_service.edit_run( + project_id=project_id, + user_id=user_id, + # + run=run_edit, + ) + + assert run, f"Failed to edit evaluation run {run_edit.id}!" + # ---------------------------------------------------------------------- + + log.info("[DONE] ", run_id=run.id) + + except: # pylint: disable=bare-except + if run and run.id: + log.error("[FAIL] ", run_id=run.id, exc_info=True) + + await evaluations_service.delete_run( + project_id=project_id, + # + run_id=run.id, + ) + else: + log.error("[FAIL]", exc_info=True) + + run = None + + return run + + +@shared_task( + name="src.tasks.evaluations.legacy.annotate", + queue="src.tasks.evaluations.legacy.annotate", + bind=True, +) +def annotate( + self, + *, + project_id: UUID, + user_id: UUID, + # + run_id: UUID, + # + testset_id: str, + revision_id: str, + autoeval_ids: Optional[List[str]], + # + run_config: Dict[str, int], +): + """ + Annotates an application revision applied to a testset using auto evaluator(s). + + Args: + self: The task instance. + project_id (str): The ID of the project. + user_id (str): The ID of the user. + run_id (str): The ID of the evaluation run. + testset_id (str): The ID of the testset. + revision_id (str): The ID of the application revision. + autoeval_ids (List[str]): The IDs of the evaluators configurations. + run_config (Dict[str, int]): Configuration for evaluation run. + + Returns: + None + """ + request = Request( + scope={ + "type": "http", + "http_version": "1.1", + "scheme": "http", + } + ) + request.state.project_id = str(project_id) + request.state.user_id = str(user_id) + + loop = get_event_loop() + + run = None + + try: + # ---------------------------------------------------------------------- + log.info("[SCOPE] ", run_id=run_id, project_id=project_id, user_id=user_id) + log.info("[TESTSET] ", run_id=run_id, ids=[testset_id]) + log.info("[INVOCATON] ", run_id=run_id, ids=[revision_id]) + log.info("[ANNOTATION]", run_id=run_id, ids=autoeval_ids) + # ---------------------------------------------------------------------- + + # fetch project -------------------------------------------------------- + project = loop.run_until_complete( + get_project_by_id( + project_id=str(project_id), + ), + ) + # ---------------------------------------------------------------------- + + # fetch secrets -------------------------------------------------------- + secrets = loop.run_until_complete( + get_llm_providers_secrets( + project_id=str(project_id), + ), + ) + # ---------------------------------------------------------------------- + + # prepare credentials -------------------------------------------------- + secret_token = loop.run_until_complete( + sign_secret_token( + user_id=str(user_id), + project_id=str(project_id), + workspace_id=str(project.workspace_id), + organization_id=str(project.organization_id), + ) + ) + + credentials = f"Secret {secret_token}" + # ---------------------------------------------------------------------- + + # fetch run ------------------------------------------------------------ + run = loop.run_until_complete( + evaluations_service.fetch_run( + project_id=project_id, + # + run_id=run_id, + ) + ) + + assert run, f"Evaluation run with id {run_id} not found!" + + assert run.data, f"Evaluation run with id {run_id} has no data!" + + assert run.data.steps, f"Evaluation run with id {run_id} has no steps!" + + steps = run.data.steps + + invocation_steps = [step for step in steps if step.type == "invocation"] + annotation_steps = [step for step in steps if step.type == "annotation"] + + invocation_steps_keys = [step.key for step in invocation_steps] + annotation_steps_keys = [step.key for step in annotation_steps] + + nof_annotations = len(annotation_steps) + # ---------------------------------------------------------------------- + + # fetch testset -------------------------------------------------------- + testset_response = loop.run_until_complete( + simple_testsets_router.fetch_simple_testset( + request=request, + testset_id=testset_id, + ) + ) + + assert testset_response.count != 0, f"Testset with id {testset_id} not found!" + + testset = testset_response.testset + + testcases = testset.data.testcases + testcases_data = [ + {**testcase.data, "id": str(testcase.id)} for testcase in testcases + ] # INEFFICIENT: might want to have testcase_id in testset data (caution with hashing) + nof_testcases = len(testcases) + + testset_step_key = get_slug_from_name_and_id(testset.name, testset.id) + # ---------------------------------------------------------------------- + + # fetch application ---------------------------------------------------- + revision = loop.run_until_complete( + fetch_app_variant_revision_by_id(revision_id), + ) + + assert revision is not None, f"App revision with id {revision_id} not found!" + + variant = loop.run_until_complete( + fetch_app_variant_by_id(str(revision.variant_id)), + ) + + assert ( + variant is not None + ), f"App variant with id {revision.variant_id} not found!" + + app = loop.run_until_complete( + fetch_app_by_id(str(variant.app_id)), + ) + + assert app is not None, f"App with id {variant.app_id} not found!" + + deployment = loop.run_until_complete( + get_deployment_by_id(str(revision.base.deployment_id)), + ) + + assert ( + deployment is not None + ), f"Deployment with id {revision.base.deployment_id} not found!" + + uri = parse_url(url=deployment.uri) + + assert uri is not None, f"Invalid URI for deployment {deployment.id}!" + + revision_parameters = revision.config_parameters + + assert ( + revision_parameters is not None + ), f"Revision parameters for variant {variant.id} not found!" + # ---------------------------------------------------------------------- + + # fetch evaluators ----------------------------------------------------- + evaluator_references = {step.key: step.references for step in annotation_steps} + + evaluators = { + evaluator_key: loop.run_until_complete( + workflows_service.fetch_workflow_revision( + project_id=project_id, + # + workflow_revision_ref=evaluator_refs.get("evaluator_revision"), + ) + ) + for evaluator_key, evaluator_refs in evaluator_references.items() + } + # ---------------------------------------------------------------------- + + # prepare headers ------------------------------------------------------ + headers = {} + if credentials: + headers = {"Authorization": credentials} + headers["ngrok-skip-browser-warning"] = "1" + + openapi_parameters = None + max_recursive_depth = 5 + runtime_prefix = uri + route_path = "" + + while max_recursive_depth > 0 and not openapi_parameters: + try: + openapi_parameters = loop.run_until_complete( + llm_apps_service.get_parameters_from_openapi( + runtime_prefix + "/openapi.json", + route_path, + headers, + ), + ) + except Exception: # pylint: disable=broad-exception-caught + openapi_parameters = None + + if not openapi_parameters: + max_recursive_depth -= 1 + if not runtime_prefix.endswith("/"): + route_path = "/" + runtime_prefix.split("/")[-1] + route_path + runtime_prefix = "/".join(runtime_prefix.split("/")[:-1]) + else: + route_path = "" + runtime_prefix = runtime_prefix[:-1] + + openapi_parameters = loop.run_until_complete( + llm_apps_service.get_parameters_from_openapi( + runtime_prefix + "/openapi.json", + route_path, + headers, + ), + ) + # ---------------------------------------------------------------------- + + # create scenarios ----------------------------------------------------- + scenarios_create = [ + EvaluationScenarioCreate( + run_id=run_id, + # + status=EvaluationStatus.RUNNING, + ) + for _ in range(nof_testcases) + ] + + scenarios = loop.run_until_complete( + evaluations_service.create_scenarios( + project_id=project_id, + user_id=user_id, + # + scenarios=scenarios_create, + ) + ) + + assert ( + len(scenarios) == nof_testcases + ), f"Failed to create evaluation scenarios for run {run_id}!" + # ---------------------------------------------------------------------- + + # create input steps --------------------------------------------------- + steps_create = [ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario.id, + step_key=testset_step_key, + # + status=EvaluationStatus.SUCCESS, + # + testcase_id=testcases[idx].id, + ) + for idx, scenario in enumerate(scenarios) + ] + + steps = loop.run_until_complete( + evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + # + results=steps_create, + ) + ) + + assert ( + len(steps) == nof_testcases + ), f"Failed to create evaluation steps for run {run_id}!" + # ---------------------------------------------------------------------- + + # flatten testcases ---------------------------------------------------- + _testcases = [testcase.model_dump(mode="json") for testcase in testcases] + + log.info( + "[BATCH] ", + run_id=run_id, + ids=[testset_id], + count=len(_testcases), + size=len(dumps(_testcases).encode("utf-8")), + ) + # ---------------------------------------------------------------------- + + # invoke application --------------------------------------------------- + invocations: List[InvokationResult] = loop.run_until_complete( + llm_apps_service.batch_invoke( + project_id=str(project_id), + user_id=str(user_id), + testset_data=testcases_data, # type: ignore + parameters=revision_parameters, # type: ignore + uri=uri, + rate_limit_config=run_config, + application_id=str(app.id), # DO NOT REMOVE + references={ + "testset": {"id": testset_id}, + "application": {"id": str(app.id)}, + "application_variant": {"id": str(variant.id)}, + "application_revision": {"id": str(revision.id)}, + }, + ) + ) + # ---------------------------------------------------------------------- + + # create invocation steps ---------------------------------------------- + steps_create = [ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario.id, + step_key=invocation_steps_keys[0], + # + status=( + EvaluationStatus.SUCCESS + if not invocations[idx].result.error + else EvaluationStatus.FAILURE + ), + # + trace_id=invocations[idx].trace_id, + error=( + invocations[idx].result.error.model_dump(mode="json") + if invocations[idx].result.error + else None + ), + ) + for idx, scenario in enumerate(scenarios) + ] + + steps = loop.run_until_complete( + evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + # + results=steps_create, + ) + ) + + assert ( + len(steps) == nof_testcases + ), f"Failed to create evaluation steps for run {run_id}!" + # ---------------------------------------------------------------------- + + run_has_errors = 0 + run_status = EvaluationStatus.SUCCESS + + # run evaluators ------------------------------------------------------- + for idx in range(nof_testcases): + scenario = scenarios[idx] + testcase = testcases[idx] + invocation = invocations[idx] + + scenario_has_errors = 0 + scenario_status = EvaluationStatus.SUCCESS + + # skip the iteration if error in the invocation -------------------- + if invocation.result.error: + log.error( + f"There is an error in invocation {invocation.trace_id} so we skip its evaluation" + ) + + scenario_has_errors += 1 + run_has_errors += 1 + scenario_status = EvaluationStatus.ERRORS + run_status = EvaluationStatus.ERRORS + + error = invocation.result.error.model_dump(mode="json") is not None + # ------------------------------------------------------------------ + + # proceed with the evaluation otherwise ---------------------------- + else: + # run the evaluators if no error in the invocation ------------- + for jdx in range(nof_annotations): + annotation_step_key = annotation_steps_keys[jdx] + + step_has_errors = 0 + step_status = EvaluationStatus.SUCCESS + + references = { + **evaluator_references[annotation_step_key], + "testset": {"id": testset_id}, + "testcase": {"id": str(testcase.id)}, + } + links = { + invocation_steps_keys[0]: { + "trace_id": invocation.trace_id, + "span_id": invocation.span_id, + } + } + + # invoke annotation workflow ------------------------------- + workflow_revision = evaluators[annotation_step_key] + + workflows_service_request = WorkflowServiceRequest( + version="2025.07.14", + flags={ + "is_annotation": True, + "inline": True, + }, + tags=None, + meta=None, + data=WorkflowServiceData( + inputs=testcase.data, + # trace= + trace_parameters=revision_parameters, + trace_outputs=invocation.result.value["data"], + tree=( + Tree( + version=invocation.result.value.get("version"), + nodes=invocation.result.value["tree"].get("nodes"), + ) + if "tree" in invocation.result.value + else None + ), + ), + references=references, + links=links, + credentials=credentials, + secrets=secrets, + ) + + workflows_service_response = loop.run_until_complete( + workflows_service.invoke_workflow( + project_id=project_id, + user_id=user_id, + # + request=workflows_service_request, + revision=workflow_revision, + ) + ) + # ---------------------------------------------------------- + + # run evaluator -------------------------------------------- + trace_id = None + error = None + + has_error = workflows_service_response.status.code != 200 + + # if error in evaluator, no annotation, only step ---------- + if has_error: + log.warn( + f"There is an error in annotation {annotation_step_key} for invocation {invocation.trace_id}." + ) + + step_has_errors += 1 + scenario_has_errors += 1 + run_has_errors += 1 + step_status = EvaluationStatus.FAILURE + scenario_status = EvaluationStatus.ERRORS + run_status = EvaluationStatus.ERRORS + + error = workflows_service_response.status.model_dump( + mode="json" + ) + + # ---------------------------------------------------------- + + # else, first annotation, then step ------------------------ + else: + outputs = workflows_service_response.data.outputs or {} + + annotation_create_request = AnnotationCreateRequest( + annotation=AnnotationCreate( + origin=AnnotationOrigin.AUTO, + kind=AnnotationKind.EVAL, + channel=AnnotationChannel.API, # hardcoded + # + data={"outputs": outputs}, + # + references=references, + links=links, + ) + ) + + annotation_response = loop.run_until_complete( + annotations_router.create_annotation( + request=request, + annotation_create_request=annotation_create_request, + ) + ) + + assert ( + annotation_response.count != 0 + ), f"Failed to create annotation for invocation {invocation.trace_id} and evaluator {references.get('evaluator').get('id')}" + + trace_id = annotation_response.annotation.trace_id + # ---------------------------------------------------------- + + steps_create = [ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario.id, + step_key=annotation_step_key, + # + status=step_status, + # + trace_id=trace_id, + error=error, + ) + ] + + steps = loop.run_until_complete( + evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + # + results=steps_create, + ) + ) + + assert ( + len(steps) == 1 + ), f"Failed to create evaluation step for scenario with id {scenario.id}!" + # ------------------------------------------------------------------ + + scenario_edit = EvaluationScenarioEdit( + id=scenario.id, + tags=scenario.tags, + meta=scenario.meta, + status=scenario_status, + ) + + scenario = loop.run_until_complete( + evaluations_service.edit_scenario( + project_id=project_id, + user_id=user_id, + # + scenario=scenario_edit, + ) + ) + + assert ( + scenario + ), f"Failed to edit evaluation scenario with id {scenario.id}!" + + if scenario_status != EvaluationStatus.FAILURE: + try: + metrics = loop.run_until_complete( + evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + # + run_id=run_id, + scenario_id=scenario.id, + ) + ) + + if not metrics: + log.warning( + f"Refreshing metrics failed for {run_id} | {scenario.id}" + ) + + except Exception as e: + log.warning( + f"Refreshing metrics failed for {run_id} | {scenario.id}", + exc_info=True, + ) + # ---------------------------------------------------------------------- + + except Exception as e: # pylint: disable=broad-exception-caught + log.error( + f"An error occurred during evaluation: {e}", + exc_info=True, + ) + + self.update_state(state=states.FAILURE) + + run_status = EvaluationStatus.FAILURE + + if not run: + log.info("[FAIL] ", run_id=run_id, project_id=project_id, user_id=user_id) + + if run_status != EvaluationStatus.FAILURE: + try: + metrics = loop.run_until_complete( + evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + # + run_id=run_id, + ) + ) + + if not metrics: + log.warning(f"Refreshing metrics failed for {run_id}") + + self.update_state(state=states.FAILURE) + + run_status = EvaluationStatus.FAILURE + + except Exception as e: # pylint: disable=broad-exception-caught + log.warning(f"Refreshing metrics failed for {run_id}", exc_info=True) + + self.update_state(state=states.FAILURE) + + run_status = EvaluationStatus.FAILURE + + # edit evaluation run status ----------------------------------------------- + run_edit = EvaluationRunEdit( + id=run_id, + # + name=run.name, + description=run.description, + # + tags=run.tags, + meta=run.meta, + # + status=run_status, + # + data=run.data, + ) + + loop.run_until_complete( + evaluations_service.edit_run( + project_id=project_id, + user_id=user_id, + # + run=run_edit, + ) + ) + + # edit meters to avoid conting failed evaluations -------------------------- + if run_status == EvaluationStatus.FAILURE: + loop.run_until_complete( + check_entitlements( + organization_id=project.organization_id, + key=Counter.EVALUATIONS, + delta=-1, + ) + ) + + log.info("[DONE] ", run_id=run_id, project_id=project_id, user_id=user_id) + + return diff --git a/api/ee/src/tasks/evaluations/live.py b/api/ee/src/tasks/evaluations/live.py new file mode 100644 index 0000000000..0095206d42 --- /dev/null +++ b/api/ee/src/tasks/evaluations/live.py @@ -0,0 +1,771 @@ +from typing import List, Dict, Any +from uuid import UUID +import asyncio +from datetime import datetime + +from celery import shared_task +from fastapi import Request + +from oss.src.utils.logging import get_module_logger +from oss.src.services.auth_helper import sign_secret_token +from oss.src.services.db_manager import get_project_by_id +from oss.src.core.secrets.utils import get_llm_providers_secrets + +from oss.src.dbs.postgres.queries.dbes import ( + QueryArtifactDBE, + QueryVariantDBE, + QueryRevisionDBE, +) +from oss.src.dbs.postgres.testcases.dbes import ( + TestcaseBlobDBE, +) +from oss.src.dbs.postgres.testsets.dbes import ( + TestsetArtifactDBE, + TestsetVariantDBE, + TestsetRevisionDBE, +) +from oss.src.dbs.postgres.workflows.dbes import ( + WorkflowArtifactDBE, + WorkflowVariantDBE, + WorkflowRevisionDBE, +) + +from oss.src.dbs.postgres.tracing.dao import TracingDAO +from oss.src.dbs.postgres.blobs.dao import BlobsDAO +from oss.src.dbs.postgres.git.dao import GitDAO +from oss.src.dbs.postgres.evaluations.dao import EvaluationsDAO + +from oss.src.core.tracing.service import TracingService +from oss.src.core.queries.service import QueriesService +from oss.src.core.testcases.service import TestcasesService +from oss.src.core.testsets.service import TestsetsService +from oss.src.core.testsets.service import SimpleTestsetsService +from oss.src.core.workflows.service import WorkflowsService +from oss.src.core.evaluators.service import EvaluatorsService +from oss.src.core.evaluators.service import SimpleEvaluatorsService +from oss.src.core.evaluations.service import EvaluationsService +from oss.src.core.annotations.service import AnnotationsService + +# from oss.src.apis.fastapi.tracing.utils import make_hash_id +from oss.src.apis.fastapi.tracing.router import TracingRouter +from oss.src.apis.fastapi.annotations.router import AnnotationsRouter + +from oss.src.core.annotations.types import ( + AnnotationOrigin, + AnnotationKind, + AnnotationChannel, +) +from oss.src.apis.fastapi.annotations.models import ( + AnnotationCreate, + AnnotationCreateRequest, +) + +from oss.src.core.evaluations.types import ( + EvaluationStatus, + EvaluationScenarioCreate, + EvaluationScenarioEdit, + EvaluationResultCreate, +) +from oss.src.core.shared.dtos import ( + Reference, + Link, +) +from oss.src.core.tracing.dtos import ( + Filtering, + Windowing, + Formatting, + Format, + Focus, + TracingQuery, + OTelSpansTree as Trace, + LogicalOperator, + SimpleTraceReferences, +) +from oss.src.core.workflows.dtos import ( + WorkflowServiceData, + WorkflowServiceRequest, +) +from oss.src.core.queries.dtos import ( + QueryRevision, +) +from oss.src.core.evaluators.dtos import ( + EvaluatorRevision, +) + +log = get_module_logger(__name__) + + +# DBS -------------------------------------------------------------------------- + +tracing_dao = TracingDAO() + +testcases_dao = BlobsDAO( + BlobDBE=TestcaseBlobDBE, +) + +queries_dao = GitDAO( + ArtifactDBE=QueryArtifactDBE, + VariantDBE=QueryVariantDBE, + RevisionDBE=QueryRevisionDBE, +) + +testsets_dao = GitDAO( + ArtifactDBE=TestsetArtifactDBE, + VariantDBE=TestsetVariantDBE, + RevisionDBE=TestsetRevisionDBE, +) + +workflows_dao = GitDAO( + ArtifactDBE=WorkflowArtifactDBE, + VariantDBE=WorkflowVariantDBE, + RevisionDBE=WorkflowRevisionDBE, +) + +evaluations_dao = EvaluationsDAO() + +# CORE ------------------------------------------------------------------------- + +tracing_service = TracingService( + tracing_dao=tracing_dao, +) + +queries_service = QueriesService( + queries_dao=queries_dao, +) + +testcases_service = TestcasesService( + testcases_dao=testcases_dao, +) + +testsets_service = TestsetsService( + testsets_dao=testsets_dao, + testcases_service=testcases_service, +) + +simple_testsets_service = SimpleTestsetsService( + testsets_service=testsets_service, +) + +workflows_service = WorkflowsService( + workflows_dao=workflows_dao, +) + +evaluators_service = EvaluatorsService( + workflows_service=workflows_service, +) + +simple_evaluators_service = SimpleEvaluatorsService( + evaluators_service=evaluators_service, +) + +evaluations_service = EvaluationsService( + evaluations_dao=evaluations_dao, + tracing_service=tracing_service, + queries_service=queries_service, + testsets_service=testsets_service, + evaluators_service=evaluators_service, +) + +# APIS ------------------------------------------------------------------------- + +tracing_router = TracingRouter( + tracing_service=tracing_service, +) + +annotations_service = AnnotationsService( + tracing_router=tracing_router, + evaluators_service=evaluators_service, + simple_evaluators_service=simple_evaluators_service, +) + +annotations_router = AnnotationsRouter( + annotations_service=annotations_service, +) # TODO: REMOVE/REPLACE ONCE ANNOTATE IS MOVED TO 'core' + +# ------------------------------------------------------------------------------ + + +@shared_task( + name="src.tasks.evaluations.live.evaluate", + queue="src.tasks.evaluations.live.evaluate", + bind=True, +) +def evaluate( + self, + *, + project_id: UUID, + user_id: UUID, + # + run_id: UUID, + # + newest: datetime, + oldest: datetime, +): + request = Request(scope={"type": "http", "http_version": "1.1", "scheme": "http"}) + + request.state.project_id = str(project_id) + request.state.user_id = str(user_id) + + loop = asyncio.get_event_loop() + + # count in minutes + timestamp = oldest + interval = int((newest - oldest).total_seconds() / 60) + + try: + # ---------------------------------------------------------------------- + log.info( + "[SCOPE] ", + run_id=run_id, + project_id=project_id, + user_id=user_id, + ) + + log.info( + "[RANGE] ", + run_id=run_id, + timestamp=timestamp, + interval=interval, + ) + # ---------------------------------------------------------------------- + + # fetch project -------------------------------------------------------- + project = loop.run_until_complete( + get_project_by_id(project_id=str(project_id)), + ) + # ---------------------------------------------------------------------- + + # fetch provider keys from secrets ------------------------------------- + secrets = loop.run_until_complete( + get_llm_providers_secrets(str(project_id)), + ) + # ---------------------------------------------------------------------- + + # prepare credentials -------------------------------------------------- + secret_token = loop.run_until_complete( + sign_secret_token( + user_id=str(user_id), + project_id=str(project_id), + workspace_id=str(project.workspace_id), + organization_id=str(project.organization_id), + ) + ) + + credentials = f"Secret {secret_token}" + # ---------------------------------------------------------------------- + + # fetch evaluation run ------------------------------------------------- + run = loop.run_until_complete( + evaluations_service.fetch_run( + project_id=project_id, + run_id=run_id, + ) + ) + + assert run, f"Evaluation run with id {run_id} not found!" + + assert run.data, f"Evaluation run with id {run_id} has no data!" + + assert run.data.steps, f"Evaluation run with id {run_id} has no steps!" + + steps = run.data.steps + + input_steps = { + step.key: step for step in steps if step.type == "input" # -------- + } + invocation_steps = { + step.key: step for step in steps if step.type == "invocation" + } + annotation_steps = { + step.key: step for step in steps if step.type == "annotation" + } + + input_steps_keys = list(input_steps.keys()) + invocation_steps_keys = list(invocation_steps.keys()) + annotation_steps_keys = list(annotation_steps.keys()) + + nof_inputs = len(input_steps_keys) + nof_invocations = len(invocation_steps_keys) + nof_annotations = len(annotation_steps_keys) + # ---------------------------------------------------------------------- + + # initialize query variables ------------------------------------------- + query_revision_refs: Dict[str, Reference] = dict() + # + query_revisions: Dict[str, QueryRevision] = dict() + query_references: Dict[str, Dict[str, Reference]] = dict() + # + query_traces: Dict[str, Dict[str, Trace]] = dict() + # ---------------------------------------------------------------------- + + # initialize evaluator variables --------------------------------------- + evaluator_revision_refs: Dict[str, Reference] = dict() + # + evaluator_revisions: Dict[str, EvaluatorRevision] = dict() + evaluator_references: Dict[str, Dict[str, Reference]] = dict() + # ---------------------------------------------------------------------- + + # get query steps references ------------------------------------------- + for input_step_key in input_steps_keys: + query_refs = input_steps[input_step_key].references + query_revision_ref = query_refs.get("query_revision") + + if query_revision_ref: + query_revision_refs[input_step_key] = query_revision_ref + + # ---------------------------------------------------------------------- + + # get evaluator steps references --------------------------------------- + for annotation_step_key in annotation_steps_keys: + evaluator_refs = annotation_steps[annotation_step_key].references + evaluator_revision_ref = evaluator_refs.get("evaluator_revision") + + if evaluator_revision_ref: + evaluator_revision_refs[annotation_step_key] = evaluator_revision_ref + # ---------------------------------------------------------------------- + + # fetch query revisions ------------------------------------------------ + for ( + query_step_key, + query_revision_ref, + ) in query_revision_refs.items(): + query_revision = loop.run_until_complete( + queries_service.fetch_query_revision( + project_id=project_id, + # + query_revision_ref=query_revision_ref, + ) + ) + + if ( + not query_revision + or not query_revision.id + or not query_revision.slug + or not query_revision.data + ): + log.warn( + f"Query revision with ref {query_revision_ref.model_dump(mode='json')} not found!" + ) + continue + + query_step = input_steps[query_step_key] + + query_revisions[query_step_key] = query_revision + query_references[query_step_key] = query_step.references + # ---------------------------------------------------------------------- + + # fetch evaluator revisions -------------------------------------------- + for ( + evaluator_step_key, + evaluator_revision_ref, + ) in evaluator_revision_refs.items(): + evaluator_revision = loop.run_until_complete( + evaluators_service.fetch_evaluator_revision( + project_id=project_id, + # + evaluator_revision_ref=evaluator_revision_ref, + ) + ) + + if ( + not evaluator_revision + or not evaluator_revision.id + or not evaluator_revision.slug + or not evaluator_revision.data + ): + log.warn( + f"Evaluator revision with ref {evaluator_revision_ref.model_dump(mode='json')} not found!" + ) + continue + + evaluator_step = annotation_steps[evaluator_step_key] + + evaluator_revisions[evaluator_step_key] = evaluator_revision + evaluator_references[evaluator_step_key] = evaluator_step.references + # ---------------------------------------------------------------------- + + # run query revisions -------------------------------------------------- + for query_step_key, query_revision in query_revisions.items(): + formatting = Formatting( + focus=Focus.TRACE, + format=Format.AGENTA, + ) + filtering = Filtering( + operator=LogicalOperator.AND, + conditions=list(), + ) + windowing = Windowing( + oldest=oldest, + newest=newest, + next=None, + limit=None, + order="ascending", + interval=None, + rate=None, + ) + + if query_revision.data: + if query_revision.data.filtering: + filtering = query_revision.data.filtering + + if query_revision.data.windowing: + windowing.rate = query_revision.data.windowing.rate + + query = TracingQuery( + formatting=formatting, + filtering=filtering, + windowing=windowing, + ) + + tracing_response = loop.run_until_complete( + tracing_router.query_spans( + request=request, + # + query=query, + ) + ) + + nof_traces = tracing_response.count + + log.info( + "[TRACES] ", + run_id=run_id, + count=nof_traces, + ) + + query_traces[query_step_key] = tracing_response.traces or dict() + # ---------------------------------------------------------------------- + + # run online evaluation ------------------------------------------------ + for query_step_key in query_traces.keys(): + if not query_traces[query_step_key].keys(): + continue + + # create scenarios ------------------------------------------------- + + nof_traces = len(query_traces[query_step_key].keys()) + + scenarios_create = [ + EvaluationScenarioCreate( + run_id=run_id, + timestamp=timestamp, + interval=interval, + # + status=EvaluationStatus.RUNNING, + ) + for _ in range(nof_traces) + ] + + scenarios = loop.run_until_complete( + evaluations_service.create_scenarios( + project_id=project_id, + user_id=user_id, + # + scenarios=scenarios_create, + ) + ) + + if len(scenarios) != nof_traces: + log.error( + "[LIVE] Could not create evaluation scenarios", + run_id=run_id, + ) + continue + # ------------------------------------------------------------------ + + # create query steps ----------------------------------------------- + query_trace_ids = list(query_traces[query_step_key].keys()) + scenario_ids = [scenario.id for scenario in scenarios if scenario.id] + + results_create = [ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario_id, + step_key=query_step_key, + repeat_idx=1, + timestamp=timestamp, + interval=interval, + # + status=EvaluationStatus.SUCCESS, + # + trace_id=query_trace_id, + ) + for scenario_id, query_trace_id in zip(scenario_ids, query_trace_ids) + ] + + results = loop.run_until_complete( + evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + # + results=results_create, + ) + ) + + assert ( + len(results) == nof_traces + ), f"Failed to create evaluation results for run {run_id}!" + # ------------------------------------------------------------------ + + scenario_has_errors: Dict[int, int] = dict() + scenario_status: Dict[int, EvaluationStatus] = dict() + + # iterate over query traces ---------------------------------------- + for idx, trace in enumerate(query_traces[query_step_key].values()): + scenario_has_errors[idx] = 0 + scenario_status[idx] = EvaluationStatus.SUCCESS + + scenario = scenarios[idx] + scenario_id = scenario_ids[idx] + query_trace_id = query_trace_ids[idx] + + if not isinstance(trace.spans, dict): + log.warn( + f"Trace with id {query_trace_id} has no root spans", + run_id=run_id, + ) + scenario_has_errors[idx] += 1 + scenario_status[idx] = EvaluationStatus.ERRORS + continue + + root_span = list(trace.spans.values())[0] + + if isinstance(root_span, list): + log.warn( + f"More than one root span for trace with id {query_trace_id}", + run_id=run_id, + ) + scenario_has_errors[idx] += 1 + scenario_status[idx] = EvaluationStatus.ERRORS + continue + + query_span_id = root_span.span_id + + log.info( + "[TRACE] ", + run_id=run_id, + trace_id=query_trace_id, + ) + + # run evaluator revisions -------------------------------------- + for ( + evaluator_step_key, + evaluator_revision, + ) in evaluator_revisions.items(): + step_has_errors = 0 + step_status = EvaluationStatus.SUCCESS + + references: dict = evaluator_references[evaluator_step_key] + links: dict = dict( + query_step_key=Link( + trace_id=query_trace_id, + span_id=query_span_id, + ) + ) + + parameters: dict = ( + evaluator_revision.data.parameters or {} + if evaluator_revision.data + else {} + ) + inputs: dict = {} + outputs: Any = None + + trace_attributes: dict = root_span.attributes or {} + trace_ag_attributes: dict = trace_attributes.get("ag", {}) + trace_data: dict = trace_ag_attributes.get("data", {}) + trace_parameters: dict = trace_data.get("parameters", {}) + trace_inputs: dict = trace_data.get("inputs", {}) + trace_outputs: Any = trace_data.get("outputs") + + workflow_service_data = WorkflowServiceData( + # + parameters=parameters, + inputs=inputs, + # + trace_parameters=trace_parameters, + trace_inputs=trace_inputs, + trace_outputs=trace_outputs, + # + trace=trace, + ) + + workflow_service_request = WorkflowServiceRequest( + version="2025.07.14", + # + flags={ + "is_annotation": True, + "inline": True, + }, + tags=None, + meta=None, + # + data=workflow_service_data, + # + references=references, + links=links, + # + credentials=credentials, + secrets=secrets, + ) + + workflow_revision = evaluator_revision + + workflows_service_response = loop.run_until_complete( + workflows_service.invoke_workflow( + project_id=project_id, + user_id=user_id, + # + request=workflow_service_request, + revision=workflow_revision, + ) + ) + + evaluator_trace_id = None + error = None + + has_error = workflows_service_response.status.code != 200 + + # if error in evaluator, no annotation, only step ---------- + if has_error: + log.warn( + f"There is an error in evaluator {evaluator_step_key} for query {query_trace_id}." + ) + + step_has_errors += 1 + step_status = EvaluationStatus.FAILURE + scenario_has_errors[idx] += 1 + scenario_status[idx] = EvaluationStatus.ERRORS + + error = workflows_service_response.status.model_dump( + mode="json", + exclude_none=True, + ) + # ---------------------------------------------------------- + + # else, first annotation, then step ------------------------ + else: + outputs = ( + workflows_service_response.data.outputs + if workflows_service_response.data + else None + ) + + annotation_create_request = AnnotationCreateRequest( + annotation=AnnotationCreate( + origin=AnnotationOrigin.AUTO, + kind=AnnotationKind.EVAL, + channel=AnnotationChannel.API, + # + data={"outputs": outputs}, + # + references=SimpleTraceReferences(**references), + links=links, + ) + ) + + annotation_response = loop.run_until_complete( + annotations_router.create_annotation( + request=request, + annotation_create_request=annotation_create_request, + ) + ) + + if ( + not annotation_response.count + or not annotation_response.annotation + ): + log.warn( + f"Failed to create annotation for query {query_trace_id} and evaluator {evaluator_revision.id}" + ) + step_has_errors += 1 + step_status = EvaluationStatus.FAILURE + scenario_has_errors[idx] += 1 + scenario_status[idx] = EvaluationStatus.ERRORS + continue + + evaluator_trace_id = annotation_response.annotation.trace_id + # ---------------------------------------------------------- + + results_create = [ + EvaluationResultCreate( + run_id=run_id, + scenario_id=scenario_id, + step_key=evaluator_step_key, + repeat_idx=1, + timestamp=timestamp, + interval=interval, + # + status=step_status, + # + trace_id=evaluator_trace_id, + error=error, + ) + ] + + results = loop.run_until_complete( + evaluations_service.create_results( + project_id=project_id, + user_id=user_id, + # + results=results_create, + ) + ) + + assert ( + len(results) == 1 + ), f"Failed to create evaluation result for scenario with id {scenario.id}!" + # -------------------------------------------------------------- + + scenario_edit = EvaluationScenarioEdit( + id=scenario.id, + tags=scenario.tags, + meta=scenario.meta, + status=scenario_status[idx], + ) + + scenario = loop.run_until_complete( + evaluations_service.edit_scenario( + project_id=project_id, + user_id=user_id, + # + scenario=scenario_edit, + ) + ) + + if not scenario or not scenario.id: + log.error( + f"Failed to update evaluation scenario with id {scenario_id}!", + run_id=run_id, + ) + + loop.run_until_complete( + evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + # + run_id=run_id, + scenario_id=scenario_id, + ) + ) + # ------------------------------------------------------------------ + + loop.run_until_complete( + evaluations_service.refresh_metrics( + project_id=project_id, + user_id=user_id, + # + run_id=run_id, + timestamp=timestamp, + interval=interval, + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e, exc_info=True) + + log.info( + "[DONE] ", + run_id=run_id, + ) + + return diff --git a/api/ee/src/utils/entitlements.py b/api/ee/src/utils/entitlements.py new file mode 100644 index 0000000000..13360aad77 --- /dev/null +++ b/api/ee/src/utils/entitlements.py @@ -0,0 +1,169 @@ +from typing import Union, Optional, Callable +from uuid import UUID + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.caching import get_cache, set_cache + +log = get_module_logger(__name__) + +from fastapi.responses import JSONResponse + +from ee.src.core.subscriptions.service import SubscriptionsService +from ee.src.core.entitlements.types import ( + Tracker, + Flag, + Counter, + Gauge, + Plan, + ENTITLEMENTS, +) +from ee.src.core.meters.service import MetersService +from ee.src.core.meters.types import MeterDTO +from ee.src.dbs.postgres.meters.dao import MetersDAO +from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO + +meters_service = MetersService( + meters_dao=MetersDAO(), +) + +subscriptions_service = SubscriptionsService( + subscriptions_dao=SubscriptionsDAO(), + meters_service=meters_service, +) + + +class EntitlementsException(Exception): + pass + + +NOT_ENTITLED_RESPONSE: Callable[ + [Tracker], JSONResponse +] = lambda tracker=None: JSONResponse( + status_code=403, + content={ + "detail": ( + "You have reached your monthly quota limit. Please upgrade your plan to continue." + if tracker == Tracker.COUNTERS + else ( + "You have reached your quota limit. Please upgrade your plan to continue." + if tracker == Tracker.GAUGES + else ( + "You do not have access to this feature. Please upgrade your plan to continue." + if tracker == Tracker.FLAGS + else "You do not have access to this feature." + ) + ) + ), + }, +) + + +async def check_entitlements( + organization_id: UUID, + key: Union[Flag, Counter, Gauge], + delta: Optional[int] = None, +) -> tuple[bool, Optional[MeterDTO], Optional[Callable]]: + flag = None + try: + flag = Flag(key) + except ValueError: + pass + + counter = None + try: + counter = Counter(key) + except ValueError: + pass + + gauge = None + try: + gauge = Gauge(key) + except ValueError: + pass + + if flag is None and counter is None and gauge is None: + raise EntitlementsException(f"Invalid key [{key}]") + + cache_key = { + "organization_id": organization_id, + } + + subscription_data = await get_cache( + namespace="entitlements:subscription", + key=cache_key, + ) + + if subscription_data is None: + subscription = await subscriptions_service.read(organization_id=organization_id) + + if not subscription: + raise EntitlementsException( + f"No subscription found for organization [{organization_id}]" + ) + + subscription_data = { + "plan": subscription.plan.value, + "anchor": subscription.anchor, + } + + await set_cache( + namespace="entitlements:subscription", + key=cache_key, + value=subscription_data, + ) + + plan = Plan(subscription_data.get("plan")) + anchor = subscription_data.get("anchor") + + if plan not in ENTITLEMENTS: + raise EntitlementsException(f"Missing plan [{plan}] in entitlements") + + if flag: + if flag not in ENTITLEMENTS[plan][Tracker.FLAGS]: + raise EntitlementsException(f"Invalid flag: {flag} for plan [{plan}]") + + check = ENTITLEMENTS[plan][Tracker.FLAGS][flag] + + if flag.name != "RBAC": + # TODO: remove this line + log.info( + f"adjusting: {organization_id} | | {'allow' if check else 'deny '} | {flag.name}" + ) + + return check is True, None, None + + quota = None + + if counter: + if counter not in ENTITLEMENTS[plan][Tracker.COUNTERS]: + raise EntitlementsException(f"Invalid counter: {counter} for plan [{plan}]") + + quota = ENTITLEMENTS[plan][Tracker.COUNTERS][counter] + + if gauge: + if gauge not in ENTITLEMENTS[plan][Tracker.GAUGES]: + raise EntitlementsException(f"Invalid gauge: {gauge} for plan [{plan}]") + + quota = ENTITLEMENTS[plan][Tracker.GAUGES][gauge] + + if not quota: + raise EntitlementsException(f"No quota found for key [{key}] in plan [{plan}]") + + meter = MeterDTO( + organization_id=organization_id, + key=key, + delta=delta, + ) + + check, meter, _ = await meters_service.adjust( + meter=meter, + quota=quota, + anchor=anchor, + ) + + # TODO: remove this line + log.info( + f"adjusting: {organization_id} | {(('0' if (meter.month != 0 and meter.month < 10) else '') + str(meter.month)) if meter.month != 0 else ' '}.{meter.year if meter.year else ' '} | {'allow' if check else 'deny '} | {meter.key}: {meter.value-meter.synced} [{meter.value}]" + ) + + return check is True, meter, _ diff --git a/api/ee/src/utils/permissions.py b/api/ee/src/utils/permissions.py new file mode 100644 index 0000000000..312bcb05b6 --- /dev/null +++ b/api/ee/src/utils/permissions.py @@ -0,0 +1,304 @@ +from typing import Dict, List, Union, Optional + +from fastapi import HTTPException +from fastapi.responses import JSONResponse + +from oss.src.utils.logging import get_module_logger +from oss.src.utils.caching import get_cache, set_cache + +from ee.src.models.db_models import ( + OrganizationDB, + WorkspaceDB, + Permission, + WorkspaceRole, + ProjectDB, +) +from oss.src.services import db_manager +from ee.src.services import db_manager_ee +from ee.src.utils.entitlements import check_entitlements, Flag +from ee.src.services.selectors import get_user_org_and_workspace_id + + +log = get_module_logger(__name__) + +FORBIDDEN_EXCEPTION = HTTPException( + status_code=403, + detail="You do not have access to perform this action. Please contact your organization admin.", +) + + +async def check_user_org_access( + kwargs: dict, organization_id: str, check_owner=False +) -> bool: + if check_owner: # Check that the user is the owner of the organization + user = await db_manager.get_user_with_id(user_id=kwargs["id"]) + organization = await db_manager_ee.get_organization(organization_id) + if not organization: + log.error("Organization not found") + raise Exception("Organization not found") + return organization.owner == str(user.id) # type: ignore + else: + user_organizations: List = kwargs["organization_ids"] + user_exists_in_organizations = organization_id in user_organizations + return user_exists_in_organizations + + +async def check_user_access_to_workspace( + user_org_workspace_data: Dict[str, Union[str, list]], + workspace: WorkspaceDB, + organization: OrganizationDB, +) -> bool: + """ + Check if a user has access to a specific workspace and the workspace organization. + + Args: + user_org_workspace_data (Dict[str, Union[str, list]]): User-specific information. + workspace (WorkspaceDB): The workspace to check. + organization (OrganizationDB): The organization to check. + + Returns: + bool: True if the user has access, False otherwise. + + Raises: + ValueError: If the workspace does not belong to the organization. + """ + + workspace_organization_id = str(workspace.organization_id) + if ( + workspace is None + or organization is None + or workspace_organization_id != str(organization.id) + ): + raise ValueError("Workspace does not belong to the provided organization") + + # Check that the user belongs to the organization + has_organization_access = await check_user_org_access( + user_org_workspace_data, workspace_organization_id + ) + if not has_organization_access: + log.error("User does not belong and have access to the organization") + return False + + # Check that the user belongs to the workspace + user_id = user_org_workspace_data.get("id") + if user_id is None: + log.error("User ID is missing in user_org_workspace_data") + return False + + workspace_members = workspace.get_all_members() + if user_id not in workspace_members: + log.error("User does not belong to the workspace") + return False + + # Check that the workspace is in the user's workspaces + has_access_to_workspace = any( + str(workspace.id) == workspace_id + for workspace_id in user_org_workspace_data["workspace_ids"] + ) + return has_access_to_workspace + + +async def check_action_access( + user_uid: str, + project_id: str = None, + permission: Permission = None, + role: str = None, +) -> bool: + """ + Check if a user belongs to a workspace and has a certain permission. + + Args: + user_id (str): The user's ID. + object_id (str): The ID of the object to check. + type (str): The type of the object to check. + permission (Permission): The permission to check. + role (str): The role to check. + + Returns: + bool: True if the user belongs to the workspace and has the specified permission, False otherwise. + """ + + if permission is None and role is None: + raise Exception("Either permission or role must be provided") + elif permission is not None and role is not None: + raise Exception("Only one of permission or role must be provided") + + cache_key = { + "permission": permission.value if permission else None, + "role": role, + } + + has_permission = await get_cache( + project_id=project_id, + user_id=user_uid, + namespace="check_action_access", + key=cache_key, + ) + + if has_permission is not None: + return has_permission + + user_org_workspace_data: dict = await get_user_org_and_workspace_id(user_uid) + has_permission = await check_rbac_permission( + user_org_workspace_data=user_org_workspace_data, + project_id=project_id, + role=role, + permission=permission, + ) + + await set_cache( + project_id=project_id, + user_id=user_uid, + namespace="check_action_access", + key=cache_key, + value=has_permission, + ) + + return has_permission + + +# async def check_apikey_action_access( +# api_key: str, user_id: str, permission: Permission +# ): +# """ +# Check if an api key belongs to a user for a workspace and has the right permission. + +# Args: +# api_key (str): The api key +# user_id (str): The user (owner) ID of the api_key +# permission (Permission): The permission to check for. +# """ + +# api_key_prefix = api_key.split(".")[0] +# api_key_db = await db_manager.get_user_api_key_by_prefix( +# api_key_prefix=api_key_prefix, user_id=user_id +# ) +# if api_key_db is None: +# raise HTTPException( +# 404, {"message": f"API Key with prefix {api_key_prefix} not found"} +# ) + +# project_db = await db_manager.get_project_by_id( +# project_id=str(api_key_db.project_id) +# ) +# if project_db is None: +# raise HTTPException( +# 404, +# {"message": f"Project with ID {str(api_key_db.workspace_id)} not found"}, +# ) + +# has_access = await check_project_has_role_or_permission( +# project_db, str(api_key_db.created_by_id), None, permission +# ) +# if not has_access: +# raise HTTPException( +# 403, +# { +# "message": "You do not have access to perform this action. Please contact your organization admin." +# }, +# ) + + +async def check_rbac_permission( + user_org_workspace_data: Dict[str, Union[str, list]], + project_id: str = None, + permission: Permission = None, + role: str = None, +) -> bool: + """ + Check if a user belongs to a workspace and has a certain permission. + + Args: + user_org_workspace_data (Dict[str, Union[str, list]]): User-specific information containing the id, uid, list of user organization and list of user workspace. + project_id (str): The ID of the project. + permission (Permission): The permission to check for. + role (str): The role to check for. + + Returns: + bool: True if the user belongs to the workspace and has the specified permission, False otherwise. + """ + + assert ( + project_id is not None + ), "Project_ID is required to check object-level permissions" + + # Assert that either permission or role is provided, but not both + assert (permission is not None) or ( + role is not None + ), "Either 'permission' or 'role' must be provided, but neither is provided" + assert not ( + (permission is not None) and (role is not None) + ), "'permission' and 'role' cannot both be provided at the same time" + + if project_id is not None: + project = await db_manager.get_project_by_id(project_id) + if project is None: + raise Exception("Project not found") + + workspace = await db_manager.get_workspace(str(project.workspace_id)) + organization = await db_manager_ee.get_organization( + str(project.organization_id) + ) + + workspace_has_access = await check_user_access_to_workspace( + user_org_workspace_data=user_org_workspace_data, + workspace=workspace, + organization=organization, + ) + if not workspace_has_access: + log.error("User does not have access to the workspace") + return False + + user_id = user_org_workspace_data["id"] + assert isinstance(user_id, str), "User ID must be a string" + has_access = await check_project_has_role_or_permission( + project, user_id, role, permission + ) + return has_access + + +async def check_project_has_role_or_permission( + # organization_id: str, + project: ProjectDB, + user_id: str, + role: Optional[str] = None, + permission: Optional[str] = None, +): + """Check if a user has the provided role or permission in a project. + + Args: + project (ProjectDB): The project to check if the user has permissions to + user_id (str): The ID of the user + role (Optional[str], optional): The role to check for. Defaults to None. + permission (Optional[str], optional): The permission to check for. Defaults to None. + """ + + check, _, _ = await check_entitlements( + organization_id=project.organization_id, + key=Flag.RBAC, + ) + + if not check: + return True + + assert ( + role is not None or permission is not None + ), "Either role or permission must be provided" + + project_members = await db_manager_ee.get_project_members( + project_id=str(project.id) + ) + if project.is_owner(user_id, project_members): + return True + + if role is not None: + if role not in list(WorkspaceRole): + raise Exception("Invalid role specified") + return project.has_role(user_id, role, project_members) + + if permission is not None: + if permission not in list(Permission): + raise Exception("Invalid permission specified") + return project.has_permission(user_id, permission, project_members) + + return False diff --git a/api/ee/tests/__init__.py b/api/ee/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/tests/manual/billing.http b/api/ee/tests/manual/billing.http new file mode 100644 index 0000000000..6158dac23f --- /dev/null +++ b/api/ee/tests/manual/billing.http @@ -0,0 +1,52 @@ + +@host = http://localhost +@base_url = {{host}}/api/billing +@api_key = xxx.xxx +### + +# @name open_portal +POST {{base_url}}/stripe/portals/ +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name open_checkout +POST {{base_url}}/stripe/checkouts/?plan=cloud_v0_pro&success_url=http://localhost/ +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name fetch_plans +GET {{base_url}}/plans +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name switch_plans +POST {{base_url}}/plans/switch?plan=cloud_v0_pro +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name fetch_subscription +GET {{base_url}}/subscription +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name cancel_subscription +POST {{base_url}}/subscription/cancel +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name fetch_usage +GET {{base_url}}/usage +Content-Type: application/json +Authorization: ApiKey {{api_key}} diff --git a/api/ee/tests/manual/evaluations/live.http b/api/ee/tests/manual/evaluations/live.http new file mode 100644 index 0000000000..6a43280046 --- /dev/null +++ b/api/ee/tests/manual/evaluations/live.http @@ -0,0 +1,131 @@ +@auth_key = {{$dotenv.AGENTA_AUTH_KEY}} || change-me +@api_url = {{$dotenv AGENTA_API_URL}} +@api_key = {{$dotenv AGENTA_API_KEY}} + + +### +# @name create_account +POST {{api_url}}/admin/account +Content-Type: application/json +Authorization: Access {{auth_key}} + +### +@user_id = {{create_account.response.body.user.id}} +# @authorization = {{create_account.response.body.scopes[0].credentials}} +@authorization = ApiKey {{api_key}} + +### +# @name list_queries +POST {{api_url}}/preview/simple/queries/query +Content-Type: application/json +Authorization: {{authorization}} + +{} + +### +# @name create_query +POST {{api_url}}/preview/simple/queries/ +Content-Type: application/json +Authorization: {{authorization}} + +{ + "query": { + "slug": "{{$guid}}", + "name": "Test Query", + "description": "This is a test query", + "tags": { + "my_key": "my_value" + }, + "data": { + "filtering": { + "conditions": [ + { + "field": "attributes", + "key": "ag.type.trace", + "operator": "is", + "value": "invocation" + } + ] + } + } + } +} + +### +# @name fetch_query_revision +POST {{api_url}}/preview/queries/revisions/retrieve +Content-Type: application/json +Authorization: {{authorization}} + +{ + "query_ref": { + "id": "{{create_query.response.body.query.id}}" + } +} + +### +# @name list_evaluators +POST {{api_url}}/preview/simple/evaluators/query +Content-Type: application/json +Authorization: {{authorization}} + +{} + +### +# @name fetch_evaluator_revision +POST {{api_url}}/preview/evaluators/revisions/retrieve +Content-Type: application/json +Authorization: {{authorization}} + +{ + "evaluator_ref": { + "id": "{{list_evaluators.response.body.evaluators[2].id}}" + } +} + +### +# @name list_evaluations +POST {{api_url}}/preview/simple/evaluations/query +Content-Type: application/json +Authorization: {{authorization}} + +{} + +### +# @name create_evaluation +POST {{api_url}}/preview/simple/evaluations/ +Content-Type: application/json +Authorization: {{authorization}} + +{ + "evaluation": { + "name": "Test JIT Evaluation", + "description": "This is a test jit evaluation", + "tags": { + "my_key": "my_value" + }, + "flags": { + "is_live": true + }, + "data": { + "query_steps": [ + "{{fetch_query_revision.response.body.query_revision.id}}" + ], + "evaluator_steps": [ + "{{fetch_evaluator_revision.response.body.evaluator_revision.evaluator_id}}" + ] + } + } +} + +### +# @name stop_evaluation +POST {{api_url}}/preview/simple/evaluations/{{create_evaluation.response.body.evaluation.id}}/stop +Content-Type: application/json +Authorization: {{authorization}} + +### +# @name start_evaluation +POST {{api_url}}/preview/simple/evaluations/{{create_evaluation.response.body.evaluation.id}}/start +Content-Type: application/json +Authorization: {{authorization}} diff --git a/api/ee/tests/manual/evaluations/sdk/client.py b/api/ee/tests/manual/evaluations/sdk/client.py new file mode 100644 index 0000000000..c930eee323 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/client.py @@ -0,0 +1,32 @@ +from os import getenv + +import requests + +BASE_TIMEOUT = 10 + +AGENTA_API_KEY = getenv("AGENTA_API_KEY") +AGENTA_API_URL = getenv("AGENTA_API_URL") + + +def authed_api(): + """ + Preconfigured requests for authenticated endpoints (supports all methods). + """ + + api_url = AGENTA_API_URL + credentials = f"ApiKey {AGENTA_API_KEY}" + + def _request(method: str, endpoint: str, **kwargs): + url = f"{api_url}{endpoint}" + headers = kwargs.pop("headers", {}) + headers.setdefault("Authorization", credentials) + + return requests.request( + method=method, + url=url, + headers=headers, + timeout=BASE_TIMEOUT, + **kwargs, + ) + + return _request diff --git a/api/ee/tests/manual/evaluations/sdk/definitions.py b/api/ee/tests/manual/evaluations/sdk/definitions.py new file mode 100644 index 0000000000..4768515ef3 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/definitions.py @@ -0,0 +1,1818 @@ +from enum import Enum +from uuid import UUID, uuid4 +from re import match +from datetime import datetime +from typing import Dict, List, Optional, Union, Literal, Callable, Any, TypeAliasType + +from pydantic import BaseModel, field_validator, Field + +# oss.src.core.shared.dtos ----------------------------------------------------- + +from typing import Optional, Dict, List, Union, Literal +from uuid import UUID +from datetime import datetime +from re import match + +from pydantic import BaseModel, field_validator + +from typing_extensions import TypeAliasType + + +BoolJson: TypeAliasType = TypeAliasType( # type: ignore + "BoolJson", + Union[bool, Dict[str, "BoolJson"]], # type: ignore +) + +StringJson: TypeAliasType = TypeAliasType( # type: ignore + "StringJson", + Union[str, Dict[str, "StringJson"]], # type: ignore +) + +FullJson: TypeAliasType = TypeAliasType( # type: ignore + "FullJson", + Union[str, int, float, bool, None, Dict[str, "FullJson"], List["FullJson"]], # type: ignore +) + +NumericJson: TypeAliasType = TypeAliasType( # type: ignore + "NumericJson", + Union[int, float, Dict[str, "NumericJson"]], # type: ignore +) + +NoListJson: TypeAliasType = TypeAliasType( # type: ignore + "NoListJson", + Union[str, int, float, bool, None, Dict[str, "NoListJson"]], # type: ignore +) + +Json = Dict[str, FullJson] # type: ignore + +Data = Dict[str, FullJson] # type: ignore + +Flags = Dict[str, bool | str] + +Tags = Dict[str, NoListJson] # type: ignore + +Meta = Dict[str, FullJson] # type: ignore + +Hashes = Dict[str, StringJson] # type: ignore + + +class Metadata(BaseModel): + flags: Optional[Flags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + tags: Optional[Tags] = None # type: ignore + + +class Windowing(BaseModel): + # RANGE + newest: Optional[datetime] = None + oldest: Optional[datetime] = None + # TOKEN + next: Optional[UUID] = None + # LIMIT + limit: Optional[int] = None + # ORDER + order: Optional[Literal["ascending", "descending"]] = None + # SAMPLES + rate: Optional[float] = None + # BUCKETS + interval: Optional[int] = None + + @field_validator("rate") + def check_rate(cls, v): + if v is not None and (v < 0.0 or v > 1.0): + raise ValueError("Sampling rate must be between 0.0 and 1.0.") + return v + + @field_validator("interval") + def check_interval(cls, v): + if v is not None and v <= 0: + raise ValueError("Bucket interval must be a positive integer.") + return v + + +class Lifecycle(BaseModel): + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + deleted_at: Optional[datetime] = None + + created_by_id: Optional[UUID] = None + updated_by_id: Optional[UUID] = None + deleted_by_id: Optional[UUID] = None + + +class TraceID(BaseModel): + trace_id: Optional[str] = None + + +class SpanID(BaseModel): + span_id: Optional[str] = None + + +class Identifier(BaseModel): + id: Optional[UUID] = None + + +class Slug(BaseModel): + slug: Optional[str] = None + + @field_validator("slug") + def check_url_safety(cls, v): + if v is not None: + if not match(r"^[a-zA-Z0-9_-]+$", v): + raise ValueError("slug must be URL-safe.") + return v + + +class Version(BaseModel): + version: Optional[str] = None + + +class Header(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + + +class Commit(BaseModel): + author: Optional[UUID] = None + date: Optional[datetime] = None + message: Optional[str] = None + + +class Reference(Identifier, Slug, Version): + pass + + +class Link(TraceID, SpanID): + pass + + +def sync_alias(primary: str, alias: str, instance: BaseModel) -> None: + primary_val = getattr(instance, primary) + alias_val = getattr(instance, alias) + if primary_val and alias_val is None: + object.__setattr__(instance, alias, primary_val) + elif alias_val and primary_val is None: + object.__setattr__(instance, primary, alias_val) + + +class AliasConfig(BaseModel): + model_config = { + "populate_by_name": True, + "from_attributes": True, + } + + +Metrics = Dict[str, NumericJson] # type: ignore + + +class LegacyLifecycleDTO(BaseModel): + created_at: Optional[str] = None + updated_at: Optional[str] = None + updated_by_id: Optional[str] = None + # DEPRECATING + updated_by: Optional[str] = None # email + + +class Status(BaseModel): + code: Optional[int] = 500 + type: Optional[str] = None + message: Optional[str] = "An unexpected error occurred. Please try again later." + stacktrace: Optional[str] = None + + +Mappings = Dict[str, str] + +Schema = Dict[str, FullJson] # type: ignore + +# ------------------------------------------------------------------------------ + +# oss.src.core.git.dtos -------------------------------------------------------- + +from typing import Optional, List +from uuid import UUID + +from pydantic import BaseModel + + +# artifacts -------------------------------------------------------------------- + + +class Artifact(Identifier, Slug, Lifecycle, Header, Metadata): + pass + + +class ArtifactCreate(Slug, Header, Metadata): + pass + + +class ArtifactEdit(Identifier, Header, Metadata): + pass + + +class ArtifactQuery(Metadata): + pass + + +# variants --------------------------------------------------------------------- + + +class Variant(Identifier, Slug, Lifecycle, Header, Metadata): + artifact_id: Optional[UUID] = None + + +class VariantCreate(Slug, Header, Metadata): + artifact_id: Optional[UUID] = None + + +class VariantEdit(Identifier, Header, Metadata): + pass + + +class VariantQuery(Metadata): + pass + + +# revisions -------------------------------------------------------------------- + + +class Revision(Identifier, Slug, Version, Lifecycle, Header, Metadata, Commit): + data: Optional[Data] = None + + artifact_id: Optional[UUID] = None + variant_id: Optional[UUID] = None + + +class RevisionCreate(Slug, Header, Metadata): + artifact_id: Optional[UUID] = None + variant_id: Optional[UUID] = None + + +class RevisionEdit(Identifier, Header, Metadata): + pass + + +class RevisionQuery(Metadata): + authors: Optional[List[UUID]] = None + + +class RevisionCommit(Slug, Header, Metadata): + data: Optional[Data] = None + + message: Optional[str] = None + + artifact_id: Optional[UUID] = None + variant_id: Optional[UUID] = None + + +class RevisionsLog(BaseModel): + artifact_id: Optional[UUID] = None + variant_id: Optional[UUID] = None + revision_id: Optional[UUID] = None + + depth: Optional[int] = None + + +# forks ------------------------------------------------------------------------ + + +class RevisionFork(Slug, Header, Metadata): + data: Optional[Data] = None + + message: Optional[str] = None + + +class VariantFork(Slug, Header, Metadata): + pass + + +class ArtifactFork(RevisionsLog): + variant: Optional[VariantFork] = None + revision: Optional[RevisionFork] = None + + +# ------------------------------------------------------------------------------ + + +Origin = Literal["custom", "human", "auto"] +# Target = Union[List[UUID], Dict[UUID, Origin], List[Callable]] +Target = Union[ + List[List[Dict[str, Any]]], # testcases_data + List[Callable], # workflow_handlers + List[UUID], # entity_ids + Dict[UUID, Origin], # entity_ids with origins +] + + +# oss.src.core.evaluations.types + + +class EvaluationStatus(str, Enum): + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILURE = "failure" + ERRORS = "errors" + CANCELLED = "cancelled" + + +class EvaluationRunFlags(BaseModel): + is_closed: Optional[bool] = None # Indicates if the run is modifiable + is_live: Optional[bool] = None # Indicates if the run has live queries + is_active: Optional[bool] = None # Indicates if the run is currently active + + +class SimpleEvaluationFlags(EvaluationRunFlags): + pass + + +SimpleEvaluationStatus = EvaluationStatus + + +class SimpleEvaluationData(BaseModel): + status: Optional[SimpleEvaluationStatus] = None + + query_steps: Optional[Target] = None + testset_steps: Optional[Target] = None + application_steps: Optional[Target] = None + evaluator_steps: Optional[Target] = None + + repeats: Optional[int] = None + + +class EvaluationRun(BaseModel): + id: UUID + + +class EvaluationScenario(BaseModel): + id: UUID + + run_id: UUID + + +class EvaluationResult(BaseModel): + id: UUID + + run_id: UUID + scenario_id: UUID + step_key: str + + testcase_id: Optional[UUID] = None + trace_id: Optional[UUID] = None + error: Optional[dict] = None + + flags: Optional[Dict[str, Any]] = None + tags: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + + +class EvaluationMetrics(Identifier, Lifecycle): + flags: Optional[Dict[str, Any]] = None + tags: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + + status: Optional[EvaluationStatus] = None + + timestamp: Optional[datetime] = None + interval: Optional[int] = None + + data: Optional[Data] = None + + scenario_id: Optional[UUID] = None + + run_id: UUID + + +# oss.src.core.tracing.dtos + +import random +import string +from enum import Enum +from datetime import datetime, timezone +from typing import List, Dict, Any, Union, Optional, Literal +from uuid import UUID + +from pydantic import BaseModel, model_validator, Field + + +class TraceType(Enum): + INVOCATION = "invocation" + ANNOTATION = "annotation" + # + UNKNOWN = "unknown" + + +class SpanType(Enum): + AGENT = "agent" + CHAIN = "chain" + WORKFLOW = "workflow" + TASK = "task" + TOOL = "tool" + EMBEDDING = "embedding" + QUERY = "query" + LLM = "llm" + COMPLETION = "completion" + CHAT = "chat" + RERANK = "rerank" + # + UNKNOWN = "unknown" + + +class AgMetricEntryAttributes(BaseModel): + # cumulative: 'cum' can't be used though + cumulative: Optional[Metrics] = None + # incremental 'inc' could be used, since 'unit' may be confusing + incremental: Optional[Metrics] = None + + model_config = {"ser_json_exclude_none": True} + + +class AgMetricsAttributes(BaseModel): + duration: Optional[AgMetricEntryAttributes] = None + errors: Optional[AgMetricEntryAttributes] = None + tokens: Optional[AgMetricEntryAttributes] = None + costs: Optional[AgMetricEntryAttributes] = None + + model_config = {"ser_json_exclude_none": True} + + +class AgTypeAttributes(BaseModel): + trace: Optional[TraceType] = TraceType.INVOCATION + span: Optional[SpanType] = SpanType.TASK + + +class AgDataAttributes(BaseModel): + parameters: Optional[Dict[str, Any]] = None + inputs: Optional[Dict[str, Any]] = None + outputs: Optional[Any] = None + internals: Optional[Dict[str, Any]] = None + + model_config = {"ser_json_exclude_none": True} + + +class AgAttributes(BaseModel): + type: AgTypeAttributes = Field(default_factory=AgTypeAttributes) + data: AgDataAttributes = Field(default_factory=AgDataAttributes) + + metrics: Optional[AgMetricsAttributes] = None + flags: Optional[Flags] = None # type: ignore + tags: Optional[Tags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + exception: Optional[Data] = None # type: ignore + references: Optional[Dict[str, "OTelReference"]] = None + unsupported: Optional[Data] = None # type: ignore + + model_config = {"ser_json_exclude_none": True} + + +## --- SUB-ENTITIES --- ## + + +class OTelStatusCode(Enum): + STATUS_CODE_UNSET = "STATUS_CODE_UNSET" + STATUS_CODE_OK = "STATUS_CODE_OK" + STATUS_CODE_ERROR = "STATUS_CODE_ERROR" + + +class OTelSpanKind(Enum): + SPAN_KIND_UNSPECIFIED = "SPAN_KIND_UNSPECIFIED" + SPAN_KIND_INTERNAL = "SPAN_KIND_INTERNAL" + SPAN_KIND_SERVER = "SPAN_KIND_SERVER" + SPAN_KIND_CLIENT = "SPAN_KIND_CLIENT" + SPAN_KIND_PRODUCER = "SPAN_KIND_PRODUCER" + SPAN_KIND_CONSUMER = "SPAN_KIND_CONSUMER" + + +OTelAttributes = Json # type: ignore +OTelMetrics = Metrics # type: ignore +OTelTags = Tags # type: ignore + +Attributes = OTelAttributes # type: ignore + + +class OTelEvent(BaseModel): + name: str + timestamp: Union[datetime, int] + + attributes: Optional[OTelAttributes] = None + + +OTelEvents = List[OTelEvent] + + +class OTelHash(Identifier): + attributes: Optional[OTelAttributes] = None + + +OTelHashes = List[OTelHash] + + +class OTelLink(TraceID, SpanID): + attributes: Optional[OTelAttributes] = None + + +OTelLinks = List[OTelLink] + + +class OTelReference(Reference): + attributes: Optional[OTelAttributes] = None + + +OTelReferences = List[OTelReference] + + +class OTelSpansTree(BaseModel): + spans: Optional["OTelNestedSpans"] = None + + +OTelSpansTrees = List[OTelSpansTree] + + +class OTelFlatSpan(Lifecycle): + trace_id: str + span_id: str + parent_id: Optional[str] = None + + trace_type: Optional[TraceType] = None + span_type: Optional[SpanType] = None + + span_kind: Optional[OTelSpanKind] = None + span_name: Optional[str] = None + + start_time: Optional[Union[datetime, int]] = None + end_time: Optional[Union[datetime, int]] = None + + status_code: Optional[OTelStatusCode] = None + status_message: Optional[str] = None + + attributes: Optional[OTelAttributes] = None + references: Optional[OTelReferences] = None + links: Optional[OTelLinks] = None + hashes: Optional[OTelHashes] = None + + exception: Optional[Data] = None # type: ignore + + events: Optional[OTelEvents] = None + + @model_validator(mode="after") + def set_defaults(self): + if self.trace_type is None: + self.trace_type = TraceType.INVOCATION + if self.span_type is None: + self.span_type = SpanType.TASK + if self.span_kind is None: + self.span_kind = OTelSpanKind.SPAN_KIND_UNSPECIFIED + if self.status_code is None: + self.status_code = OTelStatusCode.STATUS_CODE_UNSET + if self.end_time is None and self.start_time is not None: + self.end_time = self.start_time + if self.start_time is None and self.end_time is not None: + self.start_time = self.end_time + if self.start_time is None and self.end_time is None: + now = datetime.now(timezone.utc) + self.start_time = now + self.end_time = now + if self.span_name is None: + self.span_name = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + return self + + +class OTelSpan(OTelFlatSpan, OTelSpansTree): + pass + + +OTelFlatSpans = List[OTelFlatSpan] +OTelNestedSpans = Dict[str, Union[OTelSpan, List[OTelSpan]]] +OTelTraceTree = Dict[str, OTelSpansTree] +OTelTraceTrees = List[OTelTraceTree] +OTelSpans = List[OTelSpan] + + +class Fields(str, Enum): + TRACE_ID = "trace_id" + SPAN_ID = "span_id" + PARENT_ID = "parent_id" + SPAN_NAME = "span_name" + SPAN_KIND = "span_kind" + START_TIME = "start_time" + END_TIME = "end_time" + STATUS_CODE = "status_code" + STATUS_MESSAGE = "status_message" + ATTRIBUTES = "attributes" + EVENTS = "events" + LINKS = "links" + REFERENCES = "references" + CREATED_AT = "created_at" + UPDATED_AT = "updated_at" + DELETED_AT = "deleted_at" + CREATED_BY_ID = "created_by_id" + UPDATED_BY_ID = "updated_by_id" + DELETED_BY_ID = "deleted_by_id" + CONTENT = "content" + + +class LogicalOperator(str, Enum): + AND = "and" + OR = "or" + NOT = "not" + NAND = "nand" + NOR = "nor" + + +class ComparisonOperator(str, Enum): + IS = "is" + IS_NOT = "is_not" + + +class NumericOperator(str, Enum): + EQ = "eq" + NEQ = "neq" + GT = "gt" + LT = "lt" + GTE = "gte" + LTE = "lte" + BETWEEN = "btwn" + + +class StringOperator(str, Enum): + STARTSWITH = "startswith" + ENDSWITH = "endswith" + CONTAINS = "contains" + MATCHES = "matches" + LIKE = "like" + + +class DictOperator(str, Enum): + HAS = "has" + HAS_NOT = "has_not" + + +class ListOperator(str, Enum): + IN = "in" + NOT_IN = "not_in" + + +class ExistenceOperator(str, Enum): + EXISTS = "exists" + NOT_EXISTS = "not_exists" + + +class TextOptions(BaseModel): + case_sensitive: Optional[bool] = False + exact_match: Optional[bool] = False + + +class ListOptions(BaseModel): + all: Optional[bool] = False + + +class Condition(BaseModel): + field: str + key: Optional[str] = None + value: Optional[Union[str, int, float, bool, list, dict]] = None + operator: Optional[ + Union[ + ComparisonOperator, + NumericOperator, + StringOperator, + ListOperator, + DictOperator, + ExistenceOperator, + ] + ] = ComparisonOperator.IS + options: Optional[Union[TextOptions, ListOptions]] = None + + +class Filtering(BaseModel): + operator: Optional[LogicalOperator] = LogicalOperator.AND + conditions: List[Union[Condition, "Filtering"]] = list() + + +class Focus(str, Enum): + TRACE = "trace" + SPAN = "span" + + +class Format(str, Enum): + AGENTA = "agenta" + OPENTELEMETRY = "opentelemetry" + + +class Formatting(BaseModel): + focus: Optional[Focus] = Focus.SPAN + format: Optional[Format] = Format.AGENTA + + +class TracingQuery(BaseModel): + formatting: Optional[Formatting] = None + windowing: Optional[Windowing] = None + filtering: Optional[Filtering] = None + + +_C_OPS = list(ComparisonOperator) +_N_OPS = list(NumericOperator) +_S_OPS = list(StringOperator) +_L_OPS = list(ListOperator) +_D_OPS = list(DictOperator) +_E_OPS = list(ExistenceOperator) + + +class FilteringException(Exception): + pass + + +class Analytics(BaseModel): + count: Optional[int] = 0 + duration: Optional[float] = 0.0 + costs: Optional[float] = 0.0 + tokens: Optional[float] = 0.0 + + def plus(self, other: "Analytics") -> "Analytics": + self.count += other.count + self.duration += other.duration + self.costs += other.costs + self.tokens += other.tokens + + return self + + +class Bucket(BaseModel): + timestamp: datetime + interval: int + total: Analytics + errors: Analytics + + +Trace = OTelSpansTree + +# oss.src.core.observability.dtos + +from enum import Enum +from uuid import UUID +from datetime import datetime +from typing import List, Dict, Any, Union, Optional + +from pydantic import BaseModel + + +## --- SUB-ENTITIES --- ## + + +class RootDTO(BaseModel): + id: UUID + + +class TreeType(Enum): + INVOCATION = "invocation" + ANNOTATION = "annotation" + # + UNKNOWN = "unknown" + + +class TreeDTO(BaseModel): + id: UUID + type: Optional[TreeType] = None + + +class NodeType(Enum): + # --- VARIANTS --- # + ## SPAN_KIND_SERVER + AGENT = "agent" + WORKFLOW = "workflow" + CHAIN = "chain" + ## SPAN_KIND_INTERNAL + TASK = "task" + ## SPAN_KIND_CLIENT + TOOL = "tool" + EMBEDDING = "embedding" + QUERY = "query" + COMPLETION = "completion" # LEGACY + CHAT = "chat" + RERANK = "rerank" + # --- VARIANTS --- # + + +class NodeDTO(BaseModel): + id: UUID + name: str + type: Optional[NodeType] = None + + +class ParentDTO(BaseModel): + id: UUID + + +class TimeDTO(BaseModel): + start: datetime + end: datetime + + +class StatusCode(Enum): + UNSET = "UNSET" + OK = "OK" + ERROR = "ERROR" + + +class StatusDTO(BaseModel): + code: StatusCode + message: Optional[str] = None + + class Config: + use_enum_values = True + + +Attributes = Dict[str, Any] + + +class ExceptionDTO(BaseModel): + timestamp: datetime + type: str + message: Optional[str] = None + stacktrace: Optional[str] = None + attributes: Optional[Attributes] = None + + class Config: + json_encoders = { + UUID: lambda v: str(v), # pylint: disable=unnecessary-lambda + datetime: lambda dt: dt.isoformat(), + } + + +Data = Dict[str, Any] +Metrics = Dict[str, Any] +Meta = Dict[str, Any] +Refs = Dict[str, Any] + + +class LinkDTO(BaseModel): + type: TreeType # Yes, this is correct + id: UUID # node_id, this is correct + tree_id: Optional[UUID] = None + + class Config: + use_enum_values = True + json_encoders = { + UUID: lambda v: str(v), # pylint: disable=unnecessary-lambda + } + + +class OTelSpanKind(Enum): + SPAN_KIND_UNSPECIFIED = "SPAN_KIND_UNSPECIFIED" + # INTERNAL + SPAN_KIND_INTERNAL = "SPAN_KIND_INTERNAL" + # SYNCHRONOUS + SPAN_KIND_SERVER = "SPAN_KIND_SERVER" + SPAN_KIND_CLIENT = "SPAN_KIND_CLIENT" + # ASYNCHRONOUS + SPAN_KIND_PRODUCER = "SPAN_KIND_PRODUCER" + SPAN_KIND_CONSUMER = "SPAN_KIND_CONSUMER" + + +class OTelStatusCode(Enum): + STATUS_CODE_OK = "STATUS_CODE_OK" + STATUS_CODE_ERROR = "STATUS_CODE_ERROR" + STATUS_CODE_UNSET = "STATUS_CODE_UNSET" + + +class OTelContextDTO(BaseModel): + trace_id: str + span_id: str + + +class OTelEventDTO(BaseModel): + name: str + timestamp: str + + attributes: Optional[Attributes] = None + + +class OTelLinkDTO(BaseModel): + context: OTelContextDTO + + attributes: Optional[Attributes] = None + + +class OTelExtraDTO(BaseModel): + kind: Optional[str] = None + + attributes: Optional[Attributes] = None + events: Optional[List[OTelEventDTO]] = None + links: Optional[List[OTelLinkDTO]] = None + + +## --- ENTITIES --- ## + + +class SpanDTO(BaseModel): + trace_id: str + span_id: str + + lifecycle: Optional[LegacyLifecycleDTO] = None + + root: RootDTO + tree: TreeDTO + node: NodeDTO + + parent: Optional[ParentDTO] = None + + time: TimeDTO + status: StatusDTO + + exception: Optional[ExceptionDTO] = None + + data: Optional[Data] = None + metrics: Optional[Metrics] = None + meta: Optional[Meta] = None + refs: Optional[Refs] = None + + links: Optional[List[LinkDTO]] = None + + otel: Optional[OTelExtraDTO] = None + + nodes: Optional[Dict[str, Union["SpanDTO", List["SpanDTO"]]]] = None + + model_config = { + "json_encoders": { + UUID: lambda v: str(v), + datetime: lambda dt: dt.isoformat(), + }, + } + + def encode(self, data: Any) -> Any: + if isinstance(data, dict): + return {k: self.encode(v) for k, v in data.items()} + elif isinstance(data, list): + return [self.encode(item) for item in data] + for type_, encoder in self.model_config["json_encoders"].items(): # type: ignore + if isinstance(data, type_): + return encoder(data) + return data + + def model_dump(self, *args, **kwargs) -> dict: + return self.encode( + super().model_dump( + *args, + **kwargs, + exclude_none=True, + ) + ) + + +class OTelSpanDTO(BaseModel): + context: OTelContextDTO + + name: str + kind: OTelSpanKind = OTelSpanKind.SPAN_KIND_UNSPECIFIED + + start_time: datetime + end_time: datetime + + status_code: OTelStatusCode = OTelStatusCode.STATUS_CODE_UNSET + status_message: Optional[str] = None + + attributes: Optional[Attributes] = None + events: Optional[List[OTelEventDTO]] = None + + parent: Optional[OTelContextDTO] = None + links: Optional[List[OTelLinkDTO]] = None + + +# oss.src.apis.fastapi.observability.models + +from typing import List, Optional +from datetime import datetime + + +class AgentaNodeDTO(SpanDTO): + pass + + +class Tree(BaseModel): + version: str + nodes: List[AgentaNodeDTO] + + +# oss.src.core.blobs.dtos + + +class Blob(Identifier, Lifecycle): + flags: Optional[Flags] = None # type: ignore + tags: Optional[Tags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + + data: Optional[Data] = None # type: ignore + + set_id: Optional[UUID] = None + + +# oss.src.core.testcases.dtos +# oss.src.core.testsets.dtos + + +class TestsetIdAlias(AliasConfig): + testset_id: Optional[UUID] = None + set_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="testset_id", + ) + + +class TestsetVariantIdAlias(AliasConfig): + testset_variant_id: Optional[UUID] = None + variant_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="testset_variant_id", + ) + + +class Testcase(Blob, TestsetIdAlias): + def model_post_init(self, __context) -> None: + sync_alias("testset_id", "set_id", self) + + +class TestsetFlags(BaseModel): + has_testcases: Optional[bool] = None + has_traces: Optional[bool] = None + + +class TestsetRevisionData(BaseModel): + testcase_ids: Optional[List[UUID]] = None + testcases: Optional[List[Testcase]] = None + + +class SimpleTestset( + Identifier, + Slug, + Lifecycle, + Header, +): + flags: Optional[TestsetFlags] = None + tags: Optional[Tags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + + data: Optional[TestsetRevisionData] = None + + +class Testset(Artifact): + flags: Optional[TestsetFlags] = None # type: ignore + + +class TestsetRevision( + Revision, + TestsetIdAlias, + TestsetVariantIdAlias, +): + flags: Optional[TestsetFlags] = None # type: ignore + + data: Optional[TestsetRevisionData] = None # type: ignore + + def model_post_init(self, __context) -> None: + sync_alias("testset_id", "artifact_id", self) + sync_alias("testset_variant_id", "variant_id", self) + + +class SimpleTestsetCreate(Slug, Header): + tags: Optional[Tags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + data: Optional[TestsetRevisionData] = None + + +class SimpleTestsetEdit( + Identifier, + Header, +): + # flags: Optional[TestsetFlags] = None + tags: Optional[Tags] = None # type: ignore + meta: Optional[Meta] = None # type: ignore + + data: Optional[TestsetRevisionData] = None + + +class TestsetResponse(BaseModel): + count: int = 0 + testset: Optional[Testset] = None + + +class TestsetRevisionResponse(BaseModel): + count: int = 0 + testset_revision: Optional[TestsetRevision] = None + + +class SimpleTestsetResponse(BaseModel): + count: int = 0 + testset: Optional[SimpleTestset] = None + + +# oss.src.core.workflows.dtos +from typing import Optional, Dict, Any +from uuid import UUID, uuid4 +from urllib.parse import urlparse + +from pydantic import ( + BaseModel, + Field, + model_validator, + ValidationError, +) + +from jsonschema import ( + Draft202012Validator, + Draft201909Validator, + Draft7Validator, + Draft4Validator, + Draft6Validator, +) +from jsonschema.exceptions import SchemaError + +# aliases ---------------------------------------------------------------------- + + +class WorkflowIdAlias(AliasConfig): + workflow_id: Optional[UUID] = None + artifact_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="workflow_id", + ) + + +class WorkflowVariantIdAlias(AliasConfig): + workflow_variant_id: Optional[UUID] = None + variant_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="workflow_variant_id", + ) + + +class WorkflowRevisionIdAlias(AliasConfig): + workflow_revision_id: Optional[UUID] = None + revision_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="workflow_revision_id", + ) + + +# globals ---------------------------------------------------------------------- + + +class WorkflowFlags(BaseModel): + is_custom: Optional[bool] = None + is_evaluator: Optional[bool] = None + is_human: Optional[bool] = None + + +# workflows -------------------------------------------------------------------- + + +class Workflow(Artifact): + flags: Optional[WorkflowFlags] = None + + +class WorkflowCreate(ArtifactCreate): + flags: Optional[WorkflowFlags] = None + + +class WorkflowEdit(ArtifactEdit): + flags: Optional[WorkflowFlags] = None + + +# workflow variants ------------------------------------------------------------ + + +class WorkflowVariant( + Variant, + WorkflowIdAlias, +): + flags: Optional[WorkflowFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + + +class WorkflowVariantCreate( + VariantCreate, + WorkflowIdAlias, +): + flags: Optional[WorkflowFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + + +class WorkflowVariantEdit(VariantEdit): + flags: Optional[WorkflowFlags] = None + + +class WorkflowVariantQuery(VariantQuery): + flags: Optional[WorkflowFlags] = None + + +# workflow revisions ----------------------------------------------------------- + + +class WorkflowServiceVersion(BaseModel): + version: Optional[str] = None + + +class WorkflowServiceInterface(WorkflowServiceVersion): + uri: Optional[str] = None # str (Enum) w/ validation + url: Optional[str] = None # str w/ validation + headers: Optional[Dict[str, Reference | str]] = None # either hardcoded or a secret + handler: Optional[Callable] = None + + schemas: Optional[Dict[str, Schema]] = None # json-schema instead of pydantic + mappings: Optional[Mappings] = None # used in the workflow interface + + +class WorkflowServiceConfiguration(WorkflowServiceInterface): + script: Optional[str] = None # str w/ validation + parameters: Optional[Data] = None # configuration values + + +class WorkflowRevisionData(WorkflowServiceConfiguration): + # LEGACY FIELDS + service: Optional[dict] = None # url, schema, kind, etc + configuration: Optional[dict] = None # parameters, variables, etc + + @model_validator(mode="after") + def validate_all(self) -> "WorkflowRevisionData": + errors = [] + + if self.service and self.service.get("agenta") and self.service.get("format"): + _format = self.service.get("format") # pylint: disable=redefined-builtin + + try: + validator_class = self._get_validator_class_from_schema(_format) # type: ignore + validator_class.check_schema(_format) # type: ignore + except SchemaError as e: + errors.append( + { + "loc": ("format",), + "msg": f"Invalid JSON Schema: {e.message}", + "type": "value_error", + "ctx": {"error": str(e)}, + "input": _format, + } + ) + + if self.service and self.service.get("agenta") and self.service.get("url"): + url = self.service.get("url") + + if not self._is_valid_http_url(url): + errors.append( + { + "loc": ("url",), + "msg": "Invalid HTTP(S) URL", + "type": "value_error.url", + "ctx": {"error": "Invalid URL format"}, + "input": url, + } + ) + + if errors: + raise ValidationError.from_exception_data( + self.__class__.__name__, + errors, + ) + + return self + + @staticmethod + def _get_validator_class_from_schema(schema: dict): + """Detect JSON Schema draft from $schema or fallback to 2020-12.""" + schema_uri = schema.get( + "$schema", "https://json-schema.org/draft/2020-12/schema" + ) + + if "2020-12" in schema_uri: + return Draft202012Validator + elif "2019-09" in schema_uri: + return Draft201909Validator + elif "draft-07" in schema_uri: + return Draft7Validator + elif "draft-06" in schema_uri: + return Draft6Validator + elif "draft-04" in schema_uri: + return Draft4Validator + else: + # fallback default if unknown $schema + return Draft202012Validator + + @staticmethod + def _is_valid_http_url(url: str) -> bool: + parsed = urlparse(url) + return parsed.scheme in ("http", "https") and bool(parsed.netloc) + + +class WorkflowRevision( + Revision, + WorkflowIdAlias, + WorkflowVariantIdAlias, +): + flags: Optional[WorkflowFlags] = None + + data: Optional[WorkflowRevisionData] = None + + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + sync_alias("workflow_variant_id", "variant_id", self) + + +class WorkflowRevisionCreate( + RevisionCreate, + WorkflowIdAlias, + WorkflowVariantIdAlias, +): + flags: Optional[WorkflowFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + sync_alias("workflow_variant_id", "variant_id", self) + + +class WorkflowRevisionEdit(RevisionEdit): + flags: Optional[WorkflowFlags] = None + + +class WorkflowRevisionQuery(RevisionQuery): + flags: Optional[WorkflowFlags] = None + + +class WorkflowRevisionCommit( + RevisionCommit, + WorkflowIdAlias, + WorkflowVariantIdAlias, +): + flags: Optional[WorkflowFlags] = None + + data: Optional[WorkflowRevisionData] = None + + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + sync_alias("workflow_variant_id", "variant_id", self) + + +class WorkflowRevisionsLog( + RevisionsLog, + WorkflowIdAlias, + WorkflowVariantIdAlias, + WorkflowRevisionIdAlias, +): + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + sync_alias("workflow_variant_id", "variant_id", self) + sync_alias("workflow_revision_id", "revision_id", self) + + +# forks ------------------------------------------------------------------------ + + +class WorkflowRevisionFork(RevisionFork): + flags: Optional[WorkflowFlags] = None + + data: Optional[WorkflowRevisionData] = None + + +class WorkflowRevisionForkAlias(AliasConfig): + workflow_revision: Optional[WorkflowRevisionFork] = None + + revision: Optional[RevisionFork] = Field( + default=None, + exclude=True, + alias="workflow_revision", + ) + + +class WorkflowVariantFork(VariantFork): + flags: Optional[WorkflowFlags] = None + + +class WorkflowVariantForkAlias(AliasConfig): + workflow_variant: Optional[WorkflowVariantFork] = None + + variant: Optional[VariantFork] = Field( + default=None, + exclude=True, + alias="workflow_variant", + ) + + +class WorkflowFork( + ArtifactFork, + WorkflowIdAlias, + WorkflowVariantIdAlias, + WorkflowVariantForkAlias, + WorkflowRevisionIdAlias, + WorkflowRevisionForkAlias, +): + def model_post_init(self, __context) -> None: + sync_alias("workflow_id", "artifact_id", self) + sync_alias("workflow_variant_id", "variant_id", self) + sync_alias("workflow_variant", "variant", self) + sync_alias("workflow_revision_id", "revision_id", self) + sync_alias("workflow_revision", "revision", self) + + +# workflow services ------------------------------------------------------------ + + +class WorkflowServiceData(BaseModel): + parameters: Optional[Data] = None + inputs: Optional[Data] = None + outputs: Optional[Data | str] = None + # + trace_parameters: Optional[Data] = None + trace_inputs: Optional[Data] = None + trace_outputs: Optional[Data | str] = None + # + trace: Optional[Trace] = None + # LEGACY -- used for workflow execution traces + tree: Optional[Tree] = None + + +class WorkflowServiceRequest(Version, Metadata): + tags: Optional[Tags] = None + meta: Optional[Meta] = None + + data: Optional[WorkflowServiceData] = None + + references: Optional[Dict[str, Reference]] = None + links: Optional[Dict[str, Link]] = None + + credentials: Optional[str] = None # Fix typing + secrets: Optional[Dict[str, Any]] = None # Fix typing + + +class WorkflowServiceResponse(Identifier, Version): + data: Optional[WorkflowServiceData] = None + + links: Optional[Dict[str, Link]] = None + + trace_id: Optional[str] = None + + status: Status = Status() + + def __init__(self, **data): + super().__init__(**data) + + self.id = uuid4() if not self.id else self.id + self.version = "2025.07.14" if not self.version else self.version + + +class SuccessStatus(Status): + code: int = 200 + + +class HandlerNotFoundStatus(Status): + code: int = 501 + type: str = "https://docs.agenta.ai/errors#v1:uri:handler-not-found" + + def __init__(self, uri: Optional[str] = None): + super().__init__() + self.message = f"The handler at '{uri}' is not implemented or not available." + + +class RevisionDataNotFoundStatus(Status): + code: int = 404 + type: str = "https://docs.agenta.ai/errors#v1:uri:revision-data-not-found" + + def __init__(self, uri: Optional[str] = None): + super().__init__() + self.message = f"The revision data at '{uri}' could not be found." + + +class RequestDataNotFoundStatus(Status): + code: int = 404 + type: str = "https://docs.agenta.ai/errors#v1:uri:request-data-not-found" + + def __init__(self, uri: Optional[str] = None): + super().__init__() + self.message = f"The request data at '{uri}' could not be found." + + +ERRORS_BASE_URL = "https://docs.agenta.ai/errors" + + +class ErrorStatus(Exception): + code: int + type: str + message: str + stacktrace: Optional[str] = None + + def __init__( + self, + code: int, + type: str, + message: str, + stacktrace: Optional[str] = None, + ): + super().__init__() + self.code = code + self.type = type + self.message = message + self.stacktrace = stacktrace + + def __str__(self): + return f"[EVAL] {self.code} - {self.message} ({self.type})" + ( + f"\nStacktrace: {self.stacktrace}" if self.stacktrace else "" + ) + + def __repr__(self): + return f"ErrorStatus(code={self.code}, type='{self.type}', message='{self.message}')" + + +# ------------------------------------------------------------------------------ + + +class EvaluatorRevision(BaseModel): + id: Optional[UUID] = None + slug: Optional[str] = None + version: Optional[str] = None + + data: Optional[WorkflowRevisionData] = None + + +class ApplicationServiceRequest(WorkflowServiceRequest): + pass + + +class ApplicationServiceResponse(WorkflowServiceResponse): + pass + + +class EvaluatorServiceRequest(WorkflowServiceRequest): + pass + + +class EvaluatorServiceResponse(WorkflowServiceResponse): + pass + + +# oss.src.core.evaluators.dtos + + +class EvaluatorIdAlias(AliasConfig): + evaluator_id: Optional[UUID] = None + workflow_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="evaluator_id", + ) + + +class EvaluatorVariantIdAlias(AliasConfig): + evaluator_variant_id: Optional[UUID] = None + workflow_variant_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="evaluator_variant_id", + ) + + +class EvaluatorRevisionData(WorkflowRevisionData): + pass + + +class EvaluatorFlags(WorkflowFlags): + def __init__(self, **data): + data["is_evaluator"] = True + + super().__init__(**data) + + +class SimpleEvaluatorFlags(EvaluatorFlags): + pass + + +class SimpleEvaluatorData(EvaluatorRevisionData): + pass + + +class Evaluator(Workflow): + flags: Optional[EvaluatorFlags] = None + + +class SimpleEvaluatorRevision( + WorkflowRevision, + EvaluatorIdAlias, + EvaluatorVariantIdAlias, +): + flags: Optional[EvaluatorFlags] = None + + data: Optional[EvaluatorRevisionData] = None + + +class SimpleEvaluator(Identifier, Slug, Lifecycle, Header, Metadata): + flags: Optional[SimpleEvaluatorFlags] = None + + data: Optional[SimpleEvaluatorData] = None + + +class SimpleEvaluatorCreate(Slug, Header, Metadata): + flags: Optional[SimpleEvaluatorFlags] = None + + data: Optional[SimpleEvaluatorData] = None + + +class SimpleEvaluatorEdit(Identifier, Header, Metadata): + flags: Optional[SimpleEvaluatorFlags] = None + + data: Optional[SimpleEvaluatorData] = None + + +class SimpleEvaluatorResponse(BaseModel): + count: int = 0 + evaluator: Optional[SimpleEvaluator] = None + + +class EvaluatorRevisionResponse(BaseModel): + count: int = 0 + evaluator_revision: Optional[EvaluatorRevision] = None + + +# oss.src.core.applications.dtos + +# aliases ---------------------------------------------------------------------- + + +class ApplicationIdAlias(AliasConfig): + application_id: Optional[UUID] = None + workflow_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="application_id", + ) + + +class ApplicationVariantIdAlias(AliasConfig): + application_variant_id: Optional[UUID] = None + workflow_variant_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="application_variant_id", + ) + + +class ApplicationRevisionIdAlias(AliasConfig): + application_revision_id: Optional[UUID] = None + workflow_revision_id: Optional[UUID] = Field( + default=None, + exclude=True, + alias="application_revision_id", + ) + + +# globals ---------------------------------------------------------------------- + + +class ApplicationFlags(WorkflowFlags): + def __init__(self, **data): + data["is_evaluator"] = True + + super().__init__(**data) + + +# applications ------------------------------------------------------------------- + + +class Application(Workflow): + flags: Optional[ApplicationFlags] = None + + +class ApplicationCreate(WorkflowCreate): + flags: Optional[ApplicationFlags] = None + + +class ApplicationEdit(WorkflowEdit): + flags: Optional[ApplicationFlags] = None + + +# application variants ----------------------------------------------------------- + + +class ApplicationVariant( + WorkflowVariant, + ApplicationIdAlias, +): + flags: Optional[ApplicationFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("application_id", "workflow_id", self) + + +class ApplicationVariantCreate( + WorkflowVariantCreate, + ApplicationIdAlias, +): + flags: Optional[ApplicationFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("application_id", "workflow_id", self) + + +class ApplicationVariantEdit(WorkflowVariantEdit): + flags: Optional[ApplicationFlags] = None + + +# application revisions ----------------------------------------------------- + + +class ApplicationRevisionData(WorkflowRevisionData): + pass + + +class ApplicationRevision( + WorkflowRevision, + ApplicationIdAlias, + ApplicationVariantIdAlias, +): + flags: Optional[ApplicationFlags] = None + + data: Optional[ApplicationRevisionData] = None + + def model_post_init(self, __context) -> None: + sync_alias("application_id", "workflow_id", self) + sync_alias("application_variant_id", "workflow_variant_id", self) + + +class ApplicationRevisionCreate( + WorkflowRevisionCreate, + ApplicationIdAlias, + ApplicationVariantIdAlias, +): + flags: Optional[ApplicationFlags] = None + + def model_post_init(self, __context) -> None: + sync_alias("application_id", "workflow_id", self) + sync_alias("application_variant_id", "workflow_variant_id", self) + + +class ApplicationRevisionEdit(WorkflowRevisionEdit): + flags: Optional[ApplicationFlags] = None + + +class ApplicationRevisionCommit( + WorkflowRevisionCommit, + ApplicationIdAlias, + ApplicationVariantIdAlias, +): + flags: Optional[ApplicationFlags] = None + + data: Optional[ApplicationRevisionData] = None + + def model_post_init(self, __context) -> None: + sync_alias("application_id", "workflow_id", self) + sync_alias("application_variant_id", "workflow_variant_id", self) + + +class ApplicationRevisionResponse(BaseModel): + count: int = 0 + application_revision: Optional[ApplicationRevision] = None + + +class ApplicationRevisionsResponse(BaseModel): + count: int = 0 + application_revisions: List[ApplicationRevision] = [] + + +# simple applications ------------------------------------------------------------ + + +class LegacyApplicationFlags(WorkflowFlags): + pass + + +class LegacyApplicationData(WorkflowRevisionData): + pass + + +class LegacyApplication(Identifier, Slug, Lifecycle, Header, Metadata): + flags: Optional[LegacyApplicationFlags] = None + + data: Optional[LegacyApplicationData] = None + + +class LegacyApplicationCreate(Slug, Header, Metadata): + flags: Optional[LegacyApplicationFlags] = None + + data: Optional[LegacyApplicationData] = None + + +class LegacyApplicationEdit(Identifier, Header, Metadata): + flags: Optional[LegacyApplicationFlags] = None + + data: Optional[LegacyApplicationData] = None + + +class LegacyApplicationResponse(BaseModel): + count: int = 0 + application: Optional[LegacyApplication] = None + + +# end of oss.src.core.applications.dtos diff --git a/api/ee/tests/manual/evaluations/sdk/entities.py b/api/ee/tests/manual/evaluations/sdk/entities.py new file mode 100644 index 0000000000..12c714db95 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/entities.py @@ -0,0 +1,447 @@ +import asyncio +from typing import List, Dict, Any, Callable, Optional +from uuid import uuid4, UUID + +from definitions import ( + Testcase, + TestsetRevisionData, + TestsetRevision, + ApplicationRevision, + EvaluatorRevision, + # + SimpleTestsetCreate, + SimpleTestsetEdit, + # + SimpleTestsetResponse, + TestsetRevisionResponse, + # + Evaluator, + # + SimpleEvaluatorData, + SimpleEvaluatorCreate, + SimpleEvaluatorEdit, + # + EvaluatorRevisionData, + SimpleEvaluatorResponse, + EvaluatorRevisionResponse, + # + ApplicationRevisionResponse, + # + LegacyApplicationData, + LegacyApplicationCreate, + LegacyApplicationEdit, + # + LegacyApplicationResponse, +) +from services import ( + REGISTRY, + register_handler, + retrieve_handler, +) + +from client import authed_api + + +client = authed_api() + +APPLICATION_REVISION_ID = uuid4() +APPLICATION_REVISION = ApplicationRevision( + id=APPLICATION_REVISION_ID, + slug=str(APPLICATION_REVISION_ID)[-12:], + version="0", +) + +EVALUATOR_REVISION_ID = uuid4() +EVALUATOR_REVISION = EvaluatorRevision( + id=EVALUATOR_REVISION_ID, + slug=str(EVALUATOR_REVISION_ID)[-12:], + version="0", +) + + +async def _retrieve_testset( + testset_id: Optional[UUID] = None, + testset_revision_id: Optional[UUID] = None, +) -> Optional[TestsetRevision]: + response = client( + method="POST", + endpoint="/preview/testsets/revisions/retrieve", + params={ + "testset_id": testset_id, + "testset_revision_id": testset_revision_id, + }, + ) + + response.raise_for_status() + + testset_revision_response = TestsetRevisionResponse(**response.json()) + + testset_revision = testset_revision_response.testset_revision + + return testset_revision + + +async def retrieve_testset( + testset_revision_id: Optional[UUID] = None, +) -> Optional[TestsetRevision]: + response = await _retrieve_testset( + testset_revision_id=testset_revision_id, + ) + + return response + + +async def upsert_testset( + testcases_data: List[Dict[str, Any]], + # + testset_revision_id: Optional[UUID] = None, + # + testset_id: Optional[UUID] = None, + testset_name: Optional[str] = None, + testset_description: Optional[str] = None, +) -> Optional[UUID]: + testset_revision_data = TestsetRevisionData( + testcases=[ + Testcase( + data=testcase_data, + ) + for testcase_data in testcases_data + ] + ) + + retrieve_response = None + + if testset_revision_id: + retrieve_response = await _retrieve_testset( + testset_revision_id=testset_revision_id, + ) + elif testset_id: + retrieve_response = await _retrieve_testset( + testset_id=testset_id, + ) + + if retrieve_response and retrieve_response.id: + testset_edit_request = SimpleTestsetEdit( + id=testset_id, + name=testset_name, + description=testset_description, + data=testset_revision_data, + ) + + response = client( + method="PUT", + endpoint=f"/preview/simple/testsets/{testset_id}", + json={ + "testset": testset_edit_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print(f"[ERROR]: Failed to update testset: {e}") + return None + + else: + testset_create_request = SimpleTestsetCreate( + name=testset_name, + description=testset_description, + slug=uuid4().hex, + data=testset_revision_data, + ) + + response = client( + method="POST", + endpoint="/preview/simple/testsets/", + json={ + "testset": testset_create_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print(f"[ERROR]: Failed to create testset: {e}") + return None + + testset_response = SimpleTestsetResponse(**response.json()) + + testset = testset_response.testset + + if not testset or not testset.id: + return None + + testset_revision = await _retrieve_testset( + testset_id=testset.id, + ) + + if not testset_revision or not testset_revision.id: + return None + + return testset_revision.id + + +async def _retrieve_application( + application_id: Optional[UUID] = None, + application_revision_id: Optional[UUID] = None, +) -> Optional[ApplicationRevision]: + response = client( + method="POST", + endpoint=f"/preview/legacy/applications/revisions/retrieve", + params={ + "application_id": application_id, + "application_revision_id": application_revision_id, + }, + ) + response.raise_for_status() + + application_revision_response = ApplicationRevisionResponse(**response.json()) + + application_revision = application_revision_response.application_revision + + if not application_revision or not application_revision.id: + return None + + if not application_revision.data or not application_revision.data.uri: + return None + + application_revision.data.handler = retrieve_handler(application_revision.data.uri) + + return application_revision + + +async def retrieve_application( + application_revision_id: Optional[UUID] = None, +) -> Optional[ApplicationRevision]: + response = await _retrieve_application( + application_revision_id=application_revision_id, + ) + + return response + + +async def upsert_application( + application_handler: Callable, + application_script: Optional[str] = None, + application_parameters: Optional[Dict[str, Any]] = None, + # + application_revision_id: Optional[UUID] = None, + # + application_id: Optional[UUID] = None, + application_name: Optional[str] = None, + application_description: Optional[str] = None, +) -> Optional[UUID]: + legacy_application_data = LegacyApplicationData( + uri=register_handler(application_handler), + script=application_script, + parameters=application_parameters, + ) + + retrieve_response = None + + if application_revision_id: + retrieve_response = await _retrieve_application( + application_revision_id=application_revision_id, + ) + elif application_id: + retrieve_response = await _retrieve_application( + application_id=application_id, + ) + + if retrieve_response and retrieve_response.id: + application_edit_request = LegacyApplicationEdit( + id=application_id, + name=application_name, + description=application_description, + data=legacy_application_data, + ) + + response = client( + method="PUT", + endpoint=f"/preview/legacy/applications/{application_id}", + json={ + "application": application_edit_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print("[ERROR]: Failed to update application:", e) + return None + + else: + application_create_request = LegacyApplicationCreate( + name=application_name, + description=application_description, + slug=uuid4().hex, + data=legacy_application_data, + ) + + response = client( + method="POST", + endpoint="/preview/legacy/applications/", + json={ + "application": application_create_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print("[ERROR]: Failed to create application:", e) + return None + + application_response = LegacyApplicationResponse(**response.json()) + + application = application_response.application + + if not application or not application.id: + return None + + application_revision = await _retrieve_application( + application_id=application.id, + ) + + if not application_revision or not application_revision.id: + return None + + return application_revision.id + + +async def _retrieve_evaluator( + evaluator_id: Optional[UUID] = None, + evaluator_revision_id: Optional[UUID] = None, +) -> Optional[EvaluatorRevision]: + response = client( + method="POST", + endpoint=f"/preview/evaluators/revisions/retrieve", + params={ + "evaluator_id": evaluator_id, + "evaluator_revision_id": evaluator_revision_id, + }, + ) + response.raise_for_status() + + evaluator_revision_response = EvaluatorRevisionResponse(**response.json()) + + evaluator_revision = evaluator_revision_response.evaluator_revision + + return evaluator_revision + + +async def retrieve_evaluator( + evaluator_revision_id: Optional[UUID] = None, +) -> Optional[EvaluatorRevision]: + response = await _retrieve_evaluator( + evaluator_revision_id=evaluator_revision_id, + ) + + return response + + +async def upsert_evaluator( + evaluator_handler: Callable, + evaluator_script: Optional[str] = None, + evaluator_parameters: Optional[Dict[str, Any]] = None, + # + evaluator_revision_id: Optional[UUID] = None, + # + evaluator_id: Optional[UUID] = None, + evaluator_name: Optional[str] = None, + evaluator_description: Optional[str] = None, +) -> Optional[UUID]: + simple_evaluator_data = SimpleEvaluatorData( + uri=register_handler(evaluator_handler), + script=evaluator_script, + parameters=evaluator_parameters, + ) + + retrieve_response = None + + if evaluator_revision_id: + retrieve_response = await _retrieve_evaluator( + evaluator_revision_id=evaluator_revision_id, + ) + elif evaluator_id: + retrieve_response = await _retrieve_evaluator( + evaluator_id=evaluator_id, + ) + + if retrieve_response and retrieve_response.id: + evaluator_edit_request = SimpleEvaluatorEdit( + id=evaluator_id, + name=evaluator_name, + description=evaluator_description, + data=simple_evaluator_data, + ) + + response = client( + method="PUT", + endpoint=f"/preview/simple/evaluators/{evaluator_id}", + json={ + "evaluator": evaluator_edit_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print("[ERROR]: Failed to update evaluator:", e) + return None + + else: + evaluator_create_request = SimpleEvaluatorCreate( + name=evaluator_name, + description=evaluator_description, + slug=uuid4().hex, + data=simple_evaluator_data, + ) + + response = client( + method="POST", + endpoint="/preview/simple/evaluators/", + json={ + "evaluator": evaluator_create_request.model_dump( + mode="json", + exclude_none=True, + ) + }, + ) + + try: + response.raise_for_status() + except Exception as e: + print("[ERROR]: Failed to create evaluator:", e) + return None + + evaluator_response = SimpleEvaluatorResponse(**response.json()) + + evaluator = evaluator_response.evaluator + + if not evaluator or not evaluator.id: + return None + + evaluator_revision = await _retrieve_evaluator( + evaluator_id=evaluator.id, + ) + + if not evaluator_revision or not evaluator_revision.id: + return None + + return evaluator_revision.id diff --git a/api/ee/tests/manual/evaluations/sdk/evaluate.py b/api/ee/tests/manual/evaluations/sdk/evaluate.py new file mode 100644 index 0000000000..e312474144 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/evaluate.py @@ -0,0 +1,340 @@ +from typing import Dict, List +from uuid import UUID +from copy import deepcopy + +from definitions import ( + Origin, + Link, + Reference, + SimpleEvaluationFlags, + SimpleEvaluationStatus, + SimpleEvaluationData, + TestsetRevision, + ApplicationRevision, + EvaluatorRevision, + WorkflowServiceData, + ApplicationServiceRequest, + ApplicationServiceResponse, + EvaluatorServiceRequest, + EvaluatorServiceResponse, +) +from evaluations import ( + create_run, + add_scenario, + log_result, + compute_metrics, + get_slug_from_name_and_id, +) + +# from mock_entities import ( +# upsert_testset, +# retrieve_testset, +# upsert_application, +# retrieve_application, +# upsert_evaluator, +# retrieve_evaluator, +# ) + +from entities import ( + upsert_testset, + retrieve_testset, + upsert_application, + retrieve_application, + upsert_evaluator, + retrieve_evaluator, +) + +from services import ( + invoke_application, + invoke_evaluator, +) + +EvaluateSpecs = SimpleEvaluationData + + +# @debug +async def evaluate( + data: SimpleEvaluationData, +): + data = deepcopy(data) + + if data.testset_steps: + if isinstance(data.testset_steps, list): + testset_steps: Dict[str, Origin] = {} + + if all( + isinstance(testset_revision_id, UUID) + for testset_revision_id in data.testset_steps + ): + for testset_revision_id in data.testset_steps: + if isinstance(testset_revision_id, UUID): + testset_steps[str(testset_revision_id)] = "custom" + + elif all( + isinstance(testcases_data, List) + for testcases_data in data.testset_steps + ): + for testcases_data in data.testset_steps: + if isinstance(testcases_data, List): + if all(isinstance(step, Dict) for step in testcases_data): + testset_revision_id = await upsert_testset( + testcases_data=testcases_data, + ) + testset_steps[str(testset_revision_id)] = "custom" + + data.testset_steps = testset_steps + + if not data.testset_steps or not isinstance(data.testset_steps, dict): + print("[failure] missing or invalid testset steps") + return None + + if data.application_steps: + if isinstance(data.application_steps, list): + application_steps: Dict[str, Origin] = {} + + if all( + isinstance(application_revision_id, UUID) + for application_revision_id in data.application_steps + ): + for application_revision_id in data.application_steps: + if isinstance(application_revision_id, UUID): + application_steps[str(application_revision_id)] = "custom" + + elif all( + callable(application_handler) + for application_handler in data.application_steps + ): + for application_handler in data.application_steps: + if callable(application_handler): + application_revision_id = await upsert_application( + application_handler=application_handler, + ) + application_steps[str(application_revision_id)] = "custom" + + data.application_steps = application_steps + + if not data.application_steps or not isinstance(data.application_steps, dict): + print("[failure] missing or invalid application steps") + return None + + if data.evaluator_steps: + if isinstance(data.evaluator_steps, list): + evaluator_steps: Dict[str, Origin] = {} + + if all( + isinstance(evaluator_revision_id, UUID) + for evaluator_revision_id in data.evaluator_steps + ): + for evaluator_revision_id in data.evaluator_steps: + if isinstance(evaluator_revision_id, UUID): + evaluator_steps[str(evaluator_revision_id)] = "custom" + + elif all( + callable(evaluator_handler) + for evaluator_handler in data.evaluator_steps + ): + for evaluator_handler in data.evaluator_steps: + if callable(evaluator_handler): + evaluator_revision_id = await upsert_evaluator( + evaluator_handler=evaluator_handler, + ) + evaluator_steps[str(evaluator_revision_id)] = "custom" + + data.evaluator_steps = evaluator_steps + + if not data.evaluator_steps or not isinstance(data.evaluator_steps, dict): + print("[failure] missing or invalid evaluator steps") + return None + + testsets: Dict[UUID, TestsetRevision] = {} + for testset_revision_id, origin in data.testset_steps.items(): + testset_revision = await retrieve_testset( + testset_revision_id=testset_revision_id, + ) + + if not testset_revision: + continue + + testsets[testset_revision_id] = testset_revision + + applications: Dict[UUID, ApplicationRevision] = {} + for application_revision_id, origin in data.application_steps.items(): + application_revision = await retrieve_application( + application_revision_id=application_revision_id, + ) + + if not application_revision: + continue + + applications[application_revision_id] = application_revision + + evaluators: Dict[UUID, EvaluatorRevision] = {} + for evaluator_revision_id, origin in data.evaluator_steps.items(): + evaluator_revision = await retrieve_evaluator( + evaluator_revision_id=evaluator_revision_id, + ) + + if not evaluator_revision: + continue + + evaluators[evaluator_revision_id] = evaluator_revision + + run = await create_run( + testset_steps=data.testset_steps, + application_steps=data.application_steps, + evaluator_steps=data.evaluator_steps, + ) + + if not run.id: + print("[failure] could not create evaluation") + return None + + scenarios = list() + + for testset_revision_id, testset_revision in testsets.items(): + if not testset_revision.data or not testset_revision.data.testcases: + continue + + testcases = testset_revision.data.testcases + + print() + print(f"From testset_id={str(testset_revision.testset_id)}") + + for testcase in testcases: + print(f"Evaluating testcase_id={str(testcase.id)}") + scenario = await add_scenario( + run_id=run.id, + ) + + results = dict() + + result = await log_result( + run_id=run.id, + scenario_id=scenario.id, + step_key="testset-" + testset_revision.slug, # type: ignore + testcase_id=testcase.id, + ) + + results[testset_revision.slug] = result + + for application_revision_id, application_revision in applications.items(): + if not application_revision or not application_revision.data: + print("Missing or invalid application revision") + continue + + application_request = ApplicationServiceRequest( + data=WorkflowServiceData( + parameters=application_revision.data.parameters, + inputs=testcase.data, + ), + references=dict( + testset_revision=Reference( + id=testset_revision.id, + slug=testset_revision.slug, + version=testset_revision.version, + ), + application_revision=Reference( + id=application_revision.id, + slug=application_revision.slug, + version=application_revision.version, + ), + ), + ) + + application_response = await invoke_application( + request=application_request, + revision=application_revision, + ) + + if ( + not application_response + or not application_response.data + or not application_response.trace_id + ): + print("Missing or invalid application response") + continue + + trace_id = application_response.trace_id + + if not application_revision.id or not application_revision.name: + print("Missing application revision ID or name") + continue + + application_slug = get_slug_from_name_and_id( + name=application_revision.name, + id=application_revision.id, + ) + + result = await log_result( + run_id=run.id, + scenario_id=scenario.id, + step_key="application-" + application_slug, # type: ignore + trace_id=trace_id, + ) + + results[application_slug] = result + + for evaluator_revision_id, evaluator_revision in evaluators.items(): + if not evaluator_revision or not evaluator_revision.data: + print("Missing or invalid evaluator revision") + continue + + evaluator_request = EvaluatorServiceRequest( + data=WorkflowServiceData( + parameters=evaluator_revision.data.parameters, + inputs=testcase.data, + # + trace_outputs=application_response.data.outputs, + trace=application_response.data.trace, + ), + references=dict( + testset_revision=Reference( + id=testset_revision.id, + slug=testset_revision.slug, + version=testset_revision.version, + ), + evaluator_revision=Reference( + id=evaluator_revision.id, + slug=evaluator_revision.slug, + version=evaluator_revision.version, + ), + ), + links=application_response.links, + ) + + evaluator_response = await invoke_evaluator( + request=evaluator_request, + revision=evaluator_revision, + ) + + if not evaluator_response or not evaluator_response.data: + print("Missing or invalid evaluator response") + continue + + trace_id = evaluator_response.trace_id + + result = await log_result( + run_id=run.id, + scenario_id=scenario.id, + step_key="evaluator-" + evaluator_revision.slug, # type: ignore + trace_id=trace_id, + ) + + results[evaluator_revision.slug] = result + + scenarios.append( + { + "scenario": scenario, + "results": results, + }, + ) + + metrics = await compute_metrics( + run_id=run.id, + ) + + return dict( + run=run, + scenarios=scenarios, + metrics=metrics, + ) diff --git a/api/ee/tests/manual/evaluations/sdk/evaluations.py b/api/ee/tests/manual/evaluations/sdk/evaluations.py new file mode 100644 index 0000000000..70720dc583 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/evaluations.py @@ -0,0 +1,208 @@ +from typing import Optional, Dict, Any +from uuid import uuid4, UUID + +import unicodedata +import re + +from definitions import ( + EvaluationRun, + EvaluationScenario, + EvaluationResult, + EvaluationMetrics, + Origin, + Target, +) + +from client import authed_api + + +client = authed_api() + + +async def create_run( + *, + flags: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + # + query_steps: Optional[Target] = None, + testset_steps: Optional[Target] = None, + application_steps: Optional[Target] = None, + evaluator_steps: Optional[Target] = None, + repeats: Optional[int] = None, +) -> EvaluationRun: + payload = dict( + evaluation=dict( + flags=flags, + tags=tags, + meta=meta, + # + data=dict( + status="running", + query_steps=query_steps, + testset_steps=testset_steps, + application_steps=application_steps, + evaluator_steps=evaluator_steps, + repeats=repeats, + ), + ) + ) + + response = client( + method="POST", + endpoint=f"/preview/simple/evaluations/", + json=payload, + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + run = EvaluationRun(id=UUID(response["evaluation"]["id"])) + + return run + + +async def add_scenario( + *, + flags: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + # + run_id: UUID, +) -> EvaluationScenario: + payload = dict( + scenarios=[ + dict( + flags=flags, + tags=tags, + meta=meta, + # + run_id=str(run_id), + ) + ] + ) + + response = client( + method="POST", + endpoint=f"/preview/evaluations/scenarios/", + json=payload, + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + scenario = EvaluationScenario(**response["scenarios"][0]) + + return scenario + + +async def log_result( + *, + flags: Optional[Dict[str, Any]] = None, + tags: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + # + testcase_id: Optional[UUID] = None, + trace_id: Optional[str] = None, + error: Optional[dict] = None, + # + # timestamp: datetime, + # repeat_idx: str, + step_key: str, + run_id: UUID, + scenario_id: UUID, +) -> EvaluationResult: + payload = dict( + results=[ + dict( + flags=flags, + tags=tags, + meta=meta, + # + testcase_id=str(testcase_id) if testcase_id else None, + trace_id=trace_id, + error=error, + # + # interval=interval, + # timestamp=timestamp, + # repeat_idx=repeat_idx, + step_key=step_key, + run_id=str(run_id), + scenario_id=str(scenario_id), + ) + ] + ) + + response = client( + method="POST", + endpoint=f"/preview/evaluations/results/", + json=payload, + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + result = EvaluationResult(**response["results"][0]) + + return result + + +async def compute_metrics( + run_id: UUID, +) -> EvaluationMetrics: + payload = dict( + run_id=str(run_id), + ) + + response = client( + method="POST", + endpoint=f"/preview/evaluations/metrics/refresh", + params=payload, + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + metrics = EvaluationMetrics(**response["metrics"][0]) + + return metrics + + +def get_slug_from_name_and_id( + name: str, + id: UUID, # pylint: disable=redefined-builtin +) -> str: + # Normalize Unicode (e.g., é → e) + name = unicodedata.normalize("NFKD", name) + # Remove non-ASCII characters + name = name.encode("ascii", "ignore").decode("ascii") + # Lowercase and remove non-word characters except hyphens and spaces + name = re.sub(r"[^\w\s-]", "", name.lower()) + # Replace any sequence of hyphens or whitespace with a single hyphen + name = re.sub(r"[-\s]+", "-", name) + # Trim leading/trailing hyphens + name = name.strip("-") + # Last 12 characters of the ID + slug = f"{name}-{id.hex[-12:]}" + + return slug.lower() diff --git a/api/ee/tests/manual/evaluations/sdk/loop.py b/api/ee/tests/manual/evaluations/sdk/loop.py new file mode 100644 index 0000000000..9e166b5fde --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/loop.py @@ -0,0 +1,97 @@ +import asyncio +import random +import json + +from evaluate import ( + evaluate, + EvaluateSpecs, +) +from definitions import ( + ApplicationRevision, + ApplicationServiceRequest, + EvaluatorRevision, + EvaluatorServiceRequest, +) + + +dataset = [ + {"country": "Germany", "capital": "Berlin"}, + {"country": "France", "capital": "Paris"}, + {"country": "Spain", "capital": "Madrid"}, + {"country": "Italy", "capital": "Rome"}, +] + + +async def my_application( + revision: ApplicationRevision, + request: ApplicationServiceRequest, + **kwargs, +): + inputs: dict = request.data.inputs # type:ignore + chance = random.choice([True, False]) + outputs = { + "capital": (inputs.get("capital") if chance else "Aloha"), + } + + return outputs + + +async def my_evaluator( + revision: EvaluatorRevision, + request: EvaluatorServiceRequest, + **kwargs, +): + inputs: dict = request.data.inputs # type:ignore + trace_outputs: dict = request.data.trace_outputs # type:ignore + outputs = { + "success": trace_outputs.get("capital") == inputs.get("capital"), + } + + return outputs + + +async def run_evaluation(): + specs = EvaluateSpecs( + testset_steps=[dataset], + application_steps=[my_application], + evaluator_steps=[my_evaluator], + ) + + eval = await evaluate(specs) + + return eval + + +# export AGENTA_API_URL=http://localhost/api +# export AGENTA_API_KEY=xxxxxxxx + +if __name__ == "__main__": + eval = asyncio.run(run_evaluation()) + + if not eval: + exit(1) + + print() + print("Displaying evaluation") + print(f"run_id={eval['run'].id}") # type:ignore + + for scenario in eval["scenarios"]: + print(" " f"scenario_id={scenario['scenario'].id}") # type:ignore + for step_key, result in scenario["results"].items(): # type:ignore + if result.testcase_id: + print( + " " + f"step_key={str(step_key).ljust(32)}, testcase_id={result.testcase_id}", + ) + elif result.trace_id: + print( + " " + f"step_key={str(step_key).ljust(32)}, trace_id={result.trace_id}", + ) + else: + print( + " " + f"step_key={str(step_key).ljust(32)}, error={result.error}", + ) + + print(f"metrics={json.dumps(eval['metrics'].data, indent=4)}") # type:ignore diff --git a/api/ee/tests/manual/evaluations/sdk/mock_entities.py b/api/ee/tests/manual/evaluations/sdk/mock_entities.py new file mode 100644 index 0000000000..8d1d9e5ab4 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/mock_entities.py @@ -0,0 +1,90 @@ +from typing import List, Dict, Any, Callable +from uuid import uuid4, UUID + +from definitions import ( + Testcase, + TestsetRevisionData, + TestsetRevision, + ApplicationRevision, + ApplicationRevisionData, + EvaluatorRevision, + WorkflowRevisionData, +) + +from services import register_handler + +TESTSET_REVISION_ID = uuid4() +TESTSET_REVISION = TestsetRevision( + id=TESTSET_REVISION_ID, + slug=str(TESTSET_REVISION_ID)[-12:], + data=TestsetRevisionData( + testcases=[ + Testcase( + id=uuid4(), + data={"country": "Germany", "capital": "Berlin"}, + ), + Testcase( + id=uuid4(), + data={"country": "France", "capital": "Paris"}, + ), + ] + ), +) + +APPLICATION_REVISION_ID = uuid4() +APPLICATION_REVISION = ApplicationRevision( + id=APPLICATION_REVISION_ID, + slug=str(APPLICATION_REVISION_ID)[-12:], + version="0", + data=ApplicationRevisionData(), +) + +EVALUATOR_REVISION_ID = uuid4() +EVALUATOR_REVISION = EvaluatorRevision( + id=EVALUATOR_REVISION_ID, + slug=str(EVALUATOR_REVISION_ID)[-12:], + version="0", + data=WorkflowRevisionData(), +) + +MOCK_URI = None + + +async def upsert_testset( + testcases_data: List[Dict[str, Any]], +) -> UUID: + return TESTSET_REVISION_ID + + +async def retrieve_testset( + testset_revision_id: UUID, +) -> TestsetRevision: + return TESTSET_REVISION + + +async def upsert_application( + application_handler: Callable, +) -> UUID: + global MOCK_URI + MOCK_URI = register_handler(application_handler) + return APPLICATION_REVISION_ID + + +async def retrieve_application( + application_revision_id: UUID, +) -> ApplicationRevision: + application_revision = APPLICATION_REVISION + application_revision.data.uri = MOCK_URI + return application_revision + + +async def upsert_evaluator( + evaluator_handler: Callable, +) -> UUID: + return EVALUATOR_REVISION_ID + + +async def retrieve_evaluator( + evaluator_revision_id: UUID, +) -> EvaluatorRevision: + return EVALUATOR_REVISION diff --git a/api/ee/tests/manual/evaluations/sdk/services.py b/api/ee/tests/manual/evaluations/sdk/services.py new file mode 100644 index 0000000000..fee8836401 --- /dev/null +++ b/api/ee/tests/manual/evaluations/sdk/services.py @@ -0,0 +1,375 @@ +from typing import Callable, Dict, Optional, Any +from uuid import uuid4, UUID + +from definitions import ( + Status, + WorkflowServiceData, + ApplicationRevision, + ApplicationServiceRequest, + ApplicationServiceResponse, + EvaluatorRevision, + EvaluatorServiceRequest, + EvaluatorServiceResponse, + SuccessStatus, + HandlerNotFoundStatus, + ErrorStatus, + RevisionDataNotFoundStatus, + RequestDataNotFoundStatus, + Link, +) + +from client import authed_api + + +client = authed_api() + + +REGISTRY: Dict[str, Dict[str, Dict[str, Dict[str, Callable]]]] = dict( + user=dict( + custom=dict(), + ), +) + + +def register_handler(fn: Callable) -> str: + global REGISTRY + + key = f"{fn.__module__}.{fn.__name__}" + + if not REGISTRY["user"]["custom"].get(key): + REGISTRY["user"]["custom"][key] = dict() + + REGISTRY["user"]["custom"][key]["latest"] = fn + + uri = f"user:custom:{key}:latest" + + return uri + + +def retrieve_handler(uri: Optional[str] = None) -> Optional[Callable]: + if not uri: + return None + + parts = uri.split(":") + + return REGISTRY[parts[0]][parts[1]].get(parts[2], {}).get(parts[3], None) + + +async def invoke_application( + *, + request: ApplicationServiceRequest, + revision: ApplicationRevision, +) -> ApplicationServiceResponse: + try: + if not revision.data: + return ApplicationServiceResponse( + status=RevisionDataNotFoundStatus(), + ) + + if not request.data: + return ApplicationServiceResponse( + status=RequestDataNotFoundStatus(), + ) + + handler = retrieve_handler(revision.data.uri) + + if not handler: + return ApplicationServiceResponse( + status=HandlerNotFoundStatus( + uri=revision.data.uri, + ), + ) + + outputs = await handler( + revision=revision, + request=request, + # + parameters=revision.data.parameters, + inputs=request.data.inputs, + # + trace_parameters=request.data.trace_parameters, + trace_inputs=request.data.trace_inputs, + trace_outputs=request.data.trace_outputs, + # + trace=request.data.trace, + tree=request.data.tree, + ) + + data = dict( + parameters=revision.data.parameters, + inputs=request.data.inputs, + outputs=outputs, + ) + + references = ( + { + k: ref.model_dump( + mode="json", + ) + for k, ref in request.references.items() + } + if request.references + else None + ) + + links = ( + { + k: ref.model_dump( + mode="json", + ) + for k, ref in request.links.items() + } + if request.links + else None + ) + + link = None + + try: + link = await _invocations_create( + tags=request.tags, + meta=request.meta, + # + data=data, + # + references=references, + links=links, + ) + except Exception as ex: + print(ex) + + response = ApplicationServiceResponse( + status=SuccessStatus(message=""), + data=WorkflowServiceData( + outputs=outputs, + ), + trace_id=link.trace_id if link else None, + links=({revision.slug or uuid4().hex: link} if link else {}), + ) + + return response + + except ErrorStatus as error: + return ApplicationServiceResponse( + status=Status( + code=error.code, + type=error.type, + message=error.message, + stacktrace=error.stacktrace, + ), + ) + + except Exception as ex: + return ApplicationServiceResponse( + status=Status( + code=500, + message=str(ex), + ), + ) + + +async def invoke_evaluator( + *, + request: EvaluatorServiceRequest, + revision: EvaluatorRevision, +) -> EvaluatorServiceResponse: + try: + if not revision.data: + return EvaluatorServiceResponse( + status=RevisionDataNotFoundStatus(), + ) + + if not request.data: + return EvaluatorServiceResponse( + status=RequestDataNotFoundStatus(), + ) + + handler = retrieve_handler(revision.data.uri) + + if not handler: + return EvaluatorServiceResponse( + status=HandlerNotFoundStatus( + uri=revision.data.uri, + ), + ) + + outputs = await handler( + revision=revision, + request=request, + # + parameters=revision.data.parameters, + inputs=request.data.inputs, + # + trace_parameters=request.data.trace_parameters, + trace_inputs=request.data.trace_inputs, + trace_outputs=request.data.trace_outputs, + # + trace=request.data.trace, + tree=request.data.tree, + ) + + data = dict( + parameters=revision.data.parameters, + inputs=request.data.inputs, + outputs=outputs, + ) + + references = ( + { + k: ref.model_dump( + mode="json", + ) + for k, ref in request.references.items() + } + if request.references + else None + ) + + links = ( + { + k: ref.model_dump( + mode="json", + ) + for k, ref in request.links.items() + } + if request.links + else None + ) + + link = None + + try: + link = await _annotations_create( + tags=request.tags, + meta=request.meta, + # + data=data, + # + references=references, + links=links, + ) + except Exception as ex: + print(ex) + + response = EvaluatorServiceResponse( + status=SuccessStatus(message=""), + data=WorkflowServiceData( + outputs=outputs, + ), + trace_id=link.trace_id if link else None, + links=({revision.slug or uuid4().hex: link} if link else {}), + ) + + return response + + except ErrorStatus as error: + return EvaluatorServiceResponse( + status=Status( + code=error.code, + type=error.type, + message=error.message, + stacktrace=error.stacktrace, + ), + ) + + except Exception as ex: + return EvaluatorServiceResponse( + status=Status( + code=500, + message=str(ex), + ), + ) + + +async def _invocations_create( + tags: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + references: Optional[Dict[str, Any]] = None, + links: Optional[Dict[str, Any]] = None, +) -> Optional[Link]: + response = client( + method="POST", + endpoint=f"/preview/invocations/", + json=dict( + invocation=dict( + origin="custom", + kind="eval", + channel="api", + data=data, + tags=tags, + meta=meta, + references=references, + links=links, + ) + ), + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + trace_id = response.get("invocation", {}).get("trace_id", None) + span_id = response.get("invocation", {}).get("span_id", None) + + link = ( + Link( + trace_id=trace_id, + span_id=span_id, + ) + if trace_id and span_id + else None + ) + + return link + + +async def _annotations_create( + tags: Optional[Dict[str, Any]] = None, + meta: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + references: Optional[Dict[str, Any]] = None, + links: Optional[Dict[str, Any]] = None, +) -> Optional[Link]: + response = client( + method="POST", + endpoint=f"/preview/annotations/", + json=dict( + annotation=dict( + origin="custom", + kind="eval", + channel="api", + data=data, + tags=tags, + meta=meta, + references=references, + links=links, + ) + ), + ) + + try: + response.raise_for_status() + except: + print(response.text) + raise + + response = response.json() + + trace_id = response.get("annotation", {}).get("trace_id", None) + span_id = response.get("annotation", {}).get("span_id", None) + + link = ( + Link( + trace_id=trace_id, + span_id=span_id, + ) + if trace_id and span_id + else None + ) + + return link diff --git a/api/ee/tests/manual/evaluators/human-evaluator.http b/api/ee/tests/manual/evaluators/human-evaluator.http new file mode 100644 index 0000000000..8c02962cf8 --- /dev/null +++ b/api/ee/tests/manual/evaluators/human-evaluator.http @@ -0,0 +1,73 @@ + +@host = http://localhost +@base_url = {{host}}/api/human-evaluators +@api_key = xxxxxx.xxxxxxxxxxxxxxxx +### + +# @name add_human_evaluator +POST {{base_url}}/ +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +{ + "slug": "my-human-evaluator", + "header": {"name": "a/b accuracy", "description": "this is a description"}, + "revision": { + "kind": "HUMAN_EVALUATOR", + "body": { + "data": {"metrics": ["accuracy"], "notes": "Evaluator for accuracy"}, + "tags": ["human", "evaluation"] + }, + "commit": { + "parent_id": null, + "message": "Initial commit", + "author": "01964312-ad5a-7bb1-b21e-4f055c9f988b", + "date": "2025-04-18T12:25:59.609Z" + } + } +} + +### + +# @name fetch_human_evaluator +POST {{base_url}}/{{add_human_evaluator.response.body.variant_ref.id}} +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name edit_human_evaluator +PUT {{base_url}}/{{add_human_evaluator.response.body.variant_ref.id}} +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +{ + "slug": "my-human-evaluator-updated-another-another", + "body": { + "data": {"metrics": ["accuracy"], "notes": "Evaluator for accuracy"}, + "tags": ["human", "evaluation"] + }, + "commit": { + "parent_id": null, + "message": "Second commit", + "author": "01964312-ad5a-7bb1-b21e-4f055c9f988b", + "date": "2025-04-18T13:10:55.658Z" + } +} + +### + +# @name query_human_evaluators +GET {{base_url}}/query?revision_id={{add_human_evaluator.response.body.id}}&depth=1 +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + +# @name delete_human_evaluator +DELETE {{base_url}}/{{add_human_evaluator.response.body.variant_ref.id}} +Content-Type: application/json +Authorization: ApiKey {{api_key}} + +### + diff --git a/api/ee/tests/pytest/__init__.py b/api/ee/tests/pytest/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/ee/tests/requirements.txt b/api/ee/tests/requirements.txt new file mode 100644 index 0000000000..510e3b3b6f --- /dev/null +++ b/api/ee/tests/requirements.txt @@ -0,0 +1 @@ +-r ../../oss/tests/requirements.txt \ No newline at end of file diff --git a/api/oss/tests/manual/tracing/windowing.http b/api/oss/tests/manual/tracing/windowing.http index 5956e7d6a2..cad4ae83ad 100644 --- a/api/oss/tests/manual/tracing/windowing.http +++ b/api/oss/tests/manual/tracing/windowing.http @@ -1,6 +1,6 @@ @host = http://localhost @base_url = {{host}}/api/preview/tracing -@api_key = UGZaImq8.a94d2c99eab827b1cd27678358016a61f2e92c2cdea8f33b1cf3cc2afb7065e8 +@api_key = ... ### diff --git a/api/pyproject.toml b/api/pyproject.toml index 4f58361a5b..76fc53afd4 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "api" -version = "0.57.2" +version = "0.58.0" description = "Agenta API" authors = [ { name = "Mahmoud Mabrouk", email = "mahmoud@agenta.ai" }, diff --git a/docs/docs/prompt-engineering/playground/02-adding-custom-providers.mdx b/docs/docs/prompt-engineering/playground/02-adding-custom-providers.mdx index c2b3121621..898cefceb1 100644 --- a/docs/docs/prompt-engineering/playground/02-adding-custom-providers.mdx +++ b/docs/docs/prompt-engineering/playground/02-adding-custom-providers.mdx @@ -78,7 +78,7 @@ To add Azure OpenAI models, you'll need the following information: ### Configuration Example ```plaintext -API Key: c98d7a8s7d6a5s4d3a2s1d... +API Key: xxxxxxxxxx API Version: 2023-05-15 API base url: Use here your endpoint URL (e.g., https://accountnameinstance.openai.azure.com Deployment Name: Use here the deployment name in Azure (e.g., gpt-4-turbo) @@ -103,7 +103,7 @@ Refer to these tutorials for detailed instructions: ```plaintext Access Key ID: xxxxxxxxxx -Secret Access Key: xxxxxxxxxxxxxxxxxxxxxxx +Secret Access Key: xxxxxxxxxx Region: (e.g eu-central-1) Model name: (e.g anthropic.claude-3-sonnet-20240229-v1:0) ``` diff --git a/ee/LICENSE b/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/hooks/setup.sh b/hooks/setup.sh new file mode 100755 index 0000000000..dfa7669995 --- /dev/null +++ b/hooks/setup.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +set -euo pipefail + +echo "🔧 Setting up Git hooks with pre-commit + gitleaks..." + +# --- check dependencies --- +if ! command -v python3 >/dev/null 2>&1; then + echo "❌ Python3 is required but not installed." + exit 1 +fi +if ! command -v pip3 >/dev/null 2>&1; then + echo "❌ pip3 is required but not installed." + exit 1 +fi + +# --- install pre-commit globally if missing --- +if ! command -v pre-commit >/dev/null 2>&1; then + echo "📦 Installing pre-commit..." + pip3 install pre-commit +fi + +# --- install gitleaks globally if missing --- +if ! command -v gitleaks >/dev/null 2>&1; then + echo "📦 Installing gitleaks..." + if command -v brew >/dev/null 2>&1; then + brew install gitleaks + else + # fallback: go install (requires Go installed) + go install github.com/gitleaks/gitleaks/v8@latest + export PATH="$PATH:$(go env GOPATH)/bin" + fi +fi + +# --- install hooks into .git/hooks/ --- +echo "⚙️ Installing pre-commit hooks..." +pre-commit install --install-hooks +pre-commit install --hook-type pre-push + +# --- one-time full scans --- +echo "🔍 Running one-time gitleaks scans..." + +gitleaks --config .gitleaks.toml --exit-code 1 --verbose detect --no-git --source . || { + echo "❌ Gitleaks detected potential secrets in the working directory." + exit 1 +} + +echo "✅ Setup complete! Hooks installed and initial scan passed. You are safe to commit." diff --git a/hosting/docker-compose/ee/.dockerignore b/hosting/docker-compose/ee/.dockerignore new file mode 100644 index 0000000000..3a6d70aca2 --- /dev/null +++ b/hosting/docker-compose/ee/.dockerignore @@ -0,0 +1,7 @@ +node_modules +.git +docker/ +db.schema +tests/ +poetry.lock +db.schema \ No newline at end of file diff --git a/hosting/docker-compose/ee/LICENSE b/hosting/docker-compose/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/hosting/docker-compose/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/hosting/docker-compose/ee/docker-compose.dev.yml b/hosting/docker-compose/ee/docker-compose.dev.yml new file mode 100644 index 0000000000..8d540070d9 --- /dev/null +++ b/hosting/docker-compose/ee/docker-compose.dev.yml @@ -0,0 +1,372 @@ +name: agenta-ee-dev + +services: + .api: + image: agenta-ee-dev-api:latest + build: + context: ../../../api + dockerfile: ee/docker/Dockerfile.dev + command: ["true"] # exits immediately + + .web: + image: agenta-ee-dev-web:latest + build: + context: ../../../web + dockerfile: ee/docker/Dockerfile.dev + command: ["true"] # exits immediately + + web: + profiles: + - with-web + + image: agenta-ee-dev-web:latest + + volumes: + - ../../../web/ee/src:/app/ee/src + - ../../../web/ee/public:/app/ee/public + - ../../../web/oss/src:/app/oss/src + - ../../../web/oss/public:/app/oss/public + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + ports: + - "3000:3000" + + restart: always + + networks: + - agenta-network + labels: + - "traefik.http.routers.agenta-web.rule= PathPrefix(`/`)" + - "traefik.http.routers.agenta-web.entrypoints=web" + - "traefik.http.services.agenta-web.loadbalancer.server.port=3000" + + command: sh -c "pnpm dev-ee" + + api: + image: agenta-ee-dev-api:latest + + volumes: + - ../../../api:/app + # - ../../../sdk:/sdk + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + labels: + - "traefik.http.routers.api.rule=PathPrefix(`/api/`)" + - "traefik.http.routers.api.entrypoints=web" + - "traefik.http.middlewares.api-strip.stripprefix.prefixes=/api" + - "traefik.http.middlewares.api-strip.stripprefix.forceslash=true" + - "traefik.http.routers.api.middlewares=api-strip" + - "traefik.http.services.api.loadbalancer.server.port=8000" + - "traefik.http.routers.api.service=api" + + restart: always + + networks: + - agenta-network + extra_hosts: + - "host.docker.internal:host-gateway" + + command: + [ + "uvicorn", + "entrypoint:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--reload", + "--root-path", + "/api", + ] + + depends_on: + postgres: + condition: service_healthy + alembic: + condition: service_completed_successfully + supertokens: + condition: service_healthy + + worker: + image: agenta-ee-dev-api:latest + + volumes: + - ../../../api:/app + # - ../../../sdk:/sdk + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + depends_on: + - postgres + - rabbitmq + - redis + + extra_hosts: + - "host.docker.internal:host-gateway" + + restart: always + + networks: + - agenta-network + + command: > + watchmedo auto-restart --directory=/app/ --pattern=*.py --recursive -- celery -A entrypoint.celery_app worker --concurrency=1 --max-tasks-per-child=1 --prefetch-multiplier=1 --loglevel=DEBUG + + cron: + image: agenta-ee-dev-api:latest + + volumes: + - ../../../api/ee/src/crons/meters.sh:/meters.sh + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + depends_on: + - postgres + - api + + extra_hosts: + - "host.docker.internal:host-gateway" + + restart: always + + networks: + - agenta-network + + command: cron -f + + alembic: + image: agenta-ee-dev-api:latest + + volumes: + - ../../../api:/app + # - ../../../sdk:/sdk + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + depends_on: + postgres: + condition: service_healthy + networks: + - agenta-network + + command: sh -c "python -m ee.databases.postgres.migrations.runner" + + completion: + build: + context: ../../../services/completion + dockerfile: oss/docker/Dockerfile.dev + + volumes: + - ../../../services/completion:/app + - ../../../sdk:/sdk + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + extra_hosts: + - "host.docker.internal:host-gateway" + labels: + - "traefik.http.routers.completion.rule=PathPrefix(`/services/completion/`)" + - "traefik.http.routers.completion.entrypoints=web" + - "traefik.http.middlewares.completion-strip.stripprefix.prefixes=/services/completion" + - "traefik.http.middlewares.completion-strip.stripprefix.forceslash=true" + - "traefik.http.routers.completion.middlewares=completion-strip" + - "traefik.http.services.completion.loadbalancer.server.port=80" + - "traefik.http.routers.completion.service=completion" + + restart: always + + networks: + - agenta-network + + command: ["python", "oss/src/main.py"] + + chat: + build: + context: ../../../services/chat + dockerfile: oss/docker/Dockerfile.dev + + volumes: + - ../../../services/chat:/app + - ../../../sdk:/sdk + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + extra_hosts: + - "host.docker.internal:host-gateway" + labels: + - "traefik.http.routers.chat.rule=PathPrefix(`/services/chat/`)" + - "traefik.http.routers.chat.entrypoints=web" + - "traefik.http.middlewares.chat-strip.stripprefix.prefixes=/services/chat" + - "traefik.http.middlewares.chat-strip.stripprefix.forceslash=true" + - "traefik.http.routers.chat.middlewares=chat-strip" + - "traefik.http.services.chat.loadbalancer.server.port=80" + - "traefik.http.routers.chat.service=chat" + + restart: always + + networks: + - agenta-network + + command: ["python", "oss/src/main.py"] + + postgres: + image: postgres:16.2 + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + ports: + - "5432:5432" + + restart: always + + networks: + - agenta-network + volumes: + - postgres-data:/var/lib/postgresql/data/ + - ../../../api/ee/databases/postgres/init-db-ee.sql:/docker-entrypoint-initdb.d/init-db.sql + healthcheck: + test: ["CMD-SHELL", "pg_isready -U username -d agenta_ee_core"] + interval: 10s + timeout: 5s + retries: 5 + + rabbitmq: + image: rabbitmq:3-management + + ports: + - "5672:5672" + - "15672:15672" + volumes: + - rabbitmq-data:/var/lib/rabbitmq + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + restart: always + + networks: + - agenta-network + + redis: + image: redis:latest + + restart: always + + networks: + - agenta-network + volumes: + - redis-data:/data + + cache: + image: redis:latest + + command: > + redis-server + --port 6378 + --appendonly no + --appendfsync no + --save "" + --maxmemory 128mb + --maxmemory-policy allkeys-lru + + ports: + - "6378:6378" + + volumes: + - cache-data:/data + + restart: always + + networks: + - agenta-network + + labels: + - "traefik.enable=false" + + healthcheck: + test: ["CMD", "redis-cli", "-p", "6378", "ping"] + interval: 10s + timeout: 5s + retries: 5 + start_period: 5s + + traefik: + image: traefik:v2.10 + + command: --api.dashboard=true --api.insecure=true --providers.docker --entrypoints.web.address=:${AGENTA_PORT:-80} + volumes: + - /var/run/docker.sock:/var/run/docker.sock + ports: + - "${AGENTA_PORT:-80}:${AGENTA_PORT:-80}" + - "8080:8080" + - "443:443" + + restart: always + + networks: + - agenta-network + + supertokens: + image: registry.supertokens.io/supertokens/supertokens-postgresql + + depends_on: + postgres: + condition: service_healthy + alembic: + condition: service_completed_successfully + + ports: + - 3567:3567 + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + environment: + POSTGRESQL_CONNECTION_URI: ${POSTGRES_URI_SUPERTOKENS} + + restart: always + + networks: + - agenta-network + + healthcheck: + test: > + bash -c 'exec 3<>/dev/tcp/127.0.0.1/3567 && echo -e "GET /hello HTTP/1.1\r\nhost: 127.0.0.1:3567\r\nConnection: close\r\n\r\n" >&3 && cat <&3 | grep "Hello"' + interval: 10s + timeout: 5s + retries: 5 + + stripe: + image: stripe/stripe-cli:latest + + command: [ + listen, + --forward-to, + http://api:8000/billing/stripe/events/, + --events, + "customer.subscription.created,customer.subscription.deleted,invoice.updated,invoice.upcoming,invoice.payment_failed,invoice.payment_succeeded" + ] + + env_file: + - ${ENV_FILE:-./.env.ee.dev} + + restart: always + + networks: + - agenta-network + +networks: + agenta-network: + +volumes: + postgres-data: + rabbitmq-data: + redis-data: + cache-data: + nextjs_cache: diff --git a/hosting/docker-compose/ee/env.ee.dev.example b/hosting/docker-compose/ee/env.ee.dev.example new file mode 100644 index 0000000000..c42666965b --- /dev/null +++ b/hosting/docker-compose/ee/env.ee.dev.example @@ -0,0 +1,91 @@ +# First-party (required) +AGENTA_LICENSE=ee +AGENTA_STAGE=dev +AGENTA_PROVIDER=local +AGENTA_WEB_URL=http://localhost +AGENTA_API_URL=http://localhost/api +AGENTA_SERVICES_URL=http://localhost/services +AGENTA_AUTH_KEY=change-me +AGENTA_CRYPT_KEY=change-me +AGENTA_API_IMAGE_NAME=agenta-api +AGENTA_API_IMAGE_TAG=latest +AGENTA_WEB_IMAGE_NAME=agenta-web +AGENTA_WEB_IMAGE_TAG=latest + +# First-party (registry & service) +DOCKER_NETWORK_MODE=bridge +POSTGRES_USERNAME=username +POSTGRES_PASSWORD=password + +# First-party (optional) +AGENTA_AUTO_MIGRATIONS=true +AGENTA_PRICING= +AGENTA_DEMOS= +AGENTA_RUNTIME_PREFIX= +AGENTA_API_INTERNAL_URL= +AGENTA_LITELLM_MOCK= +POSTGRES_USERNAME_ADMIN= +POSTGRES_PASSWORD_ADMIN= +AGENTA_SERVICE_MIDDLEWARE_CACHE_ENABLED=true +AGENTA_OTLP_MAX_BATCH_BYTES=10485760 + +# Third-party (required) +TRAEFIK_DOMAIN= +TRAEFIK_PROTOCOL= +TRAEFIK_PORT= + +REDIS_URL=redis://redis:6379/0 +RABBITMQ_DEFAULT_PASS=guest +RABBITMQ_DEFAULT_USER=guest + +CELERY_BROKER_URL=amqp://guest@rabbitmq// +CELERY_RESULT_BACKEND=redis://redis:6379/0 + +POSTGRES_URI_SUPERTOKENS="postgresql://username:password@postgres:5432/agenta_ee_supertokens" +POSTGRES_URI_CORE="postgresql+asyncpg://username:password@postgres:5432/agenta_ee_core" +POSTGRES_URI_TRACING="postgresql+asyncpg://username:password@postgres:5432/agenta_ee_tracing" + +ALEMBIC_CFG_PATH_CORE=/app/ee/databases/postgres/migrations/core/alembic.ini +ALEMBIC_CFG_PATH_TRACING=/app/ee/databases/postgres/migrations/tracing/alembic.ini + +SUPERTOKENS_CONNECTION_URI=http://supertokens:3567 + +# Third-party (optional) +AWS_ECR_URL= +AWS_RDS_SECRET= + +POSTHOG_API_KEY=phc_3urGRy5TL1HhaHnRYL0JSHxJxigRVackhphHtozUmdp + +GITHUB_OAUTH_CLIENT_ID= +GITHUB_OAUTH_CLIENT_SECRET= +GOOGLE_OAUTH_CLIENT_ID= +GOOGLE_OAUTH_CLIENT_SECRET= + +SUPERTOKENS_API_KEY=replace-me + +NEW_RELIC_LICENSE_KEY= +NRIA_LICENSE_KEY= + +LOOPS_API_KEY= + +SENDGRID_API_KEY= + +CRISP_WEBSITE_ID= + +STRIPE_API_KEY= +STRIPE_WEBHOOK_SECRET= +STRIPE_TARGET= + +# Third-party - LLM (optional) +ALEPHALPHA_API_KEY= +ANTHROPIC_API_KEY= +ANYSCALE_API_KEY= +COHERE_API_KEY= +DEEPINFRA_API_KEY= +GEMINI_API_KEY= +GROQ_API_KEY= +MISTRAL_API_KEY= +OPENAI_API_KEY= +OPENROUTER_API_KEY= +PERPLEXITYAI_API_KEY= +TOGETHERAI_API_KEY= diff --git a/hosting/docker-compose/ee/env.ee.gh.example b/hosting/docker-compose/ee/env.ee.gh.example new file mode 100644 index 0000000000..5cba09c18c --- /dev/null +++ b/hosting/docker-compose/ee/env.ee.gh.example @@ -0,0 +1,80 @@ +# First-party (required) +AGENTA_LICENSE=ee +AGENTA_STAGE=dev +AGENTA_PROVIDER=local +AGENTA_API_URL=http://localhost/api +AGENTA_WEB_URL=http://localhost +AGENTA_SERVICES_URL=http://localhost/services +AGENTA_AUTH_KEY=change-me +AGENTA_CRYPT_KEY=change-me + +# First-party (registry & service) +DOCKER_NETWORK_MODE=bridge +POSTGRES_PASSWORD=password +POSTGRES_USERNAME=username + +# First-party (optional) +AGENTA_AUTO_MIGRATIONS=true +AGENTA_PRICING= +AGENTA_DEMOS= +AGENTA_RUNTIME_PREFIX= +AGENTA_API_INTERNAL_URL= +AGENTA_SERVICE_MIDDLEWARE_CACHE_ENABLED=true +AGENTA_OTLP_MAX_BATCH_BYTES=10485760 + +# Third-party (required) +TRAEFIK_DOMAIN= +TRAEFIK_PROTOCOL= +TRAEFIK_PORT= + +REDIS_URL=redis://redis:6379/0 +RABBITMQ_DEFAULT_PASS=guest +RABBITMQ_DEFAULT_USER=guest + +CELERY_BROKER_URL=amqp://guest@rabbitmq// +CELERY_RESULT_BACKEND=redis://redis:6379/0 + +POSTGRES_URI_SUPERTOKENS="postgresql://username:password@postgres:5432/agenta_ee_supertokens" +POSTGRES_URI_CORE="postgresql+asyncpg://username:password@postgres:5432/agenta_ee_core" +POSTGRES_URI_TRACING="postgresql+asyncpg://username:password@postgres:5432/agenta_ee_tracing" + +ALEMBIC_CFG_PATH_CORE=/app/ee/databases/postgres/migrations/core/alembic.ini +ALEMBIC_CFG_PATH_TRACING=/app/ee/databases/postgres/migrations/tracing/alembic.ini + +SUPERTOKENS_API_KEY=replace-me +SUPERTOKENS_CONNECTION_URI=http://supertokens:3567 + +# Third-party (optional) +POSTHOG_API_KEY=phc_3urGRy5TL1HhaHnRYL0JSHxJxigRVackhphHtozUmdp + +GITHUB_OAUTH_CLIENT_ID= +GITHUB_OAUTH_CLIENT_SECRET= + +GOOGLE_OAUTH_CLIENT_ID= +GOOGLE_OAUTH_CLIENT_SECRET= + +NEW_RELIC_LICENSE_KEY= +NRIA_LICENSE_KEY= + +LOOPS_API_KEY= + +SENDGRID_API_KEY= + +CRISP_WEBSITE_ID= + +STRIPE_API_KEY= +STRIPE_WEBHOOK_SECRET= + +# Third-party - LLM (optional) +ALEPHALPHA_API_KEY= +ANTHROPIC_API_KEY= +ANYSCALE_API_KEY= +COHERE_API_KEY= +DEEPINFRA_API_KEY= +GEMINI_API_KEY= +GROQ_API_KEY= +MISTRAL_API_KEY= +OPENAI_API_KEY= +OPENROUTER_API_KEY= +PERPLEXITYAI_API_KEY= +TOGETHERAI_API_KEY= \ No newline at end of file diff --git a/hosting/docker-compose/oss/.env.oss.dev.example b/hosting/docker-compose/oss/env.oss.dev.example similarity index 100% rename from hosting/docker-compose/oss/.env.oss.dev.example rename to hosting/docker-compose/oss/env.oss.dev.example diff --git a/hosting/docker-compose/oss/.env.oss.gh.example b/hosting/docker-compose/oss/env.oss.gh.example similarity index 100% rename from hosting/docker-compose/oss/.env.oss.gh.example rename to hosting/docker-compose/oss/env.oss.gh.example diff --git a/hosting/aws/agenta_instance.tf b/hosting/old/aws/agenta_instance.tf similarity index 100% rename from hosting/aws/agenta_instance.tf rename to hosting/old/aws/agenta_instance.tf diff --git a/hosting/aws/agenta_instance_sg.tf b/hosting/old/aws/agenta_instance_sg.tf similarity index 100% rename from hosting/aws/agenta_instance_sg.tf rename to hosting/old/aws/agenta_instance_sg.tf diff --git a/hosting/aws/instance-setup.sh b/hosting/old/aws/instance-setup.sh similarity index 100% rename from hosting/aws/instance-setup.sh rename to hosting/old/aws/instance-setup.sh diff --git a/hosting/aws/main.tf b/hosting/old/aws/main.tf similarity index 100% rename from hosting/aws/main.tf rename to hosting/old/aws/main.tf diff --git a/hosting/gcp/agenta-instance.tf b/hosting/old/gcp/agenta-instance.tf similarity index 100% rename from hosting/gcp/agenta-instance.tf rename to hosting/old/gcp/agenta-instance.tf diff --git a/hosting/old/gcp/credentials.json b/hosting/old/gcp/credentials.json new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hosting/gcp/main.tf b/hosting/old/gcp/main.tf similarity index 100% rename from hosting/gcp/main.tf rename to hosting/old/gcp/main.tf diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 323d0515f4..6fb0d58a0a 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agenta" -version = "0.57.2" +version = "0.58.0" description = "The SDK for agenta is an open-source LLMOps platform." readme = "README.md" authors = [ diff --git a/sdk/tests/legacy/baggage/config.toml b/sdk/tests/legacy/baggage/config.toml index f32346649b..d5a5f01895 100644 --- a/sdk/tests/legacy/baggage/config.toml +++ b/sdk/tests/legacy/baggage/config.toml @@ -1,4 +1,4 @@ app_name = "baggage" app_id = "0193b67a-b673-7919-85c2-0b5b0a2183d3" backend_host = "http://localhost" -api_key = "XELnjVve.c1f177c87250b603cf1ed2a69ebdfc1cec3124776058e7afcbba93890c515e74" +api_key = "XELnjVve.xxxx" diff --git a/sdk/tests/legacy/debugging/simple-app/config.toml b/sdk/tests/legacy/debugging/simple-app/config.toml index 389b22a2bf..7c2a204758 100644 --- a/sdk/tests/legacy/debugging/simple-app/config.toml +++ b/sdk/tests/legacy/debugging/simple-app/config.toml @@ -1,6 +1,6 @@ app_name = "asdf" app_id = "0193bbaa-4f2b-7510-9170-9bdf95249ca0" backend_host = "https://cloud.agenta.ai" -api_key = "dWdKluoL.fc56608c5e0ce7b262e9e9a795b6a5e9371200c573cafbd975ebb6b4368b6032" +api_key = "dWdKluoL.xxxx" variants = [] variant_ids = [] diff --git a/services/chat/ee/LICENSE b/services/chat/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/services/chat/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/services/chat/ee/__init__.py b/services/chat/ee/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/services/chat/ee/docker/Dockerfile.gh b/services/chat/ee/docker/Dockerfile.gh new file mode 100644 index 0000000000..7e2351a555 --- /dev/null +++ b/services/chat/ee/docker/Dockerfile.gh @@ -0,0 +1,18 @@ +FROM python:3.10-slim + +ARG ROOT_PATH=/ +ENV ROOT_PATH=${ROOT_PATH} + +WORKDIR /app/ + +RUN pip install --upgrade pip + +COPY ./requirements.txt /app/requirements.txt + +RUN pip install -r requirements.txt + +COPY ./oss /app/oss/ + +ENV PYTHONPATH=/sdk:$PYTHONPATH + +EXPOSE 80 diff --git a/services/completion/ee/LICENSE b/services/completion/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/services/completion/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/services/completion/ee/__init__.py b/services/completion/ee/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/services/completion/ee/docker/Dockerfile.gh b/services/completion/ee/docker/Dockerfile.gh new file mode 100644 index 0000000000..7e2351a555 --- /dev/null +++ b/services/completion/ee/docker/Dockerfile.gh @@ -0,0 +1,18 @@ +FROM python:3.10-slim + +ARG ROOT_PATH=/ +ENV ROOT_PATH=${ROOT_PATH} + +WORKDIR /app/ + +RUN pip install --upgrade pip + +COPY ./requirements.txt /app/requirements.txt + +RUN pip install -r requirements.txt + +COPY ./oss /app/oss/ + +ENV PYTHONPATH=/sdk:$PYTHONPATH + +EXPOSE 80 diff --git a/web/ee/.gitignore b/web/ee/.gitignore new file mode 100644 index 0000000000..6d61ed9526 --- /dev/null +++ b/web/ee/.gitignore @@ -0,0 +1,37 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + + diff --git a/web/ee/LICENSE b/web/ee/LICENSE new file mode 100644 index 0000000000..ae7a2f38f4 --- /dev/null +++ b/web/ee/LICENSE @@ -0,0 +1,37 @@ +Agenta Enterprise License (the “Enterprise License”) +Copyright (c) 2023–2025 +Agentatech UG (haftungsbeschränkt), doing business as “Agenta” (“Agenta”) + +With regard to the Agenta Software: + +This software and associated documentation files (the "Software") may only be +used in production, if you (and any entity that you represent) have agreed to, +and are in compliance with, the Agenta Subscription Terms of Service, available +at https://agenta.ai/terms (the “Enterprise Terms”), or other +agreement governing the use of the Software, as agreed by you and Agenta, +and otherwise have a valid Agenta Enterprise License. + +Subject to the foregoing sentence, you are free to modify this Software and +publish patches to the Software. You agree that Agenta and/or its licensors +(as applicable) retain all right, title and interest in and to all such +modifications and/or patches, and all such modifications and/or patches may +only be used, copied, modified, displayed, distributed, or otherwise exploited +with a valid Agenta Enterprise License. Notwithstanding the foregoing, you may +copy and modify the Software for development and testing purposes, without +requiring a subscription. You agree that Agenta and/or its licensors (as +applicable) retain all right, title and interest in and to all such +modifications. You are not granted any other rights beyond what is expressly +stated herein. Subject to the foregoing, it is forbidden to copy, merge, +publish, distribute, sublicense, and/or sell the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +For all third party components incorporated into the Agenta Software, those +components are licensed under the original license provided by the owner of the +applicable component. diff --git a/web/ee/docker/Dockerfile.dev b/web/ee/docker/Dockerfile.dev new file mode 100644 index 0000000000..719331462b --- /dev/null +++ b/web/ee/docker/Dockerfile.dev @@ -0,0 +1,43 @@ +FROM node:20.18-slim + +ENV TURBO_TELEMETRY_DISABLED=1 + +WORKDIR /app + +# Install jq for JSON parsing +RUN apt-get update && apt-get install -y jq + +# Install dependencies based on the preferred package manager +COPY package.json yarn.lock* package-lock.json* pnpm-lock.yaml* .npmrc* ./ + +# Extract PNPM version and install it +RUN PNPM_VERSION=$(cat package.json | jq -r '.packageManager | split("@")[1]') && \ + npm install -g pnpm@${PNPM_VERSION} + +COPY ee/package.json ./ee/yarn.lock* ./ee/package-lock.json* ./ee/pnpm-lock.yaml* ./ee/.npmrc* ./ee/ +COPY oss/package.json ./oss/yarn.lock* ./oss/package-lock.json* ./oss/pnpm-lock.yaml* ./oss/.npmrc* ./oss/ +COPY ./pnpm-workspace.yaml ./turbo.json ./ + +COPY ./entrypoint.sh /app/entrypoint.sh + +RUN pnpm i + +COPY ee/src ./ee/src +COPY ee/public ./ee/public +COPY oss/src ./oss/src +COPY oss/public ./oss/public +COPY tsconfig.json . +COPY ee/tsconfig.json ./ee +COPY oss/tsconfig.json ./oss + +COPY ee/postcss.config.mjs ./ee/postcss.config.mjs +COPY oss/postcss.config.mjs ./oss/postcss.config.mjs + +COPY ee/next.config.ts ./ee/next.config.ts +COPY oss/next.config.ts ./oss/next.config.ts + +COPY ee/tailwind.config.ts ./ee/tailwind.config.ts +COPY oss/tailwind.config.ts ./oss/tailwind.config.ts + +ENTRYPOINT ["./entrypoint.sh"] +EXPOSE 3000 \ No newline at end of file diff --git a/web/ee/docker/Dockerfile.gh b/web/ee/docker/Dockerfile.gh new file mode 100644 index 0000000000..c362aa886f --- /dev/null +++ b/web/ee/docker/Dockerfile.gh @@ -0,0 +1,43 @@ +FROM node:20.18.0-slim AS base + +ENV TURBO_TELEMETRY_DISABLED=1 + +ENV PNPM_HOME="/pnpm" +ENV PATH="$PNPM_HOME:$PATH" + +RUN apt-get update && apt-get install -y jq + +COPY . . +RUN PNPM_VERSION=$(cat package.json | jq -r '.packageManager | split("@")[1]') && \ + npm install -g pnpm@${PNPM_VERSION} + +RUN pnpm add -g turbo +RUN turbo prune @agenta/ee --docker + +# BUILDER --------------------------------------------------------------------- + +FROM base AS builder + +WORKDIR /app + +COPY --from=base ./out/json/ . +COPY ./.husky /app/.husky + +RUN --mount=type=cache,id=pnpm,target=/pnpm/store yes | pnpm install --frozen-lockfile --filter=@agenta/ee +COPY --from=base /out/full/ . + +RUN npx next telemetry disable + +RUN pnpm turbo run build --filter=@agenta/ee + +# RUNNER ---------------------------------------------------------------------- + +FROM base AS runner + +WORKDIR /app + +COPY --from=builder /app/ee/.next/standalone /app +COPY ../entrypoint.sh /app/entrypoint.sh + +ENTRYPOINT ["/app/entrypoint.sh"] +EXPOSE 3000 diff --git a/web/ee/next.config.ts b/web/ee/next.config.ts new file mode 100644 index 0000000000..7de0509f9c --- /dev/null +++ b/web/ee/next.config.ts @@ -0,0 +1,73 @@ +import path from "path" +import {createRequire} from "module" + +import ossConfig from "@agenta/oss/next.config" + +const require = createRequire(import.meta.url) +const reduxToolkitCjsEntry = path.join( + path.dirname(require.resolve("@reduxjs/toolkit/package.json")), + "dist/cjs/index.js", +) + +const config = { + ...ossConfig, + outputFileTracingRoot: path.resolve(__dirname, ".."), + turbopack: { + // root: path.resolve(__dirname, ".."), + resolveAlias: { + "@/oss/*": ["@/agenta-oss-common/*"], + }, + }, + experimental: { + optimizePackageImports: ["@agenta/oss"], + }, + transpilePackages: ["jotai-devtools"], + typescript: { + ignoreBuildErrors: true, + }, + webpack: (webpackConfig: any, options: any) => { + const baseConfig = + typeof ossConfig.webpack === "function" + ? ossConfig.webpack(webpackConfig, options) + : webpackConfig + + baseConfig.resolve ??= {} + baseConfig.resolve.alias = { + ...(baseConfig.resolve.alias ?? {}), + "@reduxjs/toolkit": reduxToolkitCjsEntry, + } + + return baseConfig + }, + async redirects() { + return [ + { + source: "/apps", + destination: "/w", + permanent: true, + }, + { + source: "/apps/:app_id", + destination: "/w", + permanent: true, + }, + { + source: "/apps/:app_id/:path*", + destination: "/w", + permanent: true, + }, + { + source: "/", + destination: "/w", + permanent: true, + }, + { + source: "/:workspace_id/apps/:app_id", + destination: "/:workspace_id/apps/:app_id/overview/", + permanent: true, + }, + ] + }, +} + +export default config diff --git a/web/ee/package.json b/web/ee/package.json new file mode 100644 index 0000000000..e98d5b2bed --- /dev/null +++ b/web/ee/package.json @@ -0,0 +1,94 @@ +{ + "name": "@agenta/ee", + "version": "0.58.0", + "private": true, + "engines": { + "node": ">=18" + }, + "scripts": { + "dev": "next dev --turbopack", + "dev:local": "ENV_FILE=.local.env next dev", + "dev:turbo": "ENV_FILE=.local.env next dev --turbo", + "build": "next build && cp -r public/. ./.next/standalone/ee/public && cp -r .next/static/. ./.next/standalone/ee/.next/static", + "start": "next start", + "lint": "next lint", + "lint-fix": "next lint --fix", + "format": "prettier --check .", + "format-fix": "prettier --write .", + "types:check": "tsc", + "types:watch": "tsc -w" + }, + "dependencies": { + "@ag-grid-community/client-side-row-model": "^32.3.4", + "@ag-grid-community/core": "^32.3.4", + "@ag-grid-community/csv-export": "^32.3.4", + "@ag-grid-community/react": "^32.3.4", + "@ag-grid-community/styles": "^32.3.4", + "@agenta/oss": "workspace:../oss", + "@ant-design/colors": "^7.2.0", + "@ant-design/cssinjs": "^1.22.1", + "@ant-design/icons": "^5.5.2", + "@ant-design/v5-patch-for-react-19": "^1.0.3", + "@lexical/code-shiki": "^0.35.0", + "@monaco-editor/react": "^4.7.0-rc.0", + "@phosphor-icons/react": "^2.1.10", + "@tanstack/query-core": "^5.87.1", + "@tanstack/react-query": "^5.87.1", + "@tremor/react": "^3.18.7", + "@types/js-yaml": "^4.0.9", + "@types/lodash": "^4.17.18", + "@types/react": "^19.0.10", + "@types/react-dom": "^19.0.4", + "@types/react-resizable": "^3.0.7", + "@types/react-window": "^1.8.8", + "@types/recharts": "^2.0.1", + "@types/uuid": "^10.0.0", + "antd": "^5.26.1", + "autoprefixer": "10.4.20", + "axios": "^1.12.2", + "classnames": "^2.3.2", + "clsx": "^2.1.1", + "crisp-sdk-web": "^1.0.25", + "dayjs": "^1.11.10", + "dotenv": "^16.5.0", + "fast-deep-equal": "^3.1.3", + "immer": "^10.1.1", + "jotai": "^2.13.1", + "jotai-devtools": "^0.12.0", + "jotai-eager": "^0.2.3", + "jotai-immer": "^0.4.1", + "jotai-tanstack-query": "^0.11.0", + "js-yaml": "^4.1.0", + "jsonrepair": "^3.13.0", + "lodash": "^4.17.21", + "postcss": "^8.5.6", + "postcss-antd-fixes": "^0.2.0", + "posthog-js": "^1.223.3", + "rc-virtual-list": "^3.19.1", + "react": "^19.0.0", + "react-dom": "^19.0.0", + "react-jss": "^10.10.0", + "react-resizable": "^3.0.5", + "react-window": "^1.8.11", + "recharts": "^3.1.0", + "shiki": "^3.12.2", + "supertokens-auth-react": "^0.47.0", + "supertokens-node": "^21.0.0", + "swc-loader": "^0.2.6", + "swr": "^2.3.0", + "tailwindcss": "^3.4.4", + "typescript": "5.8.3", + "use-animation-frame": "^0.2.1", + "usehooks-ts": "^3.1.0", + "uuid": "^11.1.0" + }, + "devDependencies": { + "@agenta/web-tests": "workspace:../tests", + "@swc-jotai/debug-label": "^0.2.0", + "@swc-jotai/react-refresh": "^0.3.0", + "@types/node": "^20.8.10", + "@types/prismjs": "^1.26.5", + "node-mocks-http": "^1.17.2", + "tailwind-scrollbar": "^3" + } +} diff --git a/web/ee/postcss.config.mjs b/web/ee/postcss.config.mjs new file mode 100644 index 0000000000..d286a2562d --- /dev/null +++ b/web/ee/postcss.config.mjs @@ -0,0 +1,3 @@ +import ossConfig from "@agenta/oss/postcss.config.mjs" + +export default ossConfig diff --git a/web/ee/public/assets/On-boarding.png b/web/ee/public/assets/On-boarding.png new file mode 100644 index 0000000000..00ec79f653 Binary files /dev/null and b/web/ee/public/assets/On-boarding.png differ diff --git a/web/ee/public/assets/On-boarding.webp b/web/ee/public/assets/On-boarding.webp new file mode 100644 index 0000000000..2562fd4adc Binary files /dev/null and b/web/ee/public/assets/On-boarding.webp differ diff --git a/web/ee/public/assets/dark-complete-transparent-CROPPED.png b/web/ee/public/assets/dark-complete-transparent-CROPPED.png new file mode 100644 index 0000000000..7d134ac59a Binary files /dev/null and b/web/ee/public/assets/dark-complete-transparent-CROPPED.png differ diff --git a/web/ee/public/assets/dark-complete-transparent_white_logo.png b/web/ee/public/assets/dark-complete-transparent_white_logo.png new file mode 100644 index 0000000000..8685bbf981 Binary files /dev/null and b/web/ee/public/assets/dark-complete-transparent_white_logo.png differ diff --git a/web/ee/public/assets/dark-logo.svg b/web/ee/public/assets/dark-logo.svg new file mode 100644 index 0000000000..6cb8ef3330 --- /dev/null +++ b/web/ee/public/assets/dark-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/ee/public/assets/fallback.png b/web/ee/public/assets/fallback.png new file mode 100644 index 0000000000..d3cbad0379 Binary files /dev/null and b/web/ee/public/assets/fallback.png differ diff --git a/web/ee/public/assets/favicon.ico b/web/ee/public/assets/favicon.ico new file mode 100644 index 0000000000..4dc8619b1d Binary files /dev/null and b/web/ee/public/assets/favicon.ico differ diff --git a/web/ee/public/assets/light-complete-transparent-CROPPED.png b/web/ee/public/assets/light-complete-transparent-CROPPED.png new file mode 100644 index 0000000000..6be2e99e08 Binary files /dev/null and b/web/ee/public/assets/light-complete-transparent-CROPPED.png differ diff --git a/web/ee/public/assets/light-logo.svg b/web/ee/public/assets/light-logo.svg new file mode 100644 index 0000000000..9c795f8e88 --- /dev/null +++ b/web/ee/public/assets/light-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/ee/public/assets/not-found.png b/web/ee/public/assets/not-found.png new file mode 100644 index 0000000000..f4048f6573 Binary files /dev/null and b/web/ee/public/assets/not-found.png differ diff --git a/web/ee/public/assets/onboard-page-grids.svg b/web/ee/public/assets/onboard-page-grids.svg new file mode 100644 index 0000000000..85990df21d --- /dev/null +++ b/web/ee/public/assets/onboard-page-grids.svg @@ -0,0 +1,81 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/ee/public/assets/rag-demo-app.webp b/web/ee/public/assets/rag-demo-app.webp new file mode 100644 index 0000000000..77ded11bee Binary files /dev/null and b/web/ee/public/assets/rag-demo-app.webp differ diff --git a/web/ee/src/components/Banners/BillingPlanBanner/FreePlanBanner.tsx b/web/ee/src/components/Banners/BillingPlanBanner/FreePlanBanner.tsx new file mode 100644 index 0000000000..02ca075130 --- /dev/null +++ b/web/ee/src/components/Banners/BillingPlanBanner/FreePlanBanner.tsx @@ -0,0 +1,29 @@ +import {memo} from "react" + +import {Button, Typography} from "antd" +import {useRouter} from "next/router" + +import useURL from "@/oss/hooks/useURL" + +const FreePlanBanner = () => { + const router = useRouter() + const {projectURL} = useURL() + + return ( +
+ Free Plan + + Create unlimited applications & run unlimited evaluations. Upgrade today and get + more out of Agenta.{" "} + + +
+ ) +} + +export default memo(FreePlanBanner) diff --git a/web/ee/src/components/Banners/BillingPlanBanner/FreeTrialBanner.tsx b/web/ee/src/components/Banners/BillingPlanBanner/FreeTrialBanner.tsx new file mode 100644 index 0000000000..3fa77e8671 --- /dev/null +++ b/web/ee/src/components/Banners/BillingPlanBanner/FreeTrialBanner.tsx @@ -0,0 +1,33 @@ +import {Button, Typography} from "antd" +import {useRouter} from "next/router" +import useURL from "@/oss/hooks/useURL" + +import {SubscriptionType} from "@/oss/services/billing/types" + +import SubscriptionPlanDetails from "@/agenta-oss-common/components/pages/settings/Billing/Modals/PricingModal/assets/SubscriptionPlanDetails" + +const FreeTrialBanner = ({subscription}: {subscription: SubscriptionType}) => { + const router = useRouter() + const {projectURL} = useURL() + + return ( +
+ + + + + Create unlimited applications & run unlimited evaluations. Upgrade today to keep pro + plan features. + + +
+ ) +} + +export default FreeTrialBanner diff --git a/web/ee/src/components/DeleteEvaluationModal/DeleteEvaluationModal.tsx b/web/ee/src/components/DeleteEvaluationModal/DeleteEvaluationModal.tsx new file mode 100644 index 0000000000..3631ee59eb --- /dev/null +++ b/web/ee/src/components/DeleteEvaluationModal/DeleteEvaluationModal.tsx @@ -0,0 +1,59 @@ +import EnhancedModal from "@agenta/oss/src/components/EnhancedUIs/Modal" +import {DeleteOutlined} from "@ant-design/icons" +import {Typography} from "antd" + +import {DeleteEvaluationModalProps} from "./types" + +const DeleteEvaluationModal = ({ + evaluationType, + isMultiple = false, + ...props +}: DeleteEvaluationModalProps) => { + return ( + , type: "primary"}} + centered + zIndex={2000} + > +
+ + Are you sure you want to delete? + + +
+ + {isMultiple + ? `The selected ${evaluationType.split("|").length} evaluations will be permanently deleted.` + : `A deleted ${evaluationType} cannot be restored.`} + + +
+ + {isMultiple + ? "You are about to delete the following evaluations:" + : "You are about to delete:"} + + + {isMultiple + ? evaluationType.split(" | ").map((item, index) => ( +
+ • {item.trim()} +
+ )) + : evaluationType} +
+
+
+
+
+ ) +} + +export default DeleteEvaluationModal diff --git a/web/ee/src/components/DeleteEvaluationModal/types.ts b/web/ee/src/components/DeleteEvaluationModal/types.ts new file mode 100644 index 0000000000..7acded39ee --- /dev/null +++ b/web/ee/src/components/DeleteEvaluationModal/types.ts @@ -0,0 +1,6 @@ +import type {ModalProps} from "antd" + +export interface DeleteEvaluationModalProps extends ModalProps { + evaluationType: string + isMultiple?: boolean +} diff --git a/web/ee/src/components/DeploymentHistory/DeploymentHistory.tsx b/web/ee/src/components/DeploymentHistory/DeploymentHistory.tsx new file mode 100644 index 0000000000..d596e2bc42 --- /dev/null +++ b/web/ee/src/components/DeploymentHistory/DeploymentHistory.tsx @@ -0,0 +1,347 @@ +import {useCallback, useEffect, useRef, useState} from "react" + +import {Button, Card, Divider, Space, Typography, notification} from "antd" +import dayjs from "dayjs" +import duration from "dayjs/plugin/duration" +import relativeTime from "dayjs/plugin/relativeTime" +import debounce from "lodash/debounce" +import {createUseStyles} from "react-jss" + +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import ResultComponent from "@/oss/components/ResultComponent/ResultComponent" +import {Environment, JSSTheme} from "@/oss/lib/Types" +import { + createRevertDeploymentRevision, + fetchAllDeploymentRevisions, +} from "@/oss/services/deploymentVersioning/api" + +import {DeploymentRevisionConfig, DeploymentRevisions} from "../../lib/types_ee" + +dayjs.extend(relativeTime) +dayjs.extend(duration) + +interface DeploymentHistoryProps { + selectedEnvironment: Environment +} + +const {Text} = Typography + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + gap: 20, + }, + historyItemsContainer: { + flex: 0.2, + backgroundColor: theme.isDark ? "#333" : "#FFFFFF", + border: theme.isDark ? "" : "1px solid #f0f0f0", + overflowY: "scroll", + padding: 10, + borderRadius: 10, + minWidth: 300, + height: "600px", + }, + historyItems: { + display: "flex", + flexDirection: "column", + padding: "10px 20px", + margin: "20px 0", + borderRadius: 10, + cursor: "pointer", + }, + promptHistoryCard: { + margin: "30px", + }, + promptHistoryInfo: { + flex: 0.8, + backgroundColor: theme.isDark ? "#333" : "#FFFFFF", + border: theme.isDark ? "" : "1px solid #f0f0f0", + padding: 20, + borderRadius: 10, + }, + promptHistoryInfoHeader: { + display: "flex", + alignItems: "center", + justifyContent: "space-between", + "& h1": { + fontSize: 32, + }, + }, + emptyContainer: { + display: "flex", + alignItems: "center", + justifyContent: "center", + margin: "30px auto", + fontSize: 20, + fontWeight: "bold", + }, + divider: { + margin: "10px 0", + }, + historyItemsTitle: { + fontSize: 14, + "& span": { + color: theme.isDark ? "#f1f5f8" : "#656d76", + }, + }, + noParams: { + color: theme.colorTextDescription, + textAlign: "center", + marginTop: theme.marginLG, + }, + loadingContainer: { + display: "grid", + placeItems: "center", + height: "100%", + }, +})) + +const DeploymentHistory: React.FC = ({selectedEnvironment}) => { + const {appTheme} = useAppTheme() + const classes = useStyles() + const [activeItem, setActiveItem] = useState(0) + const [isLoading, setIsLoading] = useState(false) + const [isReverting, setIsReverted] = useState(false) + const [showDeployment, setShowDeployment] = useState() + const [deploymentRevisionId, setDeploymentRevisionId] = useState("") + const [deploymentRevisions, setDeploymentRevisions] = useState() + const [showDeploymentLoading, setShowDeploymentLoading] = useState(false) + const {current} = useRef<{id: string; revision: string | undefined}>({ + id: "", + revision: "", + }) + + useEffect(() => { + current.revision = deploymentRevisions?.revisions[activeItem].deployed_app_variant_revision + }, [activeItem]) + + const fetchData = async () => { + setIsLoading(true) + try { + const data = await fetchAllDeploymentRevisions( + selectedEnvironment?.app_id, + selectedEnvironment?.name, + ) + setDeploymentRevisions(data) + current.id = data.deployed_app_variant_revision_id || "" + } catch (error) { + setIsLoading(false) + } finally { + setIsLoading(false) + } + } + + const handleRevert = useCallback(async (deploymentRevisionId: string) => { + setIsReverted(true) + try { + const response = await createRevertDeploymentRevision(deploymentRevisionId) + notification.success({ + message: "Environment Revision", + description: response?.data, + duration: 3, + }) + await fetchData() + } catch (err) { + console.error(err) + } finally { + setIsReverted(false) + } + }, []) + + useEffect(() => { + fetchData() + }, [selectedEnvironment.app_id, selectedEnvironment.name]) + + useEffect(() => { + const fetch = async () => { + try { + setShowDeploymentLoading(true) + if (deploymentRevisions && deploymentRevisions.revisions.length) { + setActiveItem(deploymentRevisions.revisions.length - 1) + + const mod = await import("@/oss/services/deploymentVersioning/api") + const fetchAllDeploymentRevisionConfig = mod?.fetchAllDeploymentRevisionConfig + if (!mod || !fetchAllDeploymentRevisionConfig) return + + const revisionConfig = await fetchAllDeploymentRevisionConfig( + deploymentRevisions.revisions[deploymentRevisions.revisions.length - 1].id, + ) + setShowDeployment(revisionConfig) + } + } catch (error) { + console.error(error) + } finally { + setShowDeploymentLoading(false) + } + } + + fetch() + }, [deploymentRevisions]) + + const handleShowDeployments = async (revision: number, index: number) => { + const findRevision = deploymentRevisions?.revisions.find( + (deploymentRevision) => deploymentRevision.revision === revision, + ) + + if (!findRevision) return + + setActiveItem(index) + setDeploymentRevisionId(findRevision.id) + + try { + setShowDeploymentLoading(true) + const mod = await import("@/oss/services/deploymentVersioning/api") + const fetchAllDeploymentRevisionConfig = mod?.fetchAllDeploymentRevisionConfig + if (!mod || !fetchAllDeploymentRevisionConfig) return + + const revisionConfig = await fetchAllDeploymentRevisionConfig(findRevision.id) + + setShowDeployment(revisionConfig) + } catch (error) { + console.error(error) + } finally { + setShowDeploymentLoading(false) + } + } + + const debouncedHandleShowDeployments = debounce(handleShowDeployments, 300) + + return ( + <> + {isLoading ? ( + + ) : deploymentRevisions?.revisions?.length ? ( +
+
+ {deploymentRevisions?.revisions + ?.map((item, index) => ( +
+ debouncedHandleShowDeployments(item.revision, index) + } + > + + + Revision v{item.revision} + + + + {dayjs(item.created_at).fromNow()} + + + + + {deploymentRevisions.deployed_app_variant_revision_id === + item.deployed_app_variant_revision && ( + + In production... + + )} + + + + +
+ Modified By: + {item.modified_by} +
+
+
+ )) + .reverse()} +
+ +
+
+

Information

+ + {deploymentRevisions.revisions.length > 1 && ( + + )} +
+ + {showDeploymentLoading ? ( +
+ +
+ ) : ( + <> + {showDeployment?.parameters && + Object.keys(showDeployment?.parameters).length ? ( + + + <> + {Object.entries(showDeployment.parameters).map( + ([key, value], index) => { + return ( + <> +
+ + {key}:{" "} + + {Array.isArray(value) + ? JSON.stringify(value) + : typeof value === "boolean" + ? `${value}` + : value} +
+ + ) + }, + )} + +
+
+ ) : ( +
No parameters to display
+ )} + + )} +
+
+ ) : ( +
You have no saved prompts
+ )} + + ) +} + +export default DeploymentHistory diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/AutoEvalRunSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/AutoEvalRunSkeleton.tsx new file mode 100644 index 0000000000..bdb1451fc2 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/AutoEvalRunSkeleton.tsx @@ -0,0 +1,28 @@ +import {memo} from "react" + +import {useRouter} from "next/router" + +import EvalRunHeaderSkeleton from "../components/EvalRunHeader/assets/EvalRunHeaderSkeleton" +import EvalRunOverviewViewerSkeleton from "../components/EvalRunOverviewViewer/assets/EvalRunOverviewViewerSkeleton" +import EvalRunPromptConfigViewerSkeleton from "../components/EvalRunPromptConfigViewer/assets/EvalRunPromptConfigViewerSkeleton" +import EvalRunTestCaseViewerSkeleton from "../components/EvalRunTestCaseViewer/assets/EvalRunTestCaseViewerSkeleton" + +const AutoEvalRunSkeleton = () => { + const router = useRouter() + const viewType = router.query.view as string + + return ( +
+ + {viewType === "test-cases" ? ( + + ) : viewType === "prompt" ? ( + + ) : ( + + )} +
+ ) +} + +export default memo(AutoEvalRunSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/EvalNameTag.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/EvalNameTag.tsx new file mode 100644 index 0000000000..c67397dbce --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/EvalNameTag.tsx @@ -0,0 +1,270 @@ +import {useCallback, useMemo} from "react" + +import {Star, XCircle} from "@phosphor-icons/react" +import {Button, Popover, PopoverProps, Tag, TagProps, Tooltip} from "antd" +import clsx from "clsx" +import {useAtom} from "jotai" +import {useRouter} from "next/router" + +import TooltipWithCopyAction from "@/oss/components/TooltipWithCopyAction" +import UserAvatarTag from "@/oss/components/ui/UserAvatarTag" +import {EnrichedEvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" + +import {urlStateAtom} from "../../state/urlState" + +import TagWithLink from "./TagWithLink" +import VariantTag from "./VariantTag" +import { + combineAppNameWithLabel, + deriveVariantAppName, + deriveVariantLabelParts, + getVariantDisplayMetadata, + normalizeId, + prettifyVariantLabel, +} from "./variantUtils" + +interface EvalNameTagProps extends TagProps { + run: EnrichedEvaluationRun + showClose?: boolean + showPin?: boolean + isBaseEval?: boolean + onlyShowBasePin?: boolean + popoverProps?: PopoverProps + allowVariantNavigation?: boolean +} +const EvalNameTag = ({ + run, + showClose = false, + showPin = false, + isBaseEval = false, + onlyShowBasePin = false, + className, + popoverProps, + allowVariantNavigation = true, + ...props +}: EvalNameTagProps) => { + const router = useRouter() + const normalizedRouteAppId = useMemo( + () => normalizeId(router.query.app_id as string | undefined), + [router.query.app_id], + ) + const [urlState, setUrlState] = useAtom(urlStateAtom) + + const onClose = useCallback( + async (runId: string) => { + const compareRunIds = urlState.compare || [] + const updatedRuns = compareRunIds.filter((id) => id !== runId) + + await router.replace( + { + pathname: router.pathname, + query: {...router.query, compare: updatedRuns}, + }, + undefined, + {shallow: true}, + ) + + setUrlState((draft) => { + draft.compare = updatedRuns.length > 0 ? updatedRuns : undefined + }) + }, + [urlState, router, setUrlState], + ) + + const onPin = useCallback(async () => { + const currentBaseId = router.query.evaluation_id?.toString() + const compareRunIds = urlState.compare || [] + const targetId = run.id + + if (!currentBaseId || targetId === currentBaseId) return + const targetIndex = compareRunIds.indexOf(targetId) + if (targetIndex === -1) return + + const updatedCompare = [...compareRunIds] + updatedCompare[targetIndex] = currentBaseId + + await router.replace( + { + pathname: router.pathname, + query: { + ...router.query, + evaluation_id: targetId, + compare: updatedCompare, + }, + }, + undefined, + {shallow: true}, + ) + setUrlState((draft) => { + draft.compare = updatedCompare + }) + }, [urlState, router, run?.id, setUrlState]) + + return ( + +
+ {run?.name} +
+ {showPin && ( + +
+
+
+
+ ID + + + {run?.id.split("-")[run?.id.split("-").length - 1]} + + +
+
+ Testset + +
+
+ Variant + {run?.variants && run?.variants.length > 0 ? ( + (() => { + const variant: any = run?.variants[0] + const summary = getVariantDisplayMetadata(variant) + const {label: formattedLabel, revision: labelRevision} = + deriveVariantLabelParts({ + variant, + displayLabel: summary.label, + }) + const resolvedAppName = + deriveVariantAppName({ + variant, + fallbackAppName: + run?.appName || + (run as any)?.app_name || + (run as any)?.app?.name, + }) ?? run?.appName + + const prettyLabel = combineAppNameWithLabel( + resolvedAppName, + prettifyVariantLabel(formattedLabel) ?? formattedLabel, + ) + + const candidateRevisionId = + summary.revisionId || + normalizeId(variant?.id) || + normalizeId(variant?.variantId) + const candidateAppId = normalizeId( + variant?.appId || + (variant as any)?.app_id || + run?.appId || + (run as any)?.app_id, + ) + + const resolvedAppId = candidateAppId || normalizedRouteAppId + const blockedByRuntime = + Boolean(normalizedRouteAppId) && + resolvedAppId === normalizedRouteAppId && + summary.hasRuntime === false + + const canNavigate = + allowVariantNavigation && + Boolean(candidateRevisionId && resolvedAppId) && + summary.isHealthy !== false && + !blockedByRuntime + + return ( + + ) + })() + ) : ( + + Not available + + )} +
+
+ Created on + {run?.createdAt} +
+ {!!run?.createdBy?.user?.username && ( +
+ Created by + +
+ )} +
+ + } + > + + {showPin && ( + + + + )} + {run?.name} + {showClose && !isBaseEval && ( + + onClose(run.id)} + /> + + )} + +
+ ) +} + +export default EvalNameTag diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/TagWithLink.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/TagWithLink.tsx new file mode 100644 index 0000000000..254c78476f --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/TagWithLink.tsx @@ -0,0 +1,34 @@ +import {ArrowSquareOut} from "@phosphor-icons/react" +import {Tag, TagProps} from "antd" +import clsx from "clsx" +import {useRouter} from "next/router" + +interface TagWithLinkProps extends TagProps { + name: string + href: string + showIcon?: boolean +} +const TagWithLink = ({name, href, className, showIcon = true, ...props}: TagWithLinkProps) => { + const router = useRouter() + return ( + router.push(href)} + {...props} + > + {name} + {showIcon && ( + + )} + + ) +} + +export default TagWithLink diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/VariantTag.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/VariantTag.tsx new file mode 100644 index 0000000000..c8a841a305 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/VariantTag.tsx @@ -0,0 +1,262 @@ +import {useMemo} from "react" + +import {ArrowSquareOut} from "@phosphor-icons/react" +import {useQueryClient} from "@tanstack/react-query" +import {Skeleton, Tag} from "antd" +import clsx from "clsx" +import {useRouter} from "next/router" +import {useSetAtom} from "jotai" + +import useURL from "@/oss/hooks/useURL" +import {buildRevisionsQueryParam} from "@/oss/lib/helpers/url" +import type {EnrichedEvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" +import {recentAppIdAtom, routerAppIdAtom} from "@/oss/state/app" + +import { + combineAppNameWithLabel, + deriveVariantAppName, + deriveVariantLabelParts, + getVariantDisplayMetadata, + normalizeId, + normalizeLabel, +} from "./variantUtils" + +interface VariantTagProps { + variantName?: string + revision?: number | string + id?: string | null + className?: string + isLoading?: boolean + disabled?: boolean + isDeleted?: boolean + enrichedRun?: EnrichedEvaluationRun + variant?: any +} + +const VariantTag = ({ + variantName, + revision, + id, + className, + isLoading, + disabled = false, + isDeleted = false, + enrichedRun, + variant, +}: VariantTagProps) => { + const router = useRouter() + const queryClient = useQueryClient() + const setRouterAppId = useSetAtom(routerAppIdAtom) + const setRecentAppId = useSetAtom(recentAppIdAtom) + const routeAppId = normalizeId(router.query.app_id as string | undefined) + const {baseAppURL} = useURL() + + const variantsFromRun = useMemo(() => { + if (enrichedRun?.variants && Array.isArray(enrichedRun.variants)) { + return enrichedRun.variants as any[] + } + return [] + }, [enrichedRun]) + + const normalizedTargetId = useMemo(() => normalizeId(id), [id]) + const normalizedTargetName = useMemo(() => normalizeLabel(variantName), [variantName]) + + const variantFromRun = useMemo(() => { + if (!variantsFromRun.length) return undefined + + const match = variantsFromRun.find((candidate: any) => { + const candidateIds = [ + normalizeId(candidate?._revisionId), + normalizeId(candidate?.id), + normalizeId(candidate?.variantId), + normalizeId(candidate?.revisionId), + ].filter(Boolean) as string[] + + if (normalizedTargetId && candidateIds.includes(normalizedTargetId)) { + return true + } + + if (normalizedTargetName) { + const candidateNames = [ + normalizeLabel(candidate?.variantName), + normalizeLabel(candidate?.configName), + normalizeLabel(candidate?.name), + normalizeLabel(candidate?.variantId), + ].filter(Boolean) as string[] + if (candidateNames.includes(normalizedTargetName)) { + return true + } + } + + return false + }) + + return match ?? variantsFromRun[0] + }, [variantsFromRun, normalizedTargetId, normalizedTargetName]) + + const resolvedVariant = useMemo(() => { + if (variant) { + if (variantFromRun) { + return { + ...variantFromRun, + ...variant, + } + } + return variant + } + return variantFromRun + }, [variant, variantFromRun]) + + if (isLoading) { + return + } + + const baseLabel = + normalizeLabel(variantName) ?? + normalizeLabel(resolvedVariant?.variantName) ?? + normalizeLabel(resolvedVariant?.name) ?? + "Variant unavailable" + + const display = useMemo( + () => + getVariantDisplayMetadata(resolvedVariant, { + fallbackLabel: normalizedTargetName ?? baseLabel, + fallbackRevisionId: normalizedTargetId, + requireRuntime: false, + }), + [resolvedVariant, normalizedTargetName, baseLabel, normalizedTargetId], + ) + + const {label: preferredLabel, revision: labelRevision} = useMemo( + () => + deriveVariantLabelParts({ + variant: resolvedVariant, + displayLabel: display.label ?? baseLabel, + }), + [resolvedVariant, display.label, baseLabel], + ) + + const variantAppName = useMemo( + () => + deriveVariantAppName({ + variant: resolvedVariant, + fallbackAppName: + (resolvedVariant as any)?.appName ?? + (resolvedVariant as any)?.application?.name ?? + enrichedRun?.appName ?? + (enrichedRun as any)?.app_name ?? + (enrichedRun as any)?.app?.name, + }), + [resolvedVariant, enrichedRun], + ) + + const variantAppId = useMemo( + () => + normalizeId( + (resolvedVariant as any)?.appId ?? + (resolvedVariant as any)?.app_id ?? + (resolvedVariant as any)?.application?.id ?? + (resolvedVariant as any)?.application_id ?? + (resolvedVariant as any)?.application_ref?.id ?? + (resolvedVariant as any)?.applicationRef?.id, + ), + [resolvedVariant], + ) + + const runAppId = useMemo( + () => + normalizeId( + (enrichedRun as any)?.appId ?? + (enrichedRun as any)?.app_id ?? + (enrichedRun as any)?.app?.id ?? + (enrichedRun as any)?.application?.id, + ), + [enrichedRun], + ) + + const targetAppId = variantAppId || runAppId || routeAppId + const resolvedLabel = isDeleted + ? "Variant deleted" + : combineAppNameWithLabel(variantAppName, preferredLabel) + + const derivedRevisionId = display.revisionId + const selectedRevisionId = derivedRevisionId || normalizedTargetId + + const derivedRevision = useMemo(() => { + if (revision !== undefined && revision !== null && revision !== "") { + return revision + } + const candidate: any = resolvedVariant + const fromVariant = + candidate?.revision ?? + candidate?.revisionLabel ?? + candidate?.version ?? + candidate?._revision ?? + undefined + if ( + fromVariant !== undefined && + fromVariant !== null && + String(fromVariant).toString().trim() !== "" + ) { + return fromVariant + } + return labelRevision ?? "" + }, [resolvedVariant, revision, labelRevision]) + + const hasValidRevision = Boolean(selectedRevisionId || labelRevision) + const isRouteAppContext = Boolean(routeAppId) && targetAppId === routeAppId + const blockedByRuntime = isRouteAppContext && display.hasRuntime === false + + const canNavigate = + !isDeleted && + Boolean(targetAppId) && + hasValidRevision && + display.isHealthy !== false && + !blockedByRuntime + const effectiveDisabled = Boolean(disabled) || isDeleted || !canNavigate + + const hasRevision = + derivedRevision !== undefined && + derivedRevision !== null && + String(derivedRevision).toString().trim() !== "" + + return ( + { + if (effectiveDisabled || !selectedRevisionId || !targetAppId) return + setRouterAppId(targetAppId) + setRecentAppId(targetAppId) + + queryClient.removeQueries({queryKey: ["variants"]}) + queryClient.removeQueries({queryKey: ["appSpec"]}) + queryClient.removeQueries({queryKey: ["variantRevisions"]}) + + await router.push({ + pathname: `${baseAppURL}/${targetAppId}/playground`, + query: { + revisions: buildRevisionsQueryParam([selectedRevisionId]), + }, + }) + }} + > + + {resolvedLabel} + {hasRevision ? ` v${derivedRevision}` : ""} + + {!effectiveDisabled && ( + + )} + + ) +} + +export default VariantTag diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/types.ts b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/types.ts new file mode 100644 index 0000000000..b2ab78e971 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/types.ts @@ -0,0 +1,7 @@ +export interface AutoEvalRunDetailsProps { + name: string + description: string + id: string + isLoading: boolean +} +export type ViewOptionsType = "overview" | "test-cases" diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/utils.ts b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/utils.ts new file mode 100644 index 0000000000..4808772970 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/utils.ts @@ -0,0 +1,52 @@ +import {canonicalizeMetricKey, getMetricDisplayName} from "@/oss/lib/metricUtils" + +export const formatMetricName = (name: string) => { + const canonical = canonicalizeMetricKey(name) + + // Prefer rich labels for well-known invocation metrics + if (canonical.startsWith("attributes.ag.metrics.")) { + return getMetricDisplayName(canonical) + } + + if (canonical.startsWith("attributes.ag.")) { + const tail = canonical.split(".").pop() ?? canonical + return tail + .replace(/[_.-]/g, " ") + .replace(/\s+/g, " ") + .trim() + .replace(/\b\w/g, (c) => c.toUpperCase()) + } + + const formattedName = canonical + .replace(/[_.]/g, " ") + .replace(/([A-Z])/g, " $1") + .trim() + .toLocaleLowerCase() + + if (formattedName === "duration") return "Latency" + if (formattedName.includes("cost")) return "Cost" + return formattedName +} + +export const EVAL_TAG_COLOR = { + 1: "blue", + 2: "orange", + 3: "purple", + 4: "cyan", + 5: "lime", +} +export const EVAL_BG_COLOR = { + 1: "rgba(230, 244, 255, 0.5)", + 2: "rgba(255, 242, 232, 0.5)", + 3: "rgba(249, 240, 255, 0.5)", + 4: "rgba(230, 255, 251, 0.5)", + 5: "rgba(255, 255, 230, 0.5)", +} + +export const EVAL_COLOR = { + 1: "rgba(145, 202, 255, 1)", + 2: "rgba(255, 187, 150, 1)", + 3: "rgba(211, 173, 247, 1)", + 4: "rgba(135, 232, 222, 1)", + 5: "rgba(200, 240, 150, 1)", +} diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/variantUtils.ts b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/variantUtils.ts new file mode 100644 index 0000000000..4531649740 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/assets/variantUtils.ts @@ -0,0 +1,170 @@ +export const normalizeId = (value: unknown): string | undefined => { + if (value === undefined || value === null) return undefined + const stringValue = String(value) + if ( + stringValue.trim() === "" || + stringValue === "undefined" || + stringValue === "null" || + stringValue === "[object Object]" || + stringValue === "NaN" + ) { + return undefined + } + return stringValue +} + +export const normalizeLabel = (value: unknown): string | undefined => { + if (typeof value !== "string") return undefined + const trimmed = value.trim() + return trimmed.length > 0 ? trimmed : undefined +} + +export interface VariantDisplayOptions { + fallbackLabel?: string + fallbackRevisionId?: string + /** When true (default), navigation requires a runtime endpoint */ + requireRuntime?: boolean +} + +export interface VariantDisplayMetadata { + label: string + revisionId: string + isHealthy: boolean + hasRuntime: boolean + canNavigate: boolean +} + +export const getVariantDisplayMetadata = ( + variant: any, + {fallbackLabel, fallbackRevisionId, requireRuntime = true}: VariantDisplayOptions = {}, +): VariantDisplayMetadata => { + const label = + normalizeLabel(variant?.variantName) ?? + normalizeLabel(variant?.configName) ?? + normalizeLabel(variant?.name) ?? + normalizeLabel(variant?.variantId) ?? + normalizeLabel(fallbackLabel) ?? + "Variant unavailable" + + const revisionId = + normalizeId(variant?._revisionId) ?? + normalizeId(variant?.id) ?? + normalizeId(variant?.variantId) ?? + normalizeId(variant?.revisionId) ?? + normalizeId(fallbackRevisionId) ?? + "" + + const hasRuntime = Boolean( + variant?.uri || + variant?.uriObject?.runtimePrefix || + variant?.runtime?.uri || + variant?.runtime?.runtimePrefix, + ) + const isHealthy = variant?.appStatus !== false + + const canNavigate = Boolean(revisionId) && isHealthy && (requireRuntime ? hasRuntime : true) + + return { + label, + revisionId, + isHealthy, + hasRuntime, + canNavigate, + } +} + +const HEX_SEGMENT_REGEX = /^[0-9a-f]{8,}$/i + +export const prettifyVariantLabel = (label?: string): string | undefined => { + if (!label) return label + const parts = label.split("-") + if (parts.length <= 1) { + return label + } + + const last = parts[parts.length - 1] + if (HEX_SEGMENT_REGEX.test(last)) { + return parts.slice(0, -1).join("-") + } + + return label +} + +export const deriveVariantLabelParts = ({ + variant, + displayLabel, +}: { + variant?: any + displayLabel?: string +}): {label: string; revision?: string} => { + const normalizedVariantLabel = + normalizeLabel(variant?.variantName) ?? + normalizeLabel(variant?.configName) ?? + normalizeLabel(variant?.name) ?? + undefined + + const normalizedVariantId = normalizeLabel(variant?.variantId) + + const rawLabel = normalizedVariantLabel ?? normalizedVariantId ?? displayLabel ?? "Variant" + const trimmed = prettifyVariantLabel(rawLabel) ?? rawLabel + + const primaryRevision = + variant?.revision ?? + variant?.revisionLabel ?? + variant?.version ?? + variant?._revision ?? + undefined + + if ( + primaryRevision !== undefined && + primaryRevision !== null && + String(primaryRevision).toString().trim() !== "" + ) { + return {label: trimmed, revision: String(primaryRevision)} + } + + const segments = trimmed.split("-") + if (segments.length > 1) { + const last = segments[segments.length - 1] + if (/^\d+$/.test(last)) { + const base = segments.slice(0, -1).join("-") || segments.join("-") + return {label: base, revision: last} + } + } + + return {label: trimmed, revision: undefined} +} + +export const deriveVariantAppName = ({ + variant, + fallbackAppName, +}: { + variant?: any + fallbackAppName?: string +}): string | undefined => { + return ( + normalizeLabel(variant?.appName) ?? + normalizeLabel(variant?.application?.name) ?? + normalizeLabel(variant?.application?.appName) ?? + normalizeLabel(variant?.application_ref?.name) ?? + normalizeLabel(variant?.applicationRef?.name) ?? + normalizeLabel(fallbackAppName) + ) +} + +export const combineAppNameWithLabel = (appName: string | undefined, label?: string): string => { + const normalizedLabel = label?.trim() + const normalizedApp = normalizeLabel(appName) + + if (!normalizedLabel || normalizedLabel.length === 0) { + return normalizedApp ?? "Variant unavailable" + } + + if (!normalizedApp) { + return normalizedLabel + } + + return normalizedLabel.toLowerCase().startsWith(normalizedApp.toLowerCase()) + ? normalizedLabel + : `${normalizedApp} ${normalizedLabel}` +} diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunCompareMenu/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunCompareMenu/index.tsx new file mode 100644 index 0000000000..9e5e22182d --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunCompareMenu/index.tsx @@ -0,0 +1,269 @@ +import {memo, useCallback, useEffect, useMemo, useRef, useState} from "react" + +import {Check, Plus} from "@phosphor-icons/react" +import {Button, ButtonProps, Input, Popover, PopoverProps, Typography, Tag, message} from "antd" +import clsx from "clsx" +import {useAtom, useAtomValue} from "jotai" +import {useRouter} from "next/router" +import {useLocalStorage} from "usehooks-ts" + +import {useRunId} from "@/oss/contexts/RunIdContext" +import useFocusInput from "@/oss/hooks/useFocusInput" +import {EvaluationType} from "@/oss/lib/enums" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import usePreviewEvaluations from "@/oss/lib/hooks/usePreviewEvaluations" +import {EnrichedEvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" + +import {urlStateAtom} from "../../../state/urlState" + +const filters = ["all", "success", "failed"] +const failedFilters = ["errors", "error", "failed", "failure"] + +const EvalRunCompareMenu = ({ + popoverProps, + buttonProps, +}: { + popoverProps?: PopoverProps + buttonProps?: ButtonProps +}) => { + const [isMenuOpen, setIsMenuOpen] = useState(false) + const [searchTerm, setSearchTerm] = useState("") + const [filter, setFilter] = useLocalStorage("eval-compare-popup-filter", "") + const {inputRef} = useFocusInput({isOpen: isMenuOpen}) + const router = useRouter() + const runId = useRunId() + // Use ref to track previous compareRunIds to avoid infinite loops + const prevCompareRunIdsRef = useRef([]) + + // atoms + const evaluation = useAtomValue(evaluationRunStateFamily(runId!)) + const [urlState, setUrlState] = useAtom(urlStateAtom) + const enrichedRun = evaluation?.enrichedRun + const compareRunIds = urlState.compare || [] + + const derivedAppId = useMemo(() => { + return enrichedRun?.appId ?? enrichedRun?.variants?.[0]?.appId ?? undefined + }, [enrichedRun]) + + const {runs: projectRuns} = usePreviewEvaluations({ + skip: false, + types: [EvaluationType.auto_exact_match], + appId: "", + }) + + const {runs: appRuns} = usePreviewEvaluations({ + skip: false, + types: [EvaluationType.auto_exact_match], + appId: derivedAppId, + }) + + const runs = (projectRuns.length ? projectRuns : appRuns) as EnrichedEvaluationRun[] + + // Track compare ids locally to avoid redundant work; do not overwrite urlState + useEffect(() => { + const prevIds = prevCompareRunIdsRef.current + const currentIds = compareRunIds + const isDifferent = + prevIds.length !== currentIds.length || + prevIds.some((id, index) => id !== currentIds[index]) + + if (isDifferent) { + prevCompareRunIdsRef.current = [...compareRunIds] + } + }, [compareRunIds]) + + const resolveTestsetIds = useCallback((run?: EnrichedEvaluationRun | null) => { + if (!run) return new Set() + const ids = new Set() + ;(run.testsets ?? []).forEach((testset) => { + if (testset?.id) ids.add(testset.id) + }) + ;(run.data?.steps ?? []).forEach((step) => { + const id = step?.references?.testset?.id + if (id) ids.add(id) + }) + return ids + }, []) + + const evaluations = useMemo(() => { + const baseIds = resolveTestsetIds(enrichedRun) + const baseIdList = Array.from(baseIds) + + const matchedTestsetEvals = runs.filter((run) => { + if (!baseIds.size) return false + const runIds = resolveTestsetIds(run) + return baseIdList.some((id) => runIds.has(id)) + }) + + const evals = matchedTestsetEvals.filter((run) => run?.id !== enrichedRun?.id) + + const autoEvals = evals?.filter((run) => + run?.data?.steps.every( + (step) => step?.type !== "annotation" || step?.origin === "auto", + ), + ) + + return autoEvals + }, [runs, enrichedRun, resolveTestsetIds]) + + const filteredEvals = useMemo(() => { + if (searchTerm.trim().length > 0) { + return evaluations.filter((e) => + e?.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) + } + + if (filter === "success") { + return evaluations.filter((e) => e.status === filter) + } + + if (filter === "failed") { + return evaluations.filter((e) => failedFilters.includes(e.status)) + } + + return evaluations + }, [searchTerm, evaluations, filter]) + + const onMutateRun = useCallback( + async (runId: string) => { + if (compareRunIds.includes(runId)) { + const updatedRuns = compareRunIds.filter((id) => id !== runId) + await router.replace( + { + pathname: router.pathname, + query: {...router.query, compare: updatedRuns}, + }, + undefined, + {shallow: true}, + ) + setUrlState((draft) => { + draft.compare = updatedRuns.length > 0 ? updatedRuns : undefined + }) + } else { + if (compareRunIds.length === 4) { + message.info("You can only compare up to 5 runs") + return + } + await router.replace( + { + pathname: router.pathname, + query: {...router.query, compare: [...compareRunIds, runId]}, + }, + undefined, + {shallow: true}, + ) + setUrlState((draft) => { + draft.compare = [...compareRunIds, runId] + }) + } + }, + [compareRunIds], + ) + + return ( + +
+
+ + Add evaluations using testset: + + + {enrichedRun?.testsets?.[0]?.name} + +
+ + setSearchTerm(e.target.value)} + /> +
+ +
+ Filters: + + {filters.map((f) => ( + + ))} +
+ + {filteredEvals?.length > 0 ? ( +
+ {filteredEvals?.map((evaluation) => ( +
onMutateRun(evaluation.id)} + > +
+ + {evaluation.name} + + +
+ + {evaluation.variants?.[0]?.variantName || "-"} + + v{evaluation.variants?.[0]?.revision || "0"} + + + {compareRunIds?.includes(evaluation.id) ? ( + + ) : null} +
+
+
+ + {evaluation.description || "No description"} + + + {evaluation.createdAt} + +
+
+ ))} +
+ ) : ( +
+ No evaluations found +
+ )} + + } + {...popoverProps} + > + +
+ ) +} + +export default memo(EvalRunCompareMenu) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunOutput.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunOutput.tsx new file mode 100644 index 0000000000..9c4151f6eb --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunOutput.tsx @@ -0,0 +1,60 @@ +import SimpleSharedEditor from "@/oss/components/EditorViews/SimpleSharedEditor" +import {useInvocationResult} from "@/oss/lib/hooks/useInvocationResult" +import clsx from "clsx" + +const RunOutput = ({ + runId, + scenarioId, + stepKey, + showComparisons, +}: { + runId: string + scenarioId?: string + stepKey?: string + showComparisons?: boolean +}) => { + const { + value, + messageNodes: nodes, + hasError: err, + } = useInvocationResult({ + scenarioId, + stepKey, + editorType: "simple", + viewType: "single", + runId, + }) + return ( +
+ {nodes ? ( + nodes + ) : ( + {}} + initialValue={ + !!value && typeof value !== "string" ? JSON.stringify(value) : value + } + headerName="Output" + editorType="borderless" + state="readOnly" + disabled + readOnly + editorClassName="!text-xs" + error={err} + placeholder="N/A" + className="!w-[97.5%]" + /> + )} +
+ ) +} + +export default RunOutput diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunTraceHeader.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunTraceHeader.tsx new file mode 100644 index 0000000000..d37ba99532 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/assets/RunTraceHeader.tsx @@ -0,0 +1,79 @@ +import {EVAL_TAG_COLOR} from "@/oss/components/EvalRunDetails/AutoEvalRun/assets/utils" +import {useRunId} from "@/oss/contexts/RunIdContext" +import { + evalAtomStore, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {useInvocationResult} from "@/oss/lib/hooks/useInvocationResult" +import clsx from "clsx" +import {useAtomValue} from "jotai" +import {memo} from "react" +import dynamic from "next/dynamic" +import EvalNameTag from "@/oss/components/EvalRunDetails/AutoEvalRun/assets/EvalNameTag" + +const GenerationResultUtils = dynamic( + () => + import( + "@/oss/components/Playground/Components/PlaygroundGenerations/assets/GenerationResultUtils" + ), + {ssr: false}, +) + +const RunTraceHeader = ({ + runId: rId, + scenarioId: scId, + stepKey, + anchorId, + showComparisons, +}: { + runId: string + scenarioId?: string + stepKey?: string + anchorId?: string + showComparisons?: boolean +}) => { + const baseRunId = useRunId() + const store = evalAtomStore() + const state = useAtomValue(evaluationRunStateFamily(rId), {store}) + const enriched = state?.enrichedRun + const {trace: runTrace} = useInvocationResult({ + scenarioId: scId, + stepKey: stepKey, + editorType: "simple", + viewType: "single", + runId: rId, + }) + + return ( +
+ {enriched ? ( + + ) : ( +
+ )} + {runTrace ? ( + + ) : ( +
+ )} +
+ ) +} + +export default memo(RunTraceHeader) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/index.tsx new file mode 100644 index 0000000000..772d0ea186 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerContent/index.tsx @@ -0,0 +1,905 @@ +import {useCallback, useEffect, useMemo, useRef, useState} from "react" + +import SimpleSharedEditor from "@agenta/oss/src/components/EditorViews/SimpleSharedEditor" +import VirtualizedSharedEditors from "@agenta/oss/src/components/EditorViews/VirtualizedSharedEditors" +import {Collapse, CollapseProps, Tag, Tooltip} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" +import {useRouter} from "next/router" + +import {renderChatMessages} from "@/oss/components/EvalRunDetails/assets/renderChatMessages" +import {STATUS_COLOR} from "@/oss/components/EvalRunDetails/components/EvalRunScenarioStatusTag/assets" +import {titleCase} from "@/oss/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/flatDataSourceBuilder" +import {comparisonRunsStepsAtom} from "@/oss/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useExpandableComparisonDataSource" +import {focusScenarioAtom} from "@/oss/components/EvalRunDetails/state/focusScenarioAtom" +import {urlStateAtom} from "@/oss/components/EvalRunDetails/state/urlState" +import {formatMetricValue} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" +import {getStatusLabel} from "@/oss/lib/constants/statusLabels" +import { + evalAtomStore, + scenarioStepFamily, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {runScopedMetricDataFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {useInvocationResult} from "@/oss/lib/hooks/useInvocationResult" +import {EvaluationStatus} from "@/oss/lib/Types" +import {useAppState} from "@/oss/state/appState" + +import FocusDrawerContentSkeleton from "../Skeletons/FocusDrawerContentSkeleton" + +import RunOutput from "./assets/RunOutput" +import RunTraceHeader from "./assets/RunTraceHeader" + +const failureRunTypes = [EvaluationStatus.FAILED, EvaluationStatus.FAILURE, EvaluationStatus.ERROR] +const EMPTY_COMPARISON_RUN_IDS: string[] = [] + +const FocusDrawerContent = () => { + const appState = useAppState() + const store = evalAtomStore() + const router = useRouter() + + const [windowHight, setWindowHight] = useState(0) + const [activeKeys, setActiveKeys] = useState<(string | number)[]>([ + "input", + "output", + "evaluators", + ]) + + // atoms + const focus = useAtomValue(focusScenarioAtom) + const urlState = useAtomValue(urlStateAtom) + const scenarioId = focus?.focusScenarioId as string + const runId = focus?.focusRunId as string + const rawCompareRunIds = Array.isArray(urlState?.compare) ? urlState.compare : [] + const compareRunIdsKey = rawCompareRunIds.join("|") + const evaluationRunData = useAtomValue(evaluationRunStateFamily(runId!)) + const comparisonRunIds = useMemo(() => { + if (!rawCompareRunIds.length) return EMPTY_COMPARISON_RUN_IDS + return rawCompareRunIds.slice() + }, [compareRunIdsKey]) + const rawBaseRunId = useMemo(() => { + const routerValue = router.query?.evaluation_id + if (Array.isArray(routerValue)) { + const firstRouterId = routerValue[0] + if (firstRouterId) return firstRouterId + } else if (typeof routerValue === "string" && routerValue.length > 0) { + return routerValue + } + + const appStateValue = appState.query?.evaluation_id + if (Array.isArray(appStateValue)) { + return appStateValue[0] ?? null + } + + return typeof appStateValue === "string" && appStateValue.length > 0 ? appStateValue : null + }, [appState.query?.evaluation_id, router.query?.evaluation_id]) + + const isBaseRun = useMemo(() => { + if (evaluationRunData?.isBase !== undefined) { + return Boolean(evaluationRunData.isBase) + } + return rawBaseRunId ? runId === rawBaseRunId : false + }, [evaluationRunData?.isBase, rawBaseRunId, runId]) + + const baseRunId = useMemo(() => { + if (evaluationRunData?.isBase) return runId + if (rawBaseRunId && typeof rawBaseRunId === "string") return rawBaseRunId + return runId + }, [evaluationRunData?.isBase, rawBaseRunId, runId]) + + const comparisonRunsStepsAtomInstance = useMemo( + () => comparisonRunsStepsAtom(comparisonRunIds), + [comparisonRunIds], + ) + const comparisonRunsSteps = useAtomValue(comparisonRunsStepsAtomInstance, {store}) + // // Derive whether to show comparison mode + const showComparisons = useMemo( + () => Boolean(isBaseRun && comparisonRunIds.length > 0), + [isBaseRun, comparisonRunIds], + ) + const stepLoadable = useAtomValue( + loadable( + scenarioStepFamily({ + runId: runId!, + scenarioId: scenarioId!, + }), + ), + ) + + const enricedRun = evaluationRunData?.enrichedRun + const invocationStep = useMemo(() => stepLoadable.data?.invocationSteps?.[0], [stepLoadable]) + const { + trace, + value: outputValue, + messageNodes, + hasError, + } = useInvocationResult({ + scenarioId: invocationStep?.scenarioId, + stepKey: invocationStep?.stepKey, + editorType: "simple", + viewType: "single", + runId, + }) + + const entries = useMemo(() => { + const inputSteps = stepLoadable.data?.inputSteps + + if (stepLoadable.state !== "hasData" || !inputSteps) return [] + const out: {k: string; v: unknown}[] = [] + inputSteps.forEach((inputCol) => { + let _inputs = {} + try { + const {testcase_dedup_id, ...rest} = inputCol.testcase.data + _inputs = {...rest} + } catch (e) { + const rawInputs = (inputCol && (inputCol as any).inputs) || {} + const {testcase_dedup_id, ...rest} = rawInputs as Record + _inputs = {...rest} + } + Object.entries(_inputs || {})?.forEach(([k, v]) => out.push({k: titleCase(k), v})) + }) + return out + }, [stepLoadable]) + + // Base testcase id to match comparison scenarios by content + const baseTestcaseId = useMemo(() => { + const inputSteps = stepLoadable.data?.inputSteps + const id = inputSteps?.[0]?.testcaseId + return id + }, [stepLoadable]) + + // Map of comparison runId -> matched scenarioId (by testcaseId) + const matchedComparisonScenarios = useMemo(() => { + if (!showComparisons || !baseTestcaseId) return [] as {runId: string; scenarioId?: string}[] + return comparisonRunIds.map((compRunId) => { + const compMap = + comparisonRunsSteps && typeof comparisonRunsSteps === "object" + ? ((comparisonRunsSteps as Record)[compRunId] as any) || {} + : {} + let matchedScenarioId: string | undefined + for (const [scId, testcaseIds] of Object.entries(compMap)) { + const first = Array.isArray(testcaseIds) ? testcaseIds[0] : undefined + if (first && first === baseTestcaseId) { + matchedScenarioId = scId + break + } + } + return {runId: compRunId, scenarioId: matchedScenarioId} + }) + }, [showComparisons, baseTestcaseId, comparisonRunsSteps, comparisonRunIds]) + + const evaluatorMetrics = useMemo(() => { + const evaluators = enricedRun?.evaluators + return evaluators?.map((evaluator) => ({ + name: evaluator.name, + metrics: evaluator.metrics, + slug: evaluator.slug, + })) + }, [enricedRun]) + + const openAndScrollTo = useCallback((key: string) => { + // Ensure the related section is expanded when navigating via hash + setActiveKeys((prev) => { + const next = new Set(prev) + next.add(key) + if (key === "output" || key.startsWith("output-")) next.add("output") + return Array.from(next) + }) + + // wait for Collapse to render/expand, then scroll + const tryScroll = (attempt = 0) => { + const el = document.getElementById(`section-${key}`) + // element is visible when offsetParent is not null (after expand) + if (el && el.offsetParent !== null) { + el.scrollIntoView({behavior: "smooth", block: "start", inline: "nearest"}) + } else if (attempt < 10) { + requestAnimationFrame(() => tryScroll(attempt + 1)) + } + } + requestAnimationFrame(() => tryScroll()) + }, []) + + const handleCollapseChange = useCallback((keys: string[]) => { + // Check if any dropdown is open by looking for the dropdown menu with the 'open' class + // This is for improving micro interactions + const openSelects = document.querySelectorAll( + ".ant-select-dropdown:not(.ant-select-dropdown-hidden)", + ) + const openDropdowns = document.querySelectorAll(".ant-dropdown:not(.ant-dropdown-hidden)") + if (openSelects.length > 0 || openDropdowns.length > 0) { + return + } + setActiveKeys(keys) + }, []) + + // TODO remove this from here and create a function or something to also use in somewhere else + const getErrorStep = useCallback( + (metricKey: string, scenarioId: string) => { + if (stepLoadable.state === "loading") return null + const [evalSlug, key] = metricKey.split(".") + if (!key) return null // if does not have key that means it's not an evaluator metric + const _step = stepLoadable.data?.steps?.find((s) => s.stepKey === evalSlug) + + if (!_step) { + const invocationStep = stepLoadable.data?.invocationSteps?.find( + (s) => s.scenarioId === scenarioId, + ) + + if (failureRunTypes.includes(invocationStep?.status)) { + return { + status: invocationStep?.status, + error: invocationStep?.error?.stacktrace || invocationStep?.error?.message, + } + } + return null + } + + if (failureRunTypes.includes(_step?.status)) { + return { + status: _step?.status, + error: _step?.error?.stacktrace || _step?.error?.message, + } + } + + return null + }, + [stepLoadable], + ) + + useEffect(() => { + setWindowHight(window.innerHeight) + }, [stepLoadable]) + + useEffect(() => { + const evaluatorSlug = enricedRun?.evaluators?.map((evaluator) => evaluator.slug) ?? [] + if (!evaluatorSlug.length) return + + setActiveKeys((prev) => { + const next = new Set(prev) + let changed = false + + evaluatorSlug.forEach((slug) => { + if (!next.has(slug)) { + next.add(slug) + changed = true + } + }) + + return changed ? Array.from(next) : prev + }) + }, [enricedRun]) + + useEffect(() => { + const hash = appState.asPath?.split("#")[1]?.trim() + if (!hash) return + openAndScrollTo(hash) + }, [appState.asPath, openAndScrollTo]) + + // Sync horizontal scroll between the Collapse header (trace) and content box (output) + const isSyncingScroll = useRef(false) + useEffect(() => { + if (!showComparisons) return + + const traceEl = document.querySelector( + ".trace-scroll-container .ant-collapse-header", + ) as HTMLDivElement | null + const outputEl = document.querySelector( + ".output-scroll-container .ant-collapse-content-box", + ) as HTMLDivElement | null + const evalEl = document.querySelector( + ".evaluator-scroll-container .ant-collapse-content-box", + ) as HTMLDivElement | null + + if (!traceEl || !outputEl) return + + const sync = (from: HTMLDivElement) => { + const left = from.scrollLeft + if (outputEl && from !== outputEl) outputEl.scrollLeft = left + if (traceEl && from !== traceEl) traceEl.scrollLeft = left + if (evalEl && from !== evalEl) evalEl.scrollLeft = left + } + + const onTraceScroll = (e: any) => { + if (isSyncingScroll.current) return + isSyncingScroll.current = true + sync(e.currentTarget as HTMLDivElement) + requestAnimationFrame(() => (isSyncingScroll.current = false)) + } + const onOutputScroll = (e: any) => { + if (isSyncingScroll.current) return + isSyncingScroll.current = true + sync(e.currentTarget as HTMLDivElement) + requestAnimationFrame(() => (isSyncingScroll.current = false)) + } + const onEvalScroll = (e: any) => { + if (isSyncingScroll.current) return + isSyncingScroll.current = true + sync(e.currentTarget as HTMLDivElement) + requestAnimationFrame(() => (isSyncingScroll.current = false)) + } + + traceEl.addEventListener("scroll", onTraceScroll) + outputEl.addEventListener("scroll", onOutputScroll) + evalEl?.addEventListener("scroll", onEvalScroll) + + return () => { + traceEl.removeEventListener("scroll", onTraceScroll) + outputEl.removeEventListener("scroll", onOutputScroll) + evalEl?.removeEventListener("scroll", onEvalScroll) + } + }, [showComparisons, activeKeys]) + + const items: CollapseProps["items"] = useMemo(() => { + if (stepLoadable.state !== "hasData" || !scenarioId) return [] + + return [ + { + key: "input", + className: "!rounded-none [&_.ant-collapse-header]:!py-2", + label: ( + + Inputs + + ), + children: ( +
+ { + // Detect chat-shaped JSON like in CellComponents.tsx + let isChat = false + if (typeof entry.v === "string") { + try { + const parsed = JSON.parse(entry.v) + isChat = + Array.isArray(parsed) && + parsed.every((m: any) => "role" in m && "content" in m) + } catch { + /* ignore */ + } + } + + if (isChat) { + const nodes = renderChatMessages({ + keyPrefix: `${scenarioId}-${entry.k}`, + rawJson: entry.v as string, + view: "single", + editorType: "simple", + }) + return ( +
+ {nodes} +
+ ) + } + + return ( + {}} + headerName={entry.k} + initialValue={String(entry.v)} + editorType="borderless" + state="readOnly" + placeholder="N/A" + disabled + readOnly + editorClassName="!text-xs" + className="!w-[97.5%]" + editorProps={{enableResize: true}} + /> + ) + }} + /> +
+ ), + }, + { + key: "trace", + className: + "trace-scroll-container !rounded-none !px-0 [&_.ant-collapse-header]:!px-0 [&_.ant-collapse-header]:overflow-x-auto [&_.ant-collapse-header]:scroll-mr-2 sticky -top-[13px] z-10 bg-white [&_.ant-collapse-header::-webkit-scrollbar]:!w-0 [&_.ant-collapse-header::-webkit-scrollbar]:!h-0", + collapsible: "disabled", + disabled: true, + showArrow: false, + label: ( +
+ {showComparisons ? ( + <> + + {matchedComparisonScenarios.map( + ({runId: rId, scenarioId: scId}) => ( + + ), + )} + + ) : ( + + )} +
+ ), + }, + { + key: "output", + label: Outputs, + className: clsx([ + "output-scroll-container", + "!rounded-none !px-0 [&_.ant-collapse-header]:!py-2 [&_.ant-collapse-content-box]:overflow-x-auto [&_.ant-collapse-content-box]:scroll-mr-2 [&_.ant-collapse-content-box::-webkit-scrollbar]:!w-0 [&_.ant-collapse-content-box::-webkit-scrollbar]:!h-0", + {"[&_.ant-collapse-content-box]:!px-1": showComparisons}, + ]), + children: showComparisons ? ( +
+ + {matchedComparisonScenarios.map(({runId: rId, scenarioId: scId}) => ( + + ))} +
+ ) : ( +
+ {messageNodes ? ( + messageNodes + ) : ( + {}} + initialValue={ + !!outputValue && typeof outputValue !== "string" + ? JSON.stringify(outputValue) + : outputValue + } + headerName="Output" + editorType="borderless" + state="readOnly" + disabled + readOnly + editorClassName="!text-xs" + error={hasError} + placeholder="N/A" + className="!w-[97.5%]" + /> + )} +
+ ), + }, + ...(showComparisons + ? [ + { + key: "evaluators", + label: null, + disabled: true, + showArrow: false, + className: + "evaluator-scroll-container !rounded-none [&_.ant-collapse-header]:!hidden [&_.ant-collapse-content-box]:overflow-x-auto [&_.ant-collapse-content-box]:!px-0 [&_.ant-collapse-content-box::-webkit-scrollbar]:!w-0 [&_.ant-collapse-content-box::-webkit-scrollbar]:!h-0", + children: (() => { + const runs = [ + {runId: baseRunId, scenarioId}, + ...matchedComparisonScenarios.map((m) => ({ + runId: m.runId, + scenarioId: m.scenarioId, + })), + ] + + // Helper: collect evaluator list for a run + const getRunEvaluators = (rId: string) => { + const rState = evalAtomStore().get(evaluationRunStateFamily(rId)) + const evaluators = rState?.enrichedRun?.evaluators || [] + return Array.isArray(evaluators) + ? evaluators + : (Object.values(evaluators) as any[]) + } + + // Build ordered set of evaluator slugs (base run first, then others) + const slugOrder = new Set() + const slugName: Record = {} + runs.forEach(({runId: rId}) => { + const list = getRunEvaluators(rId) + list.forEach((ev: any) => { + slugOrder.add(ev.slug) + if (!slugName[ev.slug]) slugName[ev.slug] = ev.name || ev.slug + }) + }) + + // Renders the value UI for a single metric in a single run + const renderMetricCell = ( + rId: string, + scId: string | undefined, + evaluatorSlug: string, + metricName: string, + ) => { + if (!scId) { + return ( + + N/A + + ) + } + + const metricData = evalAtomStore().get( + runScopedMetricDataFamily({ + runId: rId, + scenarioId: scId, + metricKey: `${evaluatorSlug}.${metricName}`, + stepSlug: invocationStep?.stepkey, + }), + ) + + // Run-scoped error fallback + let errorStep: any = null + const stepLoadableR = evalAtomStore().get( + loadable(scenarioStepFamily({runId: rId, scenarioId: scId})), + ) as any + if (stepLoadableR?.state === "hasData") { + const _step = stepLoadableR?.data?.steps?.find( + (s: any) => s.stepkey === evaluatorSlug, + ) + if (failureRunTypes.includes(_step?.status)) { + errorStep = { + status: _step?.status, + error: + _step?.error?.stacktrace || _step?.error?.message, + } + } else { + const inv = stepLoadableR?.data?.invocationSteps?.find( + (s: any) => s.scenarioId === scId, + ) + if (failureRunTypes.includes(inv?.status)) { + errorStep = { + status: inv?.status, + error: + inv?.error?.stacktrace || inv?.error?.message, + } + } + } + } + + if (errorStep?.status || errorStep?.error) { + return ( + + + {getStatusLabel(errorStep?.status)} + + + ) + } + + let value: any + if ( + metricData?.value?.frequency && + metricData?.value?.frequency?.length > 0 + ) { + const mostFrequent = metricData?.value?.frequency?.reduce( + (max: any, current: any) => + current.count > max.count ? current : max, + ).value + value = String(mostFrequent) + } else { + const prim = Object.values(metricData?.value || {}).find( + (v) => typeof v === "number" || typeof v === "string", + ) + value = + prim !== undefined + ? prim + : JSON.stringify(metricData?.value) + } + + const formatted = formatMetricValue(metricName, value || "") + + const isLongText = + typeof formatted === "string" && + (formatted.length > 180 || /\n/.test(formatted)) + + if ( + formatted === undefined || + formatted === null || + formatted === "" + ) { + return ( + + N/A + + ) + } + + return isLongText ? ( + {}} + initialValue={String(formatted)} + editorType="borderless" + state="readOnly" + disabled + readOnly + editorClassName="!text-xs" + placeholder="N/A" + className="!w-[97.5%]" + /> + ) : ( + + {String(formatted)} + + ) + } + + // Build the vertical list of evaluators with per-run metric columns + const orderedSlugs = Array.from(slugOrder) + + return ( +
+ {orderedSlugs.map((slug) => { + // Figure out which runs used this evaluator + const usedBy = new Set( + runs + .filter(({runId: rId, scenarioId: scId}) => { + if (!scId) return false + const list = getRunEvaluators(rId) + return list.some((e: any) => e.slug === slug) + }) + .map((r) => r.runId), + ) + + if (usedBy.size === 0) return null + + // Union of metric keys across participating runs only + const metricKeyOrder = new Set() + runs.forEach(({runId: rId}) => { + if (!usedBy.has(rId)) return + const list = getRunEvaluators(rId) + const ev = list.find((e: any) => e.slug === slug) + Object.keys(ev?.metrics || {}).forEach((k) => + metricKeyOrder.add(k), + ) + }) + + const keys = Array.from(metricKeyOrder) + const displayName = slugName[slug] || slug + + return ( +
+
+
+ {displayName} +
+ {runs.slice(1).map((_, idx) => ( +
+ ))} +
+
+
+ {runs.map( + ({runId: rId, scenarioId: scId}) => { + const hasThis = usedBy.has(rId) + return ( +
+ {hasThis ? ( + keys.map((metricName) => ( +
+ + {metricName} + + {renderMetricCell( + rId, + scId, + slug, + metricName, + )} +
+ )) + ) : ( + // Support structure to preserve column spacing +
+ )} +
+ ) + }, + )} +
+
+ ) + })} +
+ ) + })(), + }, + ] + : (evaluatorMetrics || []).map((evaluator, idx) => { + const metrics = evaluator.metrics + const isFirst = idx === 0 + const prevSlug = evaluatorMetrics?.[idx - 1]?.slug + const isPrevOpen = !!(prevSlug && activeKeys.includes(prevSlug)) + + if (!evaluator) return null + return { + key: evaluator.slug, + label: ( + + {evaluator.name} + + ), + className: clsx( + "[&_.ant-collapse-header]:border-0 [&_.ant-collapse-header]:border-solid [&_.ant-collapse-header]:border-gray-200", + "[&_.ant-collapse-header]:!rounded-none [&_.ant-collapse-header]:!py-[9px]", + "[&_.ant-collapse-header]:border-b", + { + // Top border for first item or when previous evaluator is open + "[&_.ant-collapse-header]:border-t": isFirst || isPrevOpen, + }, + ), + children: Object.keys(metrics || {})?.map((metricKey) => { + const metricData = evalAtomStore().get( + runScopedMetricDataFamily({ + runId: runId!, + scenarioId: scenarioId!, + metricKey: `${evaluator.slug}.${metricKey}`, + stepSlug: invocationStep?.stepkey, + }), + ) + + const errorStep = + !metricData?.distInfo || hasError + ? getErrorStep(`${evaluator.slug}.${metricKey}`, scenarioId) + : null + + let value + if ( + metricData?.value?.frequency && + metricData?.value?.frequency?.length > 0 + ) { + const mostFrequent = metricData?.value?.frequency?.reduce( + (max, current) => (current.count > max.count ? current : max), + ).value + value = String(mostFrequent) + } else { + const prim = Object.values(metricData?.value || {}).find( + (v) => typeof v === "number" || typeof v === "string", + ) + value = + prim !== undefined ? prim : JSON.stringify(metricData?.value) + } + + const formatted = formatMetricValue(metricKey, value || "") + + return ( +
+ {metricKey} + {errorStep?.status || errorStep?.error ? ( + + + {getStatusLabel(errorStep?.status)} + + + ) : ( + + {typeof formatted === "object" || + formatted === undefined || + formatted === null + ? "N/A" + : String(formatted)} + + )} +
+ ) + }), + } + })), + ] + }, [ + entries, + stepLoadable.state, + windowHight, + outputValue, + trace, + enricedRun?.name, + scenarioId, + activeKeys, + messageNodes, + hasError, + comparisonRunIds, + showComparisons, + matchedComparisonScenarios, + baseRunId, + invocationStep?.stepkey, + ]) + + if (stepLoadable.state !== "hasData" || !enricedRun) { + return + } + + return ( +
+ +
+ ) +} + +export default FocusDrawerContent diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerHeader/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerHeader/index.tsx new file mode 100644 index 0000000000..1538cce667 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerHeader/index.tsx @@ -0,0 +1,142 @@ +import {useCallback, useMemo, useState} from "react" + +import {CaretDown, CaretUp, Check, Copy} from "@phosphor-icons/react" +import {Button, Tag} from "antd" +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" + +import EvalRunScenarioNavigator from "@/oss/components/EvalRunDetails/components/EvalRunScenarioNavigator" +import {focusScenarioAtom} from "@/oss/components/EvalRunDetails/state/focusScenarioAtom" +import TooltipWithCopyAction from "@/oss/components/TooltipWithCopyAction" +import { + scenariosFamily, + scenarioStepFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {useAppNavigation} from "@/oss/state/appState" + +import FocusDrawerHeaderSkeleton from "../Skeletons/FocusDrawerHeaderSkeleton" + +const FocusDrawerHeader = () => { + const [isCopy, setIsCopy] = useState(false) + const focus = useAtomValue(focusScenarioAtom) + const navigation = useAppNavigation() + + const runId = focus?.focusRunId as string + const focusScenarioId = focus?.focusScenarioId as string + + const handleScenarioChange = useCallback( + (nextScenarioId: string) => { + navigation.patchQuery( + { + focusScenarioId: nextScenarioId, + focusRunId: runId, + }, + {shallow: true}, + ) + }, + [navigation, runId], + ) + + const stepLoadable = useAtomValue( + loadable( + scenarioStepFamily({ + runId, + scenarioId: focusScenarioId, + }), + ), + ) + const scenarios = useAtomValue(scenariosFamily(runId)) ?? [] + + const selectedScenario = useMemo(() => { + return scenarios.find((s) => s.id === focusScenarioId) + }, [scenarios, focusScenarioId]) + + const loadPrevVariant = useCallback(() => { + if (!selectedScenario) return + const prevIndex = selectedScenario.scenarioIndex - 2 + if (prevIndex < 0) return + const prevScenario = scenarios[prevIndex] + if (!prevScenario) return + handleScenarioChange(prevScenario.id) + }, [handleScenarioChange, selectedScenario, scenarios]) + + const loadNextVariant = useCallback(() => { + if (!selectedScenario) return + const nextIndex = selectedScenario.scenarioIndex || 1 + const nextScenario = scenarios[nextIndex] + if (!nextScenario) return + handleScenarioChange(nextScenario.id) + }, [handleScenarioChange, selectedScenario, scenarios]) + + const isDisablePrev = useMemo(() => selectedScenario?.scenarioIndex === 1, [selectedScenario]) + const isDisableNext = useMemo( + () => selectedScenario?.scenarioIndex === scenarios.length, + [selectedScenario, scenarios], + ) + + if (stepLoadable.state === "loading") { + return + } + + return ( +
+
+
+
+ handleScenarioChange(id), + classNames: {popup: {root: "!p-0 !min-w-[180px]"}}, + }} + showOnlySelect + /> + {stepLoadable.state === "hasData" && + stepLoadable.data?.inputSteps?.map((input, index) => ( + + { + setIsCopy(true) + setTimeout(() => { + setIsCopy(false) + }, 1500) + }} + > + {input?.testcaseId}{" "} + {isCopy ? : } + + + ))} +
+
+ ) +} + +export default FocusDrawerHeader diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerSidePanel/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerSidePanel/index.tsx new file mode 100644 index 0000000000..852964bc06 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/FocusDrawerSidePanel/index.tsx @@ -0,0 +1,164 @@ +import {Key, useCallback, useMemo} from "react" + +import {TreeStructure, Download, Sparkle, Speedometer} from "@phosphor-icons/react" +import {Tree, TreeDataNode} from "antd" +import deepEqual from "fast-deep-equal" +import {atom} from "jotai" +import {useAtomValue} from "jotai" +import {atomFamily} from "jotai/utils" +import {useRouter} from "next/router" + +import {focusScenarioAtom} from "@/oss/components/EvalRunDetails/state/focusScenarioAtom" +import {urlStateAtom} from "@/oss/components/EvalRunDetails/state/urlState" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import FocusDrawerSidePanelSkeleton from "../Skeletons/FocusDrawerSidePanelSkeleton" + +// Helper atom to read multiple run states given a list of runIds +const evaluationsRunFamily = atomFamily( + (runIds: string[]) => + atom((get) => { + return runIds.map((runId) => get(evaluationRunStateFamily(runId))) + }), + deepEqual, +) + +const FocusDrawerSidePanel = () => { + const router = useRouter() + const urlState = useAtomValue(urlStateAtom) + const focus = useAtomValue(focusScenarioAtom) + const compareRunIds = (urlState?.compare || []) as string[] + const focusRunId = focus?.focusRunId! + const focusRunState = useAtomValue(evaluationRunStateFamily(focusRunId)) + const baseRunId = useMemo(() => { + if (focusRunState?.isBase) return focusRunId + const routerValue = router.query?.evaluation_id + if (Array.isArray(routerValue)) { + return routerValue[0] ?? focusRunId + } + if (typeof routerValue === "string" && routerValue.length > 0) { + return routerValue + } + return focusRunId + }, [focusRunId, focusRunState?.isBase, router.query?.evaluation_id]) + const isComparison = Array.isArray(compareRunIds) && compareRunIds.length > 0 + const isBaseRun = focusRunState?.isBase ?? focusRunId === baseRunId + + // Read base run and all comparison run states + const runIds = useMemo(() => { + if (!isComparison) return [baseRunId] + if (!isBaseRun && isComparison) return [focusRunId] + + return [baseRunId, ...compareRunIds] + }, [baseRunId, compareRunIds, focusRunId, isBaseRun, isComparison]) + + const runs = useAtomValue(evaluationsRunFamily(runIds)) + + const baseEvaluation = useMemo( + () => runs.find((r) => r?.enrichedRun?.id === baseRunId), + [runs, baseRunId], + ) + const baseEvaluators = useMemo( + () => baseEvaluation?.enrichedRun?.evaluators || [], + [baseEvaluation], + ) + + // Build deduped evaluator list across all runs when in comparison mode + const dedupedEvaluators = useMemo(() => { + if (isBaseRun && !isComparison) return baseEvaluators + + const map = new Map() + runs?.forEach((r) => { + r?.enrichedRun?.evaluators?.forEach((e) => { + if (!map.has(e.slug)) map.set(e.slug, {slug: e.slug, name: e.name}) + }) + }) + return Array.from(map.values()) + }, [isComparison, runs, baseEvaluators, isBaseRun]) + + // Output children: evaluation names (base + comparisons) when in comparison mode + const outputChildren: TreeDataNode[] = useMemo(() => { + if (!isComparison || (!isBaseRun && isComparison)) return [] + return runs + .map((r) => r?.enrichedRun) + .filter(Boolean) + .map((enriched) => ({ + title: enriched!.name, + key: `output-${enriched!.id}`, + icon: , + })) as TreeDataNode[] + }, [isComparison, runs, isBaseRun]) + + const treeData: TreeDataNode[] = useMemo(() => { + if (!focusRunId) return [] + return [ + { + title: "Evaluation", + key: "evaluation", + icon: , + children: [ + { + title: "Input", + key: "input", + icon: , + }, + { + title: "Output", + key: "output", + icon: , + children: outputChildren, + }, + { + title: "Evaluator", + key: "evaluator", + icon: , + children: + dedupedEvaluators?.map((e) => ({ + title: e.name, + key: e.slug, + icon: , + })) || [], + }, + ], + }, + ] + }, [dedupedEvaluators, outputChildren, focusRunId]) + + const onSelect = useCallback( + async (selectedKeys: Key[]) => { + if (selectedKeys.length > 0) { + const key = selectedKeys[0].toString() + + await router.replace( + { + pathname: router.pathname, + query: router.query, + hash: key, + }, + undefined, + {scroll: false, shallow: true}, + ) + } + }, + [router], + ) + + if (!runs.length) { + return + } + + return ( +
+ +
+ ) +} + +export default FocusDrawerSidePanel diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerContentSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerContentSkeleton.tsx new file mode 100644 index 0000000000..e79155bbc1 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerContentSkeleton.tsx @@ -0,0 +1,33 @@ +import {memo} from "react" + +import {Skeleton} from "antd" + +const FocusDrawerContentSkeleton = () => { + return ( +
+
+ + +
+ + + + +
+ +
+ + + +
+
+
+ + +
+ +
+ ) +} + +export default memo(FocusDrawerContentSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerHeaderSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerHeaderSkeleton.tsx new file mode 100644 index 0000000000..20912d5e76 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerHeaderSkeleton.tsx @@ -0,0 +1,16 @@ +import {memo} from "react" + +import {Skeleton} from "antd" + +const FocusDrawerHeaderSkeleton = () => { + return ( +
+ + + + +
+ ) +} + +export default memo(FocusDrawerHeaderSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerSidePanelSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerSidePanelSkeleton.tsx new file mode 100644 index 0000000000..1c813e50aa --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/assets/Skeletons/FocusDrawerSidePanelSkeleton.tsx @@ -0,0 +1,15 @@ +import {memo} from "react" + +import {Skeleton} from "antd" + +const FocusDrawerSidePanelSkeleton = () => { + return ( +
+ {Array.from({length: 8}).map((_, idx) => ( + + ))} +
+ ) +} + +export default memo(FocusDrawerSidePanelSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/index.tsx new file mode 100644 index 0000000000..187de30274 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunFocusDrawer/index.tsx @@ -0,0 +1,68 @@ +import {memo, useCallback, useMemo} from "react" + +import {useAtomValue, useSetAtom} from "jotai" +import dynamic from "next/dynamic" + +import { + closeFocusDrawerAtom, + focusScenarioAtom, + isFocusDrawerOpenAtom, + resetFocusDrawerAtom, +} from "@/oss/components/EvalRunDetails/state/focusScenarioAtom" +import GenericDrawer from "@/oss/components/GenericDrawer" +import {RunIdProvider} from "@/oss/contexts/RunIdContext" +import {clearFocusDrawerQueryParams} from "@/oss/state/url/focusDrawer" + +const FocusDrawerHeader = dynamic(() => import("./assets/FocusDrawerHeader"), {ssr: false}) +const FocusDrawerContent = dynamic(() => import("./assets/FocusDrawerContent"), {ssr: false}) +const FocusDrawerSidePanel = dynamic(() => import("./assets/FocusDrawerSidePanel"), {ssr: false}) + +const EvalRunFocusDrawer = () => { + const isOpen = useAtomValue(isFocusDrawerOpenAtom) + const focus = useAtomValue(focusScenarioAtom) + const closeDrawer = useSetAtom(closeFocusDrawerAtom) + const resetDrawer = useSetAtom(resetFocusDrawerAtom) + + const focusRunId = focus?.focusRunId ?? null + + const handleClose = useCallback(() => { + closeDrawer(null) + }, [closeDrawer]) + + const handleAfterOpenChange = useCallback( + (nextOpen: boolean) => { + if (!nextOpen) { + resetDrawer(null) + clearFocusDrawerQueryParams() + } + }, + [resetDrawer], + ) + + const shouldRenderContent = useMemo( + () => Boolean(focusRunId && focus?.focusScenarioId), + [focusRunId, focus?.focusScenarioId], + ) + + if (!focusRunId) { + return null + } + + return ( + + : null} + mainContent={shouldRenderContent ? : null} + sideContent={shouldRenderContent ? : null} + className="[&_.ant-drawer-body]:p-0" + sideContentDefaultSize={200} + /> + + ) +} + +export default memo(EvalRunFocusDrawer) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/assets/EvalRunHeaderSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/assets/EvalRunHeaderSkeleton.tsx new file mode 100644 index 0000000000..1ae2c1dea1 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/assets/EvalRunHeaderSkeleton.tsx @@ -0,0 +1,20 @@ +import {memo} from "react" + +import {Skeleton} from "antd" +import clsx from "clsx" + +const EvalRunHeaderSkeleton = ({className}: {className?: string}) => { + return ( +
+ + +
+ ) +} + +export default memo(EvalRunHeaderSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/index.tsx new file mode 100644 index 0000000000..20dca8c02c --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunHeader/index.tsx @@ -0,0 +1,46 @@ +import {memo} from "react" + +import clsx from "clsx" +import {useAtomValue} from "jotai" + +import {useRunId} from "@/oss/contexts/RunIdContext" + +import EvalRunScenariosViewSelector from "../../../components/EvalRunScenariosViewSelector" +import {runViewTypeAtom, urlStateAtom} from "../../../state/urlState" +import EvalRunCompareMenu from "../EvalRunCompareMenu" +import EvalRunSelectedEvaluations from "../EvalRunSelectedEvaluations" + +const EvalRunHeader = ({className, name, id}: {className?: string; name: string; id: string}) => { + const viewType = useAtomValue(runViewTypeAtom) + const urlState = useAtomValue(urlStateAtom) + const baseRunId = useRunId() + return ( +
+ + +
+
+ {urlState.compare?.length > 0 && ( + + )} +
+ + +
+
+ ) +} + +export default memo(EvalRunHeader) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/assets/EvalRunOverviewViewerSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/assets/EvalRunOverviewViewerSkeleton.tsx new file mode 100644 index 0000000000..cb6953b18d --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/assets/EvalRunOverviewViewerSkeleton.tsx @@ -0,0 +1,25 @@ +import {memo} from "react" + +import EvalRunScoreTableSkeleton from "../../EvalRunScoreTable/assets/EvalRunScoreTableSkeleton" +import EvaluatorMetricsChartSkeleton from "../../EvaluatorMetricsChart/assets/EvaluatorMetricsChartSkeleton" + +const EvalRunOverviewViewerSkeleton = () => { + return ( + <> +
+ +
+ +
+ {Array.from({length: 3}).map((_, index) => ( + + ))} +
+ + ) +} + +export default memo(EvalRunOverviewViewerSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/index.tsx new file mode 100644 index 0000000000..2b8e6131c7 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunOverviewViewer/index.tsx @@ -0,0 +1,209 @@ +import {memo, useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {atom} from "jotai" +import {atomFamily} from "jotai/utils" + +import {useRunId} from "@/oss/contexts/RunIdContext" +import { + evaluationEvaluatorsFamily, + loadingStateAtom, + loadingStateFamily, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {runMetricStatsFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {canonicalizeMetricKey, getMetricValueWithAliases} from "@/oss/lib/metricUtils" + +import {urlStateAtom} from "../../../state/urlState" +import {formatMetricName} from "../../assets/utils" +import {EVAL_COLOR} from "../../assets/utils" +import EvalRunScoreTable from "../EvalRunScoreTable" +import EvaluatorMetricsChart from "../EvaluatorMetricsChart" + +import EvalRunOverviewViewerSkeleton from "./assets/EvalRunOverviewViewerSkeleton" + +// Only evaluator metrics (slug-prefixed) should render in overview charts; skip invocation metrics. +const INVOCATION_METRIC_PREFIX = "attributes.ag." + +// Lightweight readers (mirrors what ScoreTable does) to fetch multiple runs' state/metrics +const runsStateFamily = atomFamily( + (runIds: string[]) => atom((get) => runIds.map((id) => get(evaluationRunStateFamily(id)))), + deepEqual, +) +const runsMetricsFamily = atomFamily( + (runIds: string[]) => + atom((get) => runIds.map((id) => ({id, metrics: get(runMetricStatsFamily({runId: id}))}))), + deepEqual, +) + +const EvalRunOverviewViewer = () => { + const runId = useRunId() + const urlState = useAtomValue(urlStateAtom) + const compareRunIds = urlState.compare + const isCompare = !!compareRunIds?.length + + const metrics = useAtomValue(runMetricStatsFamily({runId})) + const evaluators = useAtomValue(evaluationEvaluatorsFamily(runId)) + const loadingState = useAtomValue(loadingStateAtom) + const loadingStateFamilyData = useAtomValue(loadingStateFamily(runId)) + const allRunIds = useMemo( + () => [runId!, ...(compareRunIds || []).filter((id) => id && id !== runId)], + [runId, compareRunIds], + ) + const runs = useAtomValue(runsStateFamily(allRunIds)) + const metricsByRun = useAtomValue(runsMetricsFamily(allRunIds)) + + const evaluatorsBySlug = useMemo(() => { + const map: Record = {} + runs.forEach((r) => { + r?.enrichedRun?.evaluators?.forEach((ev: any) => { + if (ev?.slug && !map[ev.slug]) { + map[ev.slug] = ev + } + }) + }) + evaluators?.forEach((ev) => { + if (ev?.slug && !map[ev.slug]) { + map[ev.slug] = ev + } + }) + return map + }, [runs, evaluators]) + + const combinedMetricEntries = useMemo(() => { + const entries: { + fullKey: string + evaluatorSlug: string + metricKey: string + metric: Record + }[] = [] + const seen = new Set() + + const pushEntry = (rawKey: string, source: Record) => { + const canonical = canonicalizeMetricKey(rawKey) + if (canonical.startsWith(INVOCATION_METRIC_PREFIX)) return + if (!canonical.includes(".")) return + if (seen.has(canonical)) return + + const metric = + (getMetricValueWithAliases(source, canonical) as Record) || + (source?.[rawKey] as Record) + if (!metric) return + + const [slug, ...rest] = canonical.split(".") + const metricKey = rest.join(".") || slug + + entries.push({fullKey: canonical, evaluatorSlug: slug, metricKey, metric}) + seen.add(canonical) + } + + const baseMetrics = (metrics || {}) as Record + Object.keys(baseMetrics).forEach((fullKey) => { + pushEntry(fullKey, baseMetrics) + }) + + metricsByRun.forEach(({metrics: runMetrics}) => { + const scoped = (runMetrics || {}) as Record + Object.keys(scoped).forEach((fullKey) => { + pushEntry(fullKey, scoped) + }) + }) + + return entries + }, [metrics, metricsByRun]) + + const evalById = useMemo(() => { + const map: Record = {} + runs.forEach((r) => (map[r.enrichedRun?.id || r.id] = r)) + return map + }, [runs]) + + const metricsLookup = useMemo(() => { + const map: Record> = {} + metricsByRun.forEach(({id, metrics}) => { + const source = (metrics || {}) as Record + const normalized: Record = {...source} + Object.keys(source || {}).forEach((rawKey) => { + const canonical = canonicalizeMetricKey(rawKey) + if (canonical !== rawKey && normalized[canonical] === undefined) { + normalized[canonical] = source[rawKey] + } + }) + map[id] = normalized + }) + return map + }, [metricsByRun]) + + if (loadingState.isLoadingMetrics || loadingStateFamilyData.isLoadingMetrics) { + return + } + return ( + <> +
+ +
+ +
+ {combinedMetricEntries.map(({fullKey, metric, evaluatorSlug, metricKey}, idx) => { + if (!metric || !Object.keys(metric || {}).length) return null + + // Build comparison rows for this evaluator metric + const rowsWithMeta = isCompare + ? allRunIds.map((id, i) => { + const state = evalById[id] + const compareIdx = state?.compareIndex || i + 1 + const stats = metricsLookup[id] || {} + const m: any = getMetricValueWithAliases(stats, fullKey) + const hasMetric = !!m + let y = 0 + if (hasMetric) { + if (Array.isArray(m?.unique)) { + const trueEntry = (m?.frequency || m?.rank || [])?.find( + (f: any) => f?.value === true, + ) + const total = m?.count ?? 0 + y = total ? ((trueEntry?.count ?? 0) / total) * 100 : 0 + } else if (typeof m?.mean === "number") { + y = m.mean + } + } + return { + id, + x: state?.enrichedRun?.name || `Eval ${compareIdx}`, + y, + hasMetric, + color: (EVAL_COLOR as any)[compareIdx] || "#3B82F6", + } + }) + : undefined + + const averageRows = rowsWithMeta + ?.filter((r) => r.hasMetric) + .map(({x, y, color}) => ({x, y, color})) + const summaryRows = rowsWithMeta?.map(({x, y, color}) => ({ + x, + y, + color, + })) + + return ( + + ) + })} +
+ + ) +} + +export default memo(EvalRunOverviewViewer) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/EvalRunPromptConfigViewerSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/EvalRunPromptConfigViewerSkeleton.tsx new file mode 100644 index 0000000000..02f6b9f69a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/EvalRunPromptConfigViewerSkeleton.tsx @@ -0,0 +1,42 @@ +import {memo} from "react" + +import {Skeleton} from "antd" +import clsx from "clsx" + +const EvalRunPromptConfigViewerSkeleton = ({className}: {className?: string}) => { + return ( +
+
+
+
+ + +
+ +
+ + +
+
+ ) +} + +export default memo(EvalRunPromptConfigViewerSkeleton) + +export const PromptConfigCardSkeleton = memo(() => { + return ( +
+ {/* Header */} +
+ + +
+ + {/* Prompt section */} +
+ + +
+
+ ) +}) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/PromptConfigCard.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/PromptConfigCard.tsx new file mode 100644 index 0000000000..0602e5d8ce --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/assets/PromptConfigCard.tsx @@ -0,0 +1,633 @@ +import {memo, useEffect, useMemo, useRef, useState} from "react" + +import {Empty, Skeleton, Tag, Typography} from "antd" +import clsx from "clsx" +import {atom, getDefaultStore, useAtomValue} from "jotai" +import {atomFamily} from "jotai/utils" +import dynamic from "next/dynamic" +import {useRouter} from "next/router" + +import {PromptsSourceProvider} from "@/oss/components/Playground/context/PromptsSource" +import {EnrichedEvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" +import type {EnhancedObjectConfig} from "@/oss/lib/shared/variant/genericTransformer/types" +import {fetchOpenApiSchemaJson} from "@/oss/lib/shared/variant/transformer" +import { + deriveCustomPropertiesFromSpec, + derivePromptsFromSpec, +} from "@/oss/lib/shared/variant/transformer/transformer" +import type {AgentaConfigPrompt} from "@/oss/lib/shared/variant/transformer/types" +import {projectScopedVariantsAtom} from "@/oss/state/projectVariantConfig" +import { + appSchemaAtom, + appUriInfoAtom, + getEnhancedRevisionById, +} from "@/oss/state/variant/atoms/fetcher" + +import EvalNameTag from "../../../assets/EvalNameTag" +import {EVAL_TAG_COLOR} from "../../../assets/utils" +import VariantTag from "../../../assets/VariantTag" +import { + combineAppNameWithLabel, + deriveVariantAppName, + deriveVariantLabelParts, + getVariantDisplayMetadata, + normalizeId, + prettifyVariantLabel, +} from "../../../assets/variantUtils" + +import {PromptConfigCardSkeleton} from "./EvalRunPromptConfigViewerSkeleton" + +const PlaygroundVariantConfigPrompt = dynamic( + () => import("@/oss/components/Playground/Components/PlaygroundVariantConfigPrompt"), + {ssr: false, loading: () => }, +) +const PlaygroundVariantCustomProperties = dynamic( + () => import("@/oss/components/Playground/Components/PlaygroundVariantCustomProperties"), + {ssr: false, loading: () => }, +) + +type ParametersShape = Record | null | undefined + +type PromptNode = EnhancedObjectConfig + +const deriveFromParametersSnapshot = (parameters: ParametersShape) => { + const ag = (parameters as any)?.ag_config ?? (parameters as any) ?? {} + const fallbackPrompts = Object.entries(ag) + .map(([name, cfg]: [string, any]) => { + if (!cfg || typeof cfg !== "object") return null + const messages = (cfg as any).messages + const llm_config = (cfg as any).llm_config || (cfg as any).llmConfig + if (!messages && !llm_config) return null + return { + __name: name, + messages, + llm_config, + } + }) + .filter(Boolean) as PromptNode[] + + return {prompts: fallbackPrompts, customProps: {}} +} + +const mergeParametersWithSnapshot = ( + baseParameters: ParametersShape, + snapshot: ParametersShape, +): ParametersShape => { + if (!snapshot || typeof snapshot !== "object") { + return baseParameters ?? undefined + } + + const base = baseParameters && typeof baseParameters === "object" ? baseParameters : {} + const merged: Record = { + ...base, + ...snapshot, + } + + const baseAgConfig = + (base as any)?.ag_config ?? (base as any)?.agConfig ?? (base as any)?.parameters?.ag_config + const snapshotAgConfig = (snapshot as any)?.ag_config ?? (snapshot as any)?.agConfig + + if (snapshotAgConfig && typeof snapshotAgConfig === "object") { + const mergedAg = { + ...(baseAgConfig && typeof baseAgConfig === "object" ? baseAgConfig : {}), + ...snapshotAgConfig, + } + merged.ag_config = mergedAg + merged.agConfig = mergedAg + } else if (baseAgConfig && typeof baseAgConfig === "object") { + merged.ag_config = baseAgConfig + merged.agConfig = baseAgConfig + } + + return merged +} + +interface DeriveParams { + variantId: string + parameters: ParametersShape +} + +// Single source atom family that derives prompts and custom props +const derivedPromptsAtomFamily = atomFamily(({variantId, parameters}: DeriveParams) => + atom((get) => { + const normalizedVariantId = typeof variantId === "string" ? variantId.trim() : "" + + if (!normalizedVariantId) { + return deriveFromParametersSnapshot(parameters) + } + + const rev = getEnhancedRevisionById(get.bind(get) as any, normalizedVariantId) + + if (rev) { + try { + const spec = get(appSchemaAtom) + const routePath = get(appUriInfoAtom)?.routePath + + if (spec) { + const mergedParameters = mergeParametersWithSnapshot( + (rev as any).parameters, + parameters, + ) + const mergedVariant = { + ...(rev as any), + parameters: mergedParameters ?? (rev as any).parameters, + } + + const derivedPrompts = derivePromptsFromSpec( + mergedVariant as any, + spec as any, + routePath, + ) as PromptNode[] + const derivedCustomProps = deriveCustomPropertiesFromSpec( + mergedVariant as any, + spec as any, + routePath, + ) as Record + + if (Array.isArray(derivedPrompts)) { + return {prompts: derivedPrompts, customProps: derivedCustomProps} + } + } + } catch (error) { + if (process.env.NODE_ENV !== "production") { + console.warn("[PromptConfig] Failed to derive prompts from spec", error) + } + } + } + + return deriveFromParametersSnapshot(parameters) + }), +) + +const PromptContentSkeleton = memo(({description}: {description: string}) => { + return ( +
+ + +
+ {description} +
+
+ ) +}) + +const PromptConfigCard = ({ + variantId, + evaluation, + isComparison, + compareIndex, + isFirstPrompt, + isMiddlePrompt, + isLastPrompt, + totalRuns, +}: { + variantId: string + evaluation: EnrichedEvaluationRun + isComparison: boolean + compareIndex: number + isFirstPrompt: boolean + isMiddlePrompt: boolean + isLastPrompt: boolean + totalRuns: number +}) => { + const router = useRouter() + const normalizedVariantId = useMemo(() => (variantId ? String(variantId) : ""), [variantId]) + const jotaiStore = useMemo(() => getDefaultStore(), []) + const projectScopedVariants = useAtomValue(projectScopedVariantsAtom) + + const [fallbackPrompts, setFallbackPrompts] = useState([]) + const [fallbackCustomProps, setFallbackCustomProps] = useState>({}) + const [fallbackTrigger, setFallbackTrigger] = useState(0) + const fallbackAttemptsRef = useRef(0) + + const variants = evaluation?.variants ?? [] + const selectedVariant = useMemo(() => { + if (!variants.length) return undefined + if (!normalizedVariantId) return variants[0] + + return ( + variants.find((variant) => { + const candidateIds = [ + (variant as any)?._revisionId, + (variant as any)?.id, + variant?.variantId, + ] + return candidateIds.some( + (candidate) => + candidate !== undefined && String(candidate) === normalizedVariantId, + ) + }) || undefined + ) + }, [variants, normalizedVariantId]) + + const projectScopedVariant = useMemo(() => { + if (!normalizedVariantId) return undefined + const scoped = projectScopedVariants?.revisionMap?.[normalizedVariantId] + return scoped && scoped.length > 0 ? scoped[0] : undefined + }, [normalizedVariantId, projectScopedVariants]) + + useEffect(() => { + setFallbackPrompts([]) + setFallbackCustomProps({}) + fallbackAttemptsRef.current = 0 + setFallbackTrigger(0) + }, [normalizedVariantId]) + + const variantForDisplay = selectedVariant ?? projectScopedVariant + + const fallbackVariantSource = useMemo(() => { + if (projectScopedVariant?.uri) return projectScopedVariant + if (selectedVariant?.uri) return selectedVariant + return projectScopedVariant ?? selectedVariant ?? null + }, [projectScopedVariant, selectedVariant]) + + const variantDisplay = useMemo( + () => + getVariantDisplayMetadata(variantForDisplay, { + fallbackLabel: normalizedVariantId || undefined, + fallbackRevisionId: normalizedVariantId || undefined, + requireRuntime: false, + }), + [variantForDisplay, normalizedVariantId], + ) + + const {label: formattedVariantLabel} = useMemo( + () => + deriveVariantLabelParts({ + variant: variantForDisplay, + displayLabel: variantDisplay.label, + }), + [variantForDisplay, variantDisplay.label], + ) + + const variantAppName = useMemo( + () => + deriveVariantAppName({ + variant: variantForDisplay, + fallbackAppName: + (evaluation as any)?.appName ?? + (evaluation as any)?.app_name ?? + (evaluation as any)?.app?.name ?? + undefined, + }), + [variantForDisplay, evaluation], + ) + + const variantLabel = combineAppNameWithLabel( + variantAppName, + prettifyVariantLabel(formattedVariantLabel) ?? formattedVariantLabel, + ) + + const revisionId = variantDisplay.revisionId || normalizedVariantId || "" + + const variantAppId = useMemo( + () => + normalizeId( + (variantForDisplay as any)?.appId ?? + (variantForDisplay as any)?.app_id ?? + (variantForDisplay as any)?.application?.id ?? + (variantForDisplay as any)?.application_id ?? + (variantForDisplay as any)?.application_ref?.id ?? + (variantForDisplay as any)?.applicationRef?.id, + ), + [variantForDisplay], + ) + + const evaluationAppId = useMemo( + () => + normalizeId( + (evaluation as any)?.appId ?? + (evaluation as any)?.app_id ?? + (evaluation as any)?.app?.id ?? + (evaluation as any)?.application?.id, + ), + [evaluation], + ) + + const normalizedRouteAppId = useMemo( + () => normalizeId(router.query?.app_id as string | undefined), + [router.query?.app_id], + ) + + const navigableAppId = variantAppId || evaluationAppId || normalizedRouteAppId + const isRouteAppContext = + Boolean(normalizedRouteAppId) && navigableAppId === normalizedRouteAppId + const blockedByRuntime = isRouteAppContext && variantDisplay.hasRuntime === false + + const canNavigateToVariant = Boolean( + revisionId && navigableAppId && variantDisplay.isHealthy !== false && !blockedByRuntime, + ) + + const parameters = useMemo(() => { + const map = (evaluation as any)?.parametersByRevisionId as + | Record + | undefined + + if (map) { + const candidateIds = [ + normalizedVariantId, + String((selectedVariant as any)?._revisionId ?? ""), + String((selectedVariant as any)?.id ?? ""), + String(selectedVariant?.variantId ?? ""), + ].filter( + (id) => + !!id && + id !== "undefined" && + id !== "null" && + id !== "[object Object]" && + id !== "NaN", + ) + + for (const id of candidateIds) { + if (map[id]) { + return map[id] + } + } + } + + const projectScopedParams = (projectScopedVariant as any)?.configParams + + return ( + (selectedVariant as any)?.parameters ?? + (selectedVariant as any)?.configParams ?? + projectScopedParams ?? + undefined + ) + }, [evaluation, normalizedVariantId, selectedVariant, projectScopedVariant]) + + const deriveParams = useMemo( + () => ({variantId: normalizedVariantId, parameters}), + [normalizedVariantId, parameters], + ) + + const {prompts, customProps} = useAtomValue(derivedPromptsAtomFamily(deriveParams), { + store: jotaiStore, + }) + + const basePrompts = prompts ?? [] + const promptsList = basePrompts.length ? basePrompts : fallbackPrompts + + const combinedCustomProps = useMemo(() => { + if (customProps && Object.keys(customProps).length > 0) return customProps + return fallbackCustomProps + }, [customProps, fallbackCustomProps]) + + const baseCustomPropsHasContent = useMemo(() => { + if (!customProps) return false + return Object.values(customProps).some((value) => { + if (value === null || value === undefined) return false + if (Array.isArray(value)) return value.length > 0 + if (typeof value === "object") return Object.keys(value).length > 0 + if (typeof value === "string") return value.trim().length > 0 + return true + }) + }, [customProps]) + + const combinedCustomPropsHasContent = useMemo(() => { + if (!combinedCustomProps) return false + return Object.values(combinedCustomProps).some((value) => { + if (value === null || value === undefined) return false + if (Array.isArray(value)) return value.length > 0 + if (typeof value === "object") return Object.keys(value).length > 0 + if (typeof value === "string") return value.trim().length > 0 + return true + }) + }, [combinedCustomProps]) + + const hasPrompts = promptsList.length > 0 + const hasContent = hasPrompts || combinedCustomPropsHasContent + const hasVariantsInRun = + (evaluation?.variants?.length ?? 0) > 0 || Boolean(projectScopedVariant) + const isVariantSelectable = Boolean(normalizedVariantId && variantForDisplay) + const showSkeleton = Boolean( + !variantForDisplay && normalizedVariantId && hasVariantsInRun && !parameters, + ) + const showPrompts = isVariantSelectable && hasContent + const emptyDescription = !isVariantSelectable + ? "Prompt configuration is unavailable because the source application or variant is no longer accessible." + : hasContent + ? "Prompt configuration isn't available because the original application is no longer accessible." + : "This evaluation does not include any prompt configuration data." + + const promptsMap = useMemo(() => { + if (!normalizedVariantId) return {} + return {[normalizedVariantId]: promptsList as PromptNode[] | undefined} + }, [normalizedVariantId, promptsList]) + + const fallbackCustomPropsPopulated = useMemo( + () => Object.keys(fallbackCustomProps).length > 0, + [fallbackCustomProps], + ) + + const shouldAttemptFallback = useMemo(() => { + if (!normalizedVariantId) return false + if (!fallbackVariantSource?.uri) return false + if (basePrompts.length > 0 || baseCustomPropsHasContent) return false + if (fallbackPrompts.length > 0 || fallbackCustomPropsPopulated) return false + return true + }, [ + normalizedVariantId, + fallbackVariantSource, + basePrompts.length, + baseCustomPropsHasContent, + fallbackPrompts.length, + fallbackCustomPropsPopulated, + ]) + + useEffect(() => { + if (!shouldAttemptFallback) return + + let isCancelled = false + let retryTimeout: ReturnType | undefined + + const snapshot = + (parameters && Object.keys(parameters as any).length > 0 + ? parameters + : (fallbackVariantSource as any)?.configParams) ?? {} + + const run = async () => { + try { + const {schema} = await fetchOpenApiSchemaJson(fallbackVariantSource!.uri as string) + if (!schema) { + throw new Error("Missing OpenAPI schema") + } + + const mergedParameters = mergeParametersWithSnapshot( + (fallbackVariantSource as any)?.parameters, + snapshot, + ) + + const fallbackVariant = { + ...fallbackVariantSource, + parameters: mergedParameters ?? snapshot, + } + + const derivedPrompts = derivePromptsFromSpec( + fallbackVariant as any, + schema as any, + ) as PromptNode[] + const derivedCustomProps = deriveCustomPropertiesFromSpec( + fallbackVariant as any, + schema as any, + ) as Record + + if (isCancelled) return + + fallbackAttemptsRef.current = 0 + setFallbackPrompts(Array.isArray(derivedPrompts) ? derivedPrompts : []) + setFallbackCustomProps(derivedCustomProps ?? {}) + + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.info("[PromptConfigCard] Fallback prompts derived", { + runId: evaluation?.id, + variantId: normalizedVariantId, + promptCount: derivedPrompts?.length ?? 0, + customPropsCount: Object.keys(derivedCustomProps ?? {}).length, + }) + } + } catch (error: any) { + if (isCancelled) return + const attempt = fallbackAttemptsRef.current + 1 + fallbackAttemptsRef.current = attempt + if (attempt <= 3) { + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.warn("[PromptConfigCard] Fallback prompt fetch failed, retrying", { + runId: evaluation?.id, + variantId: normalizedVariantId, + attempt, + error, + }) + } + retryTimeout = setTimeout(() => { + setFallbackTrigger((prev) => prev + 1) + }, 500 * attempt) + } else if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.error("[PromptConfigCard] Fallback prompt fetch failed", { + runId: evaluation?.id, + variantId: normalizedVariantId, + attempt, + error, + }) + } + } + } + + run() + + return () => { + isCancelled = true + if (retryTimeout) clearTimeout(retryTimeout) + } + }, [ + shouldAttemptFallback, + fallbackTrigger, + normalizedVariantId, + fallbackVariantSource, + evaluation?.id, + parameters, + ]) + + const usingFallbackPrompts = basePrompts.length === 0 && fallbackPrompts.length > 0 + const usingFallbackCustomProps = !baseCustomPropsHasContent && fallbackCustomPropsPopulated + const parametersSource = + usingFallbackPrompts || usingFallbackCustomProps + ? "project-fallback" + : selectedVariant + ? "run" + : projectScopedVariant + ? "project-scoped" + : "none" + + if ( + process.env.NODE_ENV !== "production" && + typeof window !== "undefined" && + normalizedVariantId + ) { + console.info("[PromptConfigCard] Render", { + runId: evaluation?.id, + variantId: normalizedVariantId, + hasSelectedVariant: Boolean(selectedVariant), + usingProjectFallback: Boolean(!selectedVariant && projectScopedVariant), + hasPrompts, + hasCustomProps: combinedCustomPropsHasContent, + showPrompts, + parametersSource, + usingFallbackPrompts, + usingFallbackCustomProps, + }) + } + + return ( +
2}, + {"!rounded-r-none": isComparison && isFirstPrompt}, + {"!rounded-none": isComparison && isMiddlePrompt}, + {"!rounded-l-none": isComparison && isLastPrompt}, + ])} + > +
+
+ + {variantForDisplay ? ( + + ) : ( + + Variant unavailable + + )} +
+
+ + {showSkeleton ? ( + + ) : showPrompts ? ( + +
+ {promptsList.map((prompt) => ( + + ))} + +
+
+ ) : ( +
+ + {emptyDescription} + + } + image={Empty.PRESENTED_IMAGE_SIMPLE} + /> +
+ )} +
+ ) +} + +export default memo(PromptConfigCard) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/index.tsx new file mode 100644 index 0000000000..866d83eb3a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunPromptConfigViewer/index.tsx @@ -0,0 +1,152 @@ +import {memo, useEffect, useMemo, useRef} from "react" + +import clsx from "clsx" +import deepEqual from "fast-deep-equal" +import {atom, useAtomValue, useSetAtom} from "jotai" +import {atomFamily} from "jotai/utils" + +import {useRunId} from "@/oss/contexts/RunIdContext" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import { + clearProjectVariantReferencesAtom, + prefetchProjectVariantConfigs, + setProjectVariantReferencesAtom, +} from "@/oss/state/projectVariantConfig" +import {projectIdAtom} from "@/oss/state/project/selectors/project" + +import {urlStateAtom} from "../../../state/urlState" +import {collectProjectVariantReferences} from "../../../../../lib/hooks/usePreviewEvaluations/projectVariantConfigs" + +import PromptConfigCard from "./assets/PromptConfigCard" + +// Helper atom to read multiple run states given a list of runIds +const evaluationsRunFamily = atomFamily( + (runIds: string[]) => + atom((get) => { + return runIds.map((runId) => get(evaluationRunStateFamily(runId))) + }), + deepEqual, +) + +const EvalRunPromptConfigViewer = () => { + const runId = useRunId() + const urlState = useAtomValue(urlStateAtom) + const compareRunIds = urlState?.compare + + // Read base run and all comparison run states + const runIds = useMemo(() => { + if (!compareRunIds?.length) return [runId!] + return [runId!, ...compareRunIds] + }, [runId, compareRunIds]) + + const runs = useAtomValue(evaluationsRunFamily(runIds)) + const renderableRuns = useMemo( + () => runs?.filter((run) => Boolean(run?.enrichedRun)) ?? [], + [runs], + ) + const projectId = useAtomValue(projectIdAtom) + const setProjectVariantReferences = useSetAtom(setProjectVariantReferencesAtom) + const clearProjectVariantReferences = useSetAtom(clearProjectVariantReferencesAtom) + + const projectVariantReferences = useMemo(() => { + if (!projectId || !renderableRuns.length) return [] + const enrichedRuns = renderableRuns + .map((run) => run.enrichedRun) + .filter((run): run is NonNullable => Boolean(run)) + return collectProjectVariantReferences(enrichedRuns, projectId) + }, [projectId, renderableRuns]) + const referencesSetRef = useRef(false) + + useEffect(() => { + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.info("[EvalRunPromptConfigViewer] Renderable runs", { + total: runs?.length ?? 0, + renderable: renderableRuns.length, + runIds, + enrichedRunIds: renderableRuns.map((r) => r.enrichedRun?.id), + }) + } + }, [runIds, runs, renderableRuns]) + + useEffect(() => { + if (!projectId || projectVariantReferences.length === 0) { + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.info("[EvalRunPromptConfigViewer] No project variant references derived", { + projectId, + renderableRuns: renderableRuns.length, + }) + } + return + } + setProjectVariantReferences(projectVariantReferences) + prefetchProjectVariantConfigs(projectVariantReferences) + referencesSetRef.current = true + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.info("[EvalRunPromptConfigViewer] Prefetch project variants", { + projectId, + referenceCount: projectVariantReferences.length, + references: projectVariantReferences, + }) + } + }, [ + projectId, + projectVariantReferences, + setProjectVariantReferences, + prefetchProjectVariantConfigs, + ]) + + useEffect( + () => () => { + if (referencesSetRef.current) { + clearProjectVariantReferences() + referencesSetRef.current = false + if (process.env.NODE_ENV !== "production" && typeof window !== "undefined") { + console.info("[EvalRunPromptConfigViewer] Cleared project variant references") + } + } + }, + [clearProjectVariantReferences], + ) + + return ( +
0}])}> + {renderableRuns.map((run, idx) => { + const enriched = run.enrichedRun! + const variants = Array.isArray(enriched?.variants) ? enriched.variants : [] + + const primaryVariant = + variants.find((variant) => { + const revisionId = + (variant as any)?._revisionId ?? + (variant as any)?.id ?? + variant?.variantId + return Boolean(revisionId) + }) ?? variants[0] + + const variantRevisionId = + (primaryVariant as any)?._revisionId ?? + (primaryVariant as any)?.id ?? + primaryVariant?.variantId ?? + "" + + const reactKey = variantRevisionId || `${enriched.id || "run"}-${idx}` + + return ( + 0} + compareIndex={run.compareIndex || 1} + isFirstPrompt={idx === 0} + isMiddlePrompt={idx > 0 && idx < renderableRuns.length - 1} + isLastPrompt={idx === renderableRuns.length - 1} + totalRuns={renderableRuns.length} + /> + ) + })} +
+ ) +} + +export default memo(EvalRunPromptConfigViewer) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/EvalRunScoreTableSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/EvalRunScoreTableSkeleton.tsx new file mode 100644 index 0000000000..ee6c0d540b --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/EvalRunScoreTableSkeleton.tsx @@ -0,0 +1,21 @@ +import {memo} from "react" + +import {Skeleton} from "antd" +import clsx from "clsx" + +const EvalRunScoreTableSkeleton = ({className}: {className?: string}) => { + return ( +
+
+ + +
+
+ + +
+
+ ) +} + +export default memo(EvalRunScoreTableSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/TraceMetrics.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/TraceMetrics.tsx new file mode 100644 index 0000000000..06a6cb9d4a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/TraceMetrics.tsx @@ -0,0 +1,49 @@ +import {memo} from "react" + +import {Timer, Coins, PlusCircle} from "@phosphor-icons/react" +import {Space, Tooltip} from "antd" + +import {formatCurrency, formatLatency, formatTokenUsage} from "@/oss/lib/helpers/formatters" + +const TraceMetrics = ({latency, cost, tokens}: {latency: number; cost: number; tokens: number}) => { + return ( +
+ + +
+ + {formatLatency(latency)} +
+
+ + +
+ + {formatCurrency(cost)} +
+
+ + +
+ + {formatTokenUsage(tokens)} +
+
+
+
+ ) +} + +export default memo(TraceMetrics) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/constants.ts b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/constants.ts new file mode 100644 index 0000000000..f4947321b6 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/assets/constants.ts @@ -0,0 +1,17 @@ +import {ColumnType} from "antd/es/table" + +export const FIXED_COLUMNS: ColumnType[] = [ + { + title: "Metric", + dataIndex: "title", + key: "title", + minWidth: 120, + fixed: "left", + }, + { + title: "Label", + dataIndex: "label", + key: "label", + minWidth: 120, + }, +] diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/index.tsx new file mode 100644 index 0000000000..74490dd563 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunScoreTable/index.tsx @@ -0,0 +1,510 @@ +import {isValidElement, cloneElement, memo, useCallback, useMemo} from "react" + +import {Table, Typography} from "antd" +import clsx from "clsx" +import deepEqual from "fast-deep-equal" +import {atom, useAtomValue} from "jotai" +import {atomFamily} from "jotai/utils" +import dynamic from "next/dynamic" + +import {formatColumnTitle} from "@/oss/components/Filters/EditColumns/assets/helper" +import {formatMetricValue} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" +import {useRunId} from "@/oss/contexts/RunIdContext" +import useURL from "@/oss/hooks/useURL" +import {formatLatency} from "@/oss/lib/helpers/formatters" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {runMetricStatsFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import { + BasicStats, + canonicalizeMetricKey, + getMetricDisplayName, + getMetricValueWithAliases, +} from "@/oss/lib/metricUtils" + +import RenameEvalButton from "../../../HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalButton" +import {urlStateAtom} from "../../../state/urlState" +import EvalNameTag from "../../assets/EvalNameTag" +import TagWithLink from "../../assets/TagWithLink" +import {EVAL_TAG_COLOR, EVAL_COLOR, EVAL_BG_COLOR} from "../../assets/utils" +import {formatMetricName} from "../../assets/utils" +import VariantTag from "../../assets/VariantTag" +import {getVariantDisplayMetadata} from "../../assets/variantUtils" +import type {EvaluatorMetricsSpiderChartProps} from "../EvaluatorMetircsSpiderChart/types" + +const EvaluatorMetricsSpiderChart = dynamic( + () => import("../EvaluatorMetircsSpiderChart"), + {ssr: false}, +) + +// Atom helpers to read multiple runs' state/metrics in one go +const runsStateFamily = atomFamily( + (runIds: string[]) => atom((get) => runIds.map((id) => get(evaluationRunStateFamily(id)))), + deepEqual, +) +const runsMetricsFamily = atomFamily( + (runIds: string[]) => + atom((get) => runIds.map((id) => ({id, metrics: get(runMetricStatsFamily({runId: id}))}))), + deepEqual, +) + +const INVOCATION_METRIC_KEYS = [ + "attributes.ag.metrics.costs.cumulative.total", + "attributes.ag.metrics.duration.cumulative", + "attributes.ag.metrics.tokens.cumulative.total", + "attributes.ag.metrics.errors.cumulative", +] as const + +const INVOCATION_METRIC_SET = new Set(INVOCATION_METRIC_KEYS) + +const COST_METRIC_KEY = INVOCATION_METRIC_KEYS[0] +const DURATION_METRIC_KEY = INVOCATION_METRIC_KEYS[1] +const TOKEN_METRIC_KEY = INVOCATION_METRIC_KEYS[2] +const ERRORS_METRIC_KEY = INVOCATION_METRIC_KEYS[3] + +const INVOCATION_METRIC_COLUMNS: Array<{key: string; label: string}> = [ + {key: COST_METRIC_KEY, label: "Cost (Total)"}, + {key: DURATION_METRIC_KEY, label: "Duration (Total)"}, + {key: TOKEN_METRIC_KEY, label: "Total tokens"}, + {key: ERRORS_METRIC_KEY, label: "Errors"}, +] + +const EvalRunScoreTable = ({className}: {className?: string}) => { + const baseRunId = useRunId() + const {projectURL} = useURL() + const urlState = useAtomValue(urlStateAtom) + const compareRunIds = (urlState?.compare || []).filter((id) => id && id !== baseRunId) + const allRunIds = useMemo(() => [baseRunId!, ...compareRunIds], [baseRunId, compareRunIds]) + + const isComparison = compareRunIds.length > 0 + + // Fetch all runs and their metrics + const runs = useAtomValue(runsStateFamily(allRunIds)) + const metricsByRun = useAtomValue(runsMetricsFamily(allRunIds)) + + // Convenience lookup maps + const evalById = useMemo(() => { + const map: Record = {} + runs.forEach((r) => (map[r.enrichedRun?.id || r.id] = r)) + return map + }, [runs]) + + const metricsLookup = useMemo(() => { + const map: Record> = {} + metricsByRun.forEach(({id, metrics}) => { + const source = (metrics || {}) as Record + const normalized: Record = {...(source as any)} + Object.keys(source || {}).forEach((rawKey) => { + const canonical = canonicalizeMetricKey(rawKey) + if (canonical !== rawKey && normalized[canonical] === undefined) { + normalized[canonical] = source[rawKey] + } + }) + map[id] = normalized + }) + return map + }, [metricsByRun]) + + const getFrequencyData = useCallback((metric: any, returnPercentage = true) => { + const trueEntry = (metric as any)?.frequency?.find((f: any) => f?.value === true) + const total = (metric as any)?.count ?? 0 + return returnPercentage + ? `${(((trueEntry?.count ?? 0) / total) * 100).toFixed(2)}%` + : ((trueEntry?.count ?? 0) / total) * 100 + }, []) + + const chartMetrics = useMemo(() => { + // Build union of evaluator metrics across all runs, then add invocation metrics per rules. + interface Axis { + name: string + maxScore: number + type: "binary" | "numeric" + value?: number + [k: string]: any + _key: string + } + + const axesByKey: Record = {} + + // 1) Union evaluator metrics from all runs + allRunIds.forEach((runId, runIdx) => { + const stats = metricsLookup[runId] || {} + const evaluators = evalById[runId]?.enrichedRun?.evaluators + const processed = new Set() + + Object.keys(stats).forEach((rawKey) => { + const canonicalKey = canonicalizeMetricKey(rawKey) + if (processed.has(canonicalKey)) return + processed.add(canonicalKey) + + if (INVOCATION_METRIC_SET.has(canonicalKey)) return + if (!canonicalKey.includes(".")) return + + const metric = getMetricValueWithAliases(stats, canonicalKey) + if (!metric) return + + const [evalSlug, ...metricParts] = canonicalKey.split(".") + const metricRemainder = metricParts.join(".") + const evaluator = evaluators?.find((e: any) => e.slug === evalSlug) + if (!evaluator) return + + const axisKey = canonicalKey + const isBinary = Array.isArray((metric as any)?.frequency) + const displayMetricName = metricRemainder + ? formatMetricName(metricRemainder) + : formatMetricName(canonicalKey) + + if (!axesByKey[axisKey]) { + axesByKey[axisKey] = { + name: `${evaluator?.name ?? evalSlug} - ${displayMetricName}`, + maxScore: isBinary ? 100 : (metric as any)?.max || 100, + type: isBinary ? "binary" : "numeric", + _key: axisKey, + } + } else if (!isBinary) { + const mx = (metric as any)?.max + if (typeof mx === "number") { + axesByKey[axisKey].maxScore = Math.max(axesByKey[axisKey].maxScore, mx) + } + } + + const seriesKey = runIdx === 0 ? "value" : `value-${runIdx + 1}` + const v = isBinary ? getFrequencyData(metric, false) : ((metric as any)?.mean ?? 0) + axesByKey[axisKey][seriesKey] = v + }) + }) + + let axes: Axis[] = Object.values(axesByKey) + + // 2) Invocation metrics only when evaluator metrics are fewer than 3 (based on union) + const evaluatorCount = axes.length + const addInvocationAxis = (metricKey: string, label?: string) => { + const axis: Axis = { + name: label ?? getMetricDisplayName(metricKey), + maxScore: 0, + type: "numeric", + _key: metricKey, + } + allRunIds.forEach((runId, runIdx) => { + const stats = metricsLookup[runId] || {} + const metric = getMetricValueWithAliases(stats, metricKey) as BasicStats | any + const seriesKey = runIdx === 0 ? "value" : `value-${runIdx + 1}` + axis[seriesKey] = metric?.mean ?? 0 + const mx = metric?.max + if (typeof mx === "number") axis.maxScore = Math.max(axis.maxScore, mx) + }) + axes.push(axis) + } + + if (evaluatorCount < 3) { + if (evaluatorCount === 2) { + addInvocationAxis(COST_METRIC_KEY, "Invocation costs") + } else if (evaluatorCount <= 1) { + addInvocationAxis(DURATION_METRIC_KEY, "Invocation duration") + addInvocationAxis(COST_METRIC_KEY, "Invocation costs") + } + } + + // 3) Ensure all series keys exist for each axis + if (axes.length > 0) { + allRunIds.forEach((_, runIdx) => { + const seriesKey = runIdx === 0 ? "value" : `value-${runIdx + 1}` + axes.forEach((a) => { + if (a[seriesKey] === undefined) a[seriesKey] = 0 + }) + }) + } + + return axes.map(({_key, ...rest}) => rest) + }, [allRunIds, evalById, metricsLookup, getFrequencyData]) + + const dataSource = useMemo(() => { + // Build union of all metric keys across runs + const metricKeys = new Set() + allRunIds.forEach((id) => { + const m = metricsLookup[id] || {} + Object.keys(m).forEach((k) => metricKeys.add(canonicalizeMetricKey(k))) + }) + + // const baseEval = evalById[baseRunId!]?.enrichedRun + const rows: any[] = [] + + // Test Sets row + const testsetRow: any = {key: "testsets", title: "Test Sets", values: {}} + allRunIds.forEach((id) => { + if (baseRunId !== id) return + const enr = evalById[id]?.enrichedRun + const tags = (enr?.testsets || []).map((t: any) => ( + + )) + testsetRow.values[id] = tags.length ? tags[0] : "" + }) + rows.push(testsetRow) + + // Evaluations row + const evalsRow: any = {key: "evaluations", title: "Evaluations", values: {}} + allRunIds.forEach((id) => { + const state = evalById[id] + const enr = state?.enrichedRun + const color = EVAL_TAG_COLOR?.[state?.compareIndex || 1] + // evalsRow.values[id] = enr ? : "" + evalsRow.values[id] = enr ? ( +
+ + + + +
+ ) : ( + "" + ) + }) + rows.push(evalsRow) + + // date row + const dateRow: any = {key: "date", title: "Created at", values: {}} + allRunIds.forEach((id) => { + const state = evalById[id] + const enr = state?.enrichedRun + dateRow.values[id] = enr?.createdAt + }) + rows.push(dateRow) + + // Variants row + const variantsRow: any = {key: "variants", title: "Variants", values: {}} + allRunIds.forEach((id) => { + const enr = evalById[id]?.enrichedRun + const v = enr?.variants?.[0] as any + if (!v) { + variantsRow.values[id] =
N/A
+ return + } + const summary = getVariantDisplayMetadata(v) + variantsRow.values[id] = ( + + ) + }) + rows.push(variantsRow) + + // Metric rows (generic + evaluator) + const pushMetricRow = (key: string, labelNode: any) => { + const row: any = {key, title: labelNode, values: {}} + allRunIds.forEach((id) => { + const metric = getMetricValueWithAliases(metricsLookup[id] || {}, key) as + | BasicStats + | any + let value: any + + if (metric && (metric as any)?.mean !== undefined) { + const meanValue = (metric as any).mean + value = + key === DURATION_METRIC_KEY + ? formatLatency(meanValue) + : formatMetricValue(key, meanValue) + } else if ( + metric && + Array.isArray((metric as any)?.unique) && + typeof (metric as any)?.unique?.[0] === "boolean" + ) { + value = getFrequencyData(metric) + } + + row.values[id] = + value === undefined || value === null || value === "" ? ( +
+ ) : ( + value + ) + }) + rows.push(row) + } + + INVOCATION_METRIC_COLUMNS.forEach(({key: canonicalKey, label}) => { + const baseMetric = getMetricValueWithAliases( + metricsLookup[baseRunId!] || {}, + canonicalKey, + ) as any + const hasMean = baseMetric && (baseMetric as any)?.mean !== undefined + const titleNode = ( +
+ {label} + {hasMean && (mean)} +
+ ) + pushMetricRow(canonicalKey, titleNode) + }) + + // Evaluator metrics grouped by evaluator slug + const allEvaluatorEntries: {slug: string; metricKey: string; fullKey: string}[] = [] + Array.from(metricKeys) + .filter((k) => !INVOCATION_METRIC_SET.has(k) && k.includes(".")) + .forEach((fullKey) => { + const [slug, ...restParts] = fullKey.split(".") + const metricKey = restParts.join(".") || slug + allEvaluatorEntries.push({slug, metricKey, fullKey}) + }) + + // Maintain stable order by slug then metricKey + allEvaluatorEntries + .sort((a, b) => + a.slug === b.slug + ? a.metricKey.localeCompare(b.metricKey) + : a.slug.localeCompare(b.slug), + ) + .forEach(({slug, metricKey, fullKey}) => { + const state = evalById[baseRunId!] + const evaluator = state?.enrichedRun?.evaluators?.find((e: any) => e.slug === slug) + const baseMetric = getMetricValueWithAliases( + metricsLookup[baseRunId!] || {}, + fullKey, + ) as any + const [, ...restParts] = fullKey.split(".") + const metricPath = restParts.length ? restParts.join(".") : metricKey + const labelSegment = metricPath.split(".").pop() || metricPath + const displayMetricName = formatColumnTitle(labelSegment) + const titleNode = ( +
+ + {evaluator?.name ?? formatColumnTitle(slug)} + +
+ {displayMetricName} + {/* Show (mean) if base has mean */} + {baseMetric && (baseMetric as any)?.mean !== undefined && ( + (mean) + )} +
+
+ ) + pushMetricRow(fullKey, titleNode) + }) + + return rows + }, [allRunIds, baseRunId, evalById, getFrequencyData, metricsLookup, runs]) + return ( +
+
+ Evaluator Scores Overview + + Average evaluator score across evaluations + +
+ +
+
+ { + // First column is the label/title + const cols: any[] = [ + { + title: "Metric", + dataIndex: "title", + key: "title", + minWidth: 120, + fixed: "left", + }, + ] + + // One value column per run (base + comparisons) + allRunIds.forEach((id, idx) => { + const state = evalById[id] + const compareIdx = state?.compareIndex || idx + 1 + cols.push({ + title: idx === 0 ? "Label" : `Label_${idx + 1}`, + key: `label_${id}`, + render: (_: any, record: any) => { + // Merge "Test Sets" row across all run columns + if (record?.key === "testsets") { + if (id === allRunIds[0]) { + return { + children: + record?.values?.[baseRunId] ?? + record?.values?.[id] ?? + "", + props: {colSpan: allRunIds.length}, + } + } + return {children: null, props: {colSpan: 0}} + } + const value = record?.values?.[id] + if (!value) return "-" + if (record?.key !== "evaluations") return value + + const runState = evalById[id] + const enriched = runState?.enrichedRun + const firstVariant: any = enriched?.variants?.[0] + const summary = getVariantDisplayMetadata(firstVariant) + + if (isValidElement(value)) { + return cloneElement(value as any, { + allowVariantNavigation: summary.canNavigate, + variantName: summary.label, + id: summary.revisionId || undefined, + }) + } + + return summary.label + }, + minWidth: 120, + onCell: (record: any) => ({ + style: + isComparison && record?.key !== "testsets" + ? {background: (EVAL_BG_COLOR as any)[compareIdx]} + : undefined, + }), + }) + }) + + return cols + }, [allRunIds, baseRunId, isComparison, evalById])} + pagination={false} + showHeader={false} + bordered + scroll={{x: "max-content"}} + rowKey={(r) => r.key} + /> + + { + return allRunIds.map((id, idx) => { + const state = evalById[id] + const compareIdx = state?.compareIndex || idx + 1 + return { + key: idx === 0 ? "value" : `value-${idx + 1}`, + color: (EVAL_COLOR as any)[compareIdx] || "#3B82F6", + name: state?.enrichedRun?.name || `Eval ${compareIdx}`, + } + }) + }, [allRunIds, evalById])} + /> + + + ) +} + +export default memo(EvalRunScoreTable) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunSelectedEvaluations/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunSelectedEvaluations/index.tsx new file mode 100644 index 0000000000..12f4a6527a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunSelectedEvaluations/index.tsx @@ -0,0 +1,73 @@ +import {memo, useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {atom, useAtomValue} from "jotai" +import {atomFamily} from "jotai/utils" + +import { + evalAtomStore, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import EvalNameTag from "../../assets/EvalNameTag" +import {EVAL_TAG_COLOR} from "../../assets/utils" + +const comparisonRunsAtom = atomFamily( + (runIds: string[]) => + atom((get) => { + return runIds.map((runId) => { + const state = get(evaluationRunStateFamily(runId)) + return { + runId, + run: state?.enrichedRun, + compareIndex: state?.compareIndex, + isBase: state?.isBase, + isComparison: state?.isComparison, + } + }) + }), + deepEqual, +) +const EvalRunSelectedEvaluations = ({runIds, baseRunId}: {runIds: string[]; baseRunId: string}) => { + // Build a stable, de-duplicated list so transient states (during swaps) don't render duplicates + const uniqueIds = useMemo(() => { + const list = [baseRunId, ...runIds] + const seen = new Set() + return list.filter((id) => { + if (!id || seen.has(id)) return false + seen.add(id) + return true + }) + }, [baseRunId, runIds.join(",")]) + + const runs = useAtomValue(comparisonRunsAtom(uniqueIds), {store: evalAtomStore()}) + + return ( +
+ Evaluations: +
+
+ {runs + ?.filter((r) => Boolean(r?.run)) + .map((r) => { + const idx = r?.compareIndex ?? (r?.isBase ? 1 : undefined) + const color = idx ? (EVAL_TAG_COLOR as any)[idx] : undefined + return ( + + ) + })} +
+ {/*
*/} +
+
+ ) +} + +export default memo(EvalRunSelectedEvaluations) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewUtilityOptions/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewUtilityOptions/index.tsx new file mode 100644 index 0000000000..35f0777096 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewUtilityOptions/index.tsx @@ -0,0 +1,332 @@ +import {Dispatch, memo, SetStateAction, useCallback, useMemo, useState} from "react" + +import {message} from "antd" +import {ColumnsType} from "antd/es/table" +import {useAtomValue} from "jotai" +import {useRouter} from "next/router" + +import EditColumns from "@/oss/components/Filters/EditColumns" +import {useRunId} from "@/oss/contexts/RunIdContext" +import {convertToStringOrJson} from "@/oss/lib/helpers/utils" +import { + evalAtomStore, + evaluationEvaluatorsFamily, + evaluationRunStateFamily, + scenarioIdsFamily, + scenarioStepFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {scenarioMetricsMapFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" + +import EvalRunScenarioNavigator from "../../../components/EvalRunScenarioNavigator" +import SaveDataButton from "../../../components/SaveDataModal/assets/SaveDataButton" +import useExpandableComparisonDataSource from "../../../components/VirtualizedScenarioTable/hooks/useExpandableComparisonDataSource" +import {metricsFromEvaluatorsFamily} from "../../../components/VirtualizedScenarioTable/hooks/useTableDataSource" +import {urlStateAtom} from "../../../state/urlState" + +const EMPTY_ROWS: any[] = [] + +interface ScenarioCsvRow { + scenarioId?: string + record: Record +} + +const EvalRunTestCaseViewUtilityOptions = ({ + columns, + setEditColumns, +}: { + columns: ColumnsType + setEditColumns: Dispatch> +}) => { + const runId = useRunId() + const router = useRouter() + // states for select dropdown + const [rows, setRows] = useState(EMPTY_ROWS) + const evaluation = useAtomValue(evaluationRunStateFamily(runId)) + const urlState = useAtomValue(urlStateAtom) + // Determine runs to include: base + comparisons (unique, exclude base duplicates) + const compareRunIds = useMemo( + () => (urlState?.compare || []).filter(Boolean) as string[], + [urlState?.compare], + ) + const hasComparisons = compareRunIds.length > 0 + const allRunIds = Array.from(new Set([runId, ...compareRunIds.filter((id) => id !== runId)])) + const selectedScenarioId = router.query.scrollTo as string + + const {rawColumns: comparisonRawColumns} = useExpandableComparisonDataSource({ + baseRunId: runId, + comparisonRunIds: compareRunIds, + }) + + const csvDataFormat = useCallback(async () => { + const store = evalAtomStore() + + // Helper: build rows for a single run + const buildRowsForRun = async (rId: string): Promise => { + // 1) Scenario IDs and evaluator info for this run + const ids = store.get(scenarioIdsFamily(rId)) + const evaluatorsRaw = store.get(evaluationEvaluatorsFamily(rId)) || [] + const evaluatorList: any[] = Array.isArray(evaluatorsRaw) + ? (evaluatorsRaw as any[]) + : Object.values(evaluatorsRaw as any) + const evaluatorSlugs = evaluatorList.map((e: any) => e.slug) + + // 2) Resolve steps and metrics for this run + const [scenarioMetricsMap, ...allScenarios] = await Promise.all([ + store.get(scenarioMetricsMapFamily(rId)), + ...ids.map((id) => store.get(scenarioStepFamily({runId: rId, scenarioId: id}))), + ]) + + // Evaluation name for this run (for column 'name' when comparing) + const runState = store.get(evaluationRunStateFamily(rId)) + const evalName = runState?.enrichedRun?.name + + // 3) Build CSV-friendly rows for this run + const rowsForRun: ScenarioCsvRow[] = [] + + allScenarios.forEach((scenario) => { + if (!scenario) return + const sid = scenario.steps?.[0]?.scenarioId + const scenarioId = sid ? String(sid) : undefined + + const primaryInput = scenario.inputSteps?.find((s: any) => s.inputs) || {} + const {inputs = {}, groundTruth = {}, status: inputStatus} = primaryInput as any + + const record: Record = {} + + // When in comparison mode, include evaluation name + if (hasComparisons && evalName) { + record.name = evalName + } + + // 1. Add input + if (!Object.keys(groundTruth).length) { + Object.entries(primaryInput.testcase?.data || {}).forEach(([k, v]) => { + if (k === "testcase_dedup_id") return + record[`input.${k}`] = convertToStringOrJson(v) + }) + } else { + Object.entries(inputs || {}).forEach(([k, v]) => { + record[`input.${k}`] = convertToStringOrJson(v) + }) + } + + // 2. Add output + // Extract model output from the first invocation step that contains a trace + const invWithTrace = scenario.invocationSteps?.find((inv: any) => inv.trace) + + if (!invWithTrace) { + const invWithErr = scenario.invocationSteps?.find((inv: any) => inv.error) + if (invWithErr) { + record.output = convertToStringOrJson( + invWithErr.error?.stacktrace || invWithErr.error, + ) + } + } + + if (invWithTrace) { + const traceObj = invWithTrace?.trace + let traceOutput: any + if (Array.isArray(traceObj?.nodes)) { + traceOutput = traceObj.nodes[0]?.data?.outputs + } else if (Array.isArray(traceObj?.trees)) { + traceOutput = traceObj.trees[0]?.nodes?.[0]?.data?.outputs + } + + if (traceOutput) { + record.output = convertToStringOrJson(traceOutput) + } + } + + // 3. Add status + if (!invWithTrace) { + const _invWithTrace = scenario.invocationSteps?.find((inv: any) => inv.error) + record.status = _invWithTrace?.status ?? "unknown" + } else { + record.status = invWithTrace?.status ?? "unknown" + } + + // 4. Add annotation and metrics/errors + const annSteps = scenario.steps.filter((step) => + evaluatorSlugs.includes(step.stepKey), + ) + const steps = annSteps.length + ? annSteps + : scenario.invocationSteps?.filter((inv: any) => inv.error) + const annotation = scenarioMetricsMap?.[sid] + + // Prefill metric columns so compare-eval metrics are visible even if values missing yet + const evalMetricsDefs = store.get(metricsFromEvaluatorsFamily(rId)) as any + if (evalMetricsDefs && typeof evalMetricsDefs === "object") { + Object.entries(evalMetricsDefs).forEach(([slug, defs]: [string, any[]]) => { + const evaluator = evaluatorList?.find((e) => e?.slug === slug) + if (!Array.isArray(defs)) return + defs.forEach((metricDef) => { + Object.keys(metricDef || {}) + .filter((k) => k !== "evaluatorSlug") + .forEach((metricName) => { + const key = `${evaluator?.name || slug}.${metricName}` + if (!(key in record)) record[key] = "" + }) + }) + }) + } + + if (steps?.some((step) => step.error) || invWithTrace?.error) { + const evalMetrics = store.get(metricsFromEvaluatorsFamily(rId)) + steps.forEach((step) => { + if (!step.error) return null + + const errorMessage = + step.error.stacktrace || step?.error?.message || step.error + Object.entries(evalMetrics || {}).forEach(([k, v]) => { + if (Array.isArray(v)) { + v.forEach((metric) => { + const evaluator = (evaluatorList as any[])?.find( + (e) => e?.slug === metric?.evaluatorSlug, + ) + const {evaluatorSlug, ...rest} = metric + + Object.keys(rest || {}).forEach((metricKey) => { + if (evaluator) { + record[`${evaluator?.name}.${metricKey}`] = + convertToStringOrJson(errorMessage) + } else { + record[`${metric?.evaluatorSlug}.${metricKey}`] = + convertToStringOrJson(errorMessage) + } + }) + }) + } + }) + }) + } + + if (annotation) { + Object.entries(annotation || {}).forEach(([k, v]) => { + if (!k.includes(".")) return + const [evalSlug, metricName] = k.split(".") + if (["error", "errors"].includes(metricName)) return + const evaluator = (evaluatorList as any[])?.find( + (e) => e?.slug === evalSlug, + ) + + if ((v as any).mean) { + record[`${evaluator?.name}.${metricName}`] = (v as any)?.mean + } else if ((v as any).unique) { + const mostFrequent = (v as any).frequency.reduce( + (max: any, current: any) => + current.count > max.count ? current : max, + ).value + record[`${evaluator?.name}.${metricName}`] = String(mostFrequent) + } else if (v && typeof v !== "object") { + record[`${evaluator?.name}.${metricName}`] = + typeof v === "number" + ? String(v).includes(".") + ? (v as number).toFixed(3) + : v + : convertToStringOrJson(v) + } + }) + } + rowsForRun.push({record, scenarioId}) + }) + + return rowsForRun + } + + // Build data across all runs + const rowsByRun = new Map() + const lookupByRun = new Map>() + + for (const rId of allRunIds) { + const rows = await buildRowsForRun(rId) + rowsByRun.set(rId, rows) + + const scenarioLookup = new Map() + rows.forEach((row) => { + if (row && row.scenarioId) { + scenarioLookup.set(row.scenarioId, row) + } + }) + lookupByRun.set(rId, scenarioLookup) + } + + if (!hasComparisons) { + const baseRows = rowsByRun.get(runId) || [] + return baseRows.map(({record}) => record) + } + + const orderedResults: Record[] = [] + const baseRows = rowsByRun.get(runId) || [] + const uniqueCompareRunIds = Array.from( + new Set(compareRunIds.filter((id) => id && id !== runId)), + ) + + baseRows.forEach((baseRow, index) => { + if (!baseRow) return + orderedResults.push(baseRow.record) + + uniqueCompareRunIds.forEach((compareId) => { + const compareRows = rowsByRun.get(compareId) || [] + if (!compareRows.length) return + + const scenarioLookup = lookupByRun.get(compareId) + const matchedRow = + (baseRow.scenarioId && scenarioLookup?.get(baseRow.scenarioId)) || + compareRows[index] + + if (matchedRow) { + orderedResults.push(matchedRow.record) + } + }) + }) + + return orderedResults + }, [runId, evalAtomStore, allRunIds, compareRunIds, hasComparisons]) + + const onClickSaveData = useCallback(async () => { + try { + const data = await csvDataFormat() + setRows(data) + } catch (error) { + message.error("Failed to export results") + } + }, [csvDataFormat]) + + return ( +
+
+ Go to test case: + +
+
+ + { + setEditColumns(keys) + }} + /> +
+
+ ) +} + +export default memo(EvalRunTestCaseViewUtilityOptions) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/assets/EvalRunTestCaseViewerSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/assets/EvalRunTestCaseViewerSkeleton.tsx new file mode 100644 index 0000000000..de4dc79c9f --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/assets/EvalRunTestCaseViewerSkeleton.tsx @@ -0,0 +1,77 @@ +import {memo} from "react" + +import {Skeleton} from "antd" + +export const EvalRunTestCaseTableSkeleton = memo( + ({rows = 8, cols = 5, rowHight = 60}: {rows?: number; cols?: number; rowHight?: number}) => { + return ( +
+
+ + + {Array.from({length: cols}).map((_, colIndex) => ( + + ))} + + + + {Array.from({length: rows}).map((_, rowIndex) => ( + + {Array.from({length: cols}).map((_, colIndex) => ( + + ))} + + ))} + +
+ +
+ +
+
+ ) + }, +) + +const EvalRunTestCaseViewerSkeleton = ({ + rows = 8, + cols = 5, + rowHight = 60, +}: { + rows?: number + cols?: number + rowHight?: number +}) => { + return ( +
+
+ +
+ + +
+
+ + +
+ ) +} + +export default memo(EvalRunTestCaseViewerSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/index.tsx new file mode 100644 index 0000000000..524724c9ea --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvalRunTestCaseViewer/index.tsx @@ -0,0 +1,31 @@ +import {memo} from "react" + +import VirtualizedScenarioTable from "@/oss/components/EvalRunDetails/components/VirtualizedScenarioTable" +import useTableDataSource from "@/oss/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useTableDataSource" + +import EvalRunTestCaseViewUtilityOptions from "../EvalRunTestCaseViewUtilityOptions" + +import EvalRunTestCaseViewerSkeleton from "./assets/EvalRunTestCaseViewerSkeleton" + +const EvalRunTestCaseViewer = () => { + const {antColumns, isLoadingSteps, setEditColumns, rawColumns} = useTableDataSource() + + if (isLoadingSteps) { + return + } + + return ( +
+ + +
+ +
+
+ ) +} + +export default memo(EvalRunTestCaseViewer) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/index.tsx new file mode 100644 index 0000000000..870d59a3b0 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/index.tsx @@ -0,0 +1,223 @@ +import {memo, useMemo} from "react" + +import {Typography} from "antd" +import clsx from "clsx" +import { + PolarAngleAxis, + PolarGrid, + PolarRadiusAxis, + Radar, + RadarChart, + ResponsiveContainer, + Tooltip, +} from "recharts" + +import {format3Sig} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" +import {formatCurrency, formatLatency} from "@/oss/lib/helpers/formatters" + +import {EVAL_COLOR} from "../../assets/utils" + +import {EvaluatorMetricsSpiderChartProps, MetricData, SeriesMeta} from "./types" + +const EvaluatorMetricsSpiderChart = ({ + className, + metrics = [], + maxScore = 100, + series = [{key: "value", color: EVAL_COLOR[1], name: "Eval 1"}], +}: EvaluatorMetricsSpiderChartProps) => { + // Build chart data with per-axis normalization to 0-100 so + // each axis can have its own maxScore while sharing a single radius scale. + const chartData: MetricData[] = useMemo(() => { + return metrics.map((m) => { + const axisMax = + typeof m.maxScore === "number" && isFinite(m.maxScore) && m.maxScore > 0 + ? m.maxScore + : maxScore + + const baseRaw = typeof m.value === "number" && isFinite(m.value) ? m.value : 0 + const baseNorm = Math.max(0, Math.min(100, (baseRaw / axisMax) * 100)) + + const obj: MetricData = { + subject: m.name, + value: baseNorm, + rawValue: baseRaw, + maxScore: axisMax, + type: m.type, + } + + // Add normalized values for additional series using same axis max + series.forEach((s) => { + const key = s.key + if (key === "value") return // already set + const raw = typeof m[key] === "number" && isFinite(m[key]) ? m[key] : 0 + const norm = Math.max(0, Math.min(100, (raw / axisMax) * 100)) + ;(obj as any)[key] = norm + }) + + return obj + }) + }, [metrics, maxScore, series]) + + if (metrics.length === 0) { + return ( +
+ No metrics available +
+ ) + } + + const LABEL_OFFSET = 12 // distance outside web + const NUDGE = 5 // small outward nudge + const RAD = Math.PI / 180 + + return ( +
+ + + + { + const {cx, cy, radius, payload, index} = props + const label = (payload?.value ?? "") as string + + const angle = Number(payload?.coordinate ?? 0) + const r = (radius ?? 0) + LABEL_OFFSET + + const x = cx + r * Math.cos(-angle * RAD) + const y = cy + r * Math.sin(-angle * RAD) + + const cos = Math.cos(-angle * RAD) + const sin = Math.sin(-angle * RAD) + + const textAnchor = + Math.abs(cos) < 0.1 ? "middle" : cos > 0 ? "start" : "end" + + const nudgeX = cos * NUDGE + const nudgeY = sin * NUDGE + + // simple 2-line clamp to avoid spilling into chart + const clampLines = (s: string, max = 18) => { + const parts = s.includes(" - ") ? s.split(" - ") : [s] + if (parts.length >= 2) return parts.slice(0, 2) + const words = s.split(/\s+/) + let line1 = "" + let line2 = "" + for (const w of words) { + if ((line1 + " " + w).trim().length <= max) + line1 = (line1 + " " + w).trim() + else if ((line2 + " " + w).trim().length <= max) + line2 = (line2 + " " + w).trim() + else { + line2 = (line2 || w).slice(0, max - 1) + "…" + break + } + } + return line2 ? [line1, line2] : [line1] + } + + const lines = clampLines(label, 18) + + return ( + + + {lines.map((ln, i) => ( + + {ln} + + ))} + + + ) + }} + /> + + { + try { + const d = payload?.payload as MetricData | undefined + if (!d) return [val, "Score"] + // val is normalized percentage for the active series + const pct = typeof val === "number" ? val : Number(val) + // Reconstruct raw from normalized and axis max (for numeric) + const rawFromPct = (pctNum: number) => + (pctNum / 100) * (d?.maxScore ?? 0) + + const color = + typeof payload?.color === "string" ? payload.color : "#0F172A" + const styledName = ( + {String(name)} + ) + + if (d.type === "binary") { + const valueLabel = `${pct.toFixed(2)}% / 100%` + return [ + {valueLabel}, + styledName, + ] + } + + // Numeric: format latency/costs specially when subject hints it + const raw = rawFromPct(pct) + const valueColor = {color, fontWeight: 600} + if (String(d?.subject).toLowerCase().includes("duration")) { + return [ + + {`${formatLatency(raw)} / ${formatLatency(d?.maxScore)}`} + , + styledName, + ] + } + if (String(d?.subject).toLowerCase().includes("cost")) { + return [ + + {`${formatCurrency(raw)} / ${formatCurrency(d?.maxScore)}`} + , + styledName, + ] + } + return [ + {`${format3Sig(raw)} / ${format3Sig( + d?.maxScore, + )}`}, + styledName, + ] + } catch (error) { + return [String(val), String(name)] + } + }} + /> + {series.map((s: SeriesMeta, i: number) => ( + + ))} + + +
+ ) +} + +export default memo(EvaluatorMetricsSpiderChart) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/types.ts b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/types.ts new file mode 100644 index 0000000000..6a485ee9bb --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetircsSpiderChart/types.ts @@ -0,0 +1,31 @@ +export interface MetricData { + subject: string + // Normalized values used for plotting (0-100) per series. + // Base series uses `value`; additional series use keys like `value-2`, `value-3`, ... + value?: number + [key: string]: any + // Raw value and axis-specific max for tooltip/labels (base series) + rawValue: number + maxScore: number + type?: "binary" | "numeric" +} + +export interface SeriesMeta { + key: string // e.g. "value", "value-2", ... + color: string + name?: string +} + +export interface EvaluatorMetricsSpiderChartProps { + className?: string + metrics: { + name: string + // Base value; additional series are passed via dynamic props (e.g., value-2) + value?: number + [key: string]: any + maxScore: number + type: "binary" | "numeric" + }[] + maxScore?: number + series?: SeriesMeta[] +} diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/BarChart.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/BarChart.tsx new file mode 100644 index 0000000000..50ae382b24 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/BarChart.tsx @@ -0,0 +1,286 @@ +import {memo, useMemo} from "react" + +import { + Bar, + CartesianGrid, + Cell, + BarChart as RechartsBarChart, + ResponsiveContainer, + Tooltip, + TooltipProps, + XAxis, + YAxis, +} from "recharts" + +type ChartDatum = Record + +interface BarChartProps { + data: readonly ChartDatum[] + xKey: string + yKey: string + /** optional key in data row that carries the color (e.g. 'color') */ + colorKey?: string + + /** Axis / chart tuning */ + yDomain?: [number | "auto" | "dataMin", number | "auto" | "dataMax"] + xAxisProps?: Partial> + yAxisProps?: Partial> + cartesianGridProps?: Partial> + chartProps?: Partial> + containerProps?: Partial> + + /** Bar sizing & spacing */ + barSize?: number // if omitted, width is auto-calculated from gaps + barGap?: number | string // e.g. 16 or '30%' + barCategoryGap?: number | string // e.g. 24 or '30%' + + /** Tooltip label for Y value. Pass falsy to hide Tooltip. */ + tooltipLabel?: string + tooltipFormatter?: (value: number, row: ChartDatum) => string + + /** Per-bar overrides */ + getCellProps?: (row: ChartDatum, index: number) => Partial> + + /** Direct pass-through to */ + barProps?: Partial> + + className?: string +} + +const BarChart = ({ + data, + xKey, + yKey, + colorKey, + yDomain = ["auto", "auto"], + xAxisProps, + yAxisProps, + cartesianGridProps, + chartProps, + containerProps, + // Use percentage-based gaps by default for consistent spacing across datasets + barSize, + barGap = "10%", + barCategoryGap = "30%", + tooltipLabel = "Value", + tooltipFormatter, + getCellProps, + barProps, + className, +}: BarChartProps) => { + const chartBarSize = !barSize ? undefined : barSize + const yAxisWidth = typeof yAxisProps?.width === "number" ? yAxisProps.width : 58 + const { + interval: xAxisInterval, + height: xAxisHeight, + tickWidth: xAxisTickWidthProp, + ...restXAxisProps + } = xAxisProps ?? {} + + const labelBasedTickWidth = useMemo(() => { + const longestLabelLength = data.reduce((max, row) => { + const rawLabel = row?.[xKey] + + if (typeof rawLabel === "string" || typeof rawLabel === "number") { + return Math.max(max, String(rawLabel).length) + } + + return max + }, 0) + + // Invert the relationship: longer labels get smaller width, shorter labels get more width + const maxPossibleWidth = 100 + const minPossibleWidth = 60 + const baseWidth = Math.max(1, longestLabelLength) // Ensure we don't divide by zero + const invertedWidth = (1 / baseWidth) * 1000 // Scale factor to get reasonable numbers + + return Math.min(maxPossibleWidth, Math.max(minPossibleWidth, invertedWidth)) + }, [data, xKey]) + + const xAxisTickWidth = xAxisTickWidthProp ?? labelBasedTickWidth + + return ( + + + ( + +
+ {payload?.value} +
+
+ )} + height={xAxisHeight ?? 24} + {...restXAxisProps} + /> + + + + {tooltipLabel ? ( + ) => { + if (!active || !payload?.length) return null + + const rows = payload.filter((p) => p?.value != null) + if (!rows.length) return null + + return ( +
+ {/*
+ {label} +
*/} + {rows.map((entry, idx) => { + const rawRow = entry?.payload as ChartDatum + const barColor = + (colorKey && typeof rawRow?.[colorKey] === "string" + ? (rawRow[colorKey] as string) + : undefined) || + entry?.color || + "#3B82F6" + const entryLabel = (() => { + const rawLabel = rawRow?.[xKey] + if ( + typeof rawLabel === "string" || + typeof rawLabel === "number" + ) + return String(rawLabel) + + return entry?.name || tooltipLabel + })() + const formattedValue = + typeof entry?.value === "number" + ? (tooltipFormatter?.(entry.value, rawRow) ?? + String(entry.value)) + : String(entry?.value ?? "") + return ( +
+
+ + + {entryLabel} + +
+ + {formattedValue} + +
+ ) + })} +
+ ) + }} + /> + ) : null} + + + {data.map((row, i) => { + const fill = + colorKey && typeof row[colorKey] === "string" + ? (row[colorKey] as string) + : undefined + return ( + + ) + })} + +
+
+ ) +} + +export default memo(BarChart) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/EvaluatorMetricsChartSkeleton.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/EvaluatorMetricsChartSkeleton.tsx new file mode 100644 index 0000000000..833b8e674e --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/EvaluatorMetricsChartSkeleton.tsx @@ -0,0 +1,20 @@ +import {memo} from "react" + +import {Skeleton} from "antd" +import clsx from "clsx" + +const EvaluatorMetricsChartSkeleton = ({className}: {className?: string}) => { + return ( +
+
+ + +
+
+ +
+
+ ) +} + +export default memo(EvaluatorMetricsChartSkeleton) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/HistogramChart.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/HistogramChart.tsx new file mode 100644 index 0000000000..e8046c92da --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/assets/HistogramChart.tsx @@ -0,0 +1,149 @@ +import {memo} from "react" + +import { + BarChart as RechartsBarChart, + Bar, + XAxis, + YAxis, + ResponsiveContainer, + Tooltip, + CartesianGrid, + Cell, +} from "recharts" + +type ChartDatum = Record + +interface HistogramChartProps { + data: readonly ChartDatum[] + xKey: string + yKey: string + /** optional key in data row that carries the color (e.g. 'color') */ + colorKey?: string + + /** Axis / chart tuning */ + yDomain?: [number | "auto" | "dataMin", number | "auto" | "dataMax"] + xAxisProps?: Partial> + yAxisProps?: Partial> + cartesianGridProps?: Partial> + chartProps?: Partial> + containerProps?: Partial> + + /** Bar sizing & spacing */ + barSize?: number // if omitted, width is auto-calculated from gaps + barGap?: number | string // e.g. 16 or '30%' + barCategoryGap?: number | string // e.g. 24 or '30%' + + /** Tooltip label for Y value. Pass falsy to hide Tooltip. */ + tooltipLabel?: string + + /** Per-bar overrides */ + getCellProps?: (row: ChartDatum, index: number) => Partial> + + /** Direct pass-through to */ + barProps?: Partial> + + className?: string +} + +const HistogramChart = ({ + data, + xKey, + yKey, + colorKey, + yDomain = ["auto", "auto"], + xAxisProps, + yAxisProps, + cartesianGridProps, + chartProps, + containerProps, + // Use percentage-based gaps by default for consistent spacing across datasets + barSize, + barGap = "10%", + barCategoryGap = "30%", + tooltipLabel = "Value", + getCellProps, + barProps, + className, +}: HistogramChartProps) => { + const chartBarSize = !barSize ? undefined : barSize + const yAxisWidth = typeof yAxisProps?.width === "number" ? yAxisProps.width : 48 + + return ( + + + + + + + {tooltipLabel ? ( + [v as number, tooltipLabel]} + cursor={false} + contentStyle={{ + backgroundColor: "white", + border: "1px solid #d9d9d9", + borderRadius: "4px", + padding: "4px 8px", + }} + /> + ) : null} + + + {data.map((row, i) => { + const fill = + colorKey && typeof row[colorKey] === "string" + ? (row[colorKey] as string) + : undefined + return ( + + ) + })} + + + + ) +} + +export default memo(HistogramChart) diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/index.tsx new file mode 100644 index 0000000000..ca5cc0ef57 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/components/EvaluatorMetricsChart/index.tsx @@ -0,0 +1,299 @@ +import {useCallback, useMemo, useState} from "react" + +import {Card, Radio, Typography} from "antd" +import clsx from "clsx" + +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" + +import {EVAL_BG_COLOR} from "../../assets/utils" + +import BarChart from "./assets/BarChart" +import HistogramChart from "./assets/HistogramChart" + +/* ---------------- helpers ---------------- */ + +const format3Sig = (n: number) => { + if (!Number.isFinite(n)) return String(n) + const abs = Math.abs(n) + if (abs !== 0 && (abs < 0.001 || abs >= 1000)) return n.toExponential(2) + const s = n.toPrecision(3) + return String(Number(s)) +} + +interface BooleanMetric { + rank: {value: boolean; count: number}[] + count: number + unique: boolean[] + frequency: {value: boolean; count: number}[] +} + +/** Boolean metric → two-bars histogram */ +export function toBooleanHistogramRows( + metric: BooleanMetric, + opts?: {trueLabel?: string; falseLabel?: string; trueColor?: string; falseColor?: string}, +) { + const source = metric.frequency?.length ? metric.frequency : metric.rank + const map = new Map(source.map((f) => [f.value, f.count])) + const t = map.get(true) ?? 0 + const f = map.get(false) ?? 0 + return [ + {x: opts?.trueLabel ?? "true", y: t, color: opts?.trueColor ?? "#22c55e"}, + {x: opts?.falseLabel ?? "false", y: f, color: opts?.falseColor ?? "#ef4444"}, + ] as const +} + +interface EvaluatorMetric { + count: number + sum: number + mean: number + min: number + max: number + range: number + distribution: {value: number; count: number}[] + percentiles: Record + iqrs: Record + binSize: number +} + +/** + * Numeric metric → XY rows from distribution, ignoring binSize. + * X is just the formatted starting value; Y is count. + * This gives a categorical X axis that still preserves the shape. + */ +export function toXYRowsFromDistributionNoBin( + metric: EvaluatorMetric, + opts?: {color?: string; digits?: number}, +) { + const rows = [...(metric.distribution ?? [])] + .sort((a, b) => a.value - b.value) + .map((d) => ({ + x: format3Sig(opts?.digits != null ? Number(d.value.toFixed(opts.digits)) : d.value), + y: d.count, + color: opts?.color ?? "rgba(145, 202, 255, 0.7)", + })) + + return rows +} + +/** Fallback: if no distribution is present, plot a single bar at the mean (x label = value) */ +export function toSingleMeanRow(metric: EvaluatorMetric, opts?: {color?: string; digits?: number}) { + const y = typeof metric.mean === "number" ? metric.mean : 0 + const x = format3Sig(opts?.digits != null ? Number(y.toFixed(opts.digits)) : y) + return [{x, y, color: opts?.color ?? "rgba(145, 202, 255, 0.7)"}] as const +} + +const items = ["average", "histogram", "total"] + +/* ---------------- page component ---------------- */ + +const EvaluatorMetricsChart = ({ + className, + name, + metricKey, + metric, + evaluator, + isCompare, + averageRows, + summaryRows, +}: { + className?: string + name: string + metricKey?: string + metric: Record + evaluator?: EvaluatorDto + isCompare?: boolean + averageRows?: readonly {x: string; y: number; color?: string}[] + summaryRows?: readonly {x: string; y: number; color?: string}[] +}) => { + const [selectedItem, setSelectedItem] = useState(items[0]) + const isBooleanMetric = !!metric?.unique?.length + const hasDistribution = Array.isArray(metric?.distribution) && metric.distribution.length > 0 + const isNumeric = typeof metric?.mean === "number" || hasDistribution + + // Big summary number + const chartSummeryValue = useMemo(() => { + if (isBooleanMetric) { + const trueEntry = metric?.frequency?.find((f: any) => f?.value === true) + const total = metric?.count ?? 0 + const pct = total ? ((trueEntry?.count ?? 0) / total) * 100 : 0 + return `${pct.toFixed(2)}%` + } + if (typeof metric?.mean === "number") return format3Sig(metric.mean) + return "" + }, [metric, isBooleanMetric]) + + // Summary for compare mode: one value per evaluation with +/- delta vs base + const compareSummaries = useMemo(() => { + // Use only evaluations that actually have this evaluator's metric (averageRows already filtered) + if (!isCompare || !averageRows?.length) + return [] as {value: string; delta?: string; color: string}[] + + const base = averageRows?.[0]?.y ?? 0 + const isPct = isBooleanMetric + return averageRows.map((r, i) => { + const color = (r as any)?.color || (EVAL_BG_COLOR as any)[i + 1] || "#3B82F6" + const valNum = Number(r.y || 0) + const value = isPct ? `${valNum.toFixed(2)}%` : format3Sig(valNum) + if (i === 0) return {value, delta: "-", color} + // percent difference vs base (avoid divide by zero) + const deltaPct = base ? ((valNum - base) / Math.abs(base)) * 100 : 0 + const sign = deltaPct > 0 ? "+" : "" + const delta = `${sign}${deltaPct.toFixed(0)}%` + return {value, delta, color} + }) + }, [isCompare, averageRows, isBooleanMetric]) + + // Shape data: + // - Boolean: two bars true/false + // - Numeric: distribution → (x = formatted start value, y = count) + // - Fallback numeric: single bar at mean (x = value, y = mean) + const chartData = useMemo(() => { + if (isBooleanMetric) { + return toBooleanHistogramRows(metric as BooleanMetric, { + trueLabel: "true", + falseLabel: "false", + trueColor: "rgba(145, 202, 255, 0.7)", + falseColor: "rgba(145, 202, 255, 0.7)", + }) + } + if (hasDistribution) { + return toXYRowsFromDistributionNoBin(metric as EvaluatorMetric, { + color: "rgba(145, 202, 255, 0.7)", + digits: 3, + }) + } + if (isNumeric) { + return toSingleMeanRow(metric as EvaluatorMetric, { + color: "rgba(145, 202, 255, 0.7)", + digits: 3, + }) + } + return [] + }, [metric, isBooleanMetric, hasDistribution, isNumeric]) + console.log("chartData", chartData) + + const showHistogram = !isCompare || selectedItem === "histogram" + const showAverageBars = isCompare && selectedItem === "average" + + const formatYAxisTick = useCallback( + (value: number) => { + if (typeof value !== "number" || Number.isNaN(value)) return "" + + const formatToThreeDecimals = (num: number) => { + if (num === 0) return "0" + const abs = Math.abs(num) + if (abs < 0.001) return num.toExponential(2) + return Number(num.toFixed(3)).toString() + } + + if (isBooleanMetric) { + return `${formatToThreeDecimals(value)}%` + } + + return formatToThreeDecimals(value) + }, + [isBooleanMetric], + ) + + return ( + +
+ + {evaluator?.name} + + + {name} + +
+
+ } + className={clsx("rounded !p-0 overflow-hidden", className)} + classNames={{title: "!py-0 !px-4", header: "!p-0", body: "!p-0"}} + > +
+ {isCompare ? ( +
+ {compareSummaries.map((s, idx) => ( +
+ + {s.value} + + + {s.delta} + +
+ ))} +
+ ) : ( +
+ + {chartSummeryValue} + +
+ )} + +
+
+ + {showHistogram ? "Frequency" : "Avg score"} + +
+ {showHistogram ? ( + + ) : ( + formatYAxisTick(value)} + yAxisProps={{tickFormatter: formatYAxisTick}} + barCategoryGap={(averageRows?.length ?? 0) < 4 ? "30%" : "10%"} + barProps={{radius: [8, 8, 0, 0]}} + /> + )} +
+
+ + {name} + +
+
+ + ) +} + +export default EvaluatorMetricsChart diff --git a/web/ee/src/components/EvalRunDetails/AutoEvalRun/index.tsx b/web/ee/src/components/EvalRunDetails/AutoEvalRun/index.tsx new file mode 100644 index 0000000000..b645cd2ac5 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/AutoEvalRun/index.tsx @@ -0,0 +1,51 @@ +import clsx from "clsx" +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {selectAtom} from "jotai/utils" +import dynamic from "next/dynamic" + +import {runViewTypeAtom} from "../state/urlState" + +import AutoEvalRunSkeleton from "./assets/AutoEvalRunSkeleton" +import {AutoEvalRunDetailsProps} from "./assets/types" +import EvalRunHeader from "./components/EvalRunHeader" + +const EvalRunOverviewViewer = dynamic(() => import("./components/EvalRunOverviewViewer"), { + ssr: false, +}) +const EvalRunPromptConfigViewer = dynamic(() => import("./components/EvalRunPromptConfigViewer"), { + ssr: false, +}) +const EvalRunTestCaseViewer = dynamic(() => import("./components/EvalRunTestCaseViewer"), { + ssr: false, +}) + +const viewTypeAtom = selectAtom(runViewTypeAtom, (v) => v, deepEqual) +const AutoEvalRunDetails = ({name, description, id, isLoading}: AutoEvalRunDetailsProps) => { + const viewType = useAtomValue(viewTypeAtom) + + if (isLoading) { + return + } + + return ( +
+ + + {viewType === "overview" ? ( + + ) : viewType === "test-cases" ? ( + + ) : viewType === "prompt" ? ( + + ) : null} +
+ ) +} + +export default AutoEvalRunDetails diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/annotationUtils.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/annotationUtils.ts new file mode 100644 index 0000000000..cacaffc627 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/annotationUtils.ts @@ -0,0 +1,383 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {uuidToSpanId, uuidToTraceId} from "@/oss/lib/hooks/useAnnotations/assets/helpers" +import {AnnotationDto} from "@/oss/lib/hooks/useAnnotations/types" +import { + evaluationEvaluatorsFamily, + evaluationRunStateFamily, + evalAtomStore, + scenarioStepFamily, + revalidateScenarioForRun, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {triggerMetricsFetch} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {IAnnotationStep, IStepResponse} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EvaluationStatus} from "@/oss/lib/Types" +import {getJWT} from "@/oss/services/api" +import {updateScenarioStatusRemote} from "@/oss/services/evaluations/workerUtils" +import {createScenarioMetrics} from "@/oss/services/runMetrics/api" +import {getProjectValues} from "@/oss/state/project" + +import {setOptimisticStepData} from "./optimisticUtils" +import {collectStepsAndMetrics} from "./stepsMetricsUtils" +/** + * Retrieve the scenario object (if present) for the given id. + * Updated for multi-run support with runId parameter. + */ +export const getScenario = (scenarioId: string, runId: string) => { + // Use run-scoped atoms for multi-run support + return ( + evalAtomStore() + .get(evaluationRunStateFamily(runId)) + ?.scenarios?.find((s: any) => s.id === scenarioId) || null + ) +} + +/** + * Retrieve the evaluators associated with the current evaluation run. + * Updated for multi-run support with runId parameter. + */ +export const getEvaluators = (runId: string) => { + return evalAtomStore().get(evaluationEvaluatorsFamily(runId)) +} + +/** + * Lazily load step data for a scenario via the jotai family. + * Updated for multi-run support with runId parameter. + */ +export const getStepData = async (scenarioId: string, runId?: string) => { + if (runId) { + // Use run-scoped atoms for multi-run support + return await evalAtomStore().get(scenarioStepFamily({scenarioId, runId})) + } +} + +/** + * Utility that checks the `requiredMetrics` object returned by payload generation. + * If any metric is missing it will call the provided formatter and returns `false`. Otherwise returns `true`. + */ +export const validateRequiredMetrics = ( + requiredMetrics: Record, + formatErrorMessages: (requiredMetrics: Record) => void, +): boolean => { + const hasMissing = Object.keys(requiredMetrics || {}).length > 0 + if (hasMissing) { + formatErrorMessages(requiredMetrics) + } + return !hasMissing +} + +// ---------------------------------- +// Backend synchronisation utilities +// ---------------------------------- + +interface PushStepsAndMetricsParams { + patchStepsFull: any[] + stepsToCreate?: any[] + metricEntries: {scenarioId: string; data: Record}[] + projectId: string + runId: string +} + +export const pushStepsAndMetrics = async ({ + patchStepsFull, + stepsToCreate = [], + metricEntries, + projectId, + runId, +}: PushStepsAndMetricsParams) => { + // Normalize payloads to results schema + const normalizePatch = (items: any[]) => + items.map((it) => { + const out: Record = { + id: it.id, + status: it.status, + trace_id: it.trace_id ?? it.traceId, + span_id: it.span_id ?? it.spanId, + references: it.references, + } + const stepKey = it.step_key ?? it.stepKey + if (stepKey) out.step_key = stepKey + return out + }) + + const normalizeCreate = (items: any[]) => + items.map((it) => { + const out: Record = { + status: it.status, + step_key: it.step_key ?? it.stepKey ?? it.key, + trace_id: it.trace_id ?? it.traceId, + span_id: it.span_id ?? it.spanId, + scenario_id: it.scenario_id ?? it.scenarioId, + run_id: it.run_id ?? it.runId, + references: it.references, + } + const testcaseId = it.testcase_id ?? it.testcaseId + if (testcaseId) out.testcase_id = testcaseId + return out + }) + + if (patchStepsFull.length) { + await axios.patch(`/preview/evaluations/results/?project_id=${projectId}`, { + results: normalizePatch(patchStepsFull), + }) + } + if (stepsToCreate.length) { + await axios.post(`/preview/evaluations/results/?project_id=${projectId}`, { + results: normalizeCreate(stepsToCreate), + }) + } + if (metricEntries.length) { + const jwt = await getJWT() + if (jwt) { + await createScenarioMetrics(getAgentaApiUrl(), jwt, runId, metricEntries, projectId) + } + } +} + +/** + * Triggers revalidation for a single scenario and cleans up optimistic overrides once fresh data arrives. + */ +/** + * Partitions Promise.allSettled results into successful responses and builds evaluator status map + */ +export const partitionAnnotationResults = ( + annotationResults: PromiseSettledResult[], + payload: any[], +): {annotationResponses: any[]; evaluatorStatuses: Record} => { + const fulfilled = annotationResults.filter( + (r): r is PromiseFulfilledResult => r.status === "fulfilled", + ) + const annotationResponses = fulfilled.map((f) => f.value) + const evaluatorStatuses: Record = {} + annotationResults.forEach((result, idx) => { + const slug = payload[idx]?.annotation?.references?.evaluator?.slug + if (!slug) return + evaluatorStatuses[slug] = + result.status === "fulfilled" ? EvaluationStatus.SUCCESS : EvaluationStatus.FAILURE + }) + return {annotationResponses, evaluatorStatuses} +} + +/** + * Returns true if metrics are missing and the caller should abort. + */ +export const abortIfMissingMetrics = ( + requiredMetrics: Record | undefined, + formatErrorMessages: (metrics: any) => void, +): boolean => { + if (requiredMetrics && Object.keys(requiredMetrics).length > 0) { + formatErrorMessages(requiredMetrics) + return true + } + return false +} + +/** + * Handles backend sync and scenario status updates after annotation succeeds + */ +export const startOptimisticAnnotation = async ( + scenarioId: string, + step: IAnnotationStep, + apiUrl: string, + jwt: string, + projectId: string, + runId?: string, +) => { + setOptimisticStepData( + scenarioId, + [ + { + ...structuredClone(step), + status: "annotating", + }, + ], + runId, + ) + updateScenarioStatusRemote(apiUrl, jwt, scenarioId, EvaluationStatus.RUNNING, projectId, runId) +} + +/** + * Build common annotation context (evaluators, trace ids, testset ids, etc.) + */ +export const buildAnnotationContext = async ({ + scenarioId, + stepKey, + runId, +}: { + scenarioId: string + stepKey: string + runId: string +}) => { + const evaluators = getEvaluators(runId) + const testsets = evalAtomStore().get(evaluationRunStateFamily(runId))?.enrichedRun?.testsets + const stepData = await getStepData(scenarioId, runId) + const jwt = await getJWT() + const {projectId} = getProjectValues() + + const invocationStep = stepData?.invocationSteps?.find((s: any) => s.stepKey === stepKey) + if (!invocationStep) return null + + const traceTree = (invocationStep as any)?.trace + if (!traceTree) return null + + const node = traceTree.nodes?.[0] + if (!node) return null + + const traceSpanIds = { + spanId: uuidToSpanId(node.node.id) as string, + traceId: uuidToTraceId(node.root.id) as string, + } + + const testcaseId = invocationStep.testcaseId + const testsetId = testsets?.find((s: any) => s.data?.testcase_ids?.includes(testcaseId))?.id + + return { + evaluators, + jwt, + projectId, + stepData, + traceSpanIds, + testsetId, + testcaseId, + invocationStep, + traceTree, + apiUrl: getAgentaApiUrl(), + } +} + +export const processAnnotationError = async ( + scenarioId: string, + err: unknown, + annotationSteps: IAnnotationStep[], + apiUrl: string, + jwt: string, + projectId: string, + runId: string, + setErrorMessages: (msgs: string[]) => void, +) => { + setErrorMessages([(err as Error).message]) + setOptimisticStepData( + scenarioId, + annotationSteps.map((st) => ({ + ...structuredClone(st), + status: EvaluationStatus.ERROR, + })), + ) + // await updateScenarioStatus(scenario, finalStatus) + updateScenarioStatusRemote(apiUrl, jwt, scenarioId, EvaluationStatus.ERROR, projectId, runId) +} + +export const finalizeAnnotationSuccess = async ({ + annotationSteps, + mode, + annotationResponses, + evaluatorStatuses, + stepData, + stepKey, + scenarioId, + runId, + projectId, + scenario, + jwt, + apiUrl, + evaluators, + setErrorMessages, +}: { + annotationSteps: IAnnotationStep[] + mode: "create" | "update" + annotationResponses: any[] + evaluatorStatuses: Record + stepData: any + stepKey: string + scenarioId: string + runId: string + projectId: string + jwt: string + apiUrl: string + scenario: any + evaluators: EvaluatorDto[] + setErrorMessages: (val: any[]) => void +}) => { + if (!annotationResponses.length) return + + const {stepsToCreate, patchStepsFull, metricEntries} = collectStepsAndMetrics({ + mode, + annotationResponses, + stepData, + stepKey, + evaluatorStatuses, + scenarioId, + runId, + evaluators, + }) + + await pushStepsAndMetrics({ + patchStepsFull, + stepsToCreate, + metricEntries, + projectId, + runId, + }) + + await updateScenarioStatusRemote( + apiUrl, + jwt, + scenarioId, + EvaluationStatus.SUCCESS, + projectId, + runId, + ) + await triggerScenarioRevalidation( + runId, + scenarioId, + annotationSteps.map((st) => ({ + ...structuredClone(st), + status: "revalidating", + })), + ) + + // Trigger metrics refresh when scenario completes (success or failure) + if (runId) { + triggerMetricsFetch(runId) + } + + // Note: Metrics will be automatically refreshed by store-level subscription + console.log(`[finalizeAnnotationSuccess] Annotation finalized for runId: ${runId}`) + + setErrorMessages([]) +} + +export const triggerScenarioRevalidation = async ( + runId: string, + scenarioId: string, + updatedSteps?: IStepResponse[], +) => { + try { + await revalidateScenarioForRun(runId, scenarioId, evalAtomStore(), updatedSteps) + } catch (err) { + console.error("Failed to revalidate scenario", err) + } +} + +/** Return all annotationSteps that match any item in the payload */ +export const findAnnotationStepsFromPayload = ( + annotationSteps: IAnnotationStep[] = [], + payload: {annotation: AnnotationDto}[], +) => { + if (!annotationSteps.length || !payload.length) return [] + + return annotationSteps.filter((step) => + payload.some(({annotation}) => { + const evaluatorSlug = annotation.references?.evaluator?.slug + const linkKeys = annotation.links ? Object.keys(annotation.links) : [] + if (!evaluatorSlug || !linkKeys.length) return false + + // backend guarantees first (and usually only) link key is the invocation key + const invocationKey = linkKeys[0] // e.g. "default-2cd951533447" + const expectedStepKey = `${invocationKey}.${evaluatorSlug}` + + return step.stepKey === expectedStepKey + }), + ) +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/helpers.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/helpers.ts new file mode 100644 index 0000000000..03fc0cfeda --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/helpers.ts @@ -0,0 +1,252 @@ +import { + generateAnnotationPayloadData, + generateNewAnnotationPayloadData, +} from "@agenta/oss/src/components/pages/observability/drawer/AnnotateDrawer/assets/transforms" + +import {AnnotationDto} from "@/oss/lib/hooks/useAnnotations/types" +import {createAnnotation, updateAnnotation} from "@/oss/services/annotations/api" + +import { + getScenario, + buildAnnotationContext, + partitionAnnotationResults, + abortIfMissingMetrics, + finalizeAnnotationSuccess, + startOptimisticAnnotation, + processAnnotationError, + findAnnotationStepsFromPayload, +} from "./annotationUtils" + +export const handleAnnotate = async ({ + runId, + scenarioId, + updatedMetrics, + formatErrorMessages, + setErrorMessages, + projectId, + stepKey, +}: { + runId: string + scenarioId: string + updatedMetrics: Record + formatErrorMessages: (requiredMetrics: Record) => void + setErrorMessages: (errorMessages: string[]) => void + projectId: string + stepKey: string +}) => { + console.log("handleAnnotate") + const ctx = await buildAnnotationContext({scenarioId, stepKey, runId}) + if (!ctx) return + const {evaluators, stepData, traceSpanIds, testsetId, testcaseId, traceTree, jwt, apiUrl} = ctx + + if (!traceTree) { + if (process.env.NODE_ENV !== "production") { + console.debug("No trace found on invocation step", scenarioId) + } + return + } + + const node = traceTree.nodes?.[0] + + if (!node) { + if (process.env.NODE_ENV !== "production") { + console.debug("No trace node found for scenario", scenarioId) + } + return + } + + const params = { + updatedMetrics, + selectedEvaluators: evaluators.map((e) => e.slug), + evaluators, + traceSpanIds, + testsetId, + testcaseId, + } + + const {payload, requiredMetrics} = generateNewAnnotationPayloadData({ + ...params, + invocationStepKey: stepKey, + testsetId, + testcaseId, + }) + + if (abortIfMissingMetrics(requiredMetrics, formatErrorMessages)) return + if (!payload.length) return + + const annotationSteps = findAnnotationStepsFromPayload(stepData?.annotationSteps, payload) + + if (!annotationSteps.length) { + console.error("No annotation steps matched payload", {scenarioId, payload}) + throw new Error("Annotation step(s) not found") + } + + try { + // optimistic update for each matched step + annotationSteps.forEach((st) => { + startOptimisticAnnotation(scenarioId, st, apiUrl, jwt, projectId, runId) + }) + + const annotationResults = await Promise.allSettled( + payload.map((evaluatorPayload) => createAnnotation(evaluatorPayload)), + ) + const {annotationResponses, evaluatorStatuses} = partitionAnnotationResults( + annotationResults, + payload, + ) + + await finalizeAnnotationSuccess({ + mode: "create", + annotationResponses, + evaluatorStatuses, + stepData, + stepKey, + scenarioId, + runId, + projectId, + scenario: getScenario(scenarioId, runId), + setErrorMessages, + annotationSteps, + jwt, + apiUrl, + evaluators, + }) + } catch (err) { + await processAnnotationError( + scenarioId, + err, + annotationSteps, + apiUrl, + jwt || "", + projectId, + runId, + setErrorMessages, + ) + } +} + +export const handleUpdateAnnotate = async ({ + runId, + scenarioId, + updatedMetrics, + formatErrorMessages, + setErrorMessages, + projectId, + stepKey, +}: { + runId: string + scenarioId: string + updatedMetrics: Record + formatErrorMessages: (requiredMetrics: Record) => void + setErrorMessages: (errorMessages: string[]) => void + projectId: string + stepKey: string +}) => { + console.log("handleUpdateAnnotate") + const ctx = await buildAnnotationContext({scenarioId, stepKey, runId}) + if (!ctx) return + const {evaluators, stepData, jwt, apiUrl} = ctx + + const allAnnotations = stepData?.annotationSteps + ?.map((s) => s.annotation) + .filter(Boolean) as AnnotationDto[] + + // Only use the new canonical payload generator + const params = { + updatedMetrics, + selectedEvaluators: evaluators.map((e) => e.slug), + evaluators, + annotations: allAnnotations, + } + const {payload, requiredMetrics} = generateAnnotationPayloadData({ + ...params, + invocationStepKey: stepKey, + }) + + if (abortIfMissingMetrics(requiredMetrics, formatErrorMessages)) return + if (!payload.length) return + + const scenario = getScenario(scenarioId, runId) + const annotationSteps = findAnnotationStepsFromPayload( + stepData?.annotationSteps, + payload + .map((p) => { + const annotation = allAnnotations.find( + (a) => a.span_id === p.span_id && a.trace_id === p.trace_id, + ) + return { + annotation, + } + }) + .filter(Boolean) as {annotation: AnnotationDto}[], + ) + + if (!annotationSteps.length) { + console.error("No annotation steps matched payload", {scenarioId, payload}) + throw new Error("Annotation step(s) not found") + } + + try { + // 1. enabling annotating state + annotationSteps.forEach((st) => { + startOptimisticAnnotation(scenarioId, st, apiUrl, jwt, projectId, runId) + }) + + // 2. updating annotations + const annotationResults = await Promise.allSettled( + payload.map((annotation) => { + const {trace_id, span_id, ...rest} = annotation + return updateAnnotation({ + payload: rest, + traceId: trace_id || "", + spanId: span_id || "", + }) + }), + ) + const {annotationResponses, evaluatorStatuses} = partitionAnnotationResults( + annotationResults, + payload, + ) + + // 3. Optimistic update: mark as revalidating + await finalizeAnnotationSuccess({ + mode: "update", + annotationResponses, + evaluatorStatuses, + stepData, + stepKey, + scenarioId, + runId, + projectId, + scenario, + setErrorMessages, + annotationSteps, + jwt, + apiUrl, + evaluators, + }) + setErrorMessages([]) + } catch (err) { + await processAnnotationError( + scenarioId, + err, + annotationSteps, + apiUrl, + jwt || "", + projectId, + runId, + setErrorMessages, + ) + } +} + +export const statusColorMap: Record = { + pending: "text-[#758391]", + incomplete: "text-[#758391]", + running: "text-[#758391]", + done: "text-green-600", + success: "text-green-600", + failed: "text-red-500", + error: "text-red-500", + cancelled: "text-yellow-500", +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/optimisticUtils.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/optimisticUtils.ts new file mode 100644 index 0000000000..361a3db1bc --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/optimisticUtils.ts @@ -0,0 +1,41 @@ +// Import run-scoped version for multi-run support +import { + scenarioStepLocalFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {IStepResponse} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" + +/** + * Merge partial step data into the optimistic cache so components can render + * interim worker results immediately while awaiting server revalidation. + */ +export const setOptimisticStepData = async ( + scenarioId: string, + updatedSteps: IStepResponse[], + runId?: string, +) => { + // Write into per-scenario atom to avoid cloning the entire cache map + // Skip if no runId provided since run-scoped atoms require it + if (!runId) { + console.warn("[setOptimisticStepData] No runId provided, skipping optimistic update") + return + } + + evalAtomStore().set(scenarioStepLocalFamily({runId, scenarioId}), (draft: any) => { + if (!draft) return + + updatedSteps.forEach((updatedStep) => { + const targetStep = + draft.invocationSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) || + draft.inputSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) || + draft.annotationSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) + + if (!targetStep) return + + Object.entries(updatedStep).forEach(([k, v]) => { + // @ts-ignore – dynamic merge + targetStep[k] = v as any + }) + }) + }) +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/runnableSelectors.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/runnableSelectors.ts new file mode 100644 index 0000000000..b2a6dc5c6c --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/runnableSelectors.ts @@ -0,0 +1,64 @@ +import deepEqual from "fast-deep-equal" +import {atom} from "jotai" +import {selectAtom, loadable, atomFamily} from "jotai/utils" + +import { + scenariosFamily, + scenarioStatusAtomFamily, + scenarioStepFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +/** + * IDs of scenarios that are currently runnable (i.e. have invocation parameters + * and are not in a final/running UI state). + */ +// 1. Combine the needed state into a single base atom +// helper shallow array equality +const shallowArrayEqual = (a: T[], b: T[]) => + a.length === b.length && a.every((v, i) => v === b[i]) + +// A scenario is considered runnable when: +// 1. Its overall status is not in a terminal / running state, AND +// 2. Its step data has been fetched (Loadable state === "hasData"), AND +// 3. At least one invocationStep still contains `invocationParameters` (i.e. not yet executed) +// Per-scenario memoised check – avoids re-running heavy logic for all 1000 scenarios +export const scenarioIsRunnableFamily = atomFamily( + (params: {scenarioId: string; runId: string}) => + atom((get) => { + const {status} = get(scenarioStatusAtomFamily(params)) + if (["running", "done", "success", "revalidating"].includes(status)) return false + const loadableStep = get(loadable(scenarioStepFamily(params))) + if (loadableStep.state !== "hasData") return false + const invSteps: any[] = loadableStep.data?.invocationSteps ?? [] + return invSteps.some((st) => !!st.invocationParameters) + }), + deepEqual, +) + +export const runnableScenarioIdsFamily = atomFamily((runId: string) => { + return atom((get) => { + const scenarios = get(scenariosFamily(runId)) + return scenarios + .filter((scenario: any) => + get(scenarioIsRunnableFamily({scenarioId: scenario.id, runId})), + ) + .map((s: any) => s.id) + }) +}, deepEqual) + +/* memoised view that won’t re-emit if the array is the same */ +export const runnableScenarioIdsMemoFamily = atomFamily((runId: string) => { + return selectAtom(runnableScenarioIdsFamily(runId), (ids) => ids, shallowArrayEqual) +}, deepEqual) + +// Boolean flag: true if at least one scenario is runnable. Uses early exit to avoid building arrays +export const hasRunnableScenarioFamily = atomFamily((runId: string) => { + return atom((get) => { + const scenarios = get(scenariosFamily(runId)) + for (const scenario of scenarios) { + if (get(scenarioIsRunnableFamily({scenarioId: (scenario as any).id, runId}))) + return true + } + return false + }) +}, deepEqual) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/stepsMetricsUtils.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/stepsMetricsUtils.ts new file mode 100644 index 0000000000..1859d83d92 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/stepsMetricsUtils.ts @@ -0,0 +1,180 @@ +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EvaluationStatus} from "@/oss/lib/Types" +import {computeRunMetrics} from "@/oss/services/runMetrics/api" + +export interface StepsAndMetricsResult { + stepsToCreate: any[] + patchStepsFull: any[] + metricEntries: {scenarioId: string; data: Record}[] +} + +interface CollectParams { + mode: "create" | "update" + annotationResponses: any[] + stepData: any + stepKey: string + evaluatorStatuses?: Record + scenarioId: string + runId: string + evaluators: EvaluatorDto[] +} + +/** + * Consolidated logic used by both handleAnnotate (create) and handleUpdateAnnotate (update) + * to build arrays for step PATCH/POST and metric creation. + */ +export const collectStepsAndMetrics = ({ + mode, + annotationResponses, + stepData, + stepKey, + evaluatorStatuses = {}, + scenarioId, + runId, + evaluators, +}: CollectParams): StepsAndMetricsResult => { + const patchStepsFull: any[] = [] + const stepsToCreate: any[] = [] + const nestedMetrics: Record> = {} + + // Filter annotation steps belonging to the selected invocation step + const stepAnnotationSteps = (stepData.annotationSteps || []).filter((ann: any) => + (ann.stepKey ?? "").startsWith(`${stepKey}.`), + ) + + if (mode === "create") { + // Track existing keys to avoid duplicates + const existingStepKeys = new Set(stepAnnotationSteps.map((s: any) => s.stepKey)) + + annotationResponses.forEach((resp: any) => { + const ann = resp?.data?.annotation + if (!ann) return + const slug = ann.references?.evaluator?.slug + const evaluatorKey = `${stepKey}.${slug}` + const status = evaluatorStatuses[slug] || EvaluationStatus.SUCCESS + + const evaluator = evaluators.find((e) => e.slug === slug) + if (!evaluator) return + + const metricSchema = evaluator?.data.service.format.properties.outputs.properties + // Add to creation list if not already existing + if (!existingStepKeys.has(evaluatorKey)) { + stepsToCreate.push({ + status, + step_key: evaluatorKey, + span_id: ann.span_id, + trace_id: ann.trace_id, + scenario_id: scenarioId, + run_id: runId, + }) + } + + // Collect metric outputs into nested structure keyed by invocation+evaluator + const outputs = ann.data?.outputs || {} + const fullKey = slug ? `${stepKey}.${slug}` : stepKey + const computed = computeRunMetrics([{data: outputs}]) + + if (!nestedMetrics[fullKey]) nestedMetrics[fullKey] = {} + Object.entries(computed).forEach(([k, v]) => { + const stat = structuredClone(v) + const schema = metricSchema[k] + if (schema?.type === "boolean") { + stat.value = stat.unique?.[0] + } else if (schema?.type === "array") { + stat.value = stat.unique + } else if (schema?.type === "string") { + stat.value = stat.unique + } else if ("anyOf" in schema) { + stat.value = stat.unique.length > 1 ? stat.unique : stat.unique[0] + } + // else if (schema?.type === "number") { + // stat.value = stat.mean + // } + if ("distribution" in stat) delete stat.distribution + if ("percentiles" in stat) delete stat.percentiles + if ("iqrs" in stat) delete stat.iqrs + if ("frequency" in stat) delete stat.frequency + if ("rank" in stat) delete stat.rank + if ("unique" in stat) delete stat.unique + if ("binSize" in stat) delete stat.binSize + + nestedMetrics[fullKey][k] = stat + }) + }) + + // Build patch list by aligning responses to existing steps + stepAnnotationSteps.forEach((ann: any) => { + const linkedResponse = annotationResponses.find((r) => { + const annKey = `${stepKey}.${r?.data?.annotation?.references?.evaluator?.slug}` + return annKey === ann.stepKey + }) + if (linkedResponse) { + const status = + evaluatorStatuses[ann.stepKey.split(".")[1]] || EvaluationStatus.SUCCESS + patchStepsFull.push({ + ...ann, + status, + trace_id: linkedResponse.data.annotation.trace_id, + span_id: linkedResponse.data.annotation.span_id, + }) + } else { + patchStepsFull.push(ann) + } + }) + } else { + // UPDATE flow: only patch existing steps, no creations + stepAnnotationSteps.forEach((ann: any) => { + const linkedResponse = annotationResponses.find( + (r) => + r?.data?.annotation?.span_id === ann.annotation?.span_id && + r?.data?.annotation?.trace_id === ann.annotation?.trace_id, + ) + if (!linkedResponse) return + + const slug = ann.stepKey.split(".")[1] + const evaluator = evaluators.find((e) => e.slug === slug) + if (!evaluator) return + + const metricSchema = evaluator?.data.service.format.properties.outputs.properties + + patchStepsFull.push({ + ...ann, + trace_id: linkedResponse?.data?.annotation?.trace_id, + span_id: linkedResponse?.data?.annotation?.span_id, + }) + + const outputs = linkedResponse?.data?.annotation?.data?.outputs || {} + const computed = computeRunMetrics([{data: outputs}]) + + const fullKey = `${stepKey}.${slug}` + if (!nestedMetrics[fullKey]) nestedMetrics[fullKey] = {} + Object.entries(computed).forEach(([k, v]) => { + const stat = structuredClone(v) + if (metricSchema?.[k]?.type === "boolean") { + stat.value = v.unique?.[0] + } else if (metricSchema?.[k]?.type === "array") { + stat.value = stat.unique + } else if (metricSchema?.[k]?.type === "string") { + stat.value = stat.unique + } else if ("anyOf" in metricSchema[k]) { + stat.value = stat.unique?.length > 1 ? stat.unique : stat.unique[0] + } + + if ("distribution" in stat) delete stat.distribution + if ("percentiles" in stat) delete stat.percentiles + if ("iqrs" in stat) delete stat.iqrs + if ("frequency" in stat) delete stat.frequency + if ("rank" in stat) delete stat.rank + if ("unique" in stat) delete stat.unique + nestedMetrics[fullKey][k] = stat + }) + }) + } + + const metricEntries: {scenarioId: string; data: Record}[] = [] + if (Object.keys(nestedMetrics).length > 0) { + metricEntries.push({scenarioId, data: nestedMetrics}) + } + + return {stepsToCreate, patchStepsFull, metricEntries} +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/types.ts new file mode 100644 index 0000000000..a48a4840bd --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/assets/types.ts @@ -0,0 +1,6 @@ +export interface EvalRunProps { + id: string + name: string + description?: string + runId?: string +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/index.tsx new file mode 100644 index 0000000000..2eeb2064c2 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/index.tsx @@ -0,0 +1,96 @@ +import {useState, useCallback, memo} from "react" + +import {Button} from "antd" +import {useAtomValue} from "jotai" + +import {AnnotationDto} from "@/oss/lib/hooks/useAnnotations/types" +import {evalAtomStore} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {scenarioUiFlagsFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/progress" +import {getProjectValues} from "@/oss/state/project" + +import {buildAnnotationContext} from "../../assets/annotationUtils" +import {handleAnnotate, handleUpdateAnnotate} from "../../assets/helpers" + +import {AnnotateScenarioButtonProps} from "./types" + +const AnnotateScenarioButton = ({ + runId, + scenarioId, + stepKey, + updatedMetrics, + formatErrorMessages, + setErrorMessages, + disabled = false, + label = "Annotate", + isAnnotated = false, + onAnnotate: propsOnAnnotate, + className, +}: AnnotateScenarioButtonProps) => { + const [annotating, setAnnotating] = useState(false) + const store = evalAtomStore() + const uiFlags = useAtomValue(scenarioUiFlagsFamily({scenarioId, runId}), {store}) + const isLoading = annotating || uiFlags.isAnnotating || uiFlags.isRevalidating + + const onAnnotate = useCallback(async () => { + try { + setAnnotating(true) + + const ctx = await buildAnnotationContext({scenarioId, stepKey, runId}) + if (!ctx) return + const {evaluators, stepData} = ctx + const annotations = stepData?.annotationSteps + ?.map((s) => s.annotation) + .filter(Boolean) as AnnotationDto[] + + const annEvalSlugs = annotations + .map((a) => a.references?.evaluator?.slug) + .filter(Boolean) as string[] + const selectedEval = evaluators + .map((e) => e.slug) + .filter((evaluator) => !annEvalSlugs.includes(evaluator)) + + if (selectedEval.length > 0) { + await handleAnnotate({ + runId, + scenarioId, + updatedMetrics, + formatErrorMessages, + setErrorMessages, + projectId: getProjectValues().projectId, + stepKey, + }) + } + + if (annotations.length > 0) { + await handleUpdateAnnotate({ + runId, + scenarioId, + updatedMetrics, + formatErrorMessages, + setErrorMessages, + projectId: getProjectValues().projectId, + stepKey, + }) + } + } catch (error) { + console.error("Failed to annotate scenario", error) + } finally { + propsOnAnnotate?.() + setAnnotating(false) + } + }, [runId, scenarioId, stepKey, updatedMetrics, formatErrorMessages, setErrorMessages]) + + return ( + + ) +} + +export default memo(AnnotateScenarioButton) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/types.ts new file mode 100644 index 0000000000..f42a66b324 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/AnnotateScenarioButton/types.ts @@ -0,0 +1,14 @@ +export interface AnnotateScenarioButtonProps { + runId: string + scenarioId: string + stepKey: string + updatedMetrics: Record + disabled?: boolean + label?: string + className?: string + isAnnotated?: boolean // check if annotations are already present + + formatErrorMessages: (requiredMetrics: Record) => void + setErrorMessages: (errorMessages: string[]) => void + onAnnotate?: () => void +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/EvaluatorMetricsCard.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/EvaluatorMetricsCard.tsx new file mode 100644 index 0000000000..6e8b81c93b --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/EvaluatorMetricsCard.tsx @@ -0,0 +1,81 @@ +import {memo, useCallback, useMemo} from "react" + +import {Card, Typography} from "antd" +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {selectAtom} from "jotai/utils" + +import {MetricDetailsPopoverWrapper} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover" +import {useRunId} from "@/oss/contexts/RunIdContext" +import { + evalAtomStore, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +interface EvaluatorMetricsCardProps { + runId: string + evaluatorSlug: string +} + +/** + * Displays all metric definitions for a single evaluator with popovers. + * Uses jotai selectAtom so only this card re-renders when its evaluator object changes. + */ +const EvaluatorMetricsCard = ({runId, evaluatorSlug}: EvaluatorMetricsCardProps) => { + // Use proper runId fallback logic: prop takes priority over context + const contextRunId = useRunId() + const effectiveRunId = runId || contextRunId + const store = evalAtomStore() + + // Create a selector to extract the specific evaluator from the evaluation run state + const evaluatorSelector = useCallback( + (state: any) => { + const evaluators = state?.enrichedRun?.evaluators + if (!evaluators) return null + + // Handle both array and object formats + if (Array.isArray(evaluators)) { + return evaluators.find((ev: any) => ev.slug === evaluatorSlug) + } else { + return Object.values(evaluators).find((ev: any) => ev.slug === evaluatorSlug) + } + }, + [evaluatorSlug], + ) + + const evaluatorAtom = useMemo( + () => selectAtom(evaluationRunStateFamily(effectiveRunId), evaluatorSelector, deepEqual), + [effectiveRunId, evaluatorSelector], + ) + + const evaluator = useAtomValue(evaluatorAtom, {store}) + + if (!evaluator) return null + + const metricEntries = Object.entries(evaluator.metrics || {}) + + return ( + + {evaluator.name} +
+ {metricEntries.map(([metricKey, def]) => ( +
+ {metricKey} + +
+ ))} +
+
+ ) +} + +export default memo(EvaluatorMetricsCard) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/index.tsx new file mode 100644 index 0000000000..37185f8eda --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalResultsView/index.tsx @@ -0,0 +1,39 @@ +import {memo, useCallback, useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {selectAtom} from "jotai/utils" + +import {evaluationEvaluatorsFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import EvaluatorMetricsCard from "./EvaluatorMetricsCard" + +/** + * Displays run-level evaluation results grouped by evaluator. + * Uses selectAtom to subscribe only to the evaluator *list shape* (slug array) so the + * parent component re-renders only when evaluators are added/removed – any metric changes + * are handled inside each card. + */ +const EvalResultsView = ({runId}: {runId: string}) => { + const slugSelector = useCallback( + (list: any[] | undefined): string[] => + (list || []).map((ev) => ev.slug || ev.id || ev.name), + [], + ) + + const slugsAtom = useMemo( + () => selectAtom(evaluationEvaluatorsFamily(runId), slugSelector, deepEqual), + [runId], + ) + const evaluatorSlugs = useAtomValue(slugsAtom) + + return ( +
+ {evaluatorSlugs.map((slug) => ( + + ))} +
+ ) +} + +export default memo(EvalResultsView) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunBatchActions.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunBatchActions.tsx new file mode 100644 index 0000000000..50ba4ccf21 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunBatchActions.tsx @@ -0,0 +1,238 @@ +import {memo, useCallback, useState} from "react" + +import RunButton from "@agenta/oss/src/components/Playground/assets/RunButton" +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" + +// agenta hooks & utils +import {useRunId} from "@/oss/contexts/RunIdContext" +import {convertToStringOrJson} from "@/oss/lib/helpers/utils" +import {useEvalScenarioQueue} from "@/oss/lib/hooks/useEvalScenarioQueue" +import { + scenarioStepFamily, + scenariosFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {scenarioMetricsMapFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" + +import SaveDataButton from "../../components/SaveDataModal/assets/SaveDataButton" +import {hasRunnableScenarioFamily} from "../assets/runnableSelectors" + +import InstructionButton from "./Modals/InstructionModal/assets/InstructionButton" + +const EMPTY_ROWS: any[] = [] + +/** + * This component renders a bar of buttons above the scenario table. + * It includes a button to run all scenarios, a button to export results, + * a button to save the test set, a button to refresh the page, and a button + * to open the instruction modal. + * + * @returns A JSX element containing a bar of buttons. + */ +// derived atom: keeps only the length (count) of runnable scenarios to minimise re-renders + +const EvalRunBatchActions = ({name}: {name: string}) => { + const [rows, setRows] = useState(EMPTY_ROWS) + const runId = useRunId() + const store = evalAtomStore() + + const {enqueueScenario} = useEvalScenarioQueue({concurrency: 5, runId}) + + // Lightweight subscription: only track the count of runnable scenarios - use global store + const hasRunnable = useAtomValue(hasRunnableScenarioFamily(runId), {store}) + const isRunAllDisabled = !hasRunnable + + const handleRunAll = useCallback(async () => { + if (!runId) return + + try { + const store = evalAtomStore() + + // Get all scenarios for this run (same as single run approach) + const scenarios = store.get(scenariosFamily(runId)) + console.log(`[EvalRunBatchActions] Found ${scenarios.length} total scenarios`) + + if (scenarios.length === 0) { + console.warn("[EvalRunBatchActions] No scenarios found") + return + } + + let enqueuedCount = 0 + + // For each scenario, get its step data using the same approach as RunEvalScenarioButton + for (const scenario of scenarios) { + const scenarioId = scenario.id + + try { + // Use the same loadable approach as RunEvalScenarioButton + const stepLoadableAtom = loadable(scenarioStepFamily({scenarioId, runId})) + const stepLoadable = store.get(stepLoadableAtom) + + if (stepLoadable.state !== "hasData" || !stepLoadable.data) { + console.log( + `[EvalRunBatchActions] Scenario ${scenarioId} - step data not ready (state: ${stepLoadable.state})`, + ) + continue + } + + const invocationSteps = stepLoadable.data.invocationSteps || [] + console.log( + `[EvalRunBatchActions] Scenario ${scenarioId} has ${invocationSteps.length} invocation steps`, + ) + + // Find the first step with invocation parameters (same logic as RunEvalScenarioButton) + const targetStep = invocationSteps.find((s: any) => s.invocationParameters) + + if (targetStep && targetStep.invocationParameters) { + // Check if step is not already running or successful + const isRunning = invocationSteps.some((s: any) => s.status === "running") + const isSuccess = (targetStep as any).status === "success" + + if (!isRunning && !isSuccess) { + console.log( + `[EvalRunBatchActions] Enqueuing scenario ${scenarioId}, step ${targetStep.stepKey}`, + ) + enqueueScenario(scenarioId, targetStep.stepKey) + enqueuedCount++ + } else { + console.log( + `[EvalRunBatchActions] Skipping scenario ${scenarioId} - already running or successful`, + ) + } + } else { + console.log( + `[EvalRunBatchActions] Skipping scenario ${scenarioId} - no invocation parameters`, + ) + } + } catch (error) { + console.error( + `[EvalRunBatchActions] Error processing scenario ${scenarioId}:`, + error, + ) + } + } + + console.log( + `[EvalRunBatchActions] Run all completed, enqueued ${enqueuedCount} scenarios`, + ) + + // Note: Metrics will be automatically fetched by store-level subscription + if (enqueuedCount > 0) { + console.log( + `[EvalRunBatchActions] Enqueued ${enqueuedCount} scenarios for runId: ${runId}`, + ) + } + } catch (error) { + console.error("[EvalRunBatchActions] Error in handleRunAll:", error) + } + }, [runId, enqueueScenario]) + + const csvDataFormat = useCallback(async () => { + if (!runId) return [] + + // 1. Gather the scenario IDs present in the current evaluation (sync) + const store = evalAtomStore() + const scenarios = store.get(scenariosFamily(runId)) + const ids = scenarios.map((s: any) => s.id) + + // 2. Resolve (possibly async) scenario step data for each id + const [scenarioMetricsMap, ...allScenarios] = await Promise.all([ + store.get(scenarioMetricsMapFamily(runId)), + ...ids.map((id) => store.get(scenarioStepFamily({runId, scenarioId: id}))), + ]) + + // 3. Build the CSV-friendly records + const data = allScenarios.map((scenario) => { + if (!scenario) return {} + const sid = scenario.steps?.[0]?.scenarioId + + const primaryInput = scenario.inputSteps?.find((s: any) => s.inputs) || {} + const {inputs = {}, groundTruth = {}, status: inputStatus} = primaryInput as any + + const record: Record = {} + + // Add inputs + Object.entries(inputs).forEach(([k, v]) => { + record[k] = convertToStringOrJson(v) + }) + + // Add ground truths + Object.entries(groundTruth).forEach(([k, v]) => { + record[k] = convertToStringOrJson(v) + }) + + // Add annotation metrics/notes per evaluator slug + scenario.annotationSteps?.forEach((annStep: any) => { + const evaluatorSlug = (annStep.stepKey as string)?.split(".")[1] + if (!evaluatorSlug) return + + // 1. summarize metrics from scenarioMetricsMap for this scenario by slug prefix + const summarized: Record = {} + // const sid = + // scenario.scenarioId || (scenario as any).scenario_id || (scenario as any).id + const scenarioMetrics = scenarioMetricsMap?.[String(sid)] || {} + Object.entries(scenarioMetrics).forEach(([fullKey, stats]) => { + if (fullKey.startsWith(`${evaluatorSlug}.`)) { + const metricKey = fullKey.slice(evaluatorSlug.length + 1) + summarized[metricKey] = stats + } + }) + + if (Object.keys(summarized).length) { + record[evaluatorSlug] = convertToStringOrJson({...summarized}) + } + }) + + // Extract model output from the first invocation step that contains a trace + const invWithTrace = scenario.invocationSteps?.find((inv: any) => inv.trace) + const traceObj = invWithTrace?.trace + let traceOutput: any + if (Array.isArray(traceObj?.nodes)) { + traceOutput = traceObj.nodes[0]?.data?.outputs + } else if (Array.isArray(traceObj?.trees)) { + traceOutput = traceObj.trees[0]?.nodes?.[0]?.data?.outputs + } + + if (traceOutput) { + record.output = convertToStringOrJson(traceOutput) + } + + record.status = inputStatus ?? "unknown" + return record + }) + + return data + }, [runId]) + + const onClickSaveData = useCallback(async () => { + const data = await csvDataFormat() + setRows(data) + }, [csvDataFormat]) + + return ( +
+ + + + + + + +
+ ) +} + +export default memo(EvalRunBatchActions) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunName/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunName/index.tsx new file mode 100644 index 0000000000..96bdcb6d14 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunName/index.tsx @@ -0,0 +1,78 @@ +import {memo} from "react" + +import {Typography} from "antd" +import {useAtomValue} from "jotai" + +import {urlStateAtom} from "../../../state/urlState" +import {EvalRunProps} from "../../assets/types" +import RenameEvalButton from "../Modals/RenameEvalModal/assets/RenameEvalButton" + +const EvalRunName = (props: EvalRunProps) => { + const {id, name, description, runId} = props || {} + const urlState = useAtomValue(urlStateAtom) + + // Check if we're in comparison mode + const isComparisonMode = Boolean(urlState.compare && urlState.compare.length > 0) + + if (isComparisonMode) { + return ( +
+
+
+ + Evaluation Run Comparison + +
+
+ {description && ( + + {description} + + )} +
+ ) + } + + return ( +
+
+
+ + {name} + + +
+
+ {description && ( + + {description} + + )} +
+ ) +} + +export default memo(EvalRunName) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/index.tsx new file mode 100644 index 0000000000..b1783d41bb --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/index.tsx @@ -0,0 +1,53 @@ +import {memo} from "react" + +import clsx from "clsx" +import {useAtomValue} from "jotai" + +import {runViewTypeAtom} from "../../../state/urlState" +import EvalRunScenarioCard from "../EvalRunScenarioCard" +import ScenarioAnnotationPanel from "../ScenarioAnnotationPanel" + +import {EvalRunScenarioProps} from "./types" + +const EvalRunScenario = ({scenarioId, runId, className}: EvalRunScenarioProps) => { + const viewType = useAtomValue(runViewTypeAtom) + + return ( +
_.ant-card]:grow": viewType !== "focus", + }, + ])} + > +
+ {viewType !== "focus" ? ( + + ) : null} + +
+
+ ) +} + +export default memo(EvalRunScenario) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/types.ts new file mode 100644 index 0000000000..26a7b69ed4 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenario/types.ts @@ -0,0 +1,5 @@ +export interface EvalRunScenarioProps { + scenarioId: string + runId: string + className?: string +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/EvalRunScenarioCardBody.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/EvalRunScenarioCardBody.tsx new file mode 100644 index 0000000000..a56efc709c --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/EvalRunScenarioCardBody.tsx @@ -0,0 +1,151 @@ +import {FC, memo, useCallback, useMemo} from "react" + +import {Typography} from "antd" +import {atom, useAtomValue} from "jotai" +import {selectAtom} from "jotai/utils" + +import { + loadableScenarioStepFamily, + bulkStepsCacheFamily, + getCurrentRunId, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {renderSkeleton} from "./assets/utils" +import InvocationRun from "./InvocationRun" + +interface EvalRunScenarioCardBodyProps { + scenarioId: string + runId?: string +} + +const EvalRunScenarioCardBody: FC = ({scenarioId, runId}) => { + const store = evalAtomStore() + + // Get effective runId - use provided runId or fallback to current run context + const effectiveRunId = useMemo(() => { + if (runId) return runId + try { + return getCurrentRunId() + } catch (error) { + return "" + } + }, [runId]) + + /* --- atoms & data --- */ + // Unified data access that prioritizes bulk cache over individual scenario atoms + // This ensures we get data from whichever source is available + const invocationSteps = useAtomValue( + useMemo( + () => + atom((get) => { + // First try bulk cache (populated by worker) + const bulkCache = get(bulkStepsCacheFamily(effectiveRunId)) + const bulkData = bulkCache?.get(scenarioId) + if ( + bulkCache && + bulkData?.state === "hasData" && + bulkData.data?.invocationSteps + ) { + return bulkData.data.invocationSteps as any[] + } + + // Fallback to individual scenario atom + const loadable = get( + loadableScenarioStepFamily({scenarioId, runId: effectiveRunId}), + ) + if (loadable.state === "hasData" && loadable.data?.invocationSteps) { + return loadable.data.invocationSteps as any[] + } + + return [] + }), + [scenarioId, effectiveRunId], + ), + {store}, + ) + + // Use the same atom for load state as we use for data to ensure consistency + // This prevents blocking UI when we have optimistically updated data + const loadState = useAtomValue( + useMemo( + () => + selectAtom(loadableScenarioStepFamily({scenarioId, runId: effectiveRunId}), (l) => { + return l.state + }), + [scenarioId, effectiveRunId], + ), + {store}, + ) + + /* --- render content --- */ + const renderRuns = useCallback(() => { + if (!invocationSteps.length) return null + + return invocationSteps.map((invStep: any) => ( + + )) + }, [scenarioId, invocationSteps, effectiveRunId]) + + /* --- loading / error states --- */ + // Determine if we truly have no cached data for this scenario yet + const hasCachedSteps = useAtomValue( + useMemo( + () => + selectAtom( + loadableScenarioStepFamily({scenarioId, runId: effectiveRunId}), + (l) => l.state === "hasData" && l.data !== undefined, + ), + [scenarioId, effectiveRunId], + ), + {store}, + ) + + // Check scenario status to determine if we're in execution/revalidation state + const scenarioStatus = useAtomValue( + useMemo( + () => + selectAtom(loadableScenarioStepFamily({scenarioId, runId: effectiveRunId}), (l) => { + if (l.state !== "hasData" || !l.data) return null + const invSteps = l.data.invocationSteps || [] + const annSteps = l.data.annotationSteps || [] + const inputSteps = l.data.inputSteps || [] + + // Check if any step is running or revalidating + const isRunning = [...invSteps, ...annSteps, ...inputSteps].some( + (s: any) => s.status === "running" || s.status === "revalidating", + ) + + return isRunning ? "active" : "idle" + }), + [scenarioId, effectiveRunId], + ), + {store}, + ) + + // Only show loading skeleton when we're actually fetching data from server AND have no cached data + // Don't show loading during scenario execution ("running") or revalidation ("revalidating") + const isInitialLoading = + loadState === "loading" && + !hasCachedSteps && + invocationSteps.length === 0 && + scenarioStatus !== "active" + + if (isInitialLoading) { + return renderSkeleton() + } + if (loadState === "hasError") { + return Failed to load scenario data. + } + + if (!invocationSteps.length) return null + + return
{renderRuns()}
+} + +export default memo(EvalRunScenarioCardBody) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationInputs.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationInputs.tsx new file mode 100644 index 0000000000..cfc9d6333e --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationInputs.tsx @@ -0,0 +1,110 @@ +import {memo} from "react" + +import JSON5 from "json5" + +import TextControl from "@/oss/components/Playground/Components/PlaygroundVariantPropertyControl/assets/TextControl" +import SharedEditor from "@/oss/components/Playground/Components/SharedEditor" +import useEvalRunScenarioData from "@/oss/lib/hooks/useEvaluationRunData/useEvalRunScenarioData" + +import {renderChatMessages} from "../../../assets/renderChatMessages" + +interface InvocationInputsProps { + scenarioId: string + testcaseId: string | undefined + runId?: string +} + +const InvocationInputs = ({scenarioId, testcaseId, runId}: InvocationInputsProps) => { + const data = useEvalRunScenarioData(scenarioId, runId) + // Prefer the inputStep directly enriched with `inputs` field (added during bulk/enrichment) + const inputStep = + data?.inputSteps?.find((s) => s.testcaseId === testcaseId) ?? + data?.steps?.find((s) => s.testcaseId === testcaseId && s.inputs) + const inputs = inputStep?.inputs ?? {} + const groundTruth = (inputStep as any)?.groundTruth ?? {} + + // Merge inputs and groundTruth, giving preference to explicit inputs if duplicate keys + const displayInputs = {...groundTruth, ...inputs} + + if (!displayInputs || Object.keys(displayInputs).length === 0) return null + + // Separate inputs into primitives, JSON objects/arrays, and chat messages + const primitiveEntries: [string, string][] = [] + const jsonEntries: [string, any][] = [] + const chatEntries: [string, string][] = [] + + Object.entries(displayInputs).forEach(([k, _v]) => { + // If already an object/array, treat as JSON directly + if (_v && typeof _v === "object") { + jsonEntries.push([k, _v]) + return + } + // Strings may encode JSON or chat messages + if (typeof _v === "string") { + try { + const parsed = JSON5.parse(_v) + if ( + parsed && + Array.isArray(parsed) && + parsed.every( + (m: any) => m && typeof m === "object" && "role" in m && "content" in m, + ) + ) { + chatEntries.push([k, _v]) + } else if (parsed && typeof parsed === "object") { + jsonEntries.push([k, parsed]) + } else { + primitiveEntries.push([k, _v]) + } + } catch { + primitiveEntries.push([k, _v]) + } + return + } + // Fallback to primitive string rendering + primitiveEntries.push([k, String(_v)]) + }) + + const renderPrimitive = ([k, v]: [string, string]) => ( +
+ {}} + disabled + state="readOnly" + className="!text-xs" + /> +
+ ) + + // Render complex chat message inputs using shared util + const renderComplex = ([k, v]: [string, string]) => + renderChatMessages({keyPrefix: k, rawJson: v, view: "single"}) + + const renderJson = ([k, obj]: [string, any]) => ( +
+ +
+ ) + + return ( +
+ {/* Render primitives first */} + {primitiveEntries.map(renderPrimitive)} + {/* Then structured JSON objects/arrays */} + {jsonEntries.map(renderJson)} + {/* Then complex chat/message inputs */} + {chatEntries.flatMap(renderComplex)} +
+ ) +} + +export default memo(InvocationInputs) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationResponse.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationResponse.tsx new file mode 100644 index 0000000000..76520ea9ee --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationResponse.tsx @@ -0,0 +1,151 @@ +import {memo} from "react" + +import {Typography} from "antd" +import JSON5 from "json5" + +import GenerationResultUtils from "@/oss/components/Playground/Components/PlaygroundGenerations/assets/GenerationResultUtils" +import SimpleDropdownSelect from "@/oss/components/Playground/Components/PlaygroundVariantPropertyControl/assets/SimpleDropdownSelect" +import SharedEditor from "@/oss/components/Playground/Components/SharedEditor" +import {useInvocationResult} from "@/oss/lib/hooks/useInvocationResult" + +import RunEvalScenarioButton from "../RunEvalScenarioButton" + +import {InvocationResponseProps} from "./types" + +const InvocationResponse = ({scenarioId, stepKey, runId}: InvocationResponseProps) => { + const {status, trace, value, messageNodes} = useInvocationResult({scenarioId, stepKey, runId}) + const editorKey = trace?.trace_id ?? trace?.id ?? `${scenarioId}-${stepKey}-${runId}` + + return ( +
+
+ + Model Response + + +
+ + {messageNodes ? ( + messageNodes + ) : typeof value === "object" && value && "role" in value && "content" in value ? ( + + {}} + disabled + /> +
+ } + initialValue={(value as any).content} + editorClassName="!text-xs" + disabled + error={!!trace?.exception} + /> + ) : typeof value === "string" ? ( + (() => { + try { + const parsed = JSON5.parse(value) + if (parsed && typeof parsed === "object") { + const pretty = JSON.stringify(parsed, null, 2) + return ( + {}} + initialValue={pretty} + editorType="border" + placeholder="Click the 'Run' icon to get variant output" + disabled + editorClassName="!text-xs" + editorProps={{enableResize: true, codeOnly: true}} + error={!!trace?.exception} + /> + ) + } + + return ( + {}} + initialValue={value} + editorType="border" + placeholder="Click the 'Run' icon to get variant output" + disabled + editorClassName="!text-xs" + editorProps={{enableResize: true}} + error={!!trace?.exception} + /> + ) + } catch { + return ( + {}} + initialValue={value} + editorType="border" + placeholder="Click the 'Run' icon to get variant output" + disabled + editorClassName="!text-xs" + editorProps={{enableResize: true}} + error={!!trace?.exception} + /> + ) + } + })() + ) : typeof value === "object" ? ( + {}} + initialValue={(() => { + try { + return JSON.stringify(value, null, 2) + } catch { + return String(value) + } + })()} + editorType="border" + placeholder="Click the 'Run' icon to get variant output" + disabled + editorClassName="!text-xs" + editorProps={{enableResize: true, codeOnly: true}} + error={!!trace?.exception} + /> + ) : ( + {}} + initialValue={status?.error ? String(status.error) : (value ?? status?.result)} + editorType="border" + placeholder="Click the 'Run' icon to get variant output" + disabled + editorClassName="!text-xs" + editorProps={{enableResize: true}} + error={!!trace?.exception} + /> + )} + {trace ? ( + + ) : null} + + ) +} + +export default memo(InvocationResponse) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationRun.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationRun.tsx new file mode 100644 index 0000000000..d04cee9eaf --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/InvocationRun.tsx @@ -0,0 +1,20 @@ +import {memo} from "react" + +import InvocationInputs from "./InvocationInputs" +import InvocationResponse from "./InvocationResponse" +import {InvocationRunProps} from "./types" + +const InvocationRun = ({invStep, scenarioId, runId}: InvocationRunProps) => { + return ( +
+ + +
+ ) +} + +export default memo(InvocationRun) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/KeyValue.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/KeyValue.tsx new file mode 100644 index 0000000000..9b13e9e97c --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/KeyValue.tsx @@ -0,0 +1,59 @@ +import {memo} from "react" + +import {Typography} from "antd" + +import {KeyValueProps} from "../types" + +const KeyValue = ({label, value, ...rest}: KeyValueProps) => { + const renderVal = () => { + if (value == null || value === "") { + return N/A + } + if (typeof value === "object") { + const entries = Object.entries(value as Record) + if (entries.length > 1) { + return ( +
    + {entries.map(([k, v]) => { + if (process.env.NODE_ENV !== "production") { + console.debug("k - v", k, v) + } + return ( +
  • + + {k}: + + + {typeof v === "object" ? JSON.stringify(v) : String(v)} + +
  • + ) + })} +
+ ) + } + const singleVal = entries[0][1] + return typeof singleVal === "object" ? JSON.stringify(singleVal) : String(singleVal) + } + return String(value) + } + + return ( + <> +
+ + {label}: + + + {renderVal()} + +
+ + ) +} + +export default memo(KeyValue) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/utils.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/utils.tsx new file mode 100644 index 0000000000..35bd0cce84 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/assets/utils.tsx @@ -0,0 +1,9 @@ +import {Skeleton} from "antd" + +export function renderSkeleton() { + return ( +
+ +
+ ) +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/index.tsx new file mode 100644 index 0000000000..281821b189 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/index.tsx @@ -0,0 +1,73 @@ +import {memo, useMemo} from "react" + +import {Card} from "antd" +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {selectAtom} from "jotai/utils" + +import { + evaluationRunStateFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {EvaluationRunState} from "@/oss/lib/hooks/useEvaluationRunData/types" + +import EvalRunScenarioCardTitle from "../EvalRunScenarioCardTitle" +import RunEvalScenarioButton from "../RunEvalScenarioButton" + +import EvalRunScenarioCardBody from "./EvalRunScenarioCardBody" +import {EvalRunScenarioCardProps} from "./types" + +/** + * Component that renders a card view for a specific evaluation run scenario. + * Depending on the `viewType`, it can display the scenario in a card format + * or a full-width format. Utilizes data from Jotai atoms to display scenario + * details, including loading state and error handling. + * + * @param {string} scenarioId - The unique identifier for the scenario to be displayed. + * @param {ViewType} [viewType="list"] - Determines the layout of the scenario display, + * either as a "list" (card format) or "single" (full-width). + */ +const EvalRunScenarioCard = ({scenarioId, runId, viewType = "list"}: EvalRunScenarioCardProps) => { + const store = evalAtomStore() + + /* scenario index for card title */ + // Read from the same global store that writes are going to + const scenarioIndex = useAtomValue( + useMemo( + () => + selectAtom( + evaluationRunStateFamily(runId), // Use run-scoped atom with runId + (state: EvaluationRunState) => + state.scenarios?.find((s) => s.id === scenarioId)?.scenarioIndex, + deepEqual, + ), + [scenarioId, runId], // Include runId in dependencies + ), + {store}, + ) + + if (scenarioIndex === undefined) return null + + return viewType === "list" ? ( + + } + style={{width: 400}} + className="self-stretch" + actions={[]} + > + + + ) : ( +
+ +
+ ) +} + +export default memo(EvalRunScenarioCard) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/types.ts new file mode 100644 index 0000000000..559d89c211 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCard/types.ts @@ -0,0 +1,29 @@ +import {ComponentProps} from "react" + +import {Typography} from "antd" + +export type ViewType = "list" | "focus" + +export interface EvalRunScenarioCardProps { + scenarioId: string + runId: string + viewType?: ViewType +} + +export interface KeyValueProps { + label: string + value: any + type?: ComponentProps["type"] +} + +export interface InvocationResponseProps { + scenarioId: string + stepKey: string + runId?: string +} + +export interface InvocationRunProps { + invStep: any + scenarioId: string + runId?: string +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/index.tsx new file mode 100644 index 0000000000..ff9c9df14e --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/index.tsx @@ -0,0 +1,22 @@ +import {memo} from "react" + +import {Typography} from "antd" + +import EvalRunScenarioStatusTag from "../../../components/EvalRunScenarioStatusTag" + +import {EvalRunScenarioCardTitleProps} from "./types" + +const EvalRunScenarioCardTitle = ({ + scenarioIndex, + scenarioId, + runId, +}: EvalRunScenarioCardTitleProps) => { + return ( +
+ Test Case #{scenarioIndex} + +
+ ) +} + +export default memo(EvalRunScenarioCardTitle) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/types.ts new file mode 100644 index 0000000000..de91d7ca09 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCardTitle/types.ts @@ -0,0 +1,5 @@ +export interface EvalRunScenarioCardTitleProps { + scenarioIndex: number + scenarioId: string + runId: string +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/EvalRunScenarioCards.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/EvalRunScenarioCards.tsx new file mode 100644 index 0000000000..9b845ee94a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/EvalRunScenarioCards.tsx @@ -0,0 +1,75 @@ +import {memo, RefObject, useRef} from "react" + +import {Typography} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" +import {FixedSizeList as List} from "react-window" +import {useResizeObserver} from "usehooks-ts" + +import { + displayedScenarioIdsFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import EvalRunScenario from "../EvalRunScenario" +import ScenarioLoadingIndicator from "../ScenarioLoadingIndicator/ScenarioLoadingIndicator" + +import {ITEM_GAP, ITEM_SIZE, ITEM_WIDTH} from "./assets/constants" + +/** + * Horizontal scroll list of `EvalRunScenario` cards with a shared loading indicator. + * Extracted clean version after refactor. No duplicated legacy code. + */ +const EvalRunScenarioCards = ({runId}: {runId: string}) => { + const store = evalAtomStore() + const scenarioIds = useAtomValue(displayedScenarioIdsFamily(runId), {store}) || [] + + const containerRef = useRef(null) + const {width = 0, height = 0} = useResizeObserver({ + ref: containerRef as RefObject, + box: "border-box", + }) + + return ( +
+
+ + All Scenarios + + +
+ +
+ {width > 0 && height > 0 && ( + scenarioIds[index]} + > + {({index, style}) => ( +
+ +
+ )} +
+ )} +
+
+ ) +} + +export default memo(EvalRunScenarioCards) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/assets/constants.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/assets/constants.ts new file mode 100644 index 0000000000..de7b942a52 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioCards/assets/constants.ts @@ -0,0 +1,3 @@ +export const ITEM_WIDTH = 400 +export const ITEM_GAP = 16 +export const ITEM_SIZE = ITEM_WIDTH + ITEM_GAP diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioFilters.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioFilters.tsx new file mode 100644 index 0000000000..843cbe16ec --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/EvalRunScenarioFilters.tsx @@ -0,0 +1,48 @@ +import {memo, useCallback} from "react" + +import {Segmented} from "antd" +import {useSetAtom, useAtomValue} from "jotai" + +import {useRunId} from "@/oss/contexts/RunIdContext" +import { + evalAtomStore, + totalCountFamily, + evalScenarioFilterAtom, + pendingCountFamily, + unannotatedCountFamily, + failedCountFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +const EvalRunScenarioFilters = () => { + const runId = useRunId() + const store = evalAtomStore() + + // Read from the same global store that writes are going to + const setFilterAtom = useSetAtom(evalScenarioFilterAtom, {store}) + const filter = useAtomValue(evalScenarioFilterAtom, {store}) + const totalCount = useAtomValue(totalCountFamily(runId), {store}) + const pendingCount = useAtomValue(pendingCountFamily(runId), {store}) + const unannotatedCount = useAtomValue(unannotatedCountFamily(runId), {store}) + const failedCount = useAtomValue(failedCountFamily(runId), {store}) + + const handleChange = useCallback((val: string) => { + setFilterAtom(val as any) + }, []) + + return ( + + ) +} + +export default memo(EvalRunScenarioFilters) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/assets/InstructionButton.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/assets/InstructionButton.tsx new file mode 100644 index 0000000000..6b7698d0dd --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/assets/InstructionButton.tsx @@ -0,0 +1,51 @@ +import {cloneElement, isValidElement, useState} from "react" + +import {Question} from "@phosphor-icons/react" +import dynamic from "next/dynamic" + +import EnhancedButton from "@/oss/components/Playground/assets/EnhancedButton" + +const InstructionModal = dynamic(() => import("../index"), {ssr: false}) + +const InstructionButton = ({ + icon = true, + children, + label, + ...props +}: { + icon?: boolean + children?: React.ReactNode + label?: string +}) => { + const [isModalOpen, setIsModalOpen] = useState(false) + + return ( + <> + {isValidElement(children) ? ( + cloneElement( + children as React.ReactElement<{ + onClick: () => void + }>, + { + onClick: () => { + setIsModalOpen(true) + }, + }, + ) + ) : ( + } + onClick={() => setIsModalOpen(true)} + tooltipProps={icon && !label ? {title: "Instructions"} : {}} + label={label} + {...props} + /> + )} + + setIsModalOpen(false)} /> + + ) +} + +export default InstructionButton diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/index.tsx new file mode 100644 index 0000000000..8fb398eca4 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/InstructionModal/index.tsx @@ -0,0 +1,37 @@ +import {Play} from "@phosphor-icons/react" +import {Modal} from "antd" +import {useRouter} from "next/router" + +import {InstructionModalProps} from "../types" + +const InstructionModal = ({...props}: InstructionModalProps) => { + const router = useRouter() + const isAbTesting = router.pathname.includes("a_b_testing") + + return ( + +
    +
  1. + Use the buttons Next and Prev or the arrow keys{" "} + {`Left (<)`} and {`Right (>)`} to navigate between + scenarios. +
  2. +
  3. + Click the Run button or press{" "} + {`Meta+Enter (⌘+↵)`} or {`Ctrl+Enter`} to run the + scenario. +
  4. + {isAbTesting && ( +
  5. + Vote by either clicking the evaluation buttons at the right sidebar + or pressing the key a for 1st Variant, b for 2nd + Variant and x if both are bad. +
  6. + )} +
  7. Annotate the scenario
  8. +
+
+ ) +} + +export default InstructionModal diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalButton.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalButton.tsx new file mode 100644 index 0000000000..2db2a20630 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalButton.tsx @@ -0,0 +1,60 @@ +import {cloneElement, isValidElement, memo, useState} from "react" + +import {EditOutlined} from "@ant-design/icons" +import dynamic from "next/dynamic" + +import EnhancedButton from "@/oss/components/Playground/assets/EnhancedButton" + +import {RenameEvalButtonProps} from "../../types" + +const RenameEvalModal = dynamic(() => import(".."), {ssr: false}) + +const RenameEvalButton = ({ + id, + name, + description, + runId, + icon = true, + children, + label, + ...props +}: RenameEvalButtonProps) => { + const [isModalOpen, setIsModalOpen] = useState(false) + + return ( + <> + {isValidElement(children) ? ( + cloneElement( + children as React.ReactElement<{ + onClick: () => void + }>, + { + onClick: () => { + setIsModalOpen(true) + }, + }, + ) + ) : ( + } + onClick={() => setIsModalOpen(true)} + tooltipProps={icon && !label ? {title: "Rename the eval run"} : {}} + label={label} + {...props} + /> + )} + + setIsModalOpen(false)} + /> + + ) +} + +export default memo(RenameEvalButton) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalModalContent.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalModalContent.tsx new file mode 100644 index 0000000000..2b7a9b5eac --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/assets/RenameEvalModalContent.tsx @@ -0,0 +1,35 @@ +import {Input, Typography} from "antd" + +import {RenameEvalModalContentProps} from "../../types" + +const RenameEvalModalContent = ({ + loading, + error, + editName, + setEditName, + editDescription, + setEditDescription, +}: RenameEvalModalContentProps) => { + return ( +
+ setEditName(e.target.value)} + maxLength={100} + placeholder="Run name" + disabled={loading} + /> + setEditDescription(e.target.value)} + rows={3} + maxLength={500} + placeholder="Description (optional)" + disabled={loading} + /> + {error && {error}} +
+ ) +} + +export default RenameEvalModalContent diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/index.tsx new file mode 100644 index 0000000000..c4c07ba0ed --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/RenameEvalModal/index.tsx @@ -0,0 +1,91 @@ +import {useCallback, useMemo, useState} from "react" + +import {message} from "antd" +import {useSWRConfig} from "swr" + +import EnhancedModal from "@/oss/components/EnhancedUIs/Modal" +import {useRunId} from "@/oss/contexts/RunIdContext" +import axios from "@/oss/lib/api/assets/axiosConfig" +import { + evalAtomStore, + evaluationRunStateFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {RenameEvalModalProps} from "../types" + +import RenameEvalModalContent from "./assets/RenameEvalModalContent" + +const RenameEvalModal = ({id, name, description, runId, ...props}: RenameEvalModalProps) => { + const {mutate} = useSWRConfig() + const contextRunId = useRunId() // Get runId from context + const effectiveRunId = runId || contextRunId // Use prop runId if available, otherwise context + const [editName, setEditName] = useState(name) + const [editDescription, setEditDescription] = useState(description || "") + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + + const onAfterClose = useCallback(() => { + setEditName(name) + setEditDescription(description || "") + setError(null) + props.afterClose?.() + }, [name, description]) + + const handleSave = useCallback(async () => { + setLoading(true) + setError(null) + + // Use run-scoped atom with effectiveRunId (from prop or context) + const state = evalAtomStore().get(evaluationRunStateFamily(effectiveRunId)) + + try { + await axios.patch(`/preview/evaluations/runs/${id}`, { + run: { + ...state.rawRun, + id, + name: editName, + description: editDescription, + }, + }) + await mutate( + (key: string) => key.includes("/preview/evaluations/runs/") || key.includes(id), + undefined, + true, + ) + + message.success("Evaluation run updated") + props.onCancel?.({} as any) + } catch (err: any) { + setError(err?.message || "Failed to update run") + } finally { + setLoading(false) + } + }, [id, editName, editDescription, mutate, runId]) + + const isDisabled = useMemo(() => { + return editName.trim() === name.trim() && editDescription.trim() === description?.trim() + }, [editName, editDescription, name, description]) + + return ( + + + + ) +} + +export default RenameEvalModal diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/types.d.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/types.d.ts new file mode 100644 index 0000000000..a453b392e1 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/Modals/types.d.ts @@ -0,0 +1,31 @@ +import {Dispatch, SetStateAction, ReactNode} from "react" + +import {ModalProps, ButtonProps} from "antd" + +export interface InstructionModalProps extends ModalProps {} + +export interface RenameEvalModalProps extends ModalProps { + id: string + name: string + description?: string + runId?: string +} + +export interface RenameEvalModalContentProps { + loading?: boolean + error: string | null + editName: string + setEditName: Dispatch> + editDescription: string + setEditDescription: Dispatch> +} + +export interface RenameEvalButtonProps extends ButtonProps { + id: string + name: string + description?: string + runId?: string + icon?: boolean + children?: ReactNode + label?: string +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/index.tsx new file mode 100644 index 0000000000..476242a1ba --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/index.tsx @@ -0,0 +1,107 @@ +import {memo, useMemo, useCallback} from "react" + +import RunButton from "@agenta/oss/src/components/Playground/assets/RunButton" +import {Tooltip} from "antd" +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" + +// Use EE run-scoped versions for multi-run support +import {useEvalScenarioQueue} from "@/oss/lib/hooks/useEvalScenarioQueue" +import { + getCurrentRunId, + scenarioStepFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {RunEvalScenarioButtonProps} from "./types" + +const RunEvalScenarioButton = memo( + ({scenarioId, stepKey, label = "Run Scenario", runId}: RunEvalScenarioButtonProps) => { + const store = evalAtomStore() + + // Use effective runId with fallback using useMemo + const effectiveRunId = useMemo(() => { + if (runId) return runId + try { + return getCurrentRunId() + } catch (error) { + console.warn("[RunEvalScenarioButton] No run ID available:", error) + return "" + } + }, [runId]) + + const {enqueueScenario} = useEvalScenarioQueue({concurrency: 5, runId: effectiveRunId}) + + // Derive invocationParameters via scenario step loadable (run-scoped) - use global store + const stepLoadable = useAtomValue( + loadable(scenarioStepFamily({scenarioId, runId: effectiveRunId})), + {store}, + ) + + // derive running flag directly from run-scoped scenario step data + const isRunning = useMemo(() => { + if (stepLoadable.state !== "hasData" || !stepLoadable.data) return false + const data = stepLoadable.data + return ( + data?.invocationSteps?.some((s: any) => s.status === "running") || + data?.annotationSteps?.some((s: any) => s.status === "running") || + data?.inputSteps?.some((s: any) => s.status === "running") + ) + }, [stepLoadable]) + + // Extract invocation steps (if any) + const invocationSteps = + stepLoadable.state === "hasData" ? stepLoadable.data?.invocationSteps || [] : [] + + // Determine target step + const targetStep = stepKey + ? invocationSteps.find((s) => s.stepKey === stepKey) + : invocationSteps.find((s) => s.invocationParameters) + + const autoStepKey = targetStep?.stepKey + const invocationParameters = targetStep?.invocationParameters + const invocationStepStatus = targetStep?.status + + const handleClick = useCallback(() => { + if (invocationParameters) { + enqueueScenario(scenarioId, autoStepKey) + } + }, [enqueueScenario, scenarioId, autoStepKey, invocationParameters]) + + const button = useMemo( + () => ( + + ), + [handleClick, isRunning, invocationStepStatus, invocationParameters, label], + ) + + return ( +
+ {invocationParameters ? ( + + {JSON.stringify(invocationParameters, null, 2)} +
+ } + > + {button} + + ) : ( + button + )} +
+ ) + }, +) + +export default RunEvalScenarioButton diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/types.ts new file mode 100644 index 0000000000..a523a32155 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/RunEvalScenarioButton/types.ts @@ -0,0 +1,6 @@ +export interface RunEvalScenarioButtonProps { + scenarioId: string + label?: string + stepKey?: string + runId?: string // Optional for multi-run support +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/index.tsx new file mode 100644 index 0000000000..dc428e3c5a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/index.tsx @@ -0,0 +1,302 @@ +import {FC, memo, useCallback, useMemo, useRef, useState} from "react" + +import {Card, Typography} from "antd" +import clsx from "clsx" +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {selectAtom, loadable} from "jotai/utils" +import dynamic from "next/dynamic" + +import { + getInitialMetricsFromAnnotations, + getInitialSelectedEvalMetrics, +} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/transforms" +import {UpdatedMetricsType} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/types" +import {isAnnotationCreatedByCurrentUser} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/utils" +import {AnnotationDto} from "@/oss/lib/hooks/useAnnotations/types" +import { + getCurrentRunId, + scenarioStepFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {UseEvaluationRunScenarioStepsFetcherResult} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" + +import AnnotateScenarioButton from "../AnnotateScenarioButton" +import RunEvalScenarioButton from "../RunEvalScenarioButton" + +import {ScenarioAnnotationPanelProps} from "./types" + +const Annotate = dynamic( + () => + import( + "@agenta/oss/src/components/pages/observability/drawer/AnnotateDrawer/assets/Annotate" + ), + {ssr: false}, +) + +const EmptyArray: any[] = [] + +const ScenarioAnnotationPanelAnnotation = memo( + ({ + onAnnotate, + runId, + scenarioId, + buttonClassName, + invStep, + annotationsByStep, + evaluators, + }: ScenarioAnnotationPanelProps) => { + const [errorMessages, setErrorMessages] = useState(EmptyArray as string[]) + + // TODO: move this to a shared utils file + const formatErrorMessages = useCallback((requiredMetrics: Record) => { + const errorMessages: string[] = [] + + for (const [key, data] of Object.entries(requiredMetrics || {})) { + errorMessages.push( + `Value ${data?.value === "" ? "empty string" : data?.value} is not assignable to type ${data?.type} in ${key}`, + ) + } + setErrorMessages(errorMessages) + }, []) + + const [updatedMetrics, setUpdatedMetrics] = useState({}) + + // helper to compute per-step annotation & evaluator lists + const buildAnnotateData = useCallback( + (stepKey: string) => { + const _steps = annotationsByStep?.[stepKey] || [] + const _annotations = _steps + .map((s) => s.annotation) + .filter(Boolean) as AnnotationDto[] + const annotationEvaluatorSlugs = _annotations + .map((annotation) => annotation?.references?.evaluator?.slug) + .filter(Boolean) + + return { + annotations: _annotations, + evaluatorSlugs: + evaluators + ?.map((e) => e.slug) + .filter((slug) => !annotationEvaluatorSlugs.includes(slug)) || [], + evaluators: + evaluators?.filter((e) => !annotationEvaluatorSlugs.includes(e.slug)) || [], + } + }, + [annotationsByStep, evaluators], + ) + + const {_annotations, isAnnotated, isCreatedByCurrentUser, selectedEvaluators} = + useMemo(() => { + const annotateData = buildAnnotateData(invStep.stepKey) + + const _annotations = annotateData.annotations + const selectedEvaluators = annotateData.evaluatorSlugs + + const isAnnotated = _annotations.length > 0 + const isCreatedByCurrentUser = _annotations.length + ? _annotations.some((ann) => isAnnotationCreatedByCurrentUser(ann)) + : true + + return { + isAnnotated, + isCreatedByCurrentUser, + selectedEvaluators, + _annotations, + } + }, [invStep.stepKey, buildAnnotateData, evaluators]) + + const isChangedMetricData = useMemo(() => { + const annotateData = buildAnnotateData(invStep.stepKey) + + const initialAnnotationMetrics = getInitialMetricsFromAnnotations({ + annotations: annotateData.annotations, + evaluators, + }) + const annotationSlugs = annotateData.annotations + .map((ann) => ann.references?.evaluator?.slug) + .filter(Boolean) + + // Filter updatedMetrics to only include user existing annotations + const filteredUpdatedMetrics = Object.fromEntries( + Object.entries(updatedMetrics).filter(([slug]) => annotationSlugs.includes(slug)), + ) + + if ( + Object.keys(filteredUpdatedMetrics).length === 0 && + filteredUpdatedMetrics.constructor === Object + ) { + return true + } + return deepEqual(filteredUpdatedMetrics, initialAnnotationMetrics) + }, [updatedMetrics, evaluators, invStep.stepKey]) + + const isChangedSelectedEvalMetrics = useMemo(() => { + const annotateData = buildAnnotateData(invStep.stepKey) + const selectedEvaluators = annotateData.evaluatorSlugs + + const initialSelectedEvalMetrics = getInitialSelectedEvalMetrics({ + evaluators: annotateData.evaluators, + selectedEvaluators, + }) + + const filteredUpdatedMetrics = Object.fromEntries( + Object.entries(updatedMetrics).filter(([slug]) => + selectedEvaluators.includes(slug), + ), + ) + + if ( + Object.keys(filteredUpdatedMetrics).length === 0 && + filteredUpdatedMetrics.constructor === Object + ) { + return true + } + + return deepEqual(filteredUpdatedMetrics, initialSelectedEvalMetrics) + }, [updatedMetrics, updatedMetrics, evaluators, invStep.stepKey]) + + return ( +
+ + +
+ ) + }, +) + +const ScenarioAnnotationPanel: FC = ({ + runId, + scenarioId, + className, + classNames, + buttonClassName, + onAnnotate, +}) => { + const store = evalAtomStore() + + // Use effective runId with fallback using useMemo + const effectiveRunId = useMemo(() => { + if (runId) return runId + try { + return getCurrentRunId() + } catch (error) { + console.warn("[ScenarioAnnotationPanel] No run ID available:", error) + return "" + } + }, [runId]) + + // Get evaluators from run-scoped state instead of global atom + const evaluatorsSelector = useCallback((state: any) => { + return state?.enrichedRun?.evaluators ? Object.values(state.enrichedRun.evaluators) : [] + }, []) + + const evaluatorsAtom = useMemo( + () => selectAtom(evaluationRunStateFamily(effectiveRunId), evaluatorsSelector, deepEqual), + [effectiveRunId, evaluatorsSelector], + ) + const evaluators = useAtomValue(evaluatorsAtom, {store}) + + // Loadable step data for this scenario (always eager) - now run-scoped + // Read from the same global store that writes are going to + const stepDataLoadable = useAtomValue( + loadable(scenarioStepFamily({scenarioId, runId: effectiveRunId})), + {store}, + ) + + // Preserve last known data so we can still show tool-tips / fields while revalidating + const prevDataRef = useRef(undefined) + + let stepData: UseEvaluationRunScenarioStepsFetcherResult | undefined = undefined + if (stepDataLoadable.state === "hasData") { + stepData = stepDataLoadable.data + prevDataRef.current = stepDataLoadable.data + } else if (stepDataLoadable.state === "loading") { + stepData = prevDataRef.current + } + + // Memoize field slices for best performance (multi-step) + const _invocationSteps = useMemo(() => stepData?.invocationSteps ?? [], [stepData]) + // Build annotations per step key + const annotationsByStep = useMemo(() => { + if (!stepData) return {} + + type AnnStep = (typeof stepData.steps)[number] + const map: Record = {} + if (!stepData?.steps || !_invocationSteps.length) return map + + // Pre-compute all annotation steps once (annotation step = has invocation key prefix) + const allAnnSteps = (stepData.steps || []).filter((s) => + _invocationSteps.some((invStep) => (s.stepKey ?? "").startsWith(`${invStep.stepKey}.`)), + ) + _invocationSteps.forEach(({stepKey}) => { + const anns = allAnnSteps.filter((s) => (s.stepKey ?? "").startsWith(`${stepKey}.`)) + map[stepKey] = anns + }) + return map + }, [stepData?.steps, _invocationSteps]) + + const hasAnyTrace = useMemo(() => _invocationSteps.some((s) => s.traceId), [_invocationSteps]) + + return ( + +
+ {_invocationSteps.map((invStep) => { + return ( + + ) + })} +
+ {!hasAnyTrace ? ( +
+ To annotate, please generate output + +
+ ) : null} +
+ ) +} + +export default ScenarioAnnotationPanel diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/types.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/types.ts new file mode 100644 index 0000000000..510a28fad2 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioAnnotationPanel/types.ts @@ -0,0 +1,16 @@ +import {CardProps} from "antd" + +import {IStepResponse} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" + +export interface ScenarioAnnotationPanelProps { + runId: string + scenarioId: string + className?: string + classNames?: CardProps["classNames"] + buttonClassName?: string + invStep?: IStepResponse + annotationsByStep?: Record + evaluators?: EvaluatorDto[] + onAnnotate?: () => void +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/ScenarioLoadingIndicator.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/ScenarioLoadingIndicator.tsx new file mode 100644 index 0000000000..395d7e4fee --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/ScenarioLoadingIndicator.tsx @@ -0,0 +1,23 @@ +import {memo} from "react" + +import {Progress} from "antd" +import {useAtomValue} from "jotai" + +import {scenarioStepProgressFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {conicColors} from "./assets/constants" + +const ScenarioLoadingIndicator = ({runId}: {runId: string}) => { + const scenarioStepProgress = useAtomValue(scenarioStepProgressFamily(runId)) + + return scenarioStepProgress.loadingStep === "scenario-steps" ? ( + + ) : null +} + +export default memo(ScenarioLoadingIndicator) diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/assets/constants.ts b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/assets/constants.ts new file mode 100644 index 0000000000..bc4a530f64 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/ScenarioLoadingIndicator/assets/constants.ts @@ -0,0 +1,7 @@ +import type {ProgressProps} from "antd" + +export const conicColors: ProgressProps["strokeColor"] = { + "0%": "#87d068", + "50%": "#ffe58f", + "100%": "#ffccc7", +} diff --git a/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/SingleScenarioViewer/index.tsx b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/SingleScenarioViewer/index.tsx new file mode 100644 index 0000000000..bcf68d8f6f --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/HumanEvalRun/components/SingleScenarioViewer/index.tsx @@ -0,0 +1,130 @@ +import {memo, useEffect} from "react" + +import {Button, Space, Typography} from "antd" +import clsx from "clsx" +import {useAtom, useAtomValue} from "jotai" +import {loadable} from "jotai/utils" +import {useRouter} from "next/router" + +import {useRunId} from "@/oss/contexts/RunIdContext" +import { + displayedScenarioIdsFamily, + scenariosFamily, + evalAtomStore, + scenarioStepProgressFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import EvalRunScenarioNavigator from "../../../components/EvalRunScenarioNavigator" +import {urlStateAtom} from "../../../state/urlState" +import EvalRunScenarioCard from "../EvalRunScenarioCard" +import ScenarioAnnotationPanel from "../ScenarioAnnotationPanel" +import ScenarioLoadingIndicator from "../ScenarioLoadingIndicator/ScenarioLoadingIndicator" + +import {SingleScenarioViewerProps} from "./types" + +const SingleScenarioViewer = ({runId}: SingleScenarioViewerProps) => { + // Use run-scoped atoms with the provided runId + const effectiveRunId = useRunId() || runId + const store = evalAtomStore() + + // Read from the same global store that writes are going to + const scenariosLoadable = useAtomValue(loadable(scenariosFamily(effectiveRunId)), {store}) + const scenarioIdsFromFamily = useAtomValue(displayedScenarioIdsFamily(effectiveRunId), {store}) + + // Fallback: if displayedScenarioIdsFamily is empty but scenariosLoadable has data, use that + const scenarioIds = + scenarioIdsFromFamily?.length > 0 + ? scenarioIdsFromFamily + : scenariosLoadable.state === "hasData" + ? scenariosLoadable.data?.map((s) => s.id) || [] + : [] + const scenarioStepProgress = useAtomValue(scenarioStepProgressFamily(effectiveRunId), {store}) + + // Access URL state atom + const router = useRouter() + const [urlState, setUrlState] = useAtom(urlStateAtom) + + // Prefer URL query first, then atom, then fallback + const activeId = + (router.query.scenarioId as string | undefined) ?? urlState.scenarioId ?? scenarioIds[0] + + // Ensure URL/atom always reference a scenario visible in current list + // Ensure URL/atom correctness + useEffect(() => { + if (scenarioIds.length === 0) return + + const currentScenarioId = + (router.query.scenarioId as string | undefined) ?? urlState.scenarioId + + if (!currentScenarioId || !scenarioIds.includes(currentScenarioId)) { + // Default to the first scenario for this run when no valid selection/deep-link. + setUrlState((draft) => { + draft.scenarioId = scenarioIds[0] + }) + return + } + }, [scenarioIds, router.query.scenarioId, urlState.scenarioId, setUrlState]) + + if (scenariosLoadable.state !== "hasData") { + const step = scenarioStepProgress.loadingStep as string | undefined + if (step === "eval-run" || step === "scenarios") { + return ( + + + )} +
+ setSearchTerm(e.target.value)} + /> + +
+ +
{menu}
+
+ )} + {...selectProps} + > + {_scenarios.map((scenario) => { + const {id, scenarioIndex} = scenario as any + + // non-hook read; never suspends + const loadableStatus = evalAtomStore().get( + loadable(scenarioStatusFamily({scenarioId: id, runId: effectiveRunId})), + ) + const scenStatus = + loadableStatus.state === "hasData" + ? loadableStatus.data + : {status: "pending", label: "Pending"} + + const colorClass = statusColorMap[scenStatus.status] + const labelIndex = scenarioIndex ?? scenarioIds.indexOf(id) + 1 + + return ( + +
+ Scenario {labelIndex} + {scenStatus.status} +
+
+ ) + })} + + + {activeId && showStatus ? ( + + ) : null} +
+ + {!showOnlySelect && ( + + )} + + ) +} + +export default memo(EvalRunScenarioNavigator) diff --git a/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/assets/index.tsx b/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/assets/index.tsx new file mode 100644 index 0000000000..2b87a19431 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/assets/index.tsx @@ -0,0 +1,32 @@ +export const STATUS_COLOR: Record = { + success: "success", + done: "success", + failure: "error", + failed: "error", + EVALUATION_FAILED: "error", + EVALUATION_FINISHED_WITH_ERRORS: "warning", + cancelled: "warning", + EVALUATION_AGGREGATION_FAILED: "warning", + pending: "default", + EVALUATION_INITIALIZED: "default", + running: "blue", + incomplete: "blue", + EVALUATION_STARTED: "blue", + revalidating: "purple", +} + +export const STATUS_COLOR_TEXT: Record = { + success: "text-green-600", + done: "text-green-600", + failure: "text-red-500", + failed: "text-red-500", + EVALUATION_FAILED: "text-red-500", + EVALUATION_FINISHED_WITH_ERRORS: "text-orange-500", + cancelled: "text-yellow-500", + EVALUATION_AGGREGATION_FAILED: "text-orange-500", + pending: "text-gray-400", + EVALUATION_INITIALIZED: "text-gray-400", + running: "text-blue-500", + EVALUATION_STARTED: "text-blue-500", + revalidating: "text-purple-500", +} diff --git a/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/index.tsx b/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/index.tsx new file mode 100644 index 0000000000..16037b2e8e --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/EvalRunScenarioStatusTag/index.tsx @@ -0,0 +1,67 @@ +import {memo, useMemo} from "react" + +import {Tag} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" + +import {getStatusLabel} from "@/oss/lib/constants/statusLabels" +import { + scenarioStatusFamily, + evalAtomStore, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {STATUS_COLOR, STATUS_COLOR_TEXT} from "./assets" +/** + * Component to display the status of an evaluation scenario as a Tag. + * + * Retrieves the optimistic scenario overrides for the given scenarioId, + * and uses them to show transient UI-only states like "annotating" or + * "revalidating" if the backend has not yet been updated. + * + * @param scenarioId The ID of the scenario to display the status for. + * @returns A Tag component displaying the status of the scenario. + */ +interface EvalRunScenarioStatusTagProps { + scenarioId: string + runId: string + className?: string + showAsTag?: boolean +} + +const EvalRunScenarioStatusTag = ({ + scenarioId, + runId, + className, + showAsTag = true, +}: EvalRunScenarioStatusTagProps) => { + const store = evalAtomStore() + + /** + * Loadable atom wrapping scenarioStatusFamily, which provides the most + * up-to-date status for the given scenarioId. This can be either a status + * that is being optimistically updated, or the latest status update from + * the backend. + * + * @type {import("jotai/utils").Loadable} + */ + const statusLoadable = useAtomValue( + useMemo(() => loadable(scenarioStatusFamily({scenarioId, runId})), [scenarioId, runId]), + {store}, + ) + const scenarioStatus = statusLoadable.state === "hasData" ? statusLoadable.data : undefined + const status = (scenarioStatus?.status as string) || "pending" + const label = getStatusLabel(status) + + return showAsTag ? ( + + {label} + + ) : ( + + {label} + + ) +} + +export default memo(EvalRunScenarioStatusTag) diff --git a/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/assets/constants.ts b/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/assets/constants.ts new file mode 100644 index 0000000000..b548d80e60 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/assets/constants.ts @@ -0,0 +1,20 @@ +// Feature flag to toggle prototype card (list) view +export const ENABLE_CARD_VIEW = process.env.NEXT_PUBLIC_ENABLE_EVAL_CARD_VIEW === "true" + +export const VIEW_HUMAN_OPTIONS = (() => { + const base = [ + {label: "Focus view", value: "focus"}, + {label: "Table view", value: "table"}, + {label: "Results view", value: "results"}, + ] + if (ENABLE_CARD_VIEW) { + base.splice(1, 0, {label: "Card view", value: "list"}) + } + return base +})() + +export const VIEW_AUTO_OPTIONS = [ + {label: "Overview", value: "overview"}, + {label: "Test cases", value: "test-cases"}, + {label: "Prompt configuration", value: "prompt"}, +] diff --git a/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/index.tsx b/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/index.tsx new file mode 100644 index 0000000000..d3f13ffab0 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/EvalRunScenariosViewSelector/index.tsx @@ -0,0 +1,51 @@ +import {memo, useTransition} from "react" + +import {Radio} from "antd" +import {useAtomValue, useSetAtom} from "jotai" + +import {evalAtomStore} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import {evalTypeAtom} from "../../state/evalType" +import {runViewTypeAtom, urlStateAtom} from "../../state/urlState" + +import {ENABLE_CARD_VIEW, VIEW_HUMAN_OPTIONS, VIEW_AUTO_OPTIONS} from "./assets/constants" + +const EvalRunScenariosViewSelector = () => { + const store = evalAtomStore() + const evalType = useAtomValue(evalTypeAtom) + // Read from the same global store that writes are going to + const viewType = useAtomValue(runViewTypeAtom, {store}) + const [_isPending, startTransition] = useTransition() + + const setUrlState = useSetAtom(urlStateAtom, {store}) + + // Sync local atom from urlStateAtom changes + return ( +
+ { + const v = e.target.value as "focus" | "list" | "table" + startTransition(() => { + setUrlState((draft) => { + draft.view = v + }) + }) + }} + defaultValue={"focus"} + value={ENABLE_CARD_VIEW ? viewType : viewType === "list" ? "focus" : viewType} + > + {(evalType === "human" ? VIEW_HUMAN_OPTIONS : VIEW_AUTO_OPTIONS).map((option) => ( + + {option.label} + + ))} + +
+ ) +} + +export default memo(EvalRunScenariosViewSelector) diff --git a/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataButton.tsx b/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataButton.tsx new file mode 100644 index 0000000000..d7b20242eb --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataButton.tsx @@ -0,0 +1,65 @@ +import {cloneElement, isValidElement, memo, MouseEvent, useState} from "react" + +import {ArrowSquareOut, Database} from "@phosphor-icons/react" +import dynamic from "next/dynamic" + +import EnhancedButton from "@/oss/components/Playground/assets/EnhancedButton" + +import {SaveDataButtonProps} from "./types" + +const SaveDataModal = dynamic(() => import(".."), {ssr: false}) + +const SaveDataButton = ({ + name, + rows, + exportDataset = false, + icon = true, + children, + label, + onClick, + ...props +}: SaveDataButtonProps) => { + const [isModalOpen, setIsModalOpen] = useState(false) + + return ( + <> + {isValidElement(children) ? ( + cloneElement( + children as React.ReactElement<{ + onClick: (e: MouseEvent) => void + }>, + { + onClick: (e) => { + onClick?.(e) + setIsModalOpen(true) + }, + }, + ) + ) : ( + : ) + } + onClick={async (e) => { + await onClick?.(e) + setIsModalOpen(true) + }} + label={label} + {...props} + /> + )} + + setIsModalOpen(false)} + /> + + ) +} + +export default memo(SaveDataButton) diff --git a/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataModalContent.tsx b/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataModalContent.tsx new file mode 100644 index 0000000000..2e1070caf3 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/SaveDataModal/assets/SaveDataModalContent.tsx @@ -0,0 +1,82 @@ +import {useMemo} from "react" + +import {Input, Select, Typography} from "antd" + +import EnhancedTable from "@/oss/components/EnhancedUIs/Table" +import useFocusInput from "@/oss/hooks/useFocusInput" + +import {SaveDataModalContentProps} from "./types" + +const SaveDataModalContent = ({ + rows, + rowKeys, + exportDataset, + name, + setName, + isOpen, + selectedColumns, + setSelectedColumns, +}: SaveDataModalContentProps) => { + const {inputRef} = useFocusInput({isOpen}) + + const columns = useMemo(() => { + if (selectedColumns.length === 0) { + return [{title: "-", dataIndex: "-"}] + } + return selectedColumns.map((key) => ({ + title: key, + dataIndex: key, + width: 150, + ellipsis: true, + })) + }, [selectedColumns]) + + const options = useMemo(() => { + return rowKeys.map((key) => ({label: key, value: key})) + }, [rowKeys]) + + return ( +
+
+ + {exportDataset ? "File name" : "Test set name"} + + setName(e.target.value)} + value={name} + /> +
+ +
+ Columns + + {availableRuns.map((run) => ( + +
+ {run.name} +
+ + {run.status} + + + {run.createdAt} + +
+
+
+ {run.id.slice(0, 8)}... +
+
+ ))} + +
+ + {selectedRuns.length > 0 && ( +
+
Comparison Preview:
+
+
+ Base + + {currentRun?.name || `Current Run`} + +
+ {selectedRuns.map((runId) => { + const run = availableRuns.find( + (r: AvailableRun) => r.id === runId, + ) + return ( +
+ Compare + {run?.name} +
+ ) + })} +
+
+ )} + +
+ + 💡 Tip: Use comparison mode to analyze performance differences between + runs + +
+
+ + + ) +} + +export default memo(ComparisonModeToggle) diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedAnnotationValueCell.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedAnnotationValueCell.tsx new file mode 100644 index 0000000000..1497de3cfe --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedAnnotationValueCell.tsx @@ -0,0 +1,106 @@ +import {memo, useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {useAtomValue} from "jotai" +import {atomFamily, selectAtom} from "jotai/utils" + +import LabelValuePill from "@/oss/components/ui/LabelValuePill" +import {loadableScenarioStepFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios" + +import {CellWrapper} from "../CellComponents" + +import {CollapsedAnnotationValueCellProps} from "./types" + +function buildCollapsedValues(data: any, keys: string[]) { + const annotations: any[] = [] + + if (Array.isArray(data?.annotationSteps) && data.annotationSteps.length) { + annotations.push(...data.annotationSteps.map((st: any) => st.annotation).filter(Boolean)) + } + if (data?.annotations?.length) { + annotations.push(...data.annotations) + } + if (data?.annotation) { + annotations.push(data.annotation) + } + + // Deduplicate by span_id+trace_id to avoid duplicates if same ann appears in multiple arrays + const unique = new Map() + annotations.forEach((ann) => { + if (!ann) return + const key = `${ann.trace_id || ""}_${ann.span_id || Math.random()}` + if (!unique.has(key)) unique.set(key, ann) + }) + + const out: Record = {} + keys.forEach((fieldPath) => { + for (const ann of unique.values()) { + let val = fieldPath + .split(".") + .reduce((acc: any, k: string) => (acc ? acc[k] : undefined), ann) + + if (val === undefined && fieldPath.startsWith("data.outputs.")) { + const suffix = fieldPath.slice("data.outputs.".length) + val = ann?.data?.outputs?.metrics?.[suffix] ?? ann?.data?.outputs?.extra?.[suffix] + } + if (val !== undefined) { + out[fieldPath] = val + break // stop at first found value + } + } + }) + return out +} + +export const collapsedAnnotationValuesFamily = atomFamily( + ({scenarioId, runId, keys}: {scenarioId: string; runId: string; keys: string[]}) => + selectAtom( + loadableScenarioStepFamily({scenarioId, runId}), + (loadableData) => + buildCollapsedValues( + loadableData.state === "hasData" ? loadableData.data : undefined, + keys, + ), + deepEqual, + ), +) + +const CollapsedAnnotationValueCell = memo( + ({scenarioId, runId, childrenDefs}) => { + const keyPaths = useMemo( + () => childrenDefs.map((c) => c.path || c.dataIndex || c.key) as string[], + [childrenDefs], + ) + const familyParam = useMemo( + () => ({scenarioId, runId, keys: keyPaths}), + [scenarioId, runId, keyPaths], + ) + + const out = useAtomValue(collapsedAnnotationValuesFamily(familyParam)) + + if (!Object.keys(out).length) { + return ( + + + + ) + } + + return ( + +
+ {Object.entries(out).map(([name, val]) => ( + + ))} +
+
+ ) + }, +) + +export default CollapsedAnnotationValueCell diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricValueCell.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricValueCell.tsx new file mode 100644 index 0000000000..e99deda613 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricValueCell.tsx @@ -0,0 +1,308 @@ +import {memo, useMemo, type ReactNode} from "react" + +import {useAtomValue} from "jotai" + +import {formatColumnTitle} from "@/oss/components/Filters/EditColumns/assets/helper" +import {formatMetricValue} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" +import LabelValuePill from "@/oss/components/ui/LabelValuePill" +import { + SchemaMetricType, + canonicalizeMetricKey, + extractPrimitive, + getMetricValueWithAliases, + summarizeMetric, +} from "@/oss/lib/metricUtils" + +import {scenarioMetricSelectorFamily} from "../../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {TableColumn} from "../../types" +import {CellWrapper} from "../CellComponents" + +export interface CollapsedMetricValueCellProps { + scenarioId: string + evaluatorSlug?: string + runId: string + childrenDefs?: TableColumn[] +} + +interface PillEntry { + label: string + value: string +} + +const includesBooleanType = (metricType?: SchemaMetricType): boolean => { + if (!metricType) return false + return Array.isArray(metricType) ? metricType.includes("boolean") : metricType === "boolean" +} + +const flattenColumns = (columns?: TableColumn[]): TableColumn[] => { + if (!columns?.length) return [] + const queue = [...columns] + const leaves: TableColumn[] = [] + + while (queue.length) { + const column = queue.shift() + if (!column) continue + if (column.children && column.children.length) { + queue.push(...column.children) + } else { + leaves.push(column) + } + } + + return leaves +} + +const toBooleanString = (value: unknown): string | undefined => { + if (typeof value === "boolean") return value ? "true" : "false" + if (typeof value === "number") { + if (value === 1) return "true" + if (value === 0) return "false" + } + if (typeof value === "string") { + const trimmed = value.trim().toLowerCase() + if (trimmed === "true" || trimmed === "false") return trimmed + if (trimmed === "1") return "true" + if (trimmed === "0") return "false" + } + return undefined +} + +const extractBooleanFromStats = (value: any): string | undefined => { + if (!value || typeof value !== "object") return undefined + const candidates: unknown[] = [] + if (Array.isArray(value.rank) && value.rank.length) { + candidates.push(value.rank[0]?.value) + } + if (Array.isArray(value.frequency) && value.frequency.length) { + candidates.push(value.frequency[0]?.value) + } + if ("mean" in value) candidates.push(value.mean) + if ("sum" in value) candidates.push(value.sum) + if ("value" in value) candidates.push((value as any).value) + + for (const candidate of candidates) { + const boolString = toBooleanString(candidate) + if (boolString) return boolString + } + return undefined +} + +const resolveBooleanDisplay = ({ + summarized, + rawValue, + metricType, +}: { + summarized: unknown + rawValue: unknown + metricType?: SchemaMetricType +}): string | undefined => { + const preferBoolean = includesBooleanType(metricType) + if (preferBoolean) { + const summaryBool = toBooleanString(summarized) + if (summaryBool) return summaryBool + const rawBool = + toBooleanString(rawValue) || + (typeof rawValue === "object" ? extractBooleanFromStats(rawValue) : undefined) + if (rawBool) return rawBool + } else { + const rawBool = + toBooleanString(rawValue) || + (typeof rawValue === "object" ? extractBooleanFromStats(rawValue) : undefined) + if (rawBool) return rawBool + const summaryBool = toBooleanString(summarized) + if (summaryBool) return summaryBool + } + return undefined +} + +const summariseMetricValue = (value: unknown, metricType?: SchemaMetricType) => { + if (value === null || value === undefined) return undefined + + if (typeof value === "object" && !Array.isArray(value)) { + const summary = summarizeMetric(value as any, metricType) + if (summary !== undefined) return summary + + const primitive = extractPrimitive(value) + if (primitive !== undefined) return primitive + } + + return value +} + +const buildCandidateKeys = (column: TableColumn, evaluatorSlug?: string): string[] => { + const keys = new Set() + const addKey = (key?: string) => { + if (!key) return + if (!keys.has(key)) keys.add(key) + const canonical = canonicalizeMetricKey(key) + if (canonical !== key && !keys.has(canonical)) { + keys.add(canonical) + } + } + + addKey(column.path) + addKey(column.fallbackPath) + if (typeof column.key === "string") addKey(column.key) + + if (column.path?.includes(".")) { + const tail = column.path.split(".").pop() + if (tail) addKey(tail) + } + + if (evaluatorSlug) { + const ensurePrefixed = (raw?: string) => { + if (!raw) return + if (raw.startsWith(`${evaluatorSlug}.`)) { + addKey(raw) + } else { + addKey(`${evaluatorSlug}.${raw}`) + } + } + + ensurePrefixed(column.path) + ensurePrefixed(column.fallbackPath) + if (typeof column.key === "string") ensurePrefixed(column.key) + } + + return Array.from(keys).filter(Boolean) +} + +const buildLabel = (column: TableColumn) => { + const raw = + (typeof column.title === "string" && column.title.trim()) || + (typeof column.name === "string" && column.name.trim()) || + column.path?.split(".").pop() || + (typeof column.key === "string" ? column.key : "") || + "" + + const base = raw || "Metric" + return /\s/.test(base) || base.includes("(") ? base : formatColumnTitle(base) +} + +const buildCollapsedPills = ({ + rowMetrics, + childrenDefs, + evaluatorSlug, +}: { + rowMetrics: Record + childrenDefs?: TableColumn[] + evaluatorSlug?: string +}): PillEntry[] => { + if (!rowMetrics || typeof rowMetrics !== "object") return [] + + const leaves = flattenColumns(childrenDefs) + if (!leaves.length) return [] + + const seenLabels = new Set() + const result: PillEntry[] = [] + + leaves.forEach((column) => { + const candidateKeys = buildCandidateKeys(column, evaluatorSlug) + let rawValue: unknown + let resolvedKey: string | undefined + + for (const key of candidateKeys) { + if (!key) continue + if (rowMetrics[key] !== undefined) { + rawValue = rowMetrics[key] + resolvedKey = key + break + } + const alias = getMetricValueWithAliases(rowMetrics, key) + if (alias !== undefined) { + rawValue = alias + resolvedKey = key + break + } + } + + if (rawValue === undefined) return + + const summarized = summariseMetricValue(rawValue, column.metricType) + if (summarized === undefined || summarized === null) return + + const canonicalKey = canonicalizeMetricKey(resolvedKey ?? column.path ?? column.key ?? "") + const label = buildLabel(column) + if (!label.trim() || seenLabels.has(label)) return + const booleanDisplay = resolveBooleanDisplay({ + summarized, + rawValue, + metricType: column.metricType, + }) + + const value = + booleanDisplay ?? + (typeof summarized === "number" + ? formatMetricValue(canonicalKey, summarized) + : String(summarized)) + + seenLabels.add(label) + result.push({label, value}) + }) + + return result +} + +interface BaseCellProps extends CollapsedMetricValueCellProps { + emptyState: ReactNode +} + +const BaseCollapsedMetricValueCell = ({ + scenarioId, + evaluatorSlug, + runId, + childrenDefs, + emptyState, +}: BaseCellProps) => { + const rowMetrics = useAtomValue(scenarioMetricSelectorFamily({runId, scenarioId})) || {} + + const pillEntries = useMemo( + () => + buildCollapsedPills({ + rowMetrics, + childrenDefs, + evaluatorSlug, + }), + [rowMetrics, childrenDefs, evaluatorSlug], + ) + + if (!pillEntries.length) { + return ( + + {typeof emptyState === "string" ? ( + {emptyState} + ) : ( + emptyState + )} + + ) + } + + return ( + +
+ {pillEntries.map(({label, value}) => ( + + ))} +
+
+ ) +} + +const CollapsedMetricValueCell = memo((props) => ( + +)) + +export const AutoEvalCollapsedMetricValueCell = memo((props) => ( + } + /> +)) + +export default CollapsedMetricValueCell diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricsCell.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricsCell.tsx new file mode 100644 index 0000000000..ceb5dd2df9 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/CollapsedMetricsCell.tsx @@ -0,0 +1,34 @@ +import {memo} from "react" + +import {useAtomValue} from "jotai" + +import {scenarioMetricsMapFamily} from "../../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {CellWrapper} from "../CellComponents" + +export interface CollapsedMetricsCellProps { + scenarioId: string + evaluatorSlug?: string // undefined → include all evaluators +} + +const CollapsedMetricsCell = memo(({scenarioId, evaluatorSlug}) => { + const rowMetrics = useAtomValue(scenarioMetricsMapFamily(scenarioId)) || {} + + const filtered: Record = {} + Object.entries(rowMetrics).forEach(([k, v]) => { + if (!evaluatorSlug) { + filtered[k] = v + } else if (k.startsWith(`${evaluatorSlug}.`)) { + filtered[k.slice(evaluatorSlug.length + 1)] = v + } + }) + + return ( + +
+                {Object.keys(filtered).length ? JSON.stringify(filtered) : ""}
+            
+
+ ) +}) + +export default CollapsedMetricsCell diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/MetricCell.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/MetricCell.tsx new file mode 100644 index 0000000000..4328802974 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/MetricCell.tsx @@ -0,0 +1,322 @@ +import {type ReactNode, memo, useMemo} from "react" + +import {Tag, Tooltip} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" + +import {urlStateAtom} from "@/oss/components/EvalRunDetails/state/urlState" +import MetricDetailsPopover from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover" // adjust path if necessary +import {formatMetricValue} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" // same util used elsewhere +import {Expandable} from "@/oss/components/Tables/ExpandableCell" +import {useRunId} from "@/oss/contexts/RunIdContext" +import {getStatusLabel} from "@/oss/lib/constants/statusLabels" +import { + evalAtomStore, + loadableScenarioStepFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {runScopedMetricDataFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import {EvaluationStatus} from "@/oss/lib/Types" + +import {STATUS_COLOR_TEXT} from "../../../EvalRunScenarioStatusTag/assets" +import {CellWrapper} from "../CellComponents" // CellWrapper is default export? need to check. + +import {AnnotationValueCellProps, MetricCellProps, MetricValueCellProps} from "./types" + +/* + * MetricCell – common renderer for metric columns (scenario-level or evaluator-level). + * Props: + * - metricKey: base metric name (without evaluator slug) + * - fullKey: full metric path as used in maps (e.g. "evaluator.slug.score") + * - value: value for current scenario row + * - distInfo: pre-computed distribution / stats for popover (optional) + * - metricType: primitive type from evaluator schema ("number", "boolean", "array", etc.) + */ + +const MetricCell = memo( + ({ + hidePrimitiveTable = true, + scenarioId, + metricKey, + fullKey, + value, + distInfo, + metricType, + isComparisonMode, + }) => { + if (value === undefined || value === null) { + if (isComparisonMode) { + return ( + +
+ + ) + } + return null + } + + if (typeof value === "object" && Object.keys(value || {}).length === 0) { + if (isComparisonMode) { + return ( + +
+ + ) + } + return null + } + + const frequency = value?.frequency || value?.freq + + if (frequency && frequency?.length > 0) { + const mostFrequent = frequency.reduce((max, current) => + current.count > max.count ? current : max, + ).value + value = mostFrequent + } + + // Non-numeric arrays rendered as Tag list + let formatted: ReactNode = formatMetricValue(metricKey, value) + + if (metricType === "boolean" && Array.isArray(value as any)) { + const trueEntry = (distInfo as any).frequency.find((f: any) => f.value === true) + const total = (distInfo as any).count ?? 0 + if (total) { + return ( +
+
+
+
+ true + false +
+
+
+
+
+
+
+
+
+
+ {(((trueEntry?.count ?? 0) / total) * 100).toFixed(2)} +
+
+ ) + } + } + + if (metricType === "array" || Array.isArray(value)) { + const values = Array.isArray(value) ? value : [value] + // const Component = metricType === "string" ? "span" : Tag + formatted = + metricType === "string" ? ( +
+ {values.map((it: any) => ( +
  • + {String(it)} +
  • + ))} +
    + ) : ( +
    + {values.map((it: any) => ( + + {String(it)} + + ))} +
    + ) + } else if (typeof value === "object") { + // Extract primitive when wrapped in an object (e.g. { score, value, ... }) + if ("score" in value) value = (value as any).score + else { + const prim = Object.values(value || {}).find( + (v) => typeof v === "number" || typeof v === "string", + ) + value = prim !== undefined ? prim : JSON.stringify(value) + } + } + + // Boolean metrics – show raw value + if (metricType === "boolean") { + formatted = String(value) + } + + // Wrap in popover when distInfo present + if (distInfo && metricType !== "string") { + return ( + + + + {formatted} + + + + ) + } + + return ( + + + {formatted} + + + ) + }, +) + +// --- Wrapper cell that fetches the value from atoms ---------------------- + +const failureRunTypes = [EvaluationStatus.FAILED, EvaluationStatus.FAILURE, EvaluationStatus.ERROR] + +export const MetricValueCell = memo( + ({scenarioId, metricKey, fallbackKey, fullKey, metricType, evalType, runId}) => { + const param = useMemo( + () => ({runId, scenarioId, metricKey}), + [runId, scenarioId, metricKey], + ) + + const fallbackParam = useMemo( + () => + fallbackKey && fallbackKey !== metricKey + ? ({runId, scenarioId, metricKey: fallbackKey} as const) + : param, + [fallbackKey, metricKey, param, runId, scenarioId], + ) + + const store = evalAtomStore() + + const urlState = useAtomValue(urlStateAtom) + const isComparisonMode = Boolean(urlState.compare && urlState.compare.length > 0) + + let value, distInfo + const result = useAtomValue(runScopedMetricDataFamily(param as any), {store}) + const fallbackResult = useAtomValue(runScopedMetricDataFamily(fallbackParam as any), { + store, + }) + + value = result.value + distInfo = result.distInfo + + if ((value === undefined || value === null) && fallbackResult) { + value = fallbackResult.value + distInfo = distInfo ?? fallbackResult.distInfo + } + const loadable = useAtomValue(loadableScenarioStepFamily({scenarioId, runId})) + + // TODO: remove this from here and create a function or something to also use in somewhere else + // Last minute implementation for eval-checkpoint + const errorStep = useMemo(() => { + if (evalType !== "auto") return null + if (loadable.state === "loading") return null + const [evalSlug, key] = metricKey.split(".") + if (!key) return null // if does not have key that means it's not an evaluator metric + const _step = loadable.data?.steps?.find((s) => s.stepKey === evalSlug) + + if (!_step) { + const invocationStep = loadable.data?.invocationSteps?.find( + (s) => s.scenarioId === scenarioId, + ) + + if (failureRunTypes.includes(invocationStep?.status)) { + return { + status: invocationStep?.status, + error: invocationStep?.error?.stacktrace || invocationStep?.error?.message, + } + } + return null + } + + if (failureRunTypes.includes(_step?.status)) { + return { + status: _step?.status, + error: _step?.error?.stacktrace || _step?.error?.message, + } + } + + return null + }, [loadable]) + + // TODO: create a separate component for error + if (errorStep?.status || errorStep?.error) { + return ( + + + {getStatusLabel(errorStep?.status)} + + + ) + } + + return ( + + ) + }, +) + +// --- Annotation value cell ----------------------------------------------- + +export const AnnotationValueCell = memo( + ({ + scenarioId, + stepKey, + name, + fieldPath, + metricKey, + metricType, + fullKey, + distInfo: propsDistInfo, + }) => { + const stepSlug = stepKey?.includes(".") ? stepKey.split(".")[1] : undefined + const param = useMemo( + () => ({scenarioId, stepSlug, metricKey: metricKey || ""}), + [scenarioId, stepSlug, metricKey], + ) + const {value: metricVal, distInfo} = useAtomValue(metricDataFamily(param)) + + return ( + + ) + }, +) + +export default MetricCell diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/types.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/types.ts new file mode 100644 index 0000000000..677dc9e6a1 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/MetricCell/types.ts @@ -0,0 +1,41 @@ +import {BasicStats, SchemaMetricType} from "@/oss/lib/metricUtils" + +import {TableColumn} from "../types" + +export interface MetricCellProps { + scenarioId: string + metricKey: string + fullKey?: string + value: any + distInfo?: Record | Promise> + metricType?: SchemaMetricType + isComparisonMode?: boolean +} + +export interface MetricValueCellProps { + scenarioId: string + metricKey: string + fallbackKey?: string + fullKey?: string + distInfo?: Record | Promise> + metricType?: SchemaMetricType + evalType?: "auto" | "human" + runId?: string +} + +export interface AnnotationValueCellProps { + scenarioId: string + fieldPath: string // e.g. "data.outputs.isGood" + metricKey: string + fullKey?: string + distInfo?: Record | Promise> + metricType?: SchemaMetricType + stepKey?: string + name?: string +} + +export interface CollapsedAnnotationValueCellProps { + scenarioId: string + childrenDefs: TableColumn[] + runId?: string +} diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/StatusCell.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/StatusCell.tsx new file mode 100644 index 0000000000..c93ed1d64f --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/StatusCell.tsx @@ -0,0 +1,30 @@ +import {memo} from "react" + +import {useRunId} from "@/oss/contexts/RunIdContext" + +import EvalRunScenarioStatusTag from "../../EvalRunScenarioStatusTag" + +import {CellWrapper} from "./CellComponents" + +interface Props { + scenarioId: string + result?: string + runId?: string +} + +/** + * Lightweight status cell for Scenario rows. + * Displays coloured status tag and optional result snippet. + */ +const StatusCell = ({scenarioId, runId: propRunId}: Props) => { + const contextRunId = useRunId() + const effectiveRunId = propRunId || contextRunId + + return ( + + + + ) +} + +export default memo(StatusCell) diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/VirtualizedScenarioTableAnnotateDrawer.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/VirtualizedScenarioTableAnnotateDrawer.tsx new file mode 100644 index 0000000000..cb7c668dd4 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/VirtualizedScenarioTableAnnotateDrawer.tsx @@ -0,0 +1,81 @@ +import {memo, useCallback} from "react" + +import {DrawerProps} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" + +import EnhancedDrawer from "@/oss/components/EnhancedUIs/Drawer" +import {virtualScenarioTableAnnotateDrawerAtom} from "@/oss/lib/atoms/virtualTable" +import {evalAtomStore} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" + +import ScenarioAnnotationPanel from "../../../HumanEvalRun/components/ScenarioAnnotationPanel" + +interface VirtualizedScenarioTableAnnotateDrawerProps extends DrawerProps { + runId?: string +} +const VirtualizedScenarioTableAnnotateDrawer = ({ + runId: propRunId, + ...props +}: VirtualizedScenarioTableAnnotateDrawerProps) => { + const store = evalAtomStore() + + // Annotate drawer state (global, per-run) + const annotateDrawer = useAtomValue(virtualScenarioTableAnnotateDrawerAtom, {store}) + const setAnnotateDrawer = store.set + + const scenarioId = annotateDrawer.scenarioId + // Use runId from atom state if available, fallback to prop + const runId = annotateDrawer.runId || propRunId + + const closeDrawer = useCallback(() => { + setAnnotateDrawer( + virtualScenarioTableAnnotateDrawerAtom, + // @ts-ignore + (prev) => { + return { + ...prev, + open: false, + } + }, + ) + }, []) + + return ( + +
    +
    + {scenarioId && runId && ( + + )} +
    +
    +
    + ) +} + +export default memo(VirtualizedScenarioTableAnnotateDrawer) diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/constants.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/constants.ts new file mode 100644 index 0000000000..d72dfefd9a --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/constants.ts @@ -0,0 +1,92 @@ +// Centralized column widths for easy reuse +export const COLUMN_WIDTHS = { + input: 400, + groundTruth: 460, + response: 400, + action: 140, + metric: 100, + padding: 100, +} as const + +// Table layout constants +export const TABLE_LAYOUT = { + rowHeight: 54, // approximate height of one table row (px) +} as const + +export const SKELETON_ROW_COUNT = 5 + +export const GeneralHumanEvalMetricColumns = [ + { + name: "totalCost", + kind: "metric", + path: "totalCost", + stepKey: "metric", + metricType: "number", + }, + { + name: "Total Duration", + kind: "metric", + path: "duration.total", + stepKey: "metric", + metricType: "number", + }, + { + name: "totalTokens", + kind: "metric", + path: "totalTokens", + stepKey: "metric", + metricType: "number", + }, + { + name: "promptTokens", + kind: "metric", + path: "promptTokens", + stepKey: "metric", + metricType: "number", + }, + { + name: "completionTokens", + kind: "metric", + path: "completionTokens", + stepKey: "metric", + metricType: "number", + }, + { + name: "errors", + kind: "metric", + path: "errors", + stepKey: "metric", + metricType: "number", + }, +] + +export const GeneralAutoEvalMetricColumns = [ + { + name: "Cost (Total)", + kind: "metric", + path: "attributes.ag.metrics.costs.cumulative.total", + stepKey: "metric", + metricType: "number", + }, + { + name: "Duration (Total)", + kind: "metric", + path: "attributes.ag.metrics.duration.cumulative", + stepKey: "metric", + metricType: "number", + }, + { + name: "Total tokens", + kind: "metric", + path: "attributes.ag.metrics.tokens.cumulative.total", + stepKey: "metric", + metricType: "number", + }, + { + name: "errors", + kind: "metric", + path: "attributes.ag.metrics.errors.cumulative", + stepKey: "metric", + metricType: "number", + }, +] diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/dataSourceBuilder.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/dataSourceBuilder.ts new file mode 100644 index 0000000000..ecb9ea6949 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/dataSourceBuilder.ts @@ -0,0 +1,394 @@ +import groupBy from "lodash/groupBy" + +import {formatColumnTitle} from "@/oss/components/Filters/EditColumns/assets/helper" +import {evalTypeAtom} from "@/oss/components/EvalRunDetails/state/evalType" +import { + evalAtomStore, + evaluationEvaluatorsFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import type { + ColumnDef, + RunIndex, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {BasicStats, canonicalizeMetricKey} from "@/oss/lib/metricUtils" +import {buildSkeletonRows} from "@/oss/lib/tableUtils" + +import {TableRow} from "../types" + +import {GeneralAutoEvalMetricColumns, GeneralHumanEvalMetricColumns} from "./constants" + +const AUTO_INVOCATION_METRIC_SUFFIXES = GeneralAutoEvalMetricColumns.map((col) => col.path) +const AUTO_INVOCATION_METRIC_CANONICAL_SET = new Set( + AUTO_INVOCATION_METRIC_SUFFIXES.map((path) => canonicalizeMetricKey(path)), +) + +const matchesGeneralInvocationMetric = (path?: string): boolean => { + if (!path) return false + if (AUTO_INVOCATION_METRIC_SUFFIXES.some((suffix) => path.endsWith(suffix))) { + return true + } + const segments = path.split(".") + for (let i = 0; i < segments.length; i += 1) { + const candidate = segments.slice(i).join(".") + if (AUTO_INVOCATION_METRIC_CANONICAL_SET.has(canonicalizeMetricKey(candidate))) { + return true + } + } + return AUTO_INVOCATION_METRIC_CANONICAL_SET.has(canonicalizeMetricKey(path)) +} + +/** + * Build the data source (rows) for the virtualised scenario table. + * This logic was previously inline inside the table component; moving it here means + * the component can stay tidy while we have a single canonical place that knows: + * • which scenarios belong to the run + * • what their execution / annotation status is + * • how to present skeleton rows while data is still loading + */ + +export function buildScenarioTableRows({ + scenarioIds, + allScenariosLoaded, + skeletonCount = 20, + runId, +}: { + scenarioIds: string[] + allScenariosLoaded: boolean + skeletonCount?: number + runId: string +}): TableRow[] { + if (!allScenariosLoaded) { + // Render placeholder skeleton rows (fixed count) so the table height is stable + return buildSkeletonRows(skeletonCount).map((r, idx) => ({ + ...r, + scenarioIndex: idx + 1, + })) + } + + return scenarioIds.map((id, idx) => { + return { + key: id, + scenarioIndex: idx + 1, + runId, + } + }) +} + +/** + * Build raw ColumnDef list for scenario table. + */ +export function buildScenarioTableData({ + runIndex, + metricsFromEvaluators, + metrics, + runId, + evaluators, +}: { + runIndex: RunIndex | null | undefined + metricsFromEvaluators: Record> + metrics: Record + runId: string + evaluators: EvaluatorDto[] +}): (ColumnDef & {values?: Record})[] { + const baseColumnDefs: ColumnDef[] = runIndex ? Object.values(runIndex.columnsByStep).flat() : [] + const evalType = evalAtomStore().get(evalTypeAtom) + + // Augment columns with per-scenario values (currently only for input columns) + let columnsInput = baseColumnDefs + .filter((col) => col.kind !== "annotation") + .filter((col) => col.name !== "testcase_dedup_id") + + // Further group metrics by evaluator when evaluators info present + const evaluatorMetricGroups: any[] = [] + + // Evaluator Metric Columns + if (metricsFromEvaluators && evalType === "human") { + const annotationData = baseColumnDefs.filter((def) => def.kind === "annotation") + const groupedAnnotationData = groupBy(annotationData, (data) => { + return data.name.split(".")[0] + }) + + for (const [k, v] of Object.entries(groupedAnnotationData)) { + const evaluator = evaluators?.find((e) => e.slug === k) + evaluatorMetricGroups.push({ + title: evaluator?.name || k, + key: `metrics_${k}`, + children: v.map((data) => { + const [evaluatorSlug, metricName] = data.name.split(".") + const formattedMetricName = formatColumnTitle( + metricName || data.name.replace(`${evaluatorSlug}.`, ""), + ) + return { + ...data, + name: metricName || data.name, + title: formattedMetricName, + kind: "metric", + path: data.name, + stepKey: "metric", + metricType: metricsFromEvaluators[evaluatorSlug]?.find( + (x) => metricName in x, + )?.[metricName]?.metricType, + } + }), + }) + } + } + + if (metricsFromEvaluators && evalType === "auto") { + const annotationData = baseColumnDefs.filter((def) => def.kind === "annotation") + const groupedAnnotationData = groupBy(annotationData, (data) => { + return data.name.split(".")[0] + }) + + for (const metricKey of Object.keys(metricsFromEvaluators)) { + const evaluator = evaluators?.find((e) => e.slug === metricKey) + + // Build children from base run annotations when available, otherwise from metrics map + let children = Object.entries(groupedAnnotationData) + .flatMap(([k, v]) => { + return v.map((data) => { + // Prefer strict match on slug in data.path when present, else stepKey + const pathPrefix = `${metricKey}.` + const belongsToEvaluator = + (data.path && data.path.startsWith(pathPrefix)) || + data.stepKey === metricKey + if (belongsToEvaluator) { + const metric = metrics?.[`${metricKey}.${data.name}`] + const isMean = metric?.mean !== undefined + const legacyPath = `${metricKey}.${data.name}` + const fullPath = data.path ? `${metricKey}.${data.path}` : legacyPath + + if ( + matchesGeneralInvocationMetric(fullPath) || + matchesGeneralInvocationMetric(legacyPath) + ) { + return undefined + } + + const formattedName = formatColumnTitle(data.name) + return { + ...data, + name: data.name, + key: `${metricKey}.${data.name}`, + title: `${formattedName} ${isMean ? "(mean)" : ""}`.trim(), + kind: "metric", + path: fullPath, + fallbackPath: legacyPath, + stepKey: "metric", + metricType: metricsFromEvaluators[metricKey]?.find( + (x) => data.name in x, + )?.[data.name]?.metricType, + } + } + return undefined + }) + }) + .filter(Boolean) as any[] + + // If no base annotations matched (evaluator only exists in comparison runs), + // fall back to constructing children from metricsFromEvaluators + if (!children.length) { + const metricDefs = metricsFromEvaluators[metricKey] || [] + const seen = new Set() + children = metricDefs + .map((def: any) => { + const metricName = Object.keys(def || {})[0] + if (!metricName || seen.has(metricName)) return undefined + seen.add(metricName) + const fullPath = `${metricKey}.${metricName}` + if ( + matchesGeneralInvocationMetric(fullPath) || + matchesGeneralInvocationMetric(metricName) + ) { + return undefined + } + const formattedName = formatColumnTitle(metricName) + return { + name: metricName, + key: `${metricKey}.${metricName}`, + title: formattedName, + kind: "metric" as const, + path: `${metricKey}.${metricName}`, + fallbackPath: `${metricKey}.${metricName}`, + stepKey: "metric", + metricType: def?.[metricName]?.metricType, + } + }) + .filter(Boolean) as any[] + } + + evaluatorMetricGroups.push({ + title: evaluator?.name || metricKey, + key: `metrics_${metricKey}_evaluators`, + children, + }) + } + } + + const genericMetricsGroup = { + title: "Metrics", + key: "__metrics_group__", + children: + evalType === "auto" ? GeneralAutoEvalMetricColumns : GeneralHumanEvalMetricColumns, + } + + let metaStart: ColumnDef[] = [ + {name: "#", kind: "meta" as any, path: "scenarioIndex", stepKey: "meta"}, + ] + + const metaEnd: ColumnDef[] = [ + {name: "Action", kind: "meta" as any, path: "action", stepKey: "meta"}, + ] + + const columnsCore = [...columnsInput, ...evaluatorMetricGroups] + if (genericMetricsGroup) columnsCore.push(genericMetricsGroup as any) + const columns = [...metaStart, ...columnsCore, ...metaEnd] + + return columns +} + +/** + * Build columns for comparison mode showing multiple runs side-by-side + */ +export function buildComparisonTableColumns({ + baseRunId, + comparisonRunIds, + baseRunIndex, + comparisonRunIndexes, + metricsFromEvaluators, +}: { + baseRunId: string + comparisonRunIds: string[] + baseRunIndex: RunIndex | null | undefined + comparisonRunIndexes: Record + metricsFromEvaluators: Record> +}): (ColumnDef & {values?: Record})[] { + if (!baseRunIndex) return [] + + const allRunIds = [baseRunId, ...comparisonRunIds] + const evalType = evalAtomStore().get(evalTypeAtom) + + // Start with meta columns + const metaColumns: ColumnDef[] = [ + {name: "#", kind: "meta" as any, path: "scenarioIndex", stepKey: "meta"}, + ] + + // Get base column definitions (inputs, outputs, etc.) + const baseColumnDefs: ColumnDef[] = Object.values(baseRunIndex.columnsByStep).flat() + const inputOutputColumns = baseColumnDefs + .filter((col) => col.kind !== "annotation" && col.kind !== "metric") + .filter((col) => col.name !== "testcase_dedup_id") + + // For comparison mode, we want to show inputs once, then outputs/metrics for each run + const inputColumns = inputOutputColumns.filter((col) => col.stepKey === "input") + + // Create run-specific output columns + const runSpecificColumns: any[] = [] + + allRunIds.forEach((runId, index) => { + const isBase = index === 0 + const runLabel = isBase ? "Base" : `Run ${index}` + const runShort = runId.slice(0, 8) + + // Output columns for this run + const outputColumns = inputOutputColumns + .filter((col) => col.stepKey === "output") + .map((col) => ({ + ...col, + name: `${col.name} (${runLabel})`, + title: `${col.name} (${runShort})`, + runId, + isComparison: !isBase, + })) + + // Metric columns for this run + if (metricsFromEvaluators && evalType !== "auto") { + const annotationData = baseColumnDefs.filter((def) => def.kind === "annotation") + const groupedAnnotationData = groupBy(annotationData, (data) => { + return data.name.split(".")[0] + }) + + for (const [evaluatorSlug, annotations] of Object.entries(groupedAnnotationData)) { + const metricGroup = { + title: `${evaluatorSlug} (${runLabel})`, + key: `metrics_${evaluatorSlug}_${runId}`, + runId, + isComparison: !isBase, + children: annotations.map((data) => { + const [, metricName] = data.name.split(".") + return { + ...data, + name: metricName, + title: `${metricName} (${runShort})`, + kind: "metric", + path: data.name, + stepKey: "metric", + runId, + isComparison: !isBase, + metricType: metricsFromEvaluators[evaluatorSlug]?.find( + (x) => metricName in x, + )?.[metricName]?.metricType, + } + }), + } + runSpecificColumns.push(metricGroup) + } + } + + runSpecificColumns.push(...outputColumns) + }) + + const actionColumns: ColumnDef[] = [ + {name: "Action", kind: "meta" as any, path: "action", stepKey: "meta"}, + ] + + return [...metaColumns, ...inputColumns, ...runSpecificColumns, ...actionColumns] +} + +/** + * Build rows for comparison mode with data from multiple runs + */ +export function buildComparisonTableRows({ + scenarioIds, + baseRunId, + comparisonRunIds, + allScenariosLoaded, + skeletonCount = 20, +}: { + scenarioIds: string[] + baseRunId: string + comparisonRunIds: string[] + allScenariosLoaded: boolean + skeletonCount?: number +}): TableRow[] { + if (!allScenariosLoaded) { + return buildSkeletonRows(skeletonCount).map((r, idx) => ({ + ...r, + scenarioIndex: idx + 1, + })) + } + + return scenarioIds.map((scenarioId, idx) => { + const row: TableRow = { + key: scenarioId, + scenarioIndex: idx + 1, + scenarioId, + baseRunId, + comparisonRunIds, + } + + // Add run-specific data placeholders + // The actual data will be populated by the table cells using atoms + const allRunIds = [baseRunId, ...comparisonRunIds] + allRunIds.forEach((runId) => { + row[`${runId}_data`] = { + runId, + scenarioId, + // Cell components will use atoms to get actual data + } + }) + + return row + }) +} diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/flatDataSourceBuilder.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/flatDataSourceBuilder.ts new file mode 100644 index 0000000000..a9bd65fcb0 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/flatDataSourceBuilder.ts @@ -0,0 +1,8 @@ +// ---------------- Helpers ------------------ +export const titleCase = (str: string) => + String(str || "") + .replace(/([a-z0-9])([A-Z])/g, "$1 $2") + .replace(/_/g, " ") + .replace(/\s+/g, " ") + .trim() + .replace(/^[a-z]|\s[a-z]/g, (m) => m.toUpperCase()) diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/types.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/types.ts new file mode 100644 index 0000000000..916a7e4cf8 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/types.ts @@ -0,0 +1,18 @@ +import {SchemaMetricType} from "@/oss/lib/metricUtils" + +export interface BaseColumn { + name: string + title: string + key: string + kind: string + path: string + fallbackPath?: string + stepKey: string + stepKeyByRunId?: Record + metricType: SchemaMetricType + children?: TableColumn[] +} + +export interface TableColumn extends BaseColumn { + children?: TableColumn[] +} diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/utils.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/utils.tsx new file mode 100644 index 0000000000..de9d89d523 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/assets/utils.tsx @@ -0,0 +1,453 @@ +import {EnhancedColumnType} from "@/oss/components/EnhancedUIs/Table/types" +import {evalTypeAtom} from "@/oss/components/EvalRunDetails/state/evalType" +import {Expandable} from "@/oss/components/Tables/ExpandableCell" +import {evalAtomStore} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {getMetricConfig} from "@/oss/lib/metrics/utils" +import {buildMetricSorter} from "@/oss/lib/metricSorter" +import {extractPrimitive, isSortableMetricType} from "@/oss/lib/metricUtils" + +import { + runMetricsStatsCacheFamily, + runScopedMetricDataFamily, + scenarioMetricValueFamily, +} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import type {TableRow} from "../types" + +import ActionCell from "./ActionCell" +import {CellWrapper, InputCell, InvocationResultCell, SkeletonCell} from "./CellComponents" +import {COLUMN_WIDTHS} from "./constants" +import {titleCase} from "./flatDataSourceBuilder" +import CollapsedAnnotationValueCell from "./MetricCell/CollapsedAnnotationValueCell" +import CollapsedMetricValueCell, { + AutoEvalCollapsedMetricValueCell, +} from "./MetricCell/CollapsedMetricValueCell" +import {AnnotationValueCell, MetricValueCell} from "./MetricCell/MetricCell" +import {BaseColumn, TableColumn} from "./types" + +// Helper to compare metric/annotation primitives across scenarios +function scenarioMetricPrimitive(recordKey: string, column: any, runId: string) { + const st = evalAtomStore() + let raw: any = column.values?.[recordKey] + if (raw === undefined) { + const metricKey = column.path || column.key || column.name || "" + const fallbackKey = column.fallbackPath + if (column.kind === "metric") { + const stepSlug = + column.stepKey && column.stepKey.includes(".") + ? column.stepKey.split(".")[1] + : undefined + raw = st.get( + scenarioMetricValueFamily({ + runId, + scenarioId: recordKey, + metricKey, + stepSlug, + }) as any, + ) + if ((raw === undefined || raw === null) && fallbackKey && fallbackKey !== metricKey) { + raw = st.get( + scenarioMetricValueFamily({ + runId, + scenarioId: recordKey, + metricKey: fallbackKey, + stepSlug, + }) as any, + ) + } + } else { + const stepSlug = + column.stepKey && column.stepKey.includes(".") + ? column.stepKey.split(".")[1] + : undefined + raw = st.get( + runScopedMetricDataFamily({ + scenarioId: recordKey, + stepSlug, + metricKey, + runId, + }) as any, + )?.value + if ((raw === undefined || raw === null) && fallbackKey && fallbackKey !== metricKey) { + raw = st.get( + runScopedMetricDataFamily({ + scenarioId: recordKey, + stepSlug, + metricKey: fallbackKey, + runId, + }) as any, + )?.value + } + } + } + return extractPrimitive(raw) +} + +function scenarioMetricSorter(column: any, runId: string) { + return buildMetricSorter((row) => + scenarioMetricPrimitive(row.key as string, column, runId), + ) +} + +/** + * Transforms a list of scenario metrics into a map of scenarioId -> metrics, merging + * nested metrics under `outputs` into the same level. + * + * @param {{scenarioMetrics: any[]}} props - The props object containing the metrics. + * @returns {Record>} - A map of scenarioId -> metrics. + */ +export const getScenarioMetricsMap = ({scenarioMetrics}: {scenarioMetrics: any[]}) => { + const map: Record> = {} + const _metrics = scenarioMetrics || [] + + _metrics.forEach((m: any) => { + const sid = m.scenarioId + if (!sid) return + + // Clone the data object to avoid accidental mutations + const data: Record = + m && typeof m === "object" && m.data && typeof m.data === "object" ? {...m.data} : {} + + // If metrics are nested under `outputs`, merge them into the same level + if (data.outputs && typeof data.outputs === "object") { + Object.assign(data, data.outputs) + delete data.outputs + } + + if (!map[sid]) map[sid] = {} + Object.assign(map[sid], data) + }) + + return map +} + +// ---------------- Column adapter ------------------ +const generateColumnTitle = (col: BaseColumn) => { + if (col.kind === "metric") { + if (typeof col.title === "string" && col.title.trim().length > 0) { + return col.title + } + if (typeof col.path === "string") { + return getMetricConfig(col.path).label + } + } + if (col.kind === "invocation") return titleCase(col.name) + if (col.kind === "annotation") return titleCase(col.name) + return titleCase(col.title ?? col.name) +} +const generateColumnWidth = (col: BaseColumn) => { + if (col.kind === "meta") return 80 + if (col.kind === "input") return COLUMN_WIDTHS.input + if (col.kind === "metric") return COLUMN_WIDTHS.metric + if (col.kind === "annotation") return COLUMN_WIDTHS.metric + if (col.kind === "invocation") return COLUMN_WIDTHS.response + return 20 +} +const orderRank = (def: EnhancedColumnType): number => { + if (def.key === "#") return 0 + if (def.key === "inputs_group") return 1 + if (def.key === "outputs") return 2 + if (def.key === "Status") return 3 + if (def.key === "annotation" || def.key?.includes("metrics")) return 4 + if (def.key?.includes("evaluators")) return 5 + if (def.key === "__metrics_group__") return 6 + return 7 +} + +export function buildAntdColumns( + cols: TableColumn[], + runId: string, + expendedRows: Record, +): EnhancedColumnType[] { + const resolveStepKeyForRun = (column: TableColumn, targetRunId: string) => { + return column.stepKeyByRunId?.[targetRunId] ?? column.stepKey + } + const distMap = runId ? evalAtomStore().get(runMetricsStatsCacheFamily(runId)) : {} + const evalType = evalAtomStore().get(evalTypeAtom) + + // Count how many input columns we have + const inputColumns = cols.filter((col) => col.kind === "input") + + return cols + .map((c: TableColumn): EnhancedColumnType | null => { + const editLabel = generateColumnTitle(c) + const common = { + metricType: c.metricType ?? c.kind, + title: editLabel, + key: c.key ?? c.name, + minWidth: generateColumnWidth(c), + width: generateColumnWidth(c), + __editLabel: editLabel, + } + const sortable = + (c.kind === "metric" || c.kind === "annotation") && + isSortableMetricType(c.metricType) + + const sorter = sortable ? scenarioMetricSorter(c, runId) : undefined + + if (c.children) { + // drop empty wrapper groups + if ((!c.title && !c.name) || c.kind === "metrics_group") { + return { + ...common, + __editLabel: editLabel, + children: buildAntdColumns(c.children, runId, expendedRows), + } as EnhancedColumnType + } + if (c.key === "__metrics_group__" || c.key?.startsWith("metrics_")) { + return { + title: ( + + {c.key === "__metrics_group__" ? "Metrics" : (c.title ?? "")} + + ), + dataIndex: c.key, + collapsible: true, + key: c.key, + __editLabel: editLabel, + renderAggregatedData: ({record}) => { + const hasAnnotation = + Array.isArray(c.children) && + c.children.some((ch: any) => ch.kind === "annotation") + const evaluatorSlug = + c.key === "__metrics_group__" + ? undefined + : c.name || + c.key.replace(/^metrics_/, "").replace(/_evaluators/, "") + const scenarioId = (record as any).scenarioId || (record as any).key + if (hasAnnotation) { + return ( + + ) + } + return evalType === "auto" ? ( + + ) : ( + + ) + }, + children: buildAntdColumns(c.children, runId, expendedRows), + } + } + + return { + ...common, + __editLabel: editLabel, + title: titleCase(c.title ?? c.name), + key: c.key ?? c.name, + children: buildAntdColumns(c.children, runId, expendedRows), + } as EnhancedColumnType + } + + if (c.kind === "meta") { + switch (c.path) { + case "scenarioIndex": + return { + ...common, + fixed: "left", + width: 50, + minWidth: 50, + onCell: (record) => { + const showBorder = + expendedRows?.[record.key] || + (record?.isComparison && !record.isLastRow) + return { + className: showBorder + ? "!border-b-0 !p-0" + : record?.children?.length || record?.isComparison + ? "!p-0" + : "", + } + }, + render: (_: any, record: TableRow) => ( + {record.scenarioIndex} + ), + } + case "action": + if (evalType === "auto") return null + return { + ...common, + fixed: "right", + width: 120, + minWidth: 120, + render: (_: any, record: TableRow) => { + // Use runId from record data instead of function parameter + const effectiveRunId = (record as any).runId || runId + return ( + + ) + }, + } + default: + return {...common, dataIndex: c.path} + } + } + + if (c.kind === "input") { + const isFirstInput = inputColumns.length > 0 && inputColumns[0] === c + if (!isFirstInput) return null + + return { + title: ( + Inputs + ), + dataIndex: "inputs_group", + key: "inputs_group", + align: "left", + collapsible: true, + addNotAvailableCell: false, + onCell: (record) => { + const showBorder = + expendedRows?.[record.key] || + (record?.isComparison && !record.isLastRow) + return { + className: showBorder ? "!border-b-0 !bg-white" : "!bg-white", + } + }, + renderAggregatedData: ({record, isCollapsed}) => { + if (record.isComparison) return null + return ( +
    + + {inputColumns.map((inputCol) => ( +
    + + {titleCase(inputCol.name!)}: + {" "} + +
    + ))} +
    +
    + ) + }, + children: inputColumns.map((inputCol, idx) => ({ + title: titleCase(inputCol.name!), + key: `${inputCol.name}-input-${idx}`, + addNotAvailableCell: false, + onCell: (record) => { + const showBorder = + expendedRows?.[record.key] || + (record?.isComparison && !record.isLastRow) + return { + className: showBorder ? "!border-b-0 !bg-white" : "!bg-white", + } + }, + render: (_: any, record: TableRow) => { + if (record.isComparison) return "" + return ( + + ) + }, + })), + } + } + + return { + ...common, + sorter, + render: (_unused: any, record: TableRow) => { + // Use runId from record data instead of function parameter + const effectiveRunId = (record as any).runId || runId + // if (record.isSkeleton) return + switch (c.kind) { + case "input": { + const inputStepKey = resolveStepKeyForRun(c, effectiveRunId) + return ( + + ) + } + case "invocation": { + const invocationStepKey = resolveStepKeyForRun(c, effectiveRunId) + return ( + + ) + } + case "annotation": { + const annotationStepKey = resolveStepKeyForRun(c, effectiveRunId) + return ( + + ) + } + case "metric": + return ( + + ) + default: + return record.isSkeleton ? ( + + ) : ( + (c.values?.[record.scenarioId || record.key] ?? "") + ) + } + }, + } + }) + .filter(Boolean) + .sort((a, b) => { + if (!a || !b) return 0 + const r = orderRank(a) - orderRank(b) + if (r !== 0) return r + const aName = "title" in a && a.title ? String(a.title) : a.key + const bName = "title" in b && b.title ? String(b.title) : b.key + return aName?.localeCompare(bName) + }) as EnhancedColumnType[] +} diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useExpandableComparisonDataSource.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useExpandableComparisonDataSource.tsx new file mode 100644 index 0000000000..2615b89538 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useExpandableComparisonDataSource.tsx @@ -0,0 +1,387 @@ +import {useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {atom, useAtomValue} from "jotai" +import {atomFamily} from "jotai/utils" + +import { + evalAtomStore, + evaluationRunStateFamily, + runIndexFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {filterColumns} from "@/oss/components/Filters/EditColumns/assets/helper" +import type {RunIndex} from "@/oss/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" + +import { + displayedScenarioIdsFamily, + scenarioStepsFamily, +} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios" +import {buildScenarioTableData} from "../assets/dataSourceBuilder" +import {buildAntdColumns} from "../assets/utils" +import {expendedRowAtom} from "../ComparisonScenarioTable" +import type {TableColumn} from "../assets/types" +import {editColumnsFamily} from "./useTableDataSource" + +export interface GroupedScenario { + key: string + scenarioId: string + testcaseId: string + inputs: any + outputs: any + runId: string + comparedScenarios: { + id: string + inputSteps: string + inputs: any + outputs: any + runId: string + }[] +} + +interface UseExpandableComparisonDataSourceProps { + baseRunId: string + comparisonRunIds: string[] +} + +const testcaseForScenarios = atomFamily((runId: string) => + atom((get) => { + const scenarioSteps = get(scenarioStepsFamily(runId)) + const allScenarioIds = Object.keys(scenarioSteps) + const allSteps = allScenarioIds.reduce((acc, scenarioId) => { + const scenarioTestcaseIds = scenarioSteps[scenarioId]?.data?.inputSteps?.map( + (s) => s?.testcaseId, + ) + acc[scenarioId] = scenarioTestcaseIds + return acc + }, {}) + return allSteps + }), +) +export const comparisonRunsStepsAtom = atomFamily((runIds: string[]) => + atom((get) => { + const steps = runIds.reduce((acc, runId) => { + const scenarioSteps = get(scenarioStepsFamily(runId)) + + const allStepIds = Object.keys(scenarioSteps) + const allSteps = allStepIds.map((stepId) => ({ + id: stepId, + ...scenarioSteps[stepId], + })) + const allStepsData = allSteps.reduce((acc, step) => { + if (step.state === "hasData") { + acc[step.id] = step?.data?.inputSteps?.map((s) => s?.testcaseId) + } + return acc + }, {}) + + acc[runId] = allStepsData + return acc + }, {}) + return steps + }), +) + +export const comparisonRunIndexesAtom = atomFamily( + (runIds: string[]) => + atom((get) => + runIds.reduce>((acc, runId) => { + acc[runId] = get(runIndexFamily(runId)) + return acc + }, {}), + ), + deepEqual, +) + +const comparisonRunsEvaluatorsAtom = atomFamily((runIds: string[]) => + atom((get) => { + const evaluators = new Set() + runIds.forEach((runId) => { + const evals = get(evaluationRunStateFamily(runId)) + const enrichRun = evals?.enrichedRun + if (enrichRun) { + enrichRun.evaluators?.forEach((e) => evaluators.add(e)) + } + }) + + return Array.from(evaluators) + }), +) + +const metricsFromEvaluatorsFamily = atomFamily( + (runIds: string[]) => + atom((get) => { + // Build a map of evaluatorSlug -> unique metrics + const result: Record = {} + const seenMetricBySlug: Record> = {} + + runIds.forEach((runId) => { + const state = get(evaluationRunStateFamily(runId)) + const evaluators = state?.enrichedRun?.evaluators + ? Object.values(state.enrichedRun.evaluators) + : [] + + evaluators.forEach((ev: any) => { + const slug = ev?.slug + if (!slug) return + + if (!seenMetricBySlug[slug]) { + seenMetricBySlug[slug] = new Set() + } + + if (ev?.metrics && typeof ev.metrics === "object") { + Object.entries(ev.metrics).forEach( + ([metricName, metricInfo]: [string, any]) => { + if (seenMetricBySlug[slug].has(metricName)) return + seenMetricBySlug[slug].add(metricName) + + if (!result[slug]) result[slug] = [] + result[slug].push({ + [metricName]: { + metricType: metricInfo?.type || "unknown", + }, + evaluatorSlug: slug, + }) + }, + ) + } + }) + }) + + return result + }), + deepEqual, +) + +const useExpandableComparisonDataSource = ({ + baseRunId, + comparisonRunIds, +}: UseExpandableComparisonDataSourceProps) => { + const store = evalAtomStore() + // const fetchMultipleRuns = useSetAtom(multiRunDataFetcherAtom) + + const comparisonRunsSteps = useAtomValue(comparisonRunsStepsAtom(comparisonRunIds), {store}) + const baseTestcases = useAtomValue(testcaseForScenarios(baseRunId), {store}) + const comparisonRunIndexes = useAtomValue(comparisonRunIndexesAtom(comparisonRunIds), {store}) + + const comparisonRunsEvaluators = useAtomValue(comparisonRunsEvaluatorsAtom(comparisonRunIds), { + store, + }) + + const metricsFromEvaluators = useAtomValue( + metricsFromEvaluatorsFamily([baseRunId, ...comparisonRunIds]), + {store}, + ) + + // Match scenarios by content rather than IDs + const matchedScenarios = useMemo(() => { + const matches: Record = {} + + // For each base scenario, find matching scenarios in comparison runs + Object.entries(baseTestcases as Record).forEach( + ([baseScenarioId, baseSteps]) => { + const baseTestcaseData = baseSteps?.[0] + if (!baseTestcaseData) return + + const comparedScenarios: any[] = [] + + // Search through all comparison runs + Object.entries(comparisonRunsSteps as Record).forEach( + ([compRunId, compScenarios]) => { + Object.entries(compScenarios as Record).forEach( + ([compScenarioId, compSteps]) => { + const compTestcaseData = compSteps?.[0] + if (!compTestcaseData) return + + const inputsMatch = baseTestcaseData === compTestcaseData + + if (inputsMatch) { + // Derive compareIndex for this run from state or fallback to order in comparisonRunIds + const compState = store.get(evaluationRunStateFamily(compRunId)) + const compareIndex = + compState?.compareIndex ?? + (comparisonRunIds.includes(compRunId) + ? comparisonRunIds.indexOf(compRunId) + 2 + : undefined) + comparedScenarios.push({ + matchedTestcaseId: compTestcaseData, + runId: compRunId, + scenarioId: compScenarioId, + compareIndex, + }) + } + }, + ) + }, + ) + + matches[baseScenarioId] = comparedScenarios + }, + ) + + return matches + }, [baseTestcases, comparisonRunsSteps, comparisonRunIds.join(",")]) + + // Build columns using EXACT same approach as regular table (useTableDataSource) + const runIndex = useAtomValue(runIndexFamily(baseRunId), {store}) + const evaluationRunState = useAtomValue(evaluationRunStateFamily(baseRunId), {store}) + const expendedRows = useAtomValue(expendedRowAtom) + const evaluators = evaluationRunState?.enrichedRun?.evaluators || [] + const baseEvaluators = Array.isArray(evaluators) ? evaluators : Object.values(evaluators) + const allEvaluators = useMemo(() => { + const bySlug = new Map() + ;[...comparisonRunsEvaluators, ...baseEvaluators].forEach((ev: any) => { + if (ev?.slug && !bySlug.has(ev.slug)) bySlug.set(ev.slug, ev) + }) + return Array.from(bySlug.values()) + }, [comparisonRunsEvaluators, baseEvaluators]) + + const rawColumns = useMemo( + () => + buildScenarioTableData({ + runIndex, + metricsFromEvaluators, + runId: baseRunId, + evaluators: allEvaluators, + }), + [runIndex, metricsFromEvaluators, allEvaluators, expendedRows], + ) + + const columnsWithRunSpecificSteps = useMemo(() => { + if (!rawColumns) return [] as TableColumn[] + + const allRunIndexes: Record = { + [baseRunId]: runIndex, + ...(comparisonRunIndexes || {}), + } + + const cache = new Map() + + const getColumnsForRun = (runId: string) => { + if (cache.has(runId)) return cache.get(runId)! + const idx = allRunIndexes[runId] + const cols = idx ? Object.values(idx.columnsByStep || {}).flat() : [] + cache.set(runId, cols) + return cols + } + + const matchStepKey = (runId: string, column: any): string | undefined => { + if (runId === baseRunId && column.stepKey) return column.stepKey + const candidates = getColumnsForRun(runId) + const match = candidates.find((candidate) => { + if (candidate.kind !== column.kind) return false + if (column.path && candidate.path) { + return candidate.path === column.path + } + if (column.name && candidate.name) { + return candidate.name === column.name + } + return false + }) + return match?.stepKey + } + + const attach = (columns: any[]): any[] => + columns.map((column) => { + const children = column.children ? attach(column.children) : undefined + const shouldAttachStepKey = + column.kind === "input" || + column.kind === "invocation" || + column.kind === "annotation" + + if (!shouldAttachStepKey) { + return children ? {...column, children} : column + } + + const stepKeyByRunId = Object.keys(allRunIndexes).reduce< + Record + >((acc, runId) => { + const mapped = matchStepKey(runId, column) + if (mapped) acc[runId] = mapped + return acc + }, {}) + + if (column.stepKey && !stepKeyByRunId[baseRunId]) { + stepKeyByRunId[baseRunId] = column.stepKey + } + + if (!Object.keys(stepKeyByRunId).length) { + return children ? {...column, children} : column + } + + const enriched = { + ...column, + stepKeyByRunId, + } + if (children) enriched.children = children + return enriched + }) + + return attach(rawColumns as any[]) as TableColumn[] + }, [rawColumns, baseRunId, runIndex, comparisonRunIndexes]) + + // Build Ant Design columns using the same function as regular table + const baseAntColumns = useMemo( + () => + buildAntdColumns(columnsWithRunSpecificSteps as TableColumn[], baseRunId, expendedRows), + [columnsWithRunSpecificSteps, baseRunId, expendedRows], + ) + + const hiddenColumns = useAtomValue(editColumnsFamily(baseRunId), {store}) + + const antColumns = useMemo( + () => filterColumns(baseAntColumns, hiddenColumns), + [baseAntColumns, hiddenColumns], + ) + + // For backward compatibility, also provide basic columns + const columns = baseAntColumns + + // No longer need expandedRowRender - using children approach instead + const expandedRowRender = undefined + + const loading = false + + // Build rows with actual scenario data - use the SAME approach as regular table + const scenarioIds = useAtomValue(displayedScenarioIdsFamily(baseRunId), {store}) || [] + + const rows = useMemo(() => { + const builtRows = scenarioIds.map((scenarioId, idx) => { + // Get matched comparison scenarios for this base scenario + const comparedScenarios = matchedScenarios[scenarioId] || [] + + // Create base row structure + const baseRow = { + key: scenarioId, + scenarioIndex: idx + 1, + runId: baseRunId, // This row represents the base run + compareIndex: 1, + // Add children for comparison scenarios + children: comparedScenarios.map((compScenario, compIdx) => ({ + key: `${scenarioId}-comp-${compScenario.runId}-${compIdx}`, + scenarioIndex: idx + 1, // Same scenario index as parent + runId: compScenario.runId, // Use comparison run ID + scenarioId: compScenario.scenarioId, // Use comparison scenario ID + isComparison: true, // Flag to identify comparison rows + isLastRow: compIdx === comparedScenarios.length - 1, + compareIndex: compScenario.compareIndex, + })), + } + + return baseRow + }) + + return builtRows + }, [scenarioIds, matchedScenarios, baseRunId]) + + return { + antColumns, + columns, + rawColumns: baseAntColumns, + rows, + expandedRowRender, + loading, + totalColumnWidth: 0, // TODO: Calculate if needed + } +} + +export default useExpandableComparisonDataSource diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useScrollToScenario.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useScrollToScenario.ts new file mode 100644 index 0000000000..f901299ad8 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useScrollToScenario.ts @@ -0,0 +1,88 @@ +import {RefObject, useEffect, useMemo, useRef} from "react" + +import {useRouter} from "next/router" + +import {TableRow} from "../types" + +type TableRowWithChildren = TableRow & { + scenarioId?: string + children?: TableRowWithChildren[] +} + +const useScrollToScenario = ({ + dataSource, + expandedRowKeys = [], +}: { + dataSource: TableRowWithChildren[] + expandedRowKeys?: string[] +}) => { + const router = useRouter() + const tableContainerRef = useRef(null) + const tableInstance = useRef(null) + + const selectedScenarioId = router.query.scrollTo as string + + const flattenedRowKeys = useMemo(() => { + const keys: string[] = [] + const expandedSet = new Set((expandedRowKeys || []).map((key) => String(key))) + + const traverse = (rows: TableRowWithChildren[] = []) => { + rows.forEach((row) => { + const rowKey = (row?.key ?? row?.scenarioId) as string | undefined + if (!rowKey) return + + keys.push(rowKey) + + const isExpanded = expandedSet.has(rowKey) + if (!isExpanded) { + return + } + + if (Array.isArray(row.children) && row.children.length > 0) { + traverse(row.children) + } + }) + } + + traverse(dataSource) + + return keys + }, [dataSource, expandedRowKeys]) + + // Scroll to the specified row when user selects a scenario in auto eval + useEffect(() => { + if (!router.isReady) return + if (!tableInstance.current || !selectedScenarioId) return + // Get the row index from the flattened dataSource including expanded rows + const rowIndex = flattenedRowKeys.findIndex((key) => key === selectedScenarioId) + if (rowIndex === -1) return + // Use Ant Design's scrollTo method for virtualized tables when available + if (typeof tableInstance.current?.scrollTo === "function") { + tableInstance.current.scrollTo({ + index: rowIndex, + behavior: "smooth", + }) + } + + const rowElement = tableContainerRef.current?.querySelector( + `[data-row-key="${selectedScenarioId}"]`, + ) as HTMLElement | null + + // Fallback to native DOM scrolling when virtualization instance is unavailable + if (typeof tableInstance.current?.scrollTo !== "function") { + rowElement?.scrollIntoView({behavior: "smooth", block: "center"}) + } + + // Add highlight effect + if (rowElement) { + rowElement.classList.add("highlight-row") + setTimeout(() => { + rowElement.classList.remove("highlight-row") + }, 2000) + } + }, [selectedScenarioId, flattenedRowKeys, router.isReady]) + + return {tableContainerRef, tableInstance} +} + +export default useScrollToScenario diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useTableDataSource.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useTableDataSource.ts new file mode 100644 index 0000000000..95f0a6f3ab --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/hooks/useTableDataSource.ts @@ -0,0 +1,156 @@ +import {useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {atom, useAtom, useAtomValue} from "jotai" +import {atomFamily, selectAtom} from "jotai/utils" +import groupBy from "lodash/groupBy" + +import {filterColumns} from "@/oss/components/Filters/EditColumns/assets/helper" +import {useRunId} from "@/oss/contexts/RunIdContext" +import {ColumnDef} from "@/oss/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" + +import { + evaluationRunStateFamily, + runIndexFamily, +} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedAtoms" +// import {scenarioMetricsMapFamily} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics" +import { + displayedScenarioIdsFamily, + loadableScenarioStepFamily, +} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios" +import {evalAtomStore} from "../../../../../lib/hooks/useEvaluationRunData/assets/atoms/store" +import {buildScenarioTableData, buildScenarioTableRows} from "../assets/dataSourceBuilder" +import {buildAntdColumns} from "../assets/utils" + +const EMPTY_SCENARIOS: any[] = [] + +export const editColumnsFamily = atomFamily((runId: string) => atom([]), deepEqual) + +// Convert to atom family for run-scoped access +export const allScenariosLoadedFamily = atomFamily( + (runId: string) => + atom( + (get) => + (get(evaluationRunStateFamily(runId)).scenarios || EMPTY_SCENARIOS).map( + (s: any) => s.id, + )?.length > 0, + ), + deepEqual, +) + +// Run-scoped metrics from evaluators atom family +export const metricsFromEvaluatorsFamily = atomFamily( + (runId: string) => + selectAtom( + evaluationRunStateFamily(runId), + (state) => { + const evs = state?.enrichedRun?.evaluators + ? Object.values(state.enrichedRun.evaluators) + : [] + if (!evs || !Array.isArray(evs)) { + return {} + } + return groupBy( + evs.reduce((acc: any[], ev: any) => { + return [ + ...acc, + ...Object.entries(ev.metrics || {}).map( + ([metricName, metricInfo]: [string, any]) => { + return { + [metricName]: { + metricType: metricInfo.type, + }, + evaluatorSlug: ev.slug, + } + }, + ), + ] + }, []), + (def: any) => { + return def.evaluatorSlug + }, + ) + }, + deepEqual, + ), + deepEqual, +) + +const useTableDataSource = () => { + const runId = useRunId() + const store = evalAtomStore() + + // states + const [editColumns, setEditColumns] = useAtom(editColumnsFamily(runId), {store}) + + // Read from the same global store that writes are going to + const scenarioIds = useAtomValue(displayedScenarioIdsFamily(runId), {store}) || EMPTY_SCENARIOS + const allScenariosLoaded = useAtomValue(allScenariosLoadedFamily(runId), {store}) + + // const metricDistributions = useAtomValue(runMetricsStatsAtom) + const runIndex = useAtomValue(runIndexFamily(runId)) + const metricsFromEvaluators = + useAtomValue(metricsFromEvaluatorsFamily(runId)) || EMPTY_SCENARIOS + // temporary implementation to implement loading state for auto eval + const loadable = useAtomValue(loadableScenarioStepFamily({runId, scenarioId: scenarioIds?.[0]})) + const evaluationRunState = useAtomValue(evaluationRunStateFamily(runId), {store}) + const evaluators = evaluationRunState?.enrichedRun?.evaluators || [] + + const isLoadingSteps = useMemo( + () => loadable.state === "loading" || !allScenariosLoaded, + [loadable, allScenariosLoaded], + ) + + const rows = useMemo(() => { + return buildScenarioTableRows({ + scenarioIds, + allScenariosLoaded, + runId, + }) + }, [scenarioIds, allScenariosLoaded]) + + // New alternative data source built via shared helper + const builtColumns: ColumnDef[] = useMemo( + () => + buildScenarioTableData({ + runIndex, + runId, + metricsFromEvaluators, + evaluators, + }), + [runIndex, runId, metricsFromEvaluators, evaluators], + ) + + // Build Ant Design columns and make them resizable + const antColumns = useMemo(() => { + return buildAntdColumns(builtColumns, runId, {}) + }, [builtColumns, runId]) + + const visibleColumns = useMemo( + () => filterColumns(antColumns, editColumns), + [antColumns, editColumns], + ) + + const totalColumnWidth = useMemo(() => { + const calc = (cols: any[]): number => + cols.reduce((sum, col) => { + if (col?.children && col?.children.length) { + return sum + calc(col?.children) + } + return sum + (col?.width ?? col?.minWidth ?? 100) + }, 0) + return calc(antColumns) + }, [antColumns]) + + return { + rawColumns: antColumns, + antColumns: visibleColumns, + rows, + totalColumnWidth, + isLoadingSteps, + editColumns, + setEditColumns, + } +} + +export default useTableDataSource diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/index.tsx b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/index.tsx new file mode 100644 index 0000000000..19db92e1fa --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/index.tsx @@ -0,0 +1,23 @@ +import {memo} from "react" + +import {useAtomValue} from "jotai" +import dynamic from "next/dynamic" + +import {urlStateAtom} from "../../state/urlState" + +import ScenarioTable from "./ScenarioTable" + +const ComparisonTable = dynamic(() => import("./ComparisonScenarioTable"), {ssr: false}) + +const VirtualizedScenarioTable = () => { + const urlState = useAtomValue(urlStateAtom) + const isComparisonMode = Boolean(urlState.compare && urlState.compare.length > 0) + + if (isComparisonMode) { + return + } + + return +} + +export default VirtualizedScenarioTable diff --git a/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/types.ts b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/types.ts new file mode 100644 index 0000000000..a1273dccd6 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/components/VirtualizedScenarioTable/types.ts @@ -0,0 +1,19 @@ +import {ColumnsType} from "antd/es/table" + +export interface TableRow { + key: string // scenarioId + scenarioIndex: number + status?: string + result?: string + baseRunId?: string + /** + * For skeleton rows shown while data is loading. + */ + isSkeleton?: boolean +} + +export interface VirtualizedScenarioTableProps { + columns?: ColumnsType + dataSource?: TableRow[] + totalColumnWidth?: number +} diff --git a/web/ee/src/components/EvalRunDetails/index.tsx b/web/ee/src/components/EvalRunDetails/index.tsx new file mode 100644 index 0000000000..d4957b4747 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/index.tsx @@ -0,0 +1,310 @@ +import {memo, useCallback, useEffect, useMemo} from "react" + +import {Spin, Typography} from "antd" +import clsx from "clsx" +import deepEqual from "fast-deep-equal" +import {createStore, getDefaultStore, Provider, useAtomValue, useSetAtom} from "jotai" +import {selectAtom} from "jotai/utils" +import {useRouter} from "next/router" + +import EvalRunDetails from "@/oss/components/EvalRunDetails/HumanEvalRun" +import ErrorState from "@/oss/components/ErrorState" +import SingleModelEvaluationTable from "@/oss/components/EvaluationTable/SingleModelEvaluationTable" +import {RunIdProvider} from "@/oss/contexts/RunIdContext" +import {useAppId} from "@/oss/hooks/useAppId" +import {appendBreadcrumbAtom, breadcrumbAtom, setBreadcrumbsAtom} from "@/oss/lib/atoms/breadcrumb" +import {isUuid} from "@/oss/lib/helpers/utils" +import useEvaluationRunData from "@/oss/lib/hooks/useEvaluationRunData" +import { + evalAtomStore, + evaluationRunStateFamily, + initializeRun, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {_EvaluationScenario, Evaluation} from "@/oss/lib/Types" +import {abortAll} from "@/oss/lib/utils/abortControllers" + +import EvaluationScenarios from "../pages/evaluations/evaluationScenarios/EvaluationScenarios" + +import AutoEvalRunDetails from "./AutoEvalRun" +import {ComparisonDataFetcher} from "./components/ComparisonDataFetcher" +import {evalTypeAtom, setEvalTypeAtom} from "./state/evalType" +import {runViewTypeAtom} from "./state/urlState" +import UrlSync from "./UrlSync" + +const EvaluationPageData = memo( + ({children, runId}: {children?: React.ReactNode; runId?: string}) => { + const router = useRouter() + + // Abort any in-flight data requests when navigating away + useEffect(() => { + if (runId) { + initializeRun(runId) + } + }, [runId]) + + // Abort any in-flight data requests when navigating away + useEffect(() => { + return () => { + abortAll() + } + }, [router.pathname]) + + useEvaluationRunData(runId || null, true, runId) + return runId ? children : null + }, +) + +const LegacyEvaluationPage = ({id: evaluationTableId}: {id: string}) => { + const evalType = useAtomValue(evalTypeAtom) + + const {legacyEvaluationSWR, legacyScenariosSWR} = useEvaluationRunData( + evaluationTableId || null, + true, + ) + + if (legacyEvaluationSWR.isLoading || legacyScenariosSWR.isLoading) { + return ( +
    +
    + + + Loading... + +
    +
    + ) + } + + const data = legacyEvaluationSWR.data + + return data ? ( + evalType === "auto" ? ( + + ) : evalType === "human" ? ( + + ) : null + ) : null +} + +const PreviewEvaluationPage = memo( + ({ + evalType, + name, + description, + id, + }: { + evalType: "auto" | "human" + name: string + description: string + id: string + }) => { + return evalType === "auto" ? ( + + ) : ( + + ) + }, +) + +const LoadingState = ({ + evalType, + name, + description, + id, +}: { + evalType: "auto" | "human" + name: string + description: string + id: string +}) => { + return evalType === "auto" ? ( + + ) : ( +
    +
    + + + Loading... + +
    +
    + ) +} + +const EvaluationPage = memo(({evalType, runId}: {evalType: "auto" | "human"; runId: string}) => { + const rootStore = getDefaultStore() + const breadcrumbs = useAtomValue(breadcrumbAtom, {store: rootStore}) + const appendBreadcrumb = useSetAtom(appendBreadcrumbAtom, {store: rootStore}) + const setEvalType = useSetAtom(setEvalTypeAtom) + const appId = useAppId() + + const {isPreview, name, description, id} = useAtomValue( + selectAtom( + evaluationRunStateFamily(runId!), + useCallback((v) => { + return { + isPreview: v.isPreview, + name: v.enrichedRun?.name, + description: v.enrichedRun?.description, + id: v.enrichedRun?.id, + } + }, []), + deepEqual, + ), + ) + + useEffect(() => { + setEvalType(evalType) + }, [evalType]) + + useEffect(() => { + // Try loaded name first; fallback to name in URL (when present as /results/:id/:name). + const base = (typeof window !== "undefined" ? window.location.pathname : "") || "" + const segs = base.split("/").filter(Boolean) + const resultsIdx = segs.findIndex((s) => s === "results") + const urlName = + resultsIdx !== -1 && segs[resultsIdx + 2] && !isUuid(segs[resultsIdx + 2]) + ? segs[resultsIdx + 2] + : undefined + + const label = name || urlName + if (!id || !label) return + + const existing = (breadcrumbs && (breadcrumbs["eval-detail"] as any)) || null + const currentLabel: string | undefined = existing?.label + if (currentLabel === label) return + + appendBreadcrumb({ + "eval-detail": { + label, + value: id as string, + }, + }) + }, [appendBreadcrumb, breadcrumbs, id, name]) + + useEffect(() => { + const base = (typeof window !== "undefined" ? window.location.pathname : "") || "" + const segs = base.split("/").filter(Boolean) + const desiredLabel = evalType === "auto" ? "auto evaluation" : "human annotation" + + const appsIdx = segs.findIndex((s) => s === "apps") + if (appsIdx !== -1) { + const appId = segs[appsIdx + 1] + if (!appId) return + const evaluationsHref = `/${segs.slice(0, appsIdx + 2).join("/")}/evaluations` + + const current = (rootStore.get(breadcrumbAtom) as any) || {} + const appPage = current["appPage"] as any + const needsHref = !appPage || !appPage.href || !appPage.href.endsWith("/evaluations") + const needsLabel = !appPage || appPage.label !== desiredLabel + if (!needsHref && !needsLabel) return + + rootStore.set(appendBreadcrumbAtom, { + appPage: { + ...(appPage || {}), + label: desiredLabel, + href: evaluationsHref, + }, + }) + return + } + + const evaluationsIdx = segs.findIndex((s) => s === "evaluations") + if (evaluationsIdx === -1) return + const evaluationsHref = `/${segs.slice(0, evaluationsIdx + 1).join("/")}` + + const current = (rootStore.get(breadcrumbAtom) as any) || {} + const projectPage = current["projectPage"] as any + const needsHref = !projectPage || projectPage.href !== evaluationsHref + const needsLabel = !projectPage || projectPage.label !== desiredLabel + if (!needsHref && !needsLabel) return + + rootStore.set(appendBreadcrumbAtom, { + projectPage: { + ...(projectPage || {}), + label: desiredLabel, + href: evaluationsHref, + }, + }) + }, [rootStore, appendBreadcrumb, evalType]) + + // Clean up eval-detail crumb when leaving the page to avoid stale breadcrumbs + useEffect(() => { + return () => { + const current = (rootStore.get(breadcrumbAtom) as any) || {} + if (current["eval-detail"]) { + const {"eval-detail": _omit, ...rest} = current + rootStore.set(setBreadcrumbsAtom, rest) + } + } + }, [rootStore]) + + const hasPreviewData = Boolean(id) + + if (isPreview && !hasPreviewData) { + return ( + router.reload()} + /> + ) + } + + return ( +
    + {/** TODO: improve the component state specially AutoEvalRunDetails */} + {isPreview === undefined ? ( + + ) : isPreview && id ? ( + <> + + + + ) : ( + + )} +
    + ) +}) + +const EvalRunDetailsPage = memo(({evalType}: {evalType: "auto" | "human"}) => { + const router = useRouter() + const runId = router.query.evaluation_id ? router.query.evaluation_id.toString() : "" + return ( + + + + + + ) +}) + +export default memo(EvalRunDetailsPage) diff --git a/web/ee/src/components/EvalRunDetails/state/evalType.ts b/web/ee/src/components/EvalRunDetails/state/evalType.ts new file mode 100644 index 0000000000..bef8a97447 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/state/evalType.ts @@ -0,0 +1,10 @@ +import {atom} from "jotai" + +// This atom is used to store the evaluation type (auto or human) for the current evaluation run. +// It is used to determine which evaluation page to render. +export const evalTypeAtom = atom<"auto" | "human" | null>(null) + +// This atom is used to set the evaluation type (auto or human) for the current evaluation run. +export const setEvalTypeAtom = atom(null, (get, set, update: "auto" | "human" | null) => { + set(evalTypeAtom, update) +}) diff --git a/web/ee/src/components/EvalRunDetails/state/focusScenarioAtom.ts b/web/ee/src/components/EvalRunDetails/state/focusScenarioAtom.ts new file mode 100644 index 0000000000..3940d1afea --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/state/focusScenarioAtom.ts @@ -0,0 +1,89 @@ +import {atom} from "jotai" +import {atomWithImmer} from "jotai-immer" + +export interface FocusTarget { + focusRunId: string | null + focusScenarioId: string | null +} + +interface FocusDrawerState extends FocusTarget { + open: boolean + isClosing: boolean +} + +export const initialFocusDrawerState: FocusDrawerState = { + open: false, + isClosing: false, + focusRunId: null, + focusScenarioId: null, +} + +export const focusDrawerAtom = atomWithImmer(initialFocusDrawerState) + +export const focusScenarioAtom = atom((get) => { + const {focusRunId, focusScenarioId} = get(focusDrawerAtom) + if (!focusScenarioId) return null + return {focusRunId, focusScenarioId} +}) + +export const isFocusDrawerOpenAtom = atom((get) => get(focusDrawerAtom).open) + +export const focusDrawerTargetAtom = atom((get) => { + const {focusRunId, focusScenarioId} = get(focusDrawerAtom) + return {focusRunId, focusScenarioId} +}) + +export const setFocusDrawerTargetAtom = atom(null, (_get, set, target: FocusTarget) => { + set(focusDrawerAtom, (draft) => { + if ( + draft.focusRunId === target.focusRunId && + draft.focusScenarioId === target.focusScenarioId + ) { + return + } + draft.focusRunId = target.focusRunId + draft.focusScenarioId = target.focusScenarioId + }) +}) + +export const openFocusDrawerAtom = atom(null, (_get, set, target: FocusTarget) => { + set(focusDrawerAtom, (draft) => { + const sameTarget = + draft.focusRunId === target.focusRunId && + draft.focusScenarioId === target.focusScenarioId && + draft.open + draft.open = true + draft.isClosing = false + if (!sameTarget) { + draft.focusRunId = target.focusRunId + draft.focusScenarioId = target.focusScenarioId + } + }) +}) + +export const closeFocusDrawerAtom = atom(null, (_get, set) => { + set(focusDrawerAtom, (draft) => { + if (!draft.open && !draft.focusScenarioId && !draft.focusRunId) { + return + } + draft.open = false + draft.isClosing = true + }) +}) + +export const resetFocusDrawerAtom = atom(null, (_get, set) => { + set(focusDrawerAtom, () => ({...initialFocusDrawerState})) +}) + +export const applyFocusDrawerStateAtom = atom( + null, + (_get, set, payload: Partial) => { + set(focusDrawerAtom, (draft) => { + const next = {...draft, ...payload} + draft.open = Boolean(next.open) + draft.isClosing = Boolean(next.isClosing) + draft.focusRunId = next.focusRunId ?? null + draft.focusScenarioId = next.focusScenarioId ?? null + }) + }, +) diff --git a/web/ee/src/components/EvalRunDetails/state/urlState.ts b/web/ee/src/components/EvalRunDetails/state/urlState.ts new file mode 100644 index 0000000000..a4d4e88a92 --- /dev/null +++ b/web/ee/src/components/EvalRunDetails/state/urlState.ts @@ -0,0 +1,36 @@ +import {atom} from "jotai" +import {atomWithImmer} from "jotai-immer" + +import {evalTypeAtom} from "../state/evalType" + +export interface EvalRunUrlState { + view?: "list" | "table" | "focus" + scenarioId?: string + compare?: string[] // Array of run IDs to compare against the base run +} + +// Holds the subset of query params we care about for EvalRunDetails page +export const urlStateAtom = atomWithImmer({}) + +type HumanEvalViewTypes = "focus" | "list" | "table" | "results" +type AutoEvalViewTypes = "overview" | "test-cases" | "prompt" + +// Derived UI atom: maps the URL state and eval type to a concrete view +export const runViewTypeAtom = atom((get) => { + const evalType = get(evalTypeAtom) + const view = get(urlStateAtom).view + + const humanViews: HumanEvalViewTypes[] = ["focus", "list", "table", "results"] + // Put "test-cases" first so it becomes the default for auto evaluations + const autoViews: AutoEvalViewTypes[] = ["test-cases", "overview", "prompt"] + + if (evalType === "auto") { + // default and validation for auto eval + const v = (view as AutoEvalViewTypes | undefined) ?? autoViews[0] + return autoViews.includes(v) ? v : undefined + } + + // default and validation for human eval + const v = (view as HumanEvalViewTypes | undefined) ?? humanViews[0] + return humanViews.includes(v) ? v : "focus" +}) diff --git a/web/ee/src/components/EvaluationTable/ABTestingEvaluationTable.tsx b/web/ee/src/components/EvaluationTable/ABTestingEvaluationTable.tsx new file mode 100644 index 0000000000..ce5467f572 --- /dev/null +++ b/web/ee/src/components/EvaluationTable/ABTestingEvaluationTable.tsx @@ -0,0 +1,823 @@ +// @ts-nocheck +import {useState, useEffect, useCallback, useMemo, useRef} from "react" + +import SecondaryButton from "@agenta/oss/src/components/SecondaryButton/SecondaryButton" +import {Button, Card, Col, Input, Radio, Row, Space, Statistic, Table, message} from "antd" +import type {ColumnType} from "antd/es/table" +import {getDefaultStore, useAtomValue} from "jotai" +import debounce from "lodash/debounce" +import {useRouter} from "next/router" + +import {useQueryParam} from "@/oss/hooks/useQuery" +import {EvaluationFlow} from "@/oss/lib/enums" +import {exportABTestingEvaluationData} from "@/oss/lib/helpers/evaluate" +import {isBaseResponse, isFuncResponse} from "@/oss/lib/helpers/playgroundResp" +import {testsetRowToChatMessages} from "@/oss/lib/helpers/testset" +import { + EvaluationTypeLabels, + batchExecute, + camelToSnake, + getStringOrJson, +} from "@/oss/lib/helpers/utils" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import {getAllMetadata} from "@/oss/lib/hooks/useStatelessVariants/state" +import {extractInputKeysFromSchema} from "@/oss/lib/shared/variant/inputHelpers" +import {getRequestSchema} from "@/oss/lib/shared/variant/openapiUtils" +import {derivePromptsFromSpec} from "@/oss/lib/shared/variant/transformer/transformer" +import {transformToRequestBody} from "@/oss/lib/shared/variant/transformer/transformToRequestBody" +import type {BaseResponse, EvaluationScenario, KeyValuePair, Variant} from "@/oss/lib/Types" +import {callVariant} from "@/oss/services/api" +import {updateEvaluationScenario, updateEvaluation} from "@/oss/services/human-evaluations/api" +import {useEvaluationResults} from "@/oss/services/human-evaluations/hooks/useEvaluationResults" +import {customPropertiesByRevisionAtomFamily} from "@/oss/state/newPlayground/core/customProperties" +import { + stablePromptVariablesAtomFamily, + transformedPromptsAtomFamily, +} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {appUriInfoAtom, appSchemaAtom} from "@/oss/state/variant/atoms/fetcher" + +import EvaluationCardView from "../Evaluations/EvaluationCardView" +import {VARIANT_COLORS} from "../Evaluations/EvaluationCardView/assets/styles" +import EvaluationVotePanel from "../Evaluations/EvaluationCardView/EvaluationVotePanel" +import VariantAlphabet from "../Evaluations/EvaluationCardView/VariantAlphabet" + +import {useABTestingEvaluationTableStyles} from "./assets/styles" +import ParamsFormWithRun from "./components/ParamsFormWithRun" +import type {ABTestingEvaluationTableProps, ABTestingEvaluationTableRow} from "./types" + +// Note: Avoid Typography.Title to prevent EllipsisMeasure/ResizeObserver loops + +/** + * + * @param evaluation - Evaluation object + * @param evaluationScenarios - Evaluation rows + * @param columnsCount - Number of variants to compare face to face (per default 2) + * @returns + */ +const ABTestingEvaluationTable: React.FC = ({ + evaluation, + evaluationScenarios, + isLoading, +}) => { + const classes = useABTestingEvaluationTableStyles() + const router = useRouter() + const appId = router.query.app_id as string + const uriObject = useAtomValue(appUriInfoAtom) + const store = getDefaultStore() + const evalVariants = [...evaluation.variants] + + const {variants: data, isLoading: isVariantsLoading} = useStatelessVariants() + + // // Select the correct variant revisions for this evaluation + const variantData = useMemo(() => { + const allVariantData = data || [] + if (!allVariantData.length) return [] + + return evaluation.variants.map((evVariant, idx) => { + const revisionId = evaluation.variant_revision_ids?.[idx] + const revisionNumber = evaluation.revisions?.[idx] + + // 1. Try to find by exact revision id + let selected = allVariantData.find((v) => v.id === revisionId) + + // 2. Try by variantId & revision number + if (!selected && revisionNumber !== undefined) { + selected = allVariantData.find( + (v) => v.variantId === evVariant.variantId && v.revision === revisionNumber, + ) + } + + // 3. Fallback – latest revision for that variant + if (!selected) { + selected = allVariantData.find( + (v) => v.variantId === evVariant.variantId && v.isLatestRevision, + ) + } + + return selected || evVariant + }) + }, [data, evaluation.variants, evaluation.variant_revision_ids, evaluation.revisions]) + + const [rows, setRows] = useState([]) + const [, setEvaluationStatus] = useState(evaluation.status) + const [viewMode, setViewMode] = useQueryParam("viewMode", "card") + const {data: evaluationResults, mutate} = useEvaluationResults({ + evaluationId: evaluation.id, + onSuccess: () => { + updateEvaluation(evaluation.id, {status: EvaluationFlow.EVALUATION_FINISHED}) + }, + onError: (err) => { + console.error("Failed to fetch results:", err) + }, + }) + + const {numOfRows, flagVotes, positiveVotes, appVariant1Votes, appVariant2Votes} = + useMemo(() => { + const votesData = evaluationResults?.votes_data || {} + const variantsVotesData = votesData.variants_votes_data || {} + + const [variant1, variant2] = evaluation.variants || [] + + return { + numOfRows: votesData.nb_of_rows || 0, + flagVotes: votesData.flag_votes?.number_of_votes || 0, + positiveVotes: votesData.positive_votes?.number_of_votes || 0, + appVariant1Votes: variantsVotesData?.[variant1?.variantId]?.number_of_votes || 0, + appVariant2Votes: variantsVotesData?.[variant2?.variantId]?.number_of_votes || 0, + } + }, [evaluationResults, evaluation.variants]) + + const depouncedUpdateEvaluationScenario = useCallback( + debounce((data: Partial, scenarioId) => { + updateEvaluationScenarioData(scenarioId, data) + }, 800), + [], + ) + + useEffect(() => { + if (evaluationScenarios) { + setRows(() => { + const obj = [...evaluationScenarios] + const spec = store.get(appSchemaAtom) as any + const routePath = uriObject?.routePath + + obj.forEach((item, rowIndex) => { + // Map outputs into row shape for table columns + item.outputs.forEach((op) => (item[op.variant_id] = op.variant_output)) + + try { + // Build a stable input name set from variants (schema for custom, stable prompts otherwise) + const names = new Set() + ;(variantData || []).forEach((v: any) => { + const rid = v?.id + if (!rid) return + const flags = store.get( + variantFlagsAtomFamily({revisionId: rid}), + ) as any + if (flags?.isCustom && spec) { + extractInputKeysFromSchema(spec as any, routePath).forEach((k) => + names.add(k), + ) + } else { + const vars = store.get( + stablePromptVariablesAtomFamily(rid), + ) as string[] + ;(vars || []).forEach((k) => names.add(k)) + } + }) + + const chatCol = evaluation?.testset?.testsetChatColumn || "" + const reserved = new Set(["correct_answer", chatCol]) + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + + const existing = new Set( + (Array.isArray(item.inputs) ? item.inputs : []) + .map((ip: any) => ip?.input_name) + .filter(Boolean), + ) + + const nextInputs = Array.isArray(item.inputs) ? [...item.inputs] : [] + Array.from(names) + .filter((k) => typeof k === "string" && k && !reserved.has(k)) + .forEach((k) => { + if (!existing.has(k)) { + nextInputs.push({ + input_name: k, + input_value: (testRow as any)?.[k] ?? "", + }) + } + }) + item.inputs = nextInputs + } catch { + // best-effort prepopulation only + } + }) + + return obj + }) + } + }, [evaluationScenarios, variantData, uriObject?.routePath, evaluation?.testset?.csvdata]) + + const handleInputChange = useCallback( + (e: React.ChangeEvent, id: string, inputIndex: number) => { + setRows((oldRows) => { + const rowIndex = oldRows.findIndex((row) => row.id === id) + const newRows = [...oldRows] + if (newRows[rowIndex] && newRows[rowIndex].inputs?.[inputIndex]) { + newRows[rowIndex].inputs[inputIndex].input_value = e.target.value + } + return newRows + }) + }, + [], + ) + + const setRowValue = useCallback( + (rowIndex: number, columnKey: keyof ABTestingEvaluationTableRow, value: any) => { + setRows((oldRows) => { + const newRows = [...oldRows] + newRows[rowIndex][columnKey] = value as never + return newRows + }) + }, + [], + ) + + // Upsert a single input value into a row by scenario id + const upsertRowInput = useCallback((rowId: string, name: string, value: any) => { + setRows((old) => { + const idx = old.findIndex((r) => r.id === rowId) + if (idx === -1) return old + const next = [...old] + const row = {...next[idx]} + const inputs = Array.isArray(row.inputs) ? [...row.inputs] : [] + const pos = inputs.findIndex((ip) => ip.input_name === name) + if (pos === -1) { + inputs.push({input_name: name, input_value: value}) + } else if (inputs[pos]?.input_value !== value) { + inputs[pos] = {...inputs[pos], input_value: value} + } + row.inputs = inputs + next[idx] = row as any + return next + }) + }, []) + + const updateEvaluationScenarioData = useCallback( + async (id: string, data: Partial, showNotification = true) => { + await updateEvaluationScenario( + evaluation.id, + id, + Object.keys(data).reduce( + (acc, key) => ({ + ...acc, + [camelToSnake(key)]: data[key as keyof EvaluationScenario], + }), + {}, + ), + evaluation.evaluationType, + ) + .then(() => { + setRows((prev) => { + const next = [...prev] + const idx = next.findIndex((r) => r.id === id) + if (idx >= 0) { + Object.keys(data).forEach((key) => { + // @ts-ignore + next[idx][key] = data[key as keyof EvaluationScenario] as any + }) + } + return next + }) + if (showNotification) message.success("Evaluation Updated!") + }) + .catch(console.error) + }, + [evaluation.evaluationType, evaluation.id], + ) + + const handleVoteClick = useCallback( + async (id: string, vote: string) => { + const rowIndex = rows.findIndex((row) => row.id === id) + const evaluation_scenario_id = rows[rowIndex]?.id + + if (evaluation_scenario_id) { + setRowValue(rowIndex, "vote", "loading") + const data = { + vote: vote, + outputs: evalVariants.map((v: Variant) => ({ + variant_id: v.variantId, + variant_output: rows[rowIndex][v.variantId], + })), + inputs: rows[rowIndex].inputs, + } + await updateEvaluationScenarioData(evaluation_scenario_id, data) + await mutate() + } + }, + [rows, setRowValue, updateEvaluationScenarioData, evalVariants], + ) + + // Keep stable refs to callback handlers to avoid re-creating table columns + // Initialize with no-ops to avoid TDZ when functions are declared below + const runEvaluationRef = useRef< + (id: string, count?: number, showNotification?: boolean) => void + >(() => {}) + const handleInputChangeRef = useRef< + (e: React.ChangeEvent, id: string, inputIndex: number) => void + >(() => {}) + const handleVoteClickRef = useRef<(id: string, vote: string) => void>(() => {}) + // // Note: assign .current values after handlers are defined (see below) + + const runEvaluation = useCallback( + async (id: string, count = 1, showNotification = true) => { + const _variantData = variantData + const rowIndex = rows.findIndex((row) => row.id === id) + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + + // Derive request schema once + const spec = store.get(appSchemaAtom) as any + const routePath = uriObject?.routePath + const requestSchema: any = spec ? getRequestSchema(spec as any, {routePath}) : undefined + const hasMessagesProp = Boolean(requestSchema?.properties?.messages) + + const outputs = rows[rowIndex].outputs.reduce( + (acc, op) => ({...acc, [op.variant_id]: op.variant_output}), + {}, + ) + + await Promise.all( + evalVariants.map(async (variant: Variant, idx: number) => { + setRowValue(rowIndex, variant.variantId, "loading...") + + const isChatTestset = !!evaluation?.testset?.testsetChatColumn + + const rawMessages = isChatTestset + ? testsetRowToChatMessages(evaluation.testset.csvdata[rowIndex], false) + : [] + + const sanitizedMessages = rawMessages.map((msg) => { + if (!Array.isArray(msg.content)) return msg + return { + ...msg, + content: msg.content.filter((part) => { + return part.type !== "image_url" || part.image_url.url.trim() !== "" + }), + } + }) + + try { + // Build stable optional parameters using atom-based prompts (stable params) + const revisionId = _variantData?.[idx]?.id as string | undefined + const flags = revisionId + ? (store.get(variantFlagsAtomFamily({revisionId})) as any) + : undefined + const isCustom = Boolean(flags?.isCustom) + // Determine effective input keys per variant + const schemaKeys = spec + ? extractInputKeysFromSchema(spec as any, routePath) + : [] + const stableFromParams: string[] = (() => { + try { + const params = (_variantData[idx] as any)?.parameters + const ag = params?.ag_config ?? params ?? {} + const s = new Set() + Object.values(ag || {}).forEach((cfg: any) => { + const arr = cfg?.input_keys + if (Array.isArray(arr)) { + arr.forEach((k) => { + if (typeof k === "string" && k) s.add(k) + }) + } + }) + return Array.from(s) + } catch { + return [] + } + })() + + console.log("stableFromParams", stableFromParams) + // Also include stable variables derived from saved prompts (handles cases where input_keys are not explicitly listed) + const stableFromPrompts: string[] = revisionId + ? (store.get(stablePromptVariablesAtomFamily(revisionId)) as string[]) + : [] + const effectiveKeys = isCustom + ? schemaKeys + : Array.from( + new Set([ + ...(stableFromParams || []), + ...(stableFromPrompts || []), + ]), + ).filter((k) => typeof k === "string" && k && k !== "chat") + + // Build input params strictly from effective keys using testcase (with row overrides) + let inputParamsDict: Record = {} + if (Array.isArray(effectiveKeys) && effectiveKeys.length > 0) { + effectiveKeys.forEach((key) => { + const fromRowInput = rows[rowIndex]?.inputs?.find( + (ip) => ip.input_name === key, + )?.input_value + const fromTestcase = (testRow as any)?.[key] + if (fromRowInput !== undefined) inputParamsDict[key] = fromRowInput + else if (fromTestcase !== undefined) + inputParamsDict[key] = fromTestcase + }) + } else { + // Fallback: preserve previous behavior if keys unavailable + inputParamsDict = rows[rowIndex].inputs.reduce( + (acc: Record, item) => { + acc[item.input_name] = item.input_value + return acc + }, + {}, + ) + } + // Fallback: if chat testset, hydrate from test row keys as needed + if (isChatTestset) { + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + const reserved = new Set([ + "correct_answer", + evaluation?.testset?.testsetChatColumn || "", + ]) + Object.keys(testRow) + .filter((k) => !reserved.has(k)) + .forEach((k) => { + if (!(k in inputParamsDict)) + inputParamsDict[k] = (testRow as any)[k] + }) + } + + const stableOptional = revisionId + ? store.get( + transformedPromptsAtomFamily({ + revisionId, + useStableParams: true, + }), + ) + : undefined + + const optionalParameters = + stableOptional || + (_variantData[idx]?.parameters + ? transformToRequestBody({ + variant: _variantData[idx], + allMetadata: getAllMetadata(), + prompts: + spec && _variantData[idx] + ? derivePromptsFromSpec( + _variantData[idx] as any, + spec as any, + uriObject?.routePath, + ) || [] + : [], + // Keep request shape aligned with OpenAPI schema + isChat: hasMessagesProp, + isCustom, + customProperties: undefined, + }) + : (_variantData[idx]?.promptOptParams as any)) + // For new arch, variable inputs must live under requestBody.inputs + // Mark them as non-"input" so callVariant places them under "inputs" + const synthesizedParamDef = Object.keys(inputParamsDict).map((name) => ({ + name, + input: false, + })) as any + + const result = await callVariant( + inputParamsDict, + synthesizedParamDef, + optionalParameters, + appId || "", + _variantData[idx].baseId || "", + sanitizedMessages, + undefined, + true, + !!_variantData[idx]._parentVariant, // isNewVariant (new arch if parent exists) + isCustom, + uriObject, + _variantData[idx].variantId, + ) + + let res: BaseResponse | undefined + + if (typeof result === "string") { + res = {version: "2.0", data: result} as BaseResponse + } else if (isFuncResponse(result)) { + res = {version: "2.0", data: result.message} as BaseResponse + } else if (isBaseResponse(result)) { + res = result as BaseResponse + } else if (result.data) { + res = {version: "2.0", data: result.data} as BaseResponse + } else { + res = {version: "2.0", data: ""} as BaseResponse + } + + const _result = getStringOrJson(res.data) + + setRowValue(rowIndex, variant.variantId, _result) + ;(outputs as KeyValuePair)[variant.variantId] = _result + setRowValue( + rowIndex, + "evaluationFlow", + EvaluationFlow.COMPARISON_RUN_STARTED, + ) + if (idx === evalVariants.length - 1) { + if (count === 1 || count === rowIndex) { + setEvaluationStatus(EvaluationFlow.EVALUATION_FINISHED) + } + } + + updateEvaluationScenarioData( + id, + { + outputs: Object.keys(outputs).map((key) => ({ + variant_id: key, + variant_output: outputs[key as keyof typeof outputs], + })), + inputs: rows[rowIndex].inputs, + }, + showNotification, + ) + } catch (err) { + console.error("Error running evaluation:", err) + setEvaluationStatus(EvaluationFlow.EVALUATION_FAILED) + setRowValue( + rowIndex, + variant.variantId, + err?.response?.data?.detail?.message || "Failed to run evaluation!", + ) + } + }), + ) + }, + [ + variantData, + rows, + evalVariants, + updateEvaluationScenarioData, + setRowValue, + appId, + evaluation.testset.csvdata, + ], + ) + + // Now that handlers are declared, update stable refs + useEffect(() => { + runEvaluationRef.current = runEvaluation + handleInputChangeRef.current = handleInputChange + handleVoteClickRef.current = handleVoteClick + }, [runEvaluation, handleInputChange, handleVoteClick]) + + const runAllEvaluations = useCallback(async () => { + setEvaluationStatus(EvaluationFlow.EVALUATION_STARTED) + batchExecute(rows.map((row) => () => runEvaluation(row.id!, rows.length - 1, false))) + .then(() => { + setEvaluationStatus(EvaluationFlow.EVALUATION_FINISHED) + mutate() + message.success("Evaluations Updated!") + }) + .catch((err) => console.error("An error occurred:", err)) + }, [runEvaluation, rows]) + + const dynamicColumns: ColumnType[] = useMemo( + () => + evalVariants.map((variant: Variant, ix) => { + const columnKey = variant.variantId + + return { + title: ( +
    + Variant: + + + {evalVariants + ? variantNameWithRev({ + variant_name: variant.variantName, + revision: evaluation.revisions[ix], + }) + : ""} + +
    + ), + dataIndex: columnKey, + key: columnKey, + width: "20%", + render: (text: any, record: ABTestingEvaluationTableRow) => { + const value = + text || + record?.[columnKey] || + record.outputs?.find((o: any) => o.variant_id === columnKey) + ?.variant_output || + "" + return ( +
    + {value} +
    + ) + }, + } + }), + [evalVariants, evaluation.revisions], + ) + + const columns = useMemo(() => { + return [ + { + key: "1", + title: ( +
    +
    + Inputs (Test set: + {evaluation.testset.name} + ) +
    +
    + ), + width: 300, + dataIndex: "inputs", + render: (_: any, record: ABTestingEvaluationTableRow, rowIndex: number) => { + return ( + runEvaluationRef.current(record.id!)} + onParamChange={(name, value) => upsertRowInput(record.id!, name, value)} + variantData={variantData} + isLoading={isVariantsLoading} + /> + ) + }, + }, + { + title: "Expected Output", + dataIndex: "expectedOutput", + key: "expectedOutput", + width: "25%", + render: (text: any, record: any, rowIndex: number) => { + const correctAnswer = + record.correctAnswer || evaluation.testset.csvdata[rowIndex].correct_answer + + return ( + <> + + depouncedUpdateEvaluationScenario( + { + correctAnswer: e.target.value, + }, + record.id, + ) + } + key={record.id} + /> + + ) + }, + }, + ...dynamicColumns, + { + title: "Score", + dataIndex: "score", + key: "score", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + { + handleVoteClickRef.current(record.id, vote)} + loading={record.vote === "loading"} + vertical + key={record.id} + outputs={record.outputs} + /> + } + + ) + }, + }, + { + title: "Additional Note", + dataIndex: "additionalNote", + key: "additionalNote", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + + depouncedUpdateEvaluationScenario( + {note: e.target.value}, + record.id, + ) + } + key={record.id} + /> + + ) + }, + }, + ] + }, [ + isVariantsLoading, + evaluation.testset.name, + classes.inputTestContainer, + classes.inputTest, + dynamicColumns, + evalVariants, + ]) + + return ( +
    +

    {EvaluationTypeLabels.human_a_b_testing}

    +
    + + + + + + exportABTestingEvaluationData( + evaluation, + evaluationScenarios, + rows, + ) + } + disabled={false} + > + Export Results + + + + + + + + + + + + + + + + + + + + + + + +
    + +
    + setViewMode(e.target.value)} + value={viewMode} + optionType="button" + /> +
    + + {viewMode === "tabular" ? ( + record.id!} + /> + ) : ( + handleVoteClick(id, vote as string)} + onInputChange={handleInputChange} + updateEvaluationScenarioData={updateEvaluationScenarioData} + evaluation={evaluation} + variantData={variantData} + isLoading={isLoading || isVariantsLoading} + /> + )} + + ) +} + +export default ABTestingEvaluationTable diff --git a/web/ee/src/components/EvaluationTable/SingleModelEvaluationTable.tsx b/web/ee/src/components/EvaluationTable/SingleModelEvaluationTable.tsx new file mode 100644 index 0000000000..8273f6aa2e --- /dev/null +++ b/web/ee/src/components/EvaluationTable/SingleModelEvaluationTable.tsx @@ -0,0 +1,752 @@ +// @ts-nocheck +import {useCallback, useEffect, useState, useMemo} from "react" + +import { + Button, + Card, + Col, + Input, + Radio, + Row, + Space, + Statistic, + Table, + Typography, + message, +} from "antd" +import type {ColumnType} from "antd/es/table" +import {getDefaultStore, useAtomValue} from "jotai" +import debounce from "lodash/debounce" +import {useRouter} from "next/router" + +import SecondaryButton from "@/oss/components/SecondaryButton/SecondaryButton" +import {useQueryParam} from "@/oss/hooks/useQuery" +import {EvaluationFlow} from "@/oss/lib/enums" +import {exportSingleModelEvaluationData} from "@/oss/lib/helpers/evaluate" +import {isBaseResponse, isFuncResponse} from "@/oss/lib/helpers/playgroundResp" +import {testsetRowToChatMessages} from "@/oss/lib/helpers/testset" +import { + EvaluationTypeLabels, + batchExecute, + camelToSnake, + getStringOrJson, +} from "@/oss/lib/helpers/utils" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import {getAllMetadata} from "@/oss/lib/hooks/useStatelessVariants/state" +import {extractInputKeysFromSchema} from "@/oss/lib/shared/variant/inputHelpers" +import {getRequestSchema} from "@/oss/lib/shared/variant/openapiUtils" +import {derivePromptsFromSpec} from "@/oss/lib/shared/variant/transformer/transformer" +import {transformToRequestBody} from "@/oss/lib/shared/variant/transformer/transformToRequestBody" +import type {BaseResponse, EvaluationScenario, KeyValuePair, Variant} from "@/oss/lib/Types" +import {callVariant} from "@/oss/services/api" +import {updateEvaluation, updateEvaluationScenario} from "@/oss/services/human-evaluations/api" +import {customPropertiesByRevisionAtomFamily} from "@/oss/state/newPlayground/core/customProperties" +import { + stablePromptVariablesAtomFamily, + transformedPromptsAtomFamily, +} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {appUriInfoAtom, appSchemaAtom} from "@/oss/state/variant/atoms/fetcher" + +import EvaluationCardView from "../Evaluations/EvaluationCardView" +import EvaluationVotePanel from "../Evaluations/EvaluationCardView/EvaluationVotePanel" +import SaveTestsetModal from "../SaveTestsetModal/SaveTestsetModal" + +import {useSingleModelEvaluationTableStyles} from "./assets/styles" +import ParamsFormWithRun from "./components/ParamsFormWithRun" +import type {EvaluationTableProps, SingleModelEvaluationRow} from "./types" + +const {Title} = Typography + +/** + * + * @param evaluation - Evaluation object + * @param evaluationScenarios - Evaluation rows + * @param columnsCount - Number of variants to compare face to face (per default 2) + * @returns + */ +const SingleModelEvaluationTable: React.FC = ({ + evaluation, + evaluationScenarios, + isLoading, +}) => { + const classes = useSingleModelEvaluationTableStyles() + const router = useRouter() + const appId = router.query.app_id as string + const uriObject = useAtomValue(appUriInfoAtom) + const store = getDefaultStore() + const variants = evaluation.variants + + const {variants: data, isLoading: isVariantsLoading} = useStatelessVariants() + + // Select the correct variant revisions for this evaluation + const variantData = useMemo(() => { + const allVariantData = data || [] + if (!allVariantData.length) return [] + + return evaluation.variants.map((evVariant, idx) => { + const revisionId = evaluation.variant_revision_ids?.[idx] + const revisionNumber = evaluation.revisions?.[idx] + + // 1. Try to find by exact revision id + let selected = allVariantData.find((v) => v.id === revisionId) + + // 2. Try by variantId & revision number + if (!selected && revisionNumber !== undefined) { + selected = allVariantData.find( + (v) => v.variantId === evVariant.variantId && v.revision === revisionNumber, + ) + } + + // 3. Fallback – latest revision for that variant + if (!selected) { + selected = allVariantData.find( + (v) => v.variantId === evVariant.variantId && v.isLatestRevision, + ) + } + + return selected || evVariant + }) + }, [data, evaluation.variants, evaluation.variant_revision_ids, evaluation.revisions]) + + const [rows, setRows] = useState([]) + const [evaluationStatus, setEvaluationStatus] = useState(evaluation.status) + const [viewMode, setViewMode] = useQueryParam("viewMode", "card") + const [accuracy, setAccuracy] = useState(0) + const [isTestsetModalOpen, setIsTestsetModalOpen] = useState(false) + + const depouncedUpdateEvaluationScenario = useCallback( + debounce((data: Partial, scenarioId) => { + updateEvaluationScenarioData(scenarioId, data) + }, 800), + [rows], + ) + + useEffect(() => { + if (evaluationScenarios) { + const obj = [...evaluationScenarios] + const spec = store.get(appSchemaAtom) as any + const routePath = uriObject?.routePath + + obj.forEach((item, rowIndex) => { + // Map outputs into row shape for table columns + item.outputs.forEach((op) => (item[op.variant_id] = op.variant_output)) + + try { + const names = new Set() + ;(variantData || []).forEach((v: any) => { + const rid = v?.id + if (!rid) return + const flags = store.get(variantFlagsAtomFamily({revisionId: rid})) as any + if (flags?.isCustom && spec) { + extractInputKeysFromSchema(spec as any, routePath).forEach((k) => + names.add(k), + ) + } else { + const vars = store.get(stablePromptVariablesAtomFamily(rid)) as string[] + ;(vars || []).forEach((k) => names.add(k)) + } + }) + + const chatCol = evaluation?.testset?.testsetChatColumn || "" + const reserved = new Set(["correct_answer", chatCol]) + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + + const existing = new Set( + (Array.isArray(item.inputs) ? item.inputs : []) + .map((ip: any) => ip?.input_name) + .filter(Boolean), + ) + + const nextInputs = Array.isArray(item.inputs) ? [...item.inputs] : [] + Array.from(names) + .filter((k) => typeof k === "string" && k && !reserved.has(k)) + .forEach((k) => { + if (!existing.has(k)) { + nextInputs.push({ + input_name: k, + input_value: (testRow as any)?.[k] ?? "", + }) + } + }) + item.inputs = nextInputs + } catch { + // best-effort only + } + }) + + setRows(obj) + } + }, [evaluationScenarios, variantData]) + + useEffect(() => { + const filtered = rows.filter((row) => typeof row.score === "number" && !isNaN(row.score)) + + if (filtered.length > 0) { + const avg = filtered.reduce((acc, val) => acc + Number(val.score), 0) / filtered.length + setAccuracy(avg) + } else { + setAccuracy(0) + } + }, [rows]) + + useEffect(() => { + if (evaluationStatus === EvaluationFlow.EVALUATION_FINISHED) { + updateEvaluation(evaluation.id, {status: EvaluationFlow.EVALUATION_FINISHED}).catch( + (err) => console.error("Failed to fetch results:", err), + ) + } + }, [evaluationStatus, evaluation.id]) + + const handleInputChange = ( + e: React.ChangeEvent, + id: string, + inputIndex: number, + ) => { + const rowIndex = rows.findIndex((row) => row.id === id) + const newRows = [...rows] + newRows[rowIndex].inputs[inputIndex].input_value = e.target.value + setRows(newRows) + } + + const handleScoreChange = (id: string, score: number) => { + const rowIndex = rows.findIndex((row) => row.id === id) + const evaluation_scenario_id = rows[rowIndex].id + + if (evaluation_scenario_id) { + setRowValue(rowIndex, "score", "loading") + const data = { + score: score ?? "", + outputs: variants.map((v: Variant) => ({ + variant_id: v.variantId, + variant_output: rows[rowIndex][v.variantId], + })), + inputs: rows[rowIndex].inputs, + } + + updateEvaluationScenarioData(evaluation_scenario_id, data) + } + } + + const depouncedHandleScoreChange = useCallback( + debounce((...args: Parameters) => { + handleScoreChange(...args) + }, 800), + [handleScoreChange], + ) + + const updateEvaluationScenarioData = async ( + id: string, + data: Partial, + showNotification = true, + ) => { + await updateEvaluationScenario( + evaluation.id, + id, + Object.keys(data).reduce( + (acc, key) => ({ + ...acc, + [camelToSnake(key)]: data[key as keyof EvaluationScenario], + }), + {}, + ), + evaluation.evaluationType, + ) + .then(() => { + Object.keys(data).forEach((key) => { + setRowValue( + rows.findIndex((item) => item.id === id), + key, + data[key as keyof EvaluationScenario], + ) + }) + if (showNotification) message.success("Evaluation Updated!") + }) + .catch(console.error) + } + + const runAllEvaluations = async () => { + setEvaluationStatus(EvaluationFlow.EVALUATION_STARTED) + batchExecute(rows.map((row) => () => runEvaluation(row.id!, rows.length - 1, false))) + .then(() => { + setEvaluationStatus(EvaluationFlow.EVALUATION_FINISHED) + message.success("Evaluations Updated!") + }) + .catch((err) => console.error("An error occurred:", err)) + } + + const runEvaluation = async (id: string, count = 1, showNotification = true) => { + const rowIndex = rows.findIndex((row) => row.id === id) + // Build input params from stable effective keys: schema keys for custom; stable prompt variables/parameters for non-custom + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + const spec = store.get(appSchemaAtom) as any + const routePath = uriObject?.routePath + const requestSchema: any = spec ? getRequestSchema(spec as any, {routePath}) : undefined + const hasMessagesProp = Boolean(requestSchema?.properties?.messages) + + const effectiveKeysForVariant = (idx: number): string[] => { + const v = variantData?.[idx] as any + const rid = v?.id + const flags = rid ? (store.get(variantFlagsAtomFamily({revisionId: rid})) as any) : null + const isCustom = Boolean(flags?.isCustom) + if (isCustom) { + return spec ? extractInputKeysFromSchema(spec as any, routePath) : [] + } + // Union of saved parameters input_keys and stable prompt variables + const fromParams: string[] = (() => { + try { + const params = v?.parameters + const ag = params?.ag_config ?? params ?? {} + const s = new Set() + Object.values(ag || {}).forEach((cfg: any) => { + const arr = cfg?.input_keys + if (Array.isArray(arr)) + arr.forEach((k) => typeof k === "string" && s.add(k)) + }) + return Array.from(s) + } catch { + return [] + } + })() + const fromPrompts: string[] = rid + ? (store.get(stablePromptVariablesAtomFamily(rid)) as string[]) || [] + : [] + return Array.from(new Set([...(fromParams || []), ...(fromPrompts || [])])).filter( + (k) => k && k !== (evaluation?.testset?.testsetChatColumn || ""), + ) + } + + let inputParamsDict: Record = {} + const keys = effectiveKeysForVariant(0) // single model uses one variant for inputs shape + if (Array.isArray(keys) && keys.length > 0) { + keys.forEach((key) => { + const fromScenario = rows[rowIndex]?.inputs?.find( + (ip) => ip.input_name === key, + )?.input_value + const fromTestcase = (testRow as any)?.[key] + if (fromScenario !== undefined) inputParamsDict[key] = fromScenario + else if (fromTestcase !== undefined) inputParamsDict[key] = fromTestcase + }) + } else { + // Fallback to backend-provided inputs + inputParamsDict = rows[rowIndex].inputs.reduce((acc: Record, item) => { + acc[item.input_name] = item.input_value + return acc + }, {}) + } + + const outputs = rows[rowIndex].outputs.reduce( + (acc, op) => ({...acc, [op.variant_id]: op.variant_output}), + {}, + ) + await Promise.all( + variants.map(async (variant: Variant, idx: number) => { + setRowValue(rowIndex, variant.variantId, "loading...") + + const isChatTestset = !!evaluation?.testset?.testsetChatColumn + const rawMessages = isChatTestset + ? testsetRowToChatMessages(evaluation.testset.csvdata[rowIndex], false) + : [] + + const sanitizedMessages = rawMessages.map((msg) => { + if (!Array.isArray(msg.content)) return msg + return { + ...msg, + content: msg.content.filter((part) => { + return part.type !== "image_url" || part.image_url.url.trim() !== "" + }), + } + }) + + try { + const revisionId = variantData?.[idx]?.id as string | undefined + const flags = revisionId + ? (store.get(variantFlagsAtomFamily({revisionId})) as any) + : undefined + const isCustom = Boolean(flags?.isCustom) + // Recompute effective keys for this variant index + const vKeys = effectiveKeysForVariant(idx) + if (Array.isArray(vKeys) && vKeys.length > 0) { + vKeys.forEach((key) => { + if (!(key in inputParamsDict)) { + const v = (testRow as any)?.[key] + if (v !== undefined) inputParamsDict[key] = v + } + }) + } + if (isChatTestset) { + const testRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + const reserved = new Set([ + "correct_answer", + evaluation?.testset?.testsetChatColumn || "", + ]) + Object.keys(testRow) + .filter((k) => !reserved.has(k)) + .forEach((k) => { + if (!(k in inputParamsDict)) + inputParamsDict[k] = (testRow as any)[k] + }) + } + + // Prefer stable transformed parameters (saved revision + schema) + const stableOptional = revisionId + ? store.get( + transformedPromptsAtomFamily({ + revisionId, + useStableParams: true, + }), + ) + : undefined + const optionalParameters = + stableOptional || + (variantData[idx]?.parameters + ? transformToRequestBody({ + variant: variantData[idx], + allMetadata: getAllMetadata(), + prompts: + spec && variantData[idx] + ? derivePromptsFromSpec( + variantData[idx] as any, + spec as any, + uriObject?.routePath, + ) || [] + : [], + // Keep request shape aligned with OpenAPI schema + isChat: hasMessagesProp, + isCustom, + // stableOptional already includes custom props; fallback path keeps schema-aligned custom props + customProperties: undefined, + }) + : (variantData[idx]?.promptOptParams as any)) + + // For new arch, variable inputs must live under requestBody.inputs + // Mark them as non-"input" so callVariant places them under "inputs" + const synthesizedParamDef = Object.keys(inputParamsDict).map((name) => ({ + name, + input: false, + })) as any + + const result = await callVariant( + inputParamsDict, + synthesizedParamDef, + optionalParameters, + appId || "", + variantData[idx].baseId || "", + sanitizedMessages, + undefined, + true, + !!variantData[idx]._parentVariant, // isNewVariant + isCustom, + uriObject, + variantData[idx].variantId, + ) + + let res: BaseResponse | undefined + + if (typeof result === "string") { + res = {version: "2.0", data: result} as BaseResponse + } else if (isFuncResponse(result)) { + res = {version: "2.0", data: result.message} as BaseResponse + } else if (isBaseResponse(result)) { + res = result as BaseResponse + } else if (result.data) { + res = {version: "2.0", data: result.data} as BaseResponse + } else { + res = {version: "2.0", data: ""} as BaseResponse + console.error("Unknown response type:", result) + } + + const _result = getStringOrJson(res.data) + + setRowValue(rowIndex, variant.variantId, _result) + ;(outputs as KeyValuePair)[variant.variantId] = _result + setRowValue(rowIndex, "evaluationFlow", EvaluationFlow.COMPARISON_RUN_STARTED) + if (idx === variants.length - 1) { + if (count === 1 || count === rowIndex) { + setEvaluationStatus(EvaluationFlow.EVALUATION_FINISHED) + } + } + + updateEvaluationScenarioData( + id, + { + outputs: Object.keys(outputs).map((key) => ({ + variant_id: key, + variant_output: outputs[key as keyof typeof outputs], + })), + inputs: rows[rowIndex].inputs, + }, + showNotification, + ) + } catch (err) { + console.error("Error running evaluation:", err) + setEvaluationStatus(EvaluationFlow.EVALUATION_FAILED) + setRowValue( + rowIndex, + variant.variantId, + err?.response?.data?.detail?.message || "Failed to run evaluation!", + ) + } + }), + ) + } + + const setRowValue = ( + rowIndex: number, + columnKey: keyof SingleModelEvaluationRow, + value: any, + ) => { + const newRows = [...rows] + newRows[rowIndex][columnKey] = value as never + setRows(newRows) + } + + const dynamicColumns: ColumnType[] = variants.map( + (variant: Variant) => { + const columnKey = variant.variantId + + return { + title: ( +
    + App Variant: + + {variants + ? variantNameWithRev({ + variant_name: variant.variantName, + revision: evaluation.revisions[0], + }) + : ""} + +
    + ), + dataIndex: columnKey, + key: columnKey, + width: "25%", + render: (text: any, record: SingleModelEvaluationRow, rowIndex: number) => { + let outputValue = text + if (!outputValue && record.outputs && record.outputs.length > 0) { + outputValue = record.outputs.find( + (output: any) => output.variant_id === columnKey, + )?.variant_output + } + return ( +
    + {outputValue} +
    + ) + }, + } + }, + ) + + const columns = [ + { + key: "1", + title: ( +
    +
    + Inputs (Test set: + {evaluation.testset.name} + ) +
    +
    + ), + width: 300, + dataIndex: "inputs", + render: (_: any, record: SingleModelEvaluationRow, rowIndex: number) => { + return ( + runEvaluation(record.id!)} + onParamChange={(name, value) => + handleInputChange( + {target: {value}} as any, + record.id, + record?.inputs.findIndex((ip) => ip.input_name === name), + ) + } + variantData={variantData} + isLoading={isVariantsLoading} + /> + ) + }, + }, + { + title: "Expected Output", + dataIndex: "expectedOutput", + key: "expectedOutput", + width: "25%", + render: (text: any, record: any, rowIndex: number) => { + const correctAnswer = + record.correctAnswer || evaluation.testset.csvdata[rowIndex].correct_answer + + return ( + <> + + depouncedUpdateEvaluationScenario( + { + correctAnswer: e.target.value, + }, + record.id, + ) + } + key={record.id} + /> + + ) + }, + }, + ...dynamicColumns, + { + title: "Score", + dataIndex: "score", + key: "score", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + { + + depouncedHandleScoreChange(record.id, val[0].score as number) + } + loading={record.score === "loading"} + showVariantName={false} + key={record.id} + outputs={record.outputs} + /> + } + + ) + }, + }, + { + title: "Additional Note", + dataIndex: "additionalNote", + key: "additionalNote", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + + depouncedUpdateEvaluationScenario({note: e.target.value}, record.id) + } + key={record.id} + /> + + ) + }, + }, + ] + + return ( +
    + {EvaluationTypeLabels.single_model_test} +
    + +
    + + + + exportSingleModelEvaluationData( + evaluation, + evaluationScenarios, + rows, + ) + } + disabled={false} + > + Export Results + + + + + + + + + + + + + + + +
    + setViewMode(e.target.value)} + value={viewMode} + optionType="button" + /> +
    + + setIsTestsetModalOpen(false)} + onSuccess={(testsetName: string) => { + message.success(`Row added to the "${testsetName}" test set!`) + setIsTestsetModalOpen(false) + }} + rows={rows} + evaluation={evaluation} + /> + + {viewMode === "tabular" ? ( +
    record.id!} + /> + ) : ( + depouncedHandleScoreChange(id, score as number)} + onInputChange={handleInputChange} + updateEvaluationScenarioData={updateEvaluationScenarioData} + evaluation={evaluation} + variantData={variantData} + isLoading={isLoading || isVariantsLoading} + /> + )} + + ) +} + +export default SingleModelEvaluationTable diff --git a/web/ee/src/components/EvaluationTable/assets/styles.ts b/web/ee/src/components/EvaluationTable/assets/styles.ts new file mode 100644 index 0000000000..ba1413d743 --- /dev/null +++ b/web/ee/src/components/EvaluationTable/assets/styles.ts @@ -0,0 +1,140 @@ +import {createUseStyles} from "react-jss" + +export const useSingleModelEvaluationTableStyles = createUseStyles({ + appVariant: { + backgroundColor: "rgb(201 255 216)", + color: "rgb(0 0 0)", + padding: 4, + borderRadius: 5, + }, + inputTestContainer: { + display: "flex", + justifyContent: "space-between", + }, + inputTest: { + backgroundColor: "rgb(201 255 216)", + color: "rgb(0 0 0)", + padding: 4, + borderRadius: 5, + }, + inputTestBtn: { + width: "100%", + display: "flex", + justifyContent: "flex-end", + "& button": { + marginLeft: 10, + }, + marginTop: "0.75rem", + }, + recordInput: { + marginBottom: 10, + }, + card: { + marginBottom: 20, + }, + statCorrect: { + "& .ant-statistic-content-value": { + color: "#3f8600", + }, + }, + statWrong: { + "& .ant-statistic-content-value": { + color: "#cf1322", + }, + }, + viewModeRow: { + display: "flex", + justifyContent: "flex-end", + margin: "1rem 0", + position: "sticky", + top: 36, + zIndex: 1, + }, + sideBar: { + marginTop: "1rem", + display: "flex", + flexDirection: "column", + gap: "2rem", + border: "1px solid #d9d9d9", + borderRadius: 6, + padding: "1rem", + alignSelf: "flex-start", + "&>h4.ant-typography": { + margin: 0, + }, + flex: 0.35, + minWidth: 240, + maxWidth: 500, + }, +}) + +export const useABTestingEvaluationTableStyles = createUseStyles({ + appVariant: { + padding: 4, + borderRadius: 5, + }, + inputTestContainer: { + display: "flex", + justifyContent: "space-between", + }, + inputTest: { + backgroundColor: "rgb(201 255 216)", + color: "rgb(0 0 0)", + padding: 4, + borderRadius: 5, + }, + inputTestBtn: { + width: "100%", + display: "flex", + justifyContent: "flex-end", + "& button": { + marginLeft: 10, + }, + marginTop: "0.75rem", + }, + recordInput: { + marginBottom: 10, + }, + card: { + marginBottom: 20, + }, + statCorrect: { + "& .ant-statistic-content-value": { + color: "#3f8600", + }, + }, + stat: { + "& .ant-statistic-content-value": { + color: "#1677ff", + }, + }, + statWrong: { + "& .ant-statistic-content-value": { + color: "#cf1322", + }, + }, + viewModeRow: { + display: "flex", + justifyContent: "flex-end", + margin: "1rem 0", + position: "sticky", + top: 36, + zIndex: 1, + }, + sideBar: { + marginTop: "1rem", + display: "flex", + flexDirection: "column", + gap: "2rem", + border: "1px solid #d9d9d9", + borderRadius: 6, + padding: "1rem", + alignSelf: "flex-start", + "&>h4.ant-typography": { + margin: 0, + }, + flex: 0.35, + minWidth: 240, + maxWidth: 500, + }, +}) diff --git a/web/ee/src/components/EvaluationTable/components/ParamsFormWithRun.tsx b/web/ee/src/components/EvaluationTable/components/ParamsFormWithRun.tsx new file mode 100644 index 0000000000..9857286493 --- /dev/null +++ b/web/ee/src/components/EvaluationTable/components/ParamsFormWithRun.tsx @@ -0,0 +1,148 @@ +// @ts-nocheck +import {useMemo} from "react" + +import {CaretRightOutlined} from "@ant-design/icons" +import {Button, Form} from "antd" +import {atom, useAtomValue} from "jotai" + +import ParamsForm from "@/oss/components/ParamsForm" +import {useLegacyVariants} from "@/oss/lib/hooks/useLegacyVariant" +import type {Evaluation} from "@/oss/lib/Types" +import {inputParamsAtomFamily} from "@/oss/state/newPlayground/core/inputParams" +import {stablePromptVariablesAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {appUriInfoAtom} from "@/oss/state/variant/atoms/fetcher" + +import {useSingleModelEvaluationTableStyles} from "../assets/styles" +import type {SingleModelEvaluationRow} from "../types" + +/** + * + * @param evaluation - Evaluation object + * @param evaluationScenarios - Evaluation rows + * @param columnsCount - Number of variants to compare face to face (per default 2) + * @returns + */ +const ParamsFormWithRun = ({ + evaluation, + record, + rowIndex, + onRun, + onParamChange, + variantData = [], + isLoading, +}: { + record: SingleModelEvaluationRow + rowIndex: number + evaluation: Evaluation + onRun: () => void + onParamChange: (name: string, value: any) => void + variantData: ReturnType + isLoading: boolean +}) => { + const classes = useSingleModelEvaluationTableStyles() + const [form] = Form.useForm() + const selectedVariant = variantData?.[0] + const routePath = useAtomValue(appUriInfoAtom)?.routePath + const hasRevision = Boolean(selectedVariant && (selectedVariant as any).id) + // Memoize the atom-family selector only when we have a proper revision and route + const inputParamsSelector = useMemo( + () => + (hasRevision && routePath + ? inputParamsAtomFamily({variant: selectedVariant as any, routePath}) + : atom([])) as any, + [hasRevision ? (selectedVariant as any).id : undefined, routePath], + ) + const baseInputParams = useAtomValue(inputParamsSelector) as any[] + // Stable variables derived from saved prompts (spec + saved parameters; no live mutations) + const stableVariableNames = useAtomValue( + selectedVariant?.id + ? (stablePromptVariablesAtomFamily((selectedVariant as any).id) as any) + : atom([]), + ) as string[] + const flags = useAtomValue( + selectedVariant?.id + ? (variantFlagsAtomFamily({revisionId: (selectedVariant as any).id}) as any) + : atom({}), + ) as any + + // Build input params similar to EvaluationCardView with robust fallbacks + const testsetRow = evaluation?.testset?.csvdata?.[rowIndex] || {} + const chatCol = evaluation?.testset?.testsetChatColumn + const reservedKeys = new Set(["correct_answer", chatCol || ""]) as Set + + const derivedInputParams = useMemo((): any[] => { + const haveSchema = Array.isArray(baseInputParams) && baseInputParams.length > 0 + let source: any[] + if (haveSchema) { + source = baseInputParams + } else if (Array.isArray(record?.inputs) && record.inputs.length > 0) { + source = record.inputs + .filter((ip: any) => (chatCol ? ip.input_name !== chatCol : true)) + .map((ip: any) => ({name: ip.input_name, type: "string"})) + } else { + source = Object.keys(testsetRow) + .filter((k) => !reservedKeys.has(k)) + .map((k) => ({name: k, type: "string"})) + } + // Filter to stable variables only for non-custom apps + if ( + !flags?.isCustom && + Array.isArray(stableVariableNames) && + stableVariableNames.length > 0 + ) { + const allow = new Set( + stableVariableNames.filter((name) => (chatCol ? name !== chatCol : true)), + ) + source = (source || []).filter((p: any) => allow.has(p?.name)) + } + + return (source || []).map((item: any) => ({ + ...item, + value: + record?.inputs?.find((ip: any) => ip.input_name === item.name)?.input_value ?? + (testsetRow as any)?.[item.name] ?? + "", + })) + }, [baseInputParams, record?.inputs, testsetRow, chatCol, stableVariableNames, flags?.isCustom]) + + return isLoading ? null : ( +
    +
    + {evaluation.testset.testsetChatColumn && ( +
    + {evaluation.testset.csvdata[rowIndex][ + evaluation.testset.testsetChatColumn + ] || " - "} +
    + )} + {derivedInputParams && derivedInputParams.length > 0 ? ( + { + // Ensure local row inputs are updated before invoking run + Object.entries(values || {}).forEach(([k, v]) => + onParamChange(k as string, v), + ) + onRun() + }} + key={`${record.id}-${rowIndex}`} + form={form} + /> + ) : null} +
    +
    + +
    +
    + ) +} + +export default ParamsFormWithRun diff --git a/web/ee/src/components/EvaluationTable/types.d.ts b/web/ee/src/components/EvaluationTable/types.d.ts new file mode 100644 index 0000000000..18006be8d6 --- /dev/null +++ b/web/ee/src/components/EvaluationTable/types.d.ts @@ -0,0 +1,21 @@ +import {EvaluationFlow} from "@/oss/lib/enums" +import {Evaluation, EvaluationScenario} from "@/oss/lib/Types" + +export interface EvaluationTableProps { + evaluation: Evaluation + evaluationScenarios: SingleModelEvaluationRow[] + isLoading: boolean +} + +export type SingleModelEvaluationRow = EvaluationScenario & { + evaluationFlow: EvaluationFlow +} & Record + +export interface ABTestingEvaluationTableProps extends EvaluationTableProps { + evaluationScenarios: ABTestingEvaluationTableRow[] + columnsCount: number +} + +export type ABTestingEvaluationTableRow = EvaluationScenario & { + evaluationFlow: EvaluationFlow +} & Record diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationCard.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationCard.tsx new file mode 100644 index 0000000000..4b825fc75f --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationCard.tsx @@ -0,0 +1,78 @@ +import {memo} from "react" + +import {createUseStyles} from "react-jss" + +import type {ABTestingEvaluationTableRow} from "@/oss/components/EvaluationTable/types" +import type {Evaluation, Variant} from "@/oss/lib/Types" + +import EvaluationChatResponse from "./EvaluationChatResponse" +import EvaluationVariantCard from "./EvaluationVariantCard" + +const useStyles = createUseStyles({ + root: { + display: "flex", + gap: "1rem", + flexWrap: "wrap", + }, +}) + +interface Props { + evaluationScenario: ABTestingEvaluationTableRow + variants: Variant[] + isChat?: boolean + showVariantName?: boolean + evaluation: Evaluation +} + +const EvaluationCard: React.FC = ({ + evaluationScenario, + variants, + isChat, + showVariantName = true, + evaluation, +}) => { + const classes = useStyles() + + return ( +
    + {variants.map((variant, ix) => + isChat ? ( + item.variant_id) + ?.variant_output || + "" + } + index={ix} + showVariantName={showVariantName} + evaluation={evaluation} + /> + ) : ( + item.variant_id) + ?.variant_output || + "" + } + index={ix} + showVariantName={showVariantName} + evaluation={evaluation} + //random image from unsplash + // outputImg={`https://fps.cdnpk.net/images/home/subhome-ai.webp?w=649&h=649`} + /> + ), + )} +
    + ) +} + +export default memo(EvaluationCard) diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationChatResponse.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationChatResponse.tsx new file mode 100644 index 0000000000..8e66303bd6 --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationChatResponse.tsx @@ -0,0 +1,69 @@ +import {memo, useMemo} from "react" + +import {Space, Typography} from "antd" +import {createUseStyles} from "react-jss" +import {v4 as uuidv4} from "uuid" + +import ChatInputs from "@/oss/components/ChatInputs/ChatInputs" +import {safeParse} from "@/oss/lib/helpers/utils" +import {ChatRole, Evaluation, Variant} from "@/oss/lib/Types" + +import {VARIANT_COLORS} from "./assets/styles" +import VariantAlphabet from "./VariantAlphabet" + +const useStyles = createUseStyles({ + title: { + fontSize: 20, + textAlign: "center", + }, +}) + +interface Props { + variant: Variant + outputText?: string + index?: number + showVariantName?: boolean + evaluation: Evaluation +} + +const EvaluationChatResponse: React.FC = ({ + variant, + outputText, + index = 0, + showVariantName = true, + evaluation, +}) => { + const classes = useStyles() + const color = VARIANT_COLORS[index] + const parsedOutput = safeParse(outputText || "", null) + const messageContent = + parsedOutput && typeof parsedOutput === "object" && "content" in parsedOutput + ? parsedOutput.content + : outputText || "" + + const chatValue = useMemo( + () => [{role: ChatRole.Assistant, content: messageContent, id: uuidv4()}], + [messageContent], + ) + + return ( + + {showVariantName && ( + + + + {variant.variantName}{" "} + {evaluation.revisions[index] && ( + + v{evaluation.revisions[index]} + + )} + + + )} + + + ) +} + +export default memo(EvaluationChatResponse) diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationInputs.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationInputs.tsx new file mode 100644 index 0000000000..21784abfde --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationInputs.tsx @@ -0,0 +1,50 @@ +import {Input, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {EvaluationScenario} from "@/oss/lib/Types" + +const useStyles = createUseStyles({ + root: { + display: "flex", + gap: "1rem", + flexDirection: "column", + }, + inputRow: { + display: "flex", + flexDirection: "column", + gap: "0.25rem", + "& .ant-typography": { + textTransform: "capitalize", + }, + "& textarea": { + width: "100%", + }, + }, +}) + +interface Props { + evaluationScenario: EvaluationScenario + onInputChange: Function +} + +const EvaluationInputs: React.FC = ({evaluationScenario, onInputChange}) => { + const classes = useStyles() + + return ( +
    + {evaluationScenario.inputs.map((ip, ix) => ( +
    + {ip.input_name}: + onInputChange(e, evaluationScenario.id, ix)} + /> +
    + ))} +
    + ) +} + +export default EvaluationInputs diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVariantCard.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVariantCard.tsx new file mode 100644 index 0000000000..58d487656c --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVariantCard.tsx @@ -0,0 +1,105 @@ +import {Typography} from "antd" +import Image from "next/image" +import {createUseStyles} from "react-jss" + +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import {Evaluation, Variant, StyleProps} from "@/oss/lib/Types" + +import {VARIANT_COLORS} from "./assets/styles" + +const useStyles = createUseStyles({ + root: ({themeMode}: StyleProps) => ({ + flex: 1, + display: "flex", + flexDirection: "column", + alignItems: "center", + gap: "0.75rem", + border: `1px solid ${themeMode === "dark" ? "#424242" : "#d9d9d9"}`, + padding: "0.75rem", + paddingTop: "1.25rem", + borderRadius: 6, + "& img": { + maxHeight: 300, + width: "100%", + objectFit: "contain", + borderRadius: "inherit", + }, + position: "relative", + }), + title: { + fontSize: 20, + textAlign: "center", + }, + output: { + whiteSpace: "pre-line", + position: "relative", + maxHeight: 300, + overflow: "auto", + }, + variantType: { + position: "absolute", + top: 10, + left: 10, + borderRadius: "50%", + border: `1.5px solid`, + width: 32, + aspectRatio: "1/1", + display: "grid", + placeItems: "center", + + "& .ant-typography": { + fontSize: 18, + }, + }, +}) + +interface Props { + variant: Variant + outputText?: string + outputImg?: string + index?: number + showVariantName?: boolean + evaluation: Evaluation +} + +const EvaluationVariantCard: React.FC = ({ + variant, + outputText, + outputImg, + index = 0, + showVariantName = true, + evaluation, +}) => { + const {appTheme} = useAppTheme() + const classes = useStyles({themeMode: appTheme} as StyleProps) + const color = VARIANT_COLORS[index] + + return ( +
    + {showVariantName && ( + <> + {" "} +
    + + {String.fromCharCode(65 + index)} + +
    + + {variant.variantName}{" "} + {evaluation.revisions[index] && ( + + v{evaluation.revisions[index]} + + )} + {" "} + + )} + {outputImg && } + + {outputText || Click the "Run" icon to get variant output} + +
    + ) +} + +export default EvaluationVariantCard diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVotePanel.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVotePanel.tsx new file mode 100644 index 0000000000..8f309454d7 --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/EvaluationVotePanel.tsx @@ -0,0 +1,405 @@ +import {StarFilled} from "@ant-design/icons" +import {Button, ConfigProvider, InputNumber, Rate, Spin, Typography, theme} from "antd" +import {createUseStyles} from "react-jss" + +import {Variant} from "@/oss/lib/Types" + +import {VARIANT_COLORS} from "./assets/styles" + +const useStyles = createUseStyles({ + root: { + display: "flex", + justifyContent: "center", + width: "100%", + }, + btnRow: { + display: "flex", + gap: "0.5rem", + }, + gradeRoot: { + display: "flex", + alignItems: "center", + gap: "1.5rem", + }, + variantName: { + display: "inline-block", + marginBottom: "0.25rem", + }, + btnsDividerHorizontal: { + height: 30, + borderRight: "1.2px solid", + alignSelf: "center", + margin: "0 4px", + }, + btnsDividerVertical: { + width: 120, + borderBottom: "1.2px solid", + alignSelf: "center", + margin: "4px 0", + }, +}) + +interface CommonProps { + onChange: (value: T) => void + value?: T + vertical?: boolean +} + +type BinaryVoteProps = CommonProps + +const BinaryVote: React.FC = ({onChange, value, vertical}) => { + const classes = useStyles() + + const getOnClick = (isGood: boolean) => () => { + onChange(isGood) + } + + return ( +
    + + +
    + ) +} + +type ComparisonVoteProps = { + variants: Variant[] + outputs: any +} & CommonProps + +const ComparisonVote: React.FC = ({ + variants, + onChange, + value, + vertical, + outputs, +}) => { + const classes = useStyles() + const {token} = theme.useToken() + const badId = "0" + const goodId = "1" + + const getOnClick = (variantId: string) => () => { + onChange(variantId) + } + + return ( +
    + {variants.map((variant, ix) => ( + + + + ))} +
    + + + + +
    + ) +} + +type GradingVoteProps = { + variants: Variant[] + maxGrade?: number +} & CommonProps< + { + grade: number | null + variantId: string + }[] +> + +const GradingVote: React.FC = ({ + variants, + onChange, + value = [], + maxGrade = 5, + vertical, +}) => { + const classes = useStyles() + + const getOnClick = (variantId: string, grade: number) => () => { + onChange( + variants.map((variant) => ({ + variantId: variant.variantId, + grade: variant.variantId === variantId ? grade : null, + })), + ) + } + + return ( +
    + {variants.map((variant, ix) => ( +
    + + {variant.variantName} + +
    + {Array.from({length: maxGrade}, (_, i) => i + 1).map((grade) => ( + + ))} +
    +
    + ))} +
    + ) +} + +type NumericScoreVoteProps = { + variants: Variant[] + min?: number + max?: number + showVariantName?: boolean + outputs: any +} & CommonProps< + { + score: number | null + variantId: string + }[] +> + +const NumericScoreVote: React.FC = ({ + variants, + onChange, + value = [], + min = 0, + max = 100, + vertical, + showVariantName = true, + outputs, +}) => { + const classes = useStyles() + + const _onChange = (variantId: string, score: number | null) => { + onChange( + variants.map((variant) => ({ + variantId: variant.variantId, + score: variant.variantId === variantId ? score : null, + })), + ) + } + + return ( +
    + {variants.map((variant, ix) => ( +
    + {showVariantName && ( + + {variant.variantName} + + )} +
    + item.variantId === variant.variantId)?.score ?? + undefined + } + min={min} + max={max} + onChange={(score) => _onChange(variant.variantId, score)} + disabled={!outputs?.length} + /> + / {max} +
    +
    + ))} +
    + ) +} + +type RatingVoteProps = NumericScoreVoteProps + +const RatingVote: React.FC = ({ + variants, + onChange, + value = [], + vertical, + showVariantName = true, + outputs, +}) => { + const classes = useStyles() + + const _onChange = (variantId: string, score: number | null) => { + onChange( + variants.map((variant) => ({ + variantId: variant.variantId, + score: variant.variantId === variantId ? score : null, + })), + ) + } + + return ( +
    + {variants.map((variant, ix) => { + const score = value.find((item) => item.variantId === variant.variantId)?.score + const finalValue = typeof score !== "number" ? null : score / 25 + 1 + + return ( +
    + {showVariantName && ( + + {variant.variantName} + + )} +
    + { + const rateColors: Record = { + 1: "#D61010", + 2: "#FFA940", + 3: "#FADB14", + 4: "#BAE637", + 5: "#73D13D", + } + + return ( + index + ? rateColors[value] || "#d9d9d9" + : "#d9d9d9", + }} + /> + ) + }} + onChange={(score) => { + const finalScore = (score - 1) * 25 + _onChange(variant.variantId, finalScore) + }} + disabled={!outputs?.length} + /> +
    +
    + ) + })} +
    + ) +} + +type Props = + | ({ + type: "binary" + } & BinaryVoteProps) + | ({ + type: "comparison" + } & ComparisonVoteProps) + | ({ + type: "grading" + } & GradingVoteProps) + | ({ + type: "numeric" + } & NumericScoreVoteProps) + | ({ + type: "rating" + } & RatingVoteProps) + +const EvaluationVotePanel: React.FC = ({type, loading, ...props}) => { + const classes = useStyles() + + return ( +
    + + {type === "binary" ? ( + + ) : type === "comparison" ? ( + + ) : type === "grading" ? ( + + ) : type === "rating" ? ( + + ) : ( + + )} + +
    + ) +} + +export default EvaluationVotePanel diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/VariantAlphabet.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/VariantAlphabet.tsx new file mode 100644 index 0000000000..da0e948a42 --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/VariantAlphabet.tsx @@ -0,0 +1,44 @@ +import {Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {VARIANT_COLORS} from "./assets/styles" + +interface StyleProps { + color: string + width: number +} + +const useStyles = createUseStyles({ + variantType: { + borderRadius: "50%", + border: `1.5px solid`, + borderColor: ({color}: StyleProps) => color, + width: ({width}: StyleProps) => width, + aspectRatio: "1/1", + display: "inline-flex", + justifyContent: "center", + alignItems: "center", + "& .ant-typography": { + fontSize: ({width}: StyleProps) => width / 1.75, + color: ({color}: StyleProps) => color, + }, + }, +}) + +interface Props { + index: number + width?: number +} + +const VariantAlphabet: React.FC = ({index, width = 28}) => { + const color = VARIANT_COLORS[index] + const classes = useStyles({width, color} as StyleProps) + + return ( +
    + {String.fromCharCode(65 + index)} +
    + ) +} + +export default VariantAlphabet diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/assets/styles.ts b/web/ee/src/components/Evaluations/EvaluationCardView/assets/styles.ts new file mode 100644 index 0000000000..eb466a544e --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/assets/styles.ts @@ -0,0 +1,108 @@ +import {createUseStyles} from "react-jss" + +export const VARIANT_COLORS = [ + "#297F87", // "#722ed1", + "#F6D167", //"#13c2c2", + "#4caf50", +] + +export const useStyles = createUseStyles({ + root: { + display: "flex", + gap: "1rem", + outline: "none", + }, + evaluation: { + flex: 1, + display: "flex", + flexDirection: "column", + padding: "1rem", + "& .ant-divider": { + margin: "2rem 0 1.5rem 0", + }, + "& h5.ant-typography": { + margin: 0, + marginBottom: "1rem", + }, + gap: "1rem", + }, + heading: { + width: "100%", + display: "flex", + justifyContent: "space-between", + alignItems: "center", + gap: "0.75rem", + "& .ant-typography": { + margin: 0, + fontWeight: 400, + }, + }, + headingDivider: { + position: "relative", + }, + helpIcon: { + position: "absolute", + right: 0, + top: 42, + fontSize: 16, + }, + instructions: { + paddingInlineStart: 0, + "& code": { + backgroundColor: "rgba(0, 0, 0, 0.05)", + padding: "0.1rem 0.3rem", + borderRadius: 3, + }, + "& li": { + marginBottom: "0.5rem", + }, + }, + note: { + marginTop: "1.25rem", + marginBottom: "-1rem", + whiteSpace: "pre-line", + display: "flex", + alignItems: "flex-start", + + "& .anticon": { + marginTop: 4, + }, + }, + chatInputsCon: { + marginTop: "0.5rem", + }, + correctAnswerCon: { + marginBottom: "0.5rem", + }, + toolBar: { + display: "flex", + alignItems: "center", + gap: "0.5rem", + justifyContent: "flex-end", + "& .anticon": { + fontSize: 18, + cursor: "pointer", + }, + }, + sideBar: { + marginTop: "1rem", + display: "flex", + flexDirection: "column", + gap: "2rem", + border: "1px solid #d9d9d9", + borderRadius: 6, + padding: "1rem", + alignSelf: "flex-start", + "&>h4.ant-typography": { + margin: 0, + }, + flex: 0.35, + minWidth: 240, + maxWidth: 500, + }, + centeredItem: { + display: "grid", + placeItems: "center", + width: "100%", + }, +}) diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/index.tsx b/web/ee/src/components/Evaluations/EvaluationCardView/index.tsx new file mode 100644 index 0000000000..96aef26f30 --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/index.tsx @@ -0,0 +1,504 @@ +// @ts-nocheck +import {useCallback, useEffect, useMemo, useRef} from "react" + +import { + LeftOutlined, + LoadingOutlined, + PlayCircleOutlined, + QuestionCircleOutlined, + RightOutlined, +} from "@ant-design/icons" +import {Button, Empty, Form, Input, Result, Space, Tooltip, Typography, theme} from "antd" +import {atom, useAtomValue} from "jotai" +import debounce from "lodash/debounce" +import {useLocalStorage} from "usehooks-ts" + +import AlertPopup from "@/oss/components/AlertPopup/AlertPopup" +import ParamsForm from "@/oss/components/ParamsForm" +import {useQueryParam} from "@/oss/hooks/useQuery" +import {EvaluationType} from "@/oss/lib/enums" +import {testsetRowToChatMessages} from "@/oss/lib/helpers/testset" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import type {ChatMessage, EvaluationScenario} from "@/oss/lib/Types" +import {inputParamsAtomFamily} from "@/oss/state/newPlayground/core/inputParams" +import {stablePromptVariablesAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {appUriInfoAtom} from "@/oss/state/variant/atoms/fetcher" + +import {useStyles} from "./assets/styles" +import EvaluationCard from "./EvaluationCard" +import EvaluationVotePanel from "./EvaluationVotePanel" +import type {EvaluationCardViewProps} from "./types" + +const EvaluationCardView: React.FC = ({ + variants, + evaluationScenarios, + onRun, + onVote, + onInputChange, + updateEvaluationScenarioData, + evaluation, + variantData = [], + isLoading, +}) => { + const classes = useStyles() + const {token} = theme.useToken() + const [evaluationsState, setEvaluationsState] = useLocalStorage< + Record + >("evaluationsState", {}) + + const [scenarioId, setScenarioId] = useQueryParam( + "evaluationScenario", + evaluationsState[evaluation.id]?.lastVisitedScenario || evaluationScenarios[0]?.id || "", + ) + const [instructionsShown, setInstructionsShown] = useLocalStorage( + "evalInstructionsShown", + false, + ) + const {scenario, scenarioIndex} = useMemo(() => { + const scenarioIndex = evaluationScenarios.findIndex( + (scenario) => scenario.id === scenarioId, + ) + return {scenario: evaluationScenarios[scenarioIndex], scenarioIndex} + }, [scenarioId, evaluationScenarios]) + + useEffect(() => { + setEvaluationsState((prevEvaluationsState) => ({ + ...prevEvaluationsState, + [evaluation.id]: { + ...(prevEvaluationsState[evaluation.id] || {}), + lastVisitedScenario: scenarioId, + }, + })) + }, [scenarioId]) + + const rootRef = useRef(null) + const opened = useRef(false) + const callbacks = useRef({ + onVote, + onRun, + onInputChange, + }) + const isChat = !!evaluation.testset.testsetChatColumn + const testsetRow = evaluation.testset.csvdata[scenarioIndex] + const isAbTesting = evaluation.evaluationType === EvaluationType.human_a_b_testing + const [form] = Form.useForm() + const {_variants: _allStatelessVariants} = useStatelessVariants() as any + + const loadPrevious = () => { + if (scenarioIndex === 0) return + setScenarioId(evaluationScenarios[scenarioIndex - 1].id) + } + + const loadNext = () => { + if (scenarioIndex === evaluationScenarios.length - 1) return + setScenarioId(evaluationScenarios[scenarioIndex + 1].id) + } + + const showInstructions = useCallback(() => { + if (opened.current) return + + opened.current = true + AlertPopup({ + title: "Instructions", + type: "info", + message: ( +
      +
    1. + Use the buttons Next and Prev or the arrow keys{" "} + {`Left (<)`} and {`Right (>)`} to navigate between + evaluations. +
    2. +
    3. + Click the Run{" "} + button on + right or press {`Enter (↵)`} key to generate the variants' + outputs. +
    4. + {isAbTesting && ( +
    5. + Vote by either clicking the evaluation buttons at the right + sidebar or pressing the key a for 1st Variant,{" "} + b for 2nd Variant and x if both are bad. +
    6. + )} +
    7. + Add a note to an evaluation from the Additional Notes input section{" "} + in the right sidebar. +
    8. +
    + ), + okText: Ok, + cancelText: null, + width: 500, + onCancel: () => (opened.current = false), + onOk: () => (opened.current = false), + }) + }, []) + + const depouncedUpdateEvaluationScenario = useCallback( + debounce((data: Partial) => { + updateEvaluationScenarioData(scenarioId, data) + }, 800), + [scenarioId], + ) + + const onChatChange = (chat: ChatMessage[]) => { + const stringified = JSON.stringify(chat) + testsetRow[evaluation.testset.testsetChatColumn] = stringified + + depouncedUpdateEvaluationScenario({ + inputs: [ + {input_name: "chat", input_value: stringified}, + ...scenario.inputs.filter( + (ip: {input_name: string; input_value: string}) => ip.input_name !== "chat", + ), + ], + [evaluation.testset.testsetChatColumn]: stringified, + }) + } + + //hack to always get the latest callbacks using ref + useEffect(() => { + callbacks.current = {onVote, onRun, onInputChange} + }, [onVote, onRun, onInputChange]) + + // focus the root element on mount + useEffect(() => { + if (rootRef.current) { + rootRef.current.focus() + } + }, []) + + useEffect(() => { + if (!instructionsShown) { + showInstructions() + setInstructionsShown(true) + } + }, [instructionsShown]) + + useEffect(() => { + if (typeof window === "undefined") return () => {} + + const listener = (e: KeyboardEvent) => { + if (document.activeElement !== rootRef.current) return + if (e.key === "ArrowLeft") loadPrevious() + else if (e.key === "ArrowRight") loadNext() + else if (e.key === "Enter") callbacks.current.onRun(scenarioId) + + if (isAbTesting) { + if (e.key === "a") callbacks.current.onVote(scenarioId, variants[0].variantId) + else if (e.key === "b") callbacks.current.onVote(scenarioId, variants[1].variantId) + else if (e.key === "x") callbacks.current.onVote(scenarioId, "0") + } + } + + document.addEventListener("keydown", listener) + return () => document.removeEventListener("keydown", listener) + }, [scenarioIndex]) + + useEffect(() => { + if (scenario) { + const chatStr = scenario?.inputs.find( + (ip: {input_name: string; input_value: string}) => ip.input_name === "chat", + )?.input_value + if (chatStr) testsetRow[evaluation.testset.testsetChatColumn] = chatStr + } + }, [scenario]) + + const correctAnswer = useMemo(() => { + if (scenario?.correctAnswer) return scenario.correctAnswer + const res = testsetRow?.correct_answer + return res || "" + }, [testsetRow?.correct_answer, scenario?.correctAnswer]) + + const chat = useMemo(() => { + const fromInput = scenario?.inputs.find( + (ip: {input_name: string; input_value: string}) => ip.input_name === "chat", + )?.input_value + if (!isChat) return [] + + return testsetRowToChatMessages( + fromInput + ? {chat: fromInput, correct_answer: testsetRow?.correct_answer} + : testsetRow || {}, + false, + ) + }, [scenarioIndex]) + + const routePath = useAtomValue(appUriInfoAtom)?.routePath + const selectedRevisionId = (variantData?.[0] as any)?.id as string | undefined + const hasRevision = Boolean(variantData?.[0] && selectedRevisionId) + const inputParamsSelector = useMemo( + () => + (hasRevision && routePath + ? inputParamsAtomFamily({variant: variantData[0] as any, routePath}) + : atom([])) as any, + [hasRevision ? (variantData?.[0] as any)?.id : undefined, routePath], + ) + const baseInputParams = useAtomValue(inputParamsSelector) as any[] + // // Stable variables derived from saved prompts (spec + saved parameters; no live mutations) + const variableNames = useAtomValue( + hasRevision ? (stablePromptVariablesAtomFamily(selectedRevisionId!) as any) : atom([]), + ) as string[] + // Avoid creating new atoms during render to prevent infinite update loops + const emptyObjAtom = useMemo(() => atom({}), []) + const stableFlagsParam = useMemo( + () => (selectedRevisionId ? {revisionId: selectedRevisionId} : undefined), + [selectedRevisionId], + ) + const flags = useAtomValue( + hasRevision && stableFlagsParam + ? (variantFlagsAtomFamily(stableFlagsParam) as any) + : (emptyObjAtom as any), + ) as any + + const derivedInputParams = useMemo(() => { + const haveSchemaParams = Array.isArray(baseInputParams) && baseInputParams.length > 0 + + // Determine candidate field names + let sourceParams: any[] = [] + if (haveSchemaParams) { + sourceParams = baseInputParams + } else if (Array.isArray(scenario?.inputs) && scenario.inputs.length > 0) { + sourceParams = scenario.inputs + .filter((ip: any) => (isChat ? ip.input_name !== "chat" : true)) + .map((ip: any) => ({name: ip.input_name, type: "string"})) + } else { + const reserved = new Set([ + "correct_answer", + evaluation?.testset?.testsetChatColumn || "", + ]) + const row = testsetRow || {} + sourceParams = Object.keys(row) + .filter((k) => !reserved.has(k)) + .map((k) => ({name: k, type: "string"})) + } + // Display only stable inputs: filter to stable variable names for non-custom apps + // For chat apps, exclude the reserved "chat" key (handled separately below). + if (!flags?.isCustom && Array.isArray(variableNames) && variableNames.length > 0) { + const allow = new Set(variableNames.filter((name) => (isChat ? name !== "chat" : true))) + sourceParams = (sourceParams || []).filter((p: any) => allow.has(p?.name)) + } + + const withValues = (sourceParams || []).map((item: any) => { + const fromScenario = scenario?.inputs.find( + (ip: {input_name: string; input_value: string}) => ip.input_name === item.name, + )?.input_value + const fromRow = (testsetRow as any)?.[item.name] + return { + ...item, + value: fromScenario ?? fromRow ?? "", + } + }) + + if (isChat) { + return [...withValues, {name: "chat", type: "string", value: chat}] + } + return withValues + }, [ + baseInputParams, + scenario?.inputs, + isChat, + chat, + evaluation?.testset?.testsetChatColumn, + testsetRow, + variableNames, + flags?.isCustom, + ]) + + const handleRun = useCallback(async () => { + try { + // Persist current derived inputs into scenario if missing, so runner sees them + const nextInputs = (derivedInputParams || []) + .filter((p: any) => p.name !== "chat") + .map((p: any) => ({input_name: p.name, input_value: p.value ?? ""})) + + if (Array.isArray(nextInputs) && nextInputs.length > 0) { + await updateEvaluationScenarioData(scenarioId, {inputs: nextInputs}) + } + } catch (e) { + console.warn("[EvaluationCardView] failed to persist inputs before run", e) + } + onRun(scenarioId) + }, [derivedInputParams, scenarioId, onRun, updateEvaluationScenarioData]) + + return ( +
    + {isLoading ? ( + } /> + ) : scenario ? ( + <> +
    +
    + +

    + Evaluation: {scenarioIndex + 1}/{evaluationScenarios.length} +

    + +
    + +
    + Inputs + {derivedInputParams.length > 0 || isChat ? ( + { + if (isChat && name === "chat") return onChatChange(value) + const idx = + scenario?.inputs?.findIndex( + (ip: any) => ip.input_name === name, + ) ?? -1 + if (idx === -1) { + // If the input key does not exist yet (cold load fallback), persist it + const nextInputs = [ + {input_name: name, input_value: value}, + ...((scenario?.inputs || []).filter( + (ip: any) => ip.input_name !== name, + ) as any[]), + ] + updateEvaluationScenarioData(scenarioId, { + inputs: nextInputs as any, + }) + } else { + onInputChange({target: {value}} as any, scenarioId, idx) + } + }} + inputParams={derivedInputParams} + key={`${scenarioId}-${(variantData?.[0] as any)?.id || ""}`} + useChatDefaultValue + form={form} + onFinish={handleRun} + imageSize="large" + /> + ) : null} +
    + +
    + + + + + onRun(scenarioId) : form.submit} + /> + +
    + +
    +
    + {!isAbTesting ? ( + + Model Response + + ) : ( + + Outputs + + )} +
    + + +
    +
    + +
    +

    Submit your feedback

    + {scenario.outputs.length > 0 && + scenario.outputs.every((item) => !!item.variant_output) && ( + + + {isAbTesting + ? "Which response is better?" + : "Rate the response"} + + {isAbTesting ? ( + onVote(scenarioId, vote)} + loading={scenario.vote === "loading"} + vertical + key={scenarioId} + outputs={scenario.outputs} + /> + ) : ( + onVote(scenarioId, val[0].score)} + loading={scenario.score === "loading"} + showVariantName={false} + key={scenarioId} + outputs={scenario.outputs} + /> + )} + + )} + + + Expected Answer + + depouncedUpdateEvaluationScenario({ + correctAnswer: e.target.value, + }) + } + key={scenarioId} + /> + + + + Additional Notes + + depouncedUpdateEvaluationScenario({note: e.target.value}) + } + key={scenarioId} + /> + +
    + + ) : ( + + )} +
    + ) +} + +export default EvaluationCardView diff --git a/web/ee/src/components/Evaluations/EvaluationCardView/types.d.ts b/web/ee/src/components/Evaluations/EvaluationCardView/types.d.ts new file mode 100644 index 0000000000..e0ad386f0b --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationCardView/types.d.ts @@ -0,0 +1,15 @@ +import type {ABTestingEvaluationTableRow} from "@/oss/components/EvaluationTable/ABTestingEvaluationTable" +import {useLegacyVariants} from "@/oss/lib/hooks/useLegacyVariant" +import type {Evaluation, EvaluationScenario, Variant} from "@/oss/lib/Types" + +export interface EvaluationCardViewProps { + variants: Variant[] + evaluationScenarios: ABTestingEvaluationTableRow[] + onRun: (id: string) => void + onVote: (id: string, vote: string | number | null) => void + onInputChange: Function + updateEvaluationScenarioData: (id: string, data: Partial) => void + evaluation: Evaluation + variantData: ReturnType + isLoading: boolean +} diff --git a/web/ee/src/components/Evaluations/EvaluationErrorModal.tsx b/web/ee/src/components/Evaluations/EvaluationErrorModal.tsx new file mode 100644 index 0000000000..11f0fa2e1a --- /dev/null +++ b/web/ee/src/components/Evaluations/EvaluationErrorModal.tsx @@ -0,0 +1,48 @@ +import {Modal, Button} from "antd" +import {createUseStyles} from "react-jss" + +const useStyles = createUseStyles({ + container: { + display: "flex", + justifyContent: "flex-end", + gap: 10, + }, +}) + +interface Props { + isModalOpen: boolean + handleNavigate: () => void + message: string + btnText: string + onClose: () => void +} + +const EvaluationErrorModal: React.FC = ({ + isModalOpen, + handleNavigate, + message, + btnText, + onClose, +}) => { + const classes = useStyles() + const handleCloseModal = () => onClose() + + const handleCTAClick = () => { + handleNavigate() + handleCloseModal() + } + + return ( + +

    {message}

    +
    + + +
    +
    + ) +} + +export default EvaluationErrorModal diff --git a/web/ee/src/components/Evaluations/HumanEvaluationResult.tsx b/web/ee/src/components/Evaluations/HumanEvaluationResult.tsx new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/ee/src/components/Evaluations/ShareEvaluationModal.tsx b/web/ee/src/components/Evaluations/ShareEvaluationModal.tsx new file mode 100644 index 0000000000..4f6a1f647d --- /dev/null +++ b/web/ee/src/components/Evaluations/ShareEvaluationModal.tsx @@ -0,0 +1,61 @@ +import qs from "querystring" + +import {Input, Modal, ModalProps, Typography} from "antd" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" + +import CopyButton from "@/oss/components/CopyButton/CopyButton" +import {EvaluationType} from "@/oss/lib/enums" +import {useOrgData} from "@/oss/state/org" + +const useStyles = createUseStyles({ + row: { + marginTop: "1rem", + display: "flex", + alignItems: "center", + gap: "0.5rem", + }, + input: { + pointerEvents: "none", + color: "rgba(0, 0, 0, 0.45)", + flex: 1, + }, +}) + +interface Props { + variantIds: string[] + testsetId: string + evaluationType: EvaluationType +} + +const ShareEvaluationModal: React.FC = ({...props}) => { + const classes = useStyles() + const {selectedOrg} = useOrgData() + const router = useRouter() + const appId = router.query.app_id as string + + const parmas = qs.stringify({ + type: props.evaluationType, + testset: props.testsetId, + variants: props.variantIds, + app: appId, + org: selectedOrg?.id, + }) + const link = `${window.location.origin}/evaluations/share?${parmas}` + + return ( + + + You can invite members of your organization to collaborate on this evaluation by + sharing the link below. + + +
    + + +
    +
    + ) +} + +export default ShareEvaluationModal diff --git a/web/ee/src/components/HumanEvaluationModal/HumanEvaluationModal.tsx b/web/ee/src/components/HumanEvaluationModal/HumanEvaluationModal.tsx new file mode 100644 index 0000000000..71af87eca4 --- /dev/null +++ b/web/ee/src/components/HumanEvaluationModal/HumanEvaluationModal.tsx @@ -0,0 +1,420 @@ +// @ts-nocheck +import {useEffect, useMemo, useState} from "react" + +import VariantDetailsWithStatus from "@agenta/oss/src/components/VariantDetailsWithStatus" +import {CaretDown, Play} from "@phosphor-icons/react" +import {Button, Col, Dropdown, MenuProps, Modal, Row, Spin, message} from "antd" +import {getDefaultStore} from "jotai" +import isEqual from "lodash/isEqual" +import dynamic from "next/dynamic" +import {useRouter} from "next/router" + +import EvaluationErrorModal from "@/oss/components/Evaluations/EvaluationErrorModal" +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import useURL from "@/oss/hooks/useURL" +import {PERMISSION_ERR_MSG} from "@/oss/lib/api/assets/axiosConfig" +import {EvaluationType} from "@/oss/lib/enums" +import {getErrorMessage} from "@/oss/lib/helpers/errorHandler" +import {isDemo} from "@/oss/lib/helpers/utils" +import {getAllVariantParameters, groupVariantsByParent} from "@/oss/lib/helpers/variantHelper" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import type {GenericObject, Parameter, StyleProps, Variant} from "@/oss/lib/Types" +import {createNewEvaluation} from "@/oss/services/human-evaluations/api" +// import {currentAppAtom} from "@/oss/state/app" +import {promptVariablesAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {useTestsetsData} from "@/oss/state/testset" + +import {useStyles} from "./assets/styles" +import type {HumanEvaluationModalProps} from "./types" + +const ShareEvaluationModal = dynamic( + () => import("@/oss/components/Evaluations/ShareEvaluationModal"), + {ssr: false}, +) + +const store = getDefaultStore() + +const HumanEvaluationModal = ({ + isEvalModalOpen, + setIsEvalModalOpen, + evaluationType, +}: HumanEvaluationModalProps) => { + const router = useRouter() + const {appURL} = useURL() + const {appTheme} = useAppTheme() + const [isError, setIsError] = useState(false) + const classes = useStyles({themeMode: appTheme} as StyleProps) + const {projectURL} = useURL() + const [selectedTestset, setSelectedTestset] = useState<{ + _id?: string + name: string + }>({name: "Select a Test set"}) + const [testsetsList, setTestsetsList] = useState([]) + + const [selectedVariants, setSelectedVariants] = useState( + new Array(1).fill({variantName: "Select a variant"}), + ) + + const [_selectedCustomEvaluationID, _setSelectedCustomEvaluationID] = useState("") + + const appId = router.query.app_id?.toString() || "" + + const {testsets, isError: isTestsetsLoadingError} = useTestsetsData() + + const [variantsInputs, setVariantsInputs] = useState>({}) + + const [error, setError] = useState({message: "", btnText: "", endpoint: ""}) + + const [shareModalOpen, setShareModalOpen] = useState(false) + + const { + variants: data, + isLoading: areAppVariantsLoading, + specMap, + uriMap, + } = useStatelessVariants() + + const variants = useMemo(() => groupVariantsByParent(data || [], true), [data]) + + useEffect(() => { + if (variants.length > 0) { + const fetchAndSetSchema = async () => { + try { + let results: { + variantName: string + inputs: string[] + }[] + // Prefer deriving inputs from OpenAPI schema exposed by useStatelessVariants + results = variants.map((_variant) => { + const variant = _variant.revisions.sort( + (a, b) => b.updatedAtTimestamp - a.updatedAtTimestamp, + )[0] + const vId = variant.variantId || variant.id + const inputs = store.get(promptVariablesAtomFamily(vId)) + return { + variantName: variant.variantName, + inputs, + } + }) + + // Fallback: if some variants have no inputs from schema, try server-side parameters API + if (results.some((r) => (r.inputs || []).length === 0)) { + const promises = variants.map((variant) => + getAllVariantParameters(appId, variant).then((data) => ({ + variantName: variant.variantName, + inputs: + data?.inputs.map((inputParam: Parameter) => inputParam.name) || + [], + })), + ) + const fallback = await Promise.all(promises) + // Merge fallback only where empty + const map = Object.fromEntries( + fallback.map((f) => [f.variantName, f.inputs]), + ) as Record + results = results.map((r) => ({ + variantName: r.variantName, + inputs: + r.inputs && r.inputs.length > 0 + ? r.inputs + : map[r.variantName] || [], + })) + } + + // Reduce the results into the desired newVariantsInputs object structure + const newVariantsInputs: Record = results.reduce( + (acc: GenericObject, result) => { + acc[result.variantName] = result.inputs + return acc + }, + {}, + ) + + setVariantsInputs(newVariantsInputs) + } catch (e: any) { + setIsError("Failed to fetch some variants parameters. Error: " + e?.message) + } + } + + fetchAndSetSchema() + } + }, [appId, variants]) + + useEffect(() => { + if (!isTestsetsLoadingError && testsets) { + setTestsetsList((prev) => { + if (isEqual(prev, testsets)) { + return prev + } + + return testsets + }) + } + }, [testsets, isTestsetsLoadingError]) + + const onTestsetSelect = (selectedTestsetIndexInTestsetsList: number) => { + setSelectedTestset(testsetsList[selectedTestsetIndexInTestsetsList]) + } + + const getTestsetDropdownMenu = (): MenuProps => { + const items: MenuProps["items"] = testsetsList.map((testset, index) => { + return { + label: ( + <> +
    {testset.name}
    + + ), + key: `${testset.name}-${testset._id}`, + } + }) + + const menuProps: MenuProps = { + items, + onClick: ({key}) => { + const index = items.findIndex((item) => item?.key === key) + onTestsetSelect(index) + }, + } + + return menuProps + } + + const handleAppVariantsMenuClick = + (dropdownIndex: number) => + ({key}: {key: string}) => { + const data = { + variants: [ + selectedVariants[dropdownIndex]?.variantName, + selectedVariants[dropdownIndex]?.variantName, + ], + } + + data.variants[dropdownIndex] = key + const _selectedVariant = variants.find((variant) => variant.variantName === key) + const selectedVariant = (_selectedVariant?.revisions || []).sort( + (a, b) => b.updatedAtTimestamp - a.updatedAtTimestamp, + )[0] + if (!selectedVariant) { + console.error("Error: No variant found") + } + + setSelectedVariants((prevState) => { + const newState = [...prevState] + newState[dropdownIndex] = selectedVariant + return newState + }) + } + + const getVariantsDropdownMenu = (index: number): MenuProps => { + const selectedVariantsNames = selectedVariants.map( + (revision) => revision.__parentVariant?.variantName, + ) + + const items = variants.reduce((filteredVariants, variant, idx) => { + const label = variant.variantName + + if (!selectedVariantsNames.includes(label)) { + filteredVariants.push({ + label: ( + <> +
    + + + # + { + ( + variant.variantId || + variant.id || + variant.variant_id + ).split("-")[0] + } + +
    + + ), + key: label, + }) + } + + return filteredVariants + }, []) + + const menuProps: MenuProps = { + items, + onClick: handleAppVariantsMenuClick(index), + } + + return menuProps + } + + const onStartEvaluation = async () => { + const selectedVariant = selectedVariants[0] + // 1. We check all data is provided + if (selectedTestset === undefined || selectedTestset.name === "Select a Test set") { + message.error("Please select a Testset") + return + } else if (selectedVariant?.variantName === "Select a variant") { + message.error("Please select a variant") + return + } else if ( + evaluationType === EvaluationType.human_a_b_testing && + selectedVariants[1]?.variantName === "Select a variant" + ) { + message.error("Please select a second variant") + return + } + + const inputs = store.get( + promptVariablesAtomFamily(selectedVariant.variantId || selectedVariant.id), + ) + + // 2. We create a new app evaluation + const evaluationTableId = await createNewEvaluation({ + variant_ids: selectedVariants.map((variant) => variant.variantId || variant.id), + appId, + inputs, + evaluationType: EvaluationType[evaluationType as keyof typeof EvaluationType], + evaluationTypeSettings: {}, + llmAppPromptTemplate: "", + selectedCustomEvaluationID: _selectedCustomEvaluationID, + testsetId: selectedTestset._id!, + }).catch((err) => { + if (err.message !== PERMISSION_ERR_MSG) { + setError({ + message: getErrorMessage(err), + btnText: "Go to Test sets", + endpoint: `${projectURL}/testsets`, + }) + } + }) + + if (!evaluationTableId) { + return + } + + // 3 We set the variants + // setVariants(selectedVariants) + + if (evaluationType === EvaluationType.human_a_b_testing) { + router.push(`${appURL}/evaluations/human_a_b_testing/${evaluationTableId}`) + } else if (evaluationType === EvaluationType.single_model_test) { + router.push(`${appURL}/evaluations/single_model_test/${evaluationTableId}`) + } + } + + return ( + <> + { + setIsEvalModalOpen(false) + + setSelectedTestset({name: "Select a Test set"}) + setSelectedVariants(new Array(1).fill({variantName: "Select a variant"})) + }} + title="New Evaluation" + footer={null} + > + + {typeof isError === "string" ? ( +
    {isError}
    + ) : ( +
    +
    +

    Which testset you want to use?

    + + + +
    + +
    +

    Which variants would you like to evaluate

    + {Array.from({ + length: evaluationType === "human_a_b_testing" ? 2 : 1, + }).map((_, index) => ( + + + + ))} +
    + + + + {evaluationType === EvaluationType.human_a_b_testing && + isDemo() && ( +
    + + + )} + + + + + + )} + + + + setError({message: "", btnText: "", endpoint: ""})} + handleNavigate={() => router.push(error.endpoint)} + message={error.message} + btnText={error.btnText} + /> + + setShareModalOpen(false)} + destroyOnHidden + variantIds={selectedVariants.map((v) => v.variantId)} + testsetId={selectedTestset._id} + evaluationType={EvaluationType.human_a_b_testing} + /> + + ) +} + +export default HumanEvaluationModal diff --git a/web/ee/src/components/HumanEvaluationModal/assets/styles.ts b/web/ee/src/components/HumanEvaluationModal/assets/styles.ts new file mode 100644 index 0000000000..5fbdda1955 --- /dev/null +++ b/web/ee/src/components/HumanEvaluationModal/assets/styles.ts @@ -0,0 +1,105 @@ +import {createUseStyles} from "react-jss" + +import type {JSSTheme, StyleProps} from "@/oss/lib/Types" + +export const useStyles = createUseStyles((theme: JSSTheme) => ({ + evaluationContainer: { + border: "1px solid lightgrey", + padding: "20px", + borderRadius: "14px", + marginBottom: 50, + }, + evaluationImg: ({themeMode}: StyleProps) => ({ + width: 24, + height: 24, + marginRight: "8px", + filter: themeMode === "dark" ? "invert(1)" : "none", + }), + createCustomEvalBtn: { + color: "#fff !important", + backgroundColor: "#0fbf0f", + marginRight: "20px", + borderColor: "#0fbf0f !important", + }, + evaluationType: { + display: "flex", + alignItems: "center", + }, + dropdownStyles: { + display: "flex", + justifyContent: "space-between", + alignItems: "center", + width: "100%", + }, + dropdownBtn: { + marginRight: 10, + width: "100%", + }, + optionSelected: { + border: "1px solid #1668dc", + "& .ant-select-selection-item": { + color: "#1668dc !important", + }, + }, + radioGroup: { + width: "100%", + "& .ant-radio-button-wrapper": { + marginBottom: "0.5rem", + borderRadius: theme.borderRadius, + borderLeft: `1px solid ${theme.colorBorder}`, + "&::before": { + display: "none", + }, + }, + "& .ant-radio-button-wrapper-checked ": { + borderLeft: `1px solid ${theme.colorPrimary}`, + }, + }, + radioBtn: { + display: "block", + marginBottom: "10px", + }, + selectGroup: { + width: "100%", + display: "block", + "& .ant-select-selector": { + borderRadius: 0, + }, + "& .ant-select-selection-item": { + marginLeft: 34, + }, + }, + customCodeSelectContainer: { + position: "relative", + }, + customCodeIcon: { + position: "absolute", + left: 16, + top: 4.5, + pointerEvents: "none", + }, + thresholdStyles: { + paddingLeft: 10, + paddingRight: 10, + }, + variantDropdown: { + marginRight: 10, + width: "100%", + }, + newCodeEval: { + display: "flex", + alignItems: "center", + gap: 8, + color: "#1668dc", + }, + newCodeEvalList: { + display: "flex", + alignItems: "center", + justifyContent: "space-between", + }, + dropdownItemLabels: { + fontSize: theme.fontSizeSM, + lineHeight: theme.lineHeightSM, + color: theme.colorTextDescription, + }, +})) diff --git a/web/ee/src/components/HumanEvaluationModal/types.d.ts b/web/ee/src/components/HumanEvaluationModal/types.d.ts new file mode 100644 index 0000000000..baf1a2b734 --- /dev/null +++ b/web/ee/src/components/HumanEvaluationModal/types.d.ts @@ -0,0 +1,5 @@ +export interface HumanEvaluationModalProps { + isEvalModalOpen: boolean + setIsEvalModalOpen: React.Dispatch> + evaluationType: "single_model_test" | "human_a_b_testing" +} diff --git a/web/ee/src/components/HumanEvaluations/AbTestingEvaluation.tsx b/web/ee/src/components/HumanEvaluations/AbTestingEvaluation.tsx new file mode 100644 index 0000000000..e85bd50861 --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/AbTestingEvaluation.tsx @@ -0,0 +1,551 @@ +import {type Key, useEffect, useState} from "react" + +import VariantDetailsWithStatus from "@agenta/oss/src/components/VariantDetailsWithStatus" +import {MoreOutlined, PlusOutlined} from "@ant-design/icons" +import {Database, Export, GearSix, Note, Plus, Rocket, Trash} from "@phosphor-icons/react" +import {Button, Dropdown, message, Space, Spin, Statistic, Table, Typography} from "antd" +import {ColumnsType} from "antd/es/table" +import {useAtomValue} from "jotai" +import Link from "next/link" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" + +import DeleteEvaluationModal from "@/oss/components/DeleteEvaluationModal/DeleteEvaluationModal" +import HumanEvaluationModal from "@/oss/components/HumanEvaluationModal/HumanEvaluationModal" +import useURL from "@/oss/hooks/useURL" +import {EvaluationType} from "@/oss/lib/enums" +import {formatDate24} from "@/oss/lib/helpers/dateTimeHelper" +import {getVotesPercentage} from "@/oss/lib/helpers/evaluate" +import {convertToCsv, downloadCsv} from "@/oss/lib/helpers/fileManipulations" +import {buildRevisionsQueryParam} from "@/oss/lib/helpers/url" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import {abTestingEvaluationTransformer} from "@/oss/lib/transformers" +import {HumanEvaluationListTableDataType, JSSTheme} from "@/oss/lib/Types" +import { + deleteEvaluations, + fetchAllLoadEvaluations, + fetchEvaluationResults, +} from "@/oss/services/human-evaluations/api" +import {getAppValues, selectedAppIdAtom} from "@/oss/state/app" +import {projectIdAtom} from "@/oss/state/project" + +const {Title} = Typography + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + flexDirection: "column", + gap: theme.paddingXS, + "& > div h1.ant-typography": { + fontSize: theme.fontSize, + }, + }, + statFlag: { + lineHeight: theme.lineHeight, + "& .ant-statistic-content-value": { + fontSize: theme.fontSize, + color: theme.colorError, + }, + "& .ant-statistic-content-suffix": { + fontSize: theme.fontSize, + color: theme.colorError, + }, + }, + stat: { + lineHeight: theme.lineHeight, + "& .ant-statistic-content-value": { + fontSize: theme.fontSize, + color: theme.colorPrimary, + }, + "& .ant-statistic-content-suffix": { + fontSize: theme.fontSize, + color: theme.colorPrimary, + }, + }, + statGood: { + lineHeight: theme.lineHeight, + "& .ant-statistic-content-value": { + fontSize: theme.fontSize, + color: theme.colorSuccess, + }, + "& .ant-statistic-content-suffix": { + fontSize: theme.fontSize, + color: theme.colorSuccess, + }, + }, + button: { + display: "flex", + alignItems: "center", + }, +})) + +const AbTestingEvaluation = ({viewType}: {viewType: "evaluation" | "overview"}) => { + const classes = useStyles() + const router = useRouter() + const {appURL, projectURL} = useURL() + const projectId = useAtomValue(projectIdAtom) + const appId = useAtomValue(selectedAppIdAtom) + + const [evaluationsList, setEvaluationsList] = useState([]) + const [fetchingEvaluations, setFetchingEvaluations] = useState(false) + const [isEvalModalOpen, setIsEvalModalOpen] = useState(false) + const [selectedEvalRecord, setSelectedEvalRecord] = useState() + const [isDeleteEvalModalOpen, setIsDeleteEvalModalOpen] = useState(false) + const [isDeleteMultipleEvalModalOpen, setIsDeleteMultipleEvalModalOpen] = useState(false) + const [selectedRowKeys, setSelectedRowKeys] = useState([]) + + useEffect(() => { + if (!appId || !projectId) return + + const fetchEvaluations = async () => { + try { + setFetchingEvaluations(true) + const evals = await fetchAllLoadEvaluations(appId, projectId) + + const fetchPromises = evals.map(async (item: any) => { + return fetchEvaluationResults(item.id) + .then((results) => { + if (item.evaluation_type === EvaluationType.human_a_b_testing) { + if (Object.keys(results.votes_data).length > 0) { + return abTestingEvaluationTransformer({item, results}) + } + } + }) + .catch((err) => console.error(err)) + }) + + const results = (await Promise.all(fetchPromises)) + .filter((evaluation) => evaluation !== undefined) + .sort( + (a, b) => + new Date(b.createdAt || 0).getTime() - + new Date(a.createdAt || 0).getTime(), + ) + + setEvaluationsList(viewType === "overview" ? results.slice(0, 5) : results) + } catch (error) { + console.error(error) + } finally { + setFetchingEvaluations(false) + } + } + + fetchEvaluations() + }, [appId, projectId]) + + const handleNavigation = (variantRevisionId: string) => { + router.push({ + pathname: `${appURL}/playground`, + query: { + revisions: buildRevisionsQueryParam([variantRevisionId]), + }, + }) + } + + const rowSelection = { + onChange: (selectedRowKeys: Key[]) => { + setSelectedRowKeys(selectedRowKeys) + }, + } + + const handleDeleteMultipleEvaluations = async () => { + const evaluationsIds = selectedRowKeys.map((key) => key.toString()) + try { + setFetchingEvaluations(true) + await deleteEvaluations(evaluationsIds) + setEvaluationsList((prevEvaluationsList) => + prevEvaluationsList.filter( + (evaluation) => !evaluationsIds.includes(evaluation.key), + ), + ) + setSelectedRowKeys([]) + message.success("Evaluations Deleted") + } catch (error) { + console.error(error) + } finally { + setFetchingEvaluations(false) + } + } + + const handleDeleteEvaluation = async (record: HumanEvaluationListTableDataType) => { + try { + setFetchingEvaluations(true) + await deleteEvaluations([record.key]) + setEvaluationsList((prevEvaluationsList) => + prevEvaluationsList.filter((evaluation) => ![record.key].includes(evaluation.key)), + ) + message.success("Evaluation Deleted") + } catch (error) { + console.error(error) + } finally { + setFetchingEvaluations(false) + } + } + + const columns: ColumnsType = [ + { + title: "Variant 1", + dataIndex: "variantNames", + key: "variant1", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (value, record) => { + return ( + + ) + }, + }, + { + title: "Variant 2", + dataIndex: "variantNames", + key: "variant2", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (value, record) => { + return ( + + ) + }, + }, + { + title: "Test set", + dataIndex: "testsetName", + key: "testsetName", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (_, record: HumanEvaluationListTableDataType, index: number) => { + return {record.testset.name} + }, + }, + { + title: "Results", + key: "results", + onHeaderCell: () => ({ + style: {minWidth: 240}, + }), + render: (_, record: HumanEvaluationListTableDataType) => { + const stat1 = getVotesPercentage(record, 0) + const stat2 = getVotesPercentage(record, 1) + + return ( +
    + + | + +
    + ) + }, + }, + { + title: "Both are good", + dataIndex: "positive", + key: "positive", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (_, record: HumanEvaluationListTableDataType) => { + const percentage = record.votesData.positive_votes.percentage + return ( + + + + ) + }, + }, + { + title: "Flag", + dataIndex: "flag", + key: "flag", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (value: any, record: HumanEvaluationListTableDataType) => { + const percentage = record.votesData.flag_votes.percentage + return ( + + + + ) + }, + }, + ] + + columns.push( + ...([ + { + title: "Created on", + dataIndex: "createdAt", + key: "createdAt", + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + }, + { + title: , + key: "key", + width: 56, + fixed: "right", + align: "center", + render: (_: any, record: HumanEvaluationListTableDataType) => { + return ( + , + onClick: (e) => { + e.domEvent.stopPropagation() + router.push( + `${appURL}/evaluations/human_a_b_testing/${record.key}`, + ) + }, + }, + { + key: "variant1", + label: "View variant 1", + icon: , + onClick: (e) => { + e.domEvent.stopPropagation() + handleNavigation(record.variant_revision_ids[0]) + }, + }, + { + key: "variant2", + label: "View variant 2", + icon: , + onClick: (e) => { + e.domEvent.stopPropagation() + handleNavigation(record.variant_revision_ids[1]) + }, + }, + { + key: "view_testset", + label: "View test set", + icon: , + onClick: (e) => { + e.domEvent.stopPropagation() + router.push( + `${projectURL}/testsets/${record.testset._id}`, + ) + }, + }, + {type: "divider"}, + { + key: "delete_eval", + label: "Delete", + icon: , + danger: true, + onClick: (e) => { + e.domEvent.stopPropagation() + setSelectedEvalRecord(record) + setIsDeleteEvalModalOpen(true) + }, + }, + ], + }} + > + + + + + + ) : ( +
    + + + + + + +
    + )} + + +
    ({ + style: {cursor: "pointer"}, + onClick: () => + router.push(`${appURL}/evaluations/human_a_b_testing/${record.key}`), + })} + /> + + + + + {selectedEvalRecord && ( + setIsDeleteEvalModalOpen(false)} + onOk={async () => { + await handleDeleteEvaluation(selectedEvalRecord) + setIsDeleteEvalModalOpen(false) + }} + evaluationType={"a/b testing evaluation"} + /> + )} + + {isDeleteMultipleEvalModalOpen && ( + setIsDeleteMultipleEvalModalOpen(false)} + onOk={async () => { + await handleDeleteMultipleEvaluations() + setIsDeleteMultipleEvalModalOpen(false) + }} + evaluationType={"a/b testing evaluation"} + /> + )} + + ) +} + +export default AbTestingEvaluation diff --git a/web/ee/src/components/HumanEvaluations/SingleModelEvaluation.tsx b/web/ee/src/components/HumanEvaluations/SingleModelEvaluation.tsx new file mode 100644 index 0000000000..cdc664f9b2 --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/SingleModelEvaluation.tsx @@ -0,0 +1,228 @@ +import {memo, useMemo, useCallback, useState, type Key} from "react" + +import {ColumnsType} from "antd/es/table" +import clsx from "clsx" +import {useRouter} from "next/router" + +import EnhancedTable from "@/oss/components/EnhancedUIs/Table" +import {useAppId} from "@/oss/hooks/useAppId" +import useURL from "@/oss/hooks/useURL" +import {EvaluationType} from "@/oss/lib/enums" +import {buildRevisionsQueryParam} from "@/oss/lib/helpers/url" +import useEvaluations from "@/oss/lib/hooks/useEvaluations" +import useRunMetricsMap from "@/oss/lib/hooks/useRunMetricsMap" +import {useAppsData} from "@/oss/state/app" + +import SingleModelEvaluationHeader from "./assets/SingleModelEvaluationHeader" +import {useStyles} from "./assets/styles" +import {getColumns} from "./assets/utils" +import {EvaluationRow} from "./types" +import { + buildAppScopedUrl, + buildEvaluationNavigationUrl, + extractEvaluationAppId, +} from "../pages/evaluations/utils" + +interface SingleModelEvaluationProps { + viewType: "evaluation" | "overview" + scope?: "app" | "project" +} + +const SingleModelEvaluation = ({viewType, scope = "app"}: SingleModelEvaluationProps) => { + const classes = useStyles() + const router = useRouter() + const {appURL, projectURL, baseAppURL} = useURL() + const routeAppId = useAppId() + const activeAppId = scope === "app" ? routeAppId || undefined : undefined + const {apps: availableApps = []} = useAppsData() + + const [selectedEvalRecord, setSelectedEvalRecord] = useState() + const [isDeleteEvalModalOpen, setIsDeleteEvalModalOpen] = useState(false) + const [selectedRowKeys, setSelectedRowKeys] = useState([]) + + const {mergedEvaluations, isLoadingPreview, isLoadingLegacy} = useEvaluations({ + withPreview: true, + types: [EvaluationType.single_model_test], + evalType: "human", + appId: activeAppId, + }) + + const runIds = useMemo( + () => mergedEvaluations.map((e) => ("id" in e ? e.id : e.key)), + [mergedEvaluations], + ) + const evaluatorSlugs = useMemo(() => { + const evaSlugs = new Set() + mergedEvaluations.forEach((e) => { + const key = e?.data.steps?.find((step) => step.type === "annotation")?.key + if (key) evaSlugs.add(key) + }) + return evaSlugs + }, [mergedEvaluations]) + + const {data: runMetricsMap} = useRunMetricsMap(runIds, evaluatorSlugs) + + const knownAppIds = useMemo(() => { + return new Set( + (availableApps as Array<{app_id?: string}>) + .map((app) => app?.app_id) + .filter(Boolean) as string[], + ) + }, [availableApps]) + + const resolveAppId = useCallback( + (record: EvaluationRow): string | undefined => { + const candidate = extractEvaluationAppId(record) || activeAppId + if (!candidate) return undefined + if (scope === "project" && !knownAppIds.has(candidate)) return undefined + return candidate + }, + [activeAppId, knownAppIds, scope], + ) + + const isRecordNavigable = useCallback( + (record: EvaluationRow): boolean => { + const evaluationId = "id" in record ? record.id : record.key + const recordAppId = resolveAppId(record) + return Boolean(evaluationId && recordAppId) + }, + [resolveAppId], + ) + + const rowSelection = useMemo(() => { + return { + onChange: (selectedRowKeys: Key[]) => { + setSelectedRowKeys(selectedRowKeys) + }, + getCheckboxProps: (record: EvaluationRow) => ({ + disabled: !isRecordNavigable(record), + }), + } + }, [isRecordNavigable]) + + const handleNavigation = useCallback( + ({revisionId, appId: recordAppId}: {revisionId: string; appId?: string}) => { + const targetAppId = recordAppId || activeAppId + if (!targetAppId) return + + router.push({ + pathname: buildAppScopedUrl(baseAppURL, targetAppId, "/playground"), + query: { + revisions: buildRevisionsQueryParam([revisionId]), + }, + }) + }, + [router, baseAppURL, activeAppId], + ) + + const columns: ColumnsType = useMemo(() => { + return getColumns({ + evaluations: mergedEvaluations, + onVariantNavigation: handleNavigation, + evalType: "human", + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + runMetricsMap, + scope, + baseAppURL, + extractAppId: extractEvaluationAppId, + projectURL, + resolveAppId, + }) + }, [ + mergedEvaluations, + handleNavigation, + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + runMetricsMap, + scope, + baseAppURL, + projectURL, + resolveAppId, + ]) + + const dataSource = useMemo(() => { + return viewType === "overview" ? mergedEvaluations.slice(0, 5) : mergedEvaluations + }, [viewType, mergedEvaluations]) + + return ( +
    + + +
    + { + return record.id || record.key + }} + className={clsx("ph-no-capture", "grow min-h-0", "eval-runs-table")} + showHorizontalScrollBar={true} + columns={columns} + dataSource={dataSource} + virtualized + loading={isLoadingPreview || isLoadingLegacy} + uniqueKey="human-annotation" + onRow={(record) => { + const evaluationId = "id" in record ? record.id : record.key + const recordAppId = resolveAppId(record) + const isNavigable = isRecordNavigable(record) + + return { + className: isNavigable ? undefined : "cursor-not-allowed opacity-60", + style: {cursor: isNavigable ? "pointer" : "not-allowed"}, + onClick: () => { + if (!isNavigable || !recordAppId || !evaluationId) return + + const pathname = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: recordAppId, + path: `/evaluations/single_model_test/${evaluationId}`, + }) + + if (scope === "project") { + router.push({ + pathname, + query: recordAppId ? {app_id: recordAppId} : undefined, + }) + } else { + router.push(pathname) + } + }, + } + }} + /> +
    +
    + ) +} + +export default memo(SingleModelEvaluation) diff --git a/web/ee/src/components/HumanEvaluations/assets/EvaluationStatusCell.tsx b/web/ee/src/components/HumanEvaluations/assets/EvaluationStatusCell.tsx new file mode 100644 index 0000000000..dc4b2afb3c --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/EvaluationStatusCell.tsx @@ -0,0 +1,147 @@ +import {memo, useEffect, useMemo, useRef} from "react" + +import {Tag, theme} from "antd" +import {useAtom, useAtomValue} from "jotai" +import {mutate} from "swr" + +import {EvaluationType} from "@/oss/lib/enums" +import useEvaluationRunScenarios, { + getEvaluationRunScenariosKey, +} from "@/oss/lib/hooks/useEvaluationRunScenarios" +import useEvaluations from "@/oss/lib/hooks/useEvaluations" +import {resourceStatusQueryFamily} from "@/oss/lib/hooks/usePreviewRunningEvaluations" +import {tempEvaluationAtom} from "@/oss/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom" +import {EvaluationStatus} from "@/oss/lib/Types" + +import {statusMapper} from "../../pages/evaluations/cellRenderers/cellRenderers" + +import {extractEvaluationStatus} from "./utils" + +const EvaluationStatusCell = ({ + runId, + status, + evalType, +}: { + runId: string + status?: EvaluationStatus + evalType?: "auto" | "human" +}) => { + const swrData = useEvaluationRunScenarios(runId, undefined, { + syncAtom: false, + revalidateOnMount: true, + }) + const {token} = theme.useToken() + const {refetch} = useEvaluations({ + withPreview: true, + types: + evalType === "auto" + ? [EvaluationType.automatic, EvaluationType.auto_exact_match] + : [EvaluationType.human, EvaluationType.single_model_test], + evalType, + }) + const runningEvaluations = useAtomValue( + resourceStatusQueryFamily(evalType === "auto" ? runId : ""), + ) + const [tempEvaluation, setTempEvaluation] = useAtom(tempEvaluationAtom) + const handledCompletionRef = useRef>(new Set()) + const lastMutatedStatusRef = useRef<{runId?: string; status?: EvaluationStatus | null} | null>( + null, + ) + + // Force refetch once when component mounts (useful when returning from details page) + useEffect(() => { + if (!runId) return + + const key = getEvaluationRunScenariosKey(runId) + if (!key) return + + const status = runningEvaluations.data?.run?.status ?? null + const hasChanged = + !lastMutatedStatusRef.current || + lastMutatedStatusRef.current.runId !== runId || + lastMutatedStatusRef.current.status !== status + + if (!hasChanged) return + + lastMutatedStatusRef.current = {runId, status} + + mutate(`${key}-false`) + }, [runId, runningEvaluations.data?.run?.status]) + + // refresh the eval after a completed run + useEffect(() => { + if (evalType !== "auto") return + + const runIdToCheck = runningEvaluations.data?.run?.id + const runStatus = runningEvaluations.data?.run?.status + + if (!runIdToCheck || !runStatus) return + + const isTrackedTempEvaluation = tempEvaluation.some( + (evaluation) => evaluation.id === runIdToCheck, + ) + + if (!isTrackedTempEvaluation) { + handledCompletionRef.current.delete(runIdToCheck) + return + } + + const isTerminalStatus = ![ + EvaluationStatus.PENDING, + EvaluationStatus.RUNNING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + ].includes(runStatus) + + if (!isTerminalStatus) { + handledCompletionRef.current.delete(runIdToCheck) + return + } + + const hasHandledCompletion = handledCompletionRef.current.has(runIdToCheck) + + if (hasHandledCompletion) return + + handledCompletionRef.current.add(runIdToCheck) + + setTempEvaluation((prev) => prev.filter((evaluation) => evaluation.id !== runIdToCheck)) + refetch() + }, [ + evalType, + refetch, + runningEvaluations.data?.run?.id, + runningEvaluations.data?.run?.status, + setTempEvaluation, + tempEvaluation, + ]) + + const {runStatus, scenarios} = useMemo(() => { + return extractEvaluationStatus(swrData.data?.scenarios || [], status, evalType) + }, [status, token, swrData.data?.scenarios, evalType]) + + const completedStatuses = [EvaluationStatus.SUCCESS] + const {completedCount, totalCount} = useMemo(() => { + return { + completedCount: scenarios.filter((s) => + completedStatuses.includes(s.status as EvaluationStatus), + ).length, + totalCount: scenarios.length, + } + }, [scenarios]) + + const _status = useMemo(() => { + if (evalType !== "auto") return runStatus + return runningEvaluations.data?.run?.status || runStatus + }, [runningEvaluations.data?.run?.status, runStatus]) + + return ( +
    + + {statusMapper(token)(_status).label} + +
    {`${completedCount} / ${totalCount}`}
    +
    + ) +} + +export default memo(EvaluationStatusCell) diff --git a/web/ee/src/components/HumanEvaluations/assets/LegacyEvalResultCell.tsx b/web/ee/src/components/HumanEvaluations/assets/LegacyEvalResultCell.tsx new file mode 100644 index 0000000000..3dda59ca5b --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/LegacyEvalResultCell.tsx @@ -0,0 +1,32 @@ +import {memo} from "react" + +import {Tag, Typography, Space} from "antd" + +import {getTypedValue} from "@/oss/lib/helpers/evaluate" + +import EvaluationErrorPopover from "../../pages/evaluations/EvaluationErrorProps/EvaluationErrorPopover" + +export const LegacyEvalResultCell = memo(({matchingResults}: {matchingResults: any}) => { + return ( + + {matchingResults?.map((result, index) => + result?.result?.error ? ( + + ) : ( + {getTypedValue(result?.result)} + ), + )} + + ) +}) + +export const LegacyEvalResultCellTitle = memo(({evaluator}: {evaluator: any}) => { + return ( +
    + {evaluator?.name} + + {evaluator?.evaluator?.name} + +
    + ) +}) diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartAxis.tsx b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartAxis.tsx new file mode 100644 index 0000000000..b548672a5c --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartAxis.tsx @@ -0,0 +1,91 @@ +import {FC} from "react" + +import {format3Sig} from "./utils" + +interface ChartAxisProps { + svgWidth: number + svgHeight: number + plotWidth: number + plotHeight: number + margin: {top: number; right: number; bottom: number; left: number} + xLabels: (string | number)[] + yTicks?: number[] // for numeric axes + yLabels?: (string | number)[] // for categorical axes + xScale: (idx: number) => number + yScale: (value: number) => number + yLabelScale?: (idx: number) => number // for categorical axes +} + +export const ChartAxis: FC = ({ + svgWidth, + svgHeight, + plotWidth, + plotHeight, + margin, + xLabels, + yTicks, + yLabels, + xScale, + yScale, + yLabelScale, +}) => ( + + {/* X Axis Line */} + + {/* X Axis Labels */} + {xLabels.map((label, idx) => ( + + {label} + + ))} + {/* Y Axis Line */} + + {/* Y Axis Labels */} + {yLabels && yLabelScale + ? yLabels.map((label, idx) => ( + + {label} + + )) + : yTicks?.map((tick) => ( + + {format3Sig(tick)} + + ))} + +) diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartFrame.tsx b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartFrame.tsx new file mode 100644 index 0000000000..18bdb0fb4a --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ChartFrame.tsx @@ -0,0 +1,71 @@ +import {type FC, type ReactNode, RefObject, useRef} from "react" + +import {useResizeObserver} from "usehooks-ts" + +export interface ChartFrameProps { + minWidth?: number + minHeight?: number + maxWidth?: number | string + maxHeight?: number | string + margin?: {top: number; right: number; bottom: number; left: number} + children: (frame: { + svgWidth: number + svgHeight: number + plotWidth: number + plotHeight: number + margin: {top: number; right: number; bottom: number; left: number} + }) => ReactNode +} + +const DEFAULT_MARGIN = {top: 8, right: 16, left: 40, bottom: 32} +const DEFAULT_MIN_WIDTH = 200 +const DEFAULT_MIN_HEIGHT = 120 + +const ChartFrame: FC = ({ + minWidth = DEFAULT_MIN_WIDTH, + minHeight = DEFAULT_MIN_HEIGHT, + maxWidth, + maxHeight, + margin = DEFAULT_MARGIN, + children, +}) => { + const containerRef = useRef(null) + const {width: chartWidth = 280, height: chartHeight = 120} = useResizeObserver({ + ref: containerRef as RefObject, + box: "border-box", + }) + const svgWidth = Math.max(chartWidth, minWidth) + const svgHeight = Math.max(chartHeight, minHeight) + const plotWidth = svgWidth - margin.left - margin.right + const plotHeight = svgHeight - margin.top - margin.bottom + + return ( +
    + {children({ + svgWidth, + svgHeight, + plotWidth, + plotHeight, + margin, + })} +
    + ) +} + +export default ChartFrame diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveFrequencyChart.tsx b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveFrequencyChart.tsx new file mode 100644 index 0000000000..bdc77ae83a --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveFrequencyChart.tsx @@ -0,0 +1,463 @@ +import {type FC, memo, useCallback, useState} from "react" + +import clsx from "clsx" + +import {ChartAxis} from "./ChartAxis" +import ChartFrame from "./ChartFrame" +import {getYTicks} from "./chartUtils" + +interface FrequencyDatum { + label: string | number + count: number +} + +interface ResponsiveFrequencyChartProps { + data: FrequencyDatum[] + highlightValues?: (string | number)[] + labelWidth?: number + direction?: "horizontal" | "vertical" + /** Optional: color for bars (also used for highlight when provided) */ + barColor?: string + /** Optional: disable gradient and use solid bars */ + disableGradient?: boolean + dynamicMargin?: Partial<{top: number; right: number; bottom: number; left: number}> +} + +// Resolve fills based on props (keep defaults when not provided) +const DEFAULTS = { + greenSolid: "#95DE64", + blueSolid: "#69B1FF", + graySolid: "#97A4B0", +} + +const CUSTOM_GRADIENT_ID = "barGradientCustom" + +/** + * ResponsiveFrequencyChart renders a vertical bar chart for categorical/frequency data. + * Bars to highlight are inferred automatically from highlightValues (if provided). + */ +const ResponsiveFrequencyChart: FC = memo( + ({ + data, + highlightValues = [], + labelWidth, + direction = "horizontal", + barColor, + disableGradient = false, + dynamicMargin: dynamicPropsMargin, + }) => { + const isVertical = direction === "vertical" + const xMax = Math.max(...data.map((d) => d.count), 1) + const yCount = data.length + const xTicks = getYTicks(xMax) + const yLabels = data.map((d) => d.label) + + // Tooltip state + const [hoveredBar, setHoveredBar] = useState(null) + const [mousePos, setMousePos] = useState<{x: number; y: number} | null>(null) + + // Dynamically calculate margins based on orientation + const defaultMargin = {top: 16, right: 16, bottom: 32, left: 40} + let dynamicMargin = defaultMargin + if (isVertical) { + const longestBottomLabel = yLabels.reduce( + (max: number, label) => Math.max(max, String(label).length), + 0, + ) + const bottomMargin = Math.max(32, Math.min(120, longestBottomLabel * 7 + 16)) + const longestCountLabel = xTicks.reduce( + (max: number, tick) => Math.max(max, String(tick).length), + 0, + ) + const leftMargin = Math.max(40, Math.min(120, longestCountLabel * 7 + 16)) + dynamicMargin = { + ...defaultMargin, + left: leftMargin, + bottom: bottomMargin, + ...dynamicPropsMargin, + } + } else { + const longestLabelLength = yLabels.reduce( + (max: number, label) => Math.max(max, String(label).length), + 0, + ) + const dynamicLeftMargin = Math.max(40, Math.min(120, longestLabelLength * 7 + 16)) + dynamicMargin = {...defaultMargin, left: dynamicLeftMargin, ...dynamicPropsMargin} + } + + // Calculate maxCount and maxCountOccurrences once + const countMap = data.map((d) => d.count) + const maxCount = Math.max(...countMap) + const maxCountOccurrences = countMap.filter((count) => count === maxCount).length + // Store maxCount for later use in rendering + const uniqueMaxCount = maxCountOccurrences === 1 ? maxCount : null + + // Compute highlighted bar indices from highlightValues + const computedHighlightBarIndices = + highlightValues.length > 0 + ? data + .map((d, i) => + highlightValues.some((hv) => String(hv) === String(d.label)) ? i : -1, + ) + .filter((i) => i !== -1) + : [] + + return ( + + {({svgWidth, svgHeight, plotWidth, plotHeight, margin}) => { + // Scales for both orientations + const yLabelScaleHorizontal = (idx: number) => + (idx + 0.5) * (plotHeight / yCount) + const barHeightHorizontal = plotHeight / yCount - 6 + const xScaleHorizontal = (count: number) => (count / xMax) * plotWidth + + const xLabelScaleVertical = (idx: number) => (idx + 0.5) * (plotWidth / yCount) + const barWidthVertical = plotWidth / yCount - 6 + const yScaleVertical = (value: number) => ((xMax - value) / xMax) * plotHeight + + const getFill = useCallback( + (isHighlighted: boolean, d: FrequencyDatum): string => { + // If user supplies barColor, it overrides category colors: + if (barColor) { + // highlighted also uses barColor (solid), mirroring prior component behavior + if (isHighlighted) return barColor + return disableGradient ? barColor : `url(#${CUSTOM_GRADIENT_ID})` + } + // Default behavior (no barColor override) + if (isHighlighted) return DEFAULTS.greenSolid + if (disableGradient) { + // Solid fallbacks + if (d.label === "true") return DEFAULTS.greenSolid + if (uniqueMaxCount !== null && d.count === uniqueMaxCount) + return DEFAULTS.blueSolid + return DEFAULTS.graySolid + } + // Gradient fallbacks + if (d.label === "true") return "url(#barGradientGreen)" + if (uniqueMaxCount !== null && d.count === uniqueMaxCount) + return "url(#barGradientBlue)" + return "url(#barGradientGray)" + }, + [barColor], + ) + + return ( + <> + + {/* Bar gradient defs */} + + {/* If a custom barColor is provided and gradient is enabled, use a single custom gradient */} + {!disableGradient && barColor && ( + + + + + )} + + {/* Otherwise keep the existing three gradients (when gradient is enabled) */} + {!disableGradient && !barColor && ( + <> + {/* Gradient for "true" state */} + + + + + + {/* Gradient for default/false state */} + + + + + + {/* Gradient for most-count (unique max) */} + + + + + + )} + + + {/* Grid and highlight lines */} + {isVertical ? ( + + {xTicks.map((tick) => ( + + ))} + {highlightValues.map((val, i) => { + const idx = data.findIndex( + (d) => + typeof d.label === "number" && d.label === val, + ) + if (idx === -1) return null + return ( + + ) + })} + + ) : ( + + {xTicks.map((tick) => ( + + ))} + {highlightValues.map((val, i) => { + const idx = data.findIndex( + (d) => + typeof d.label === "number" && d.label === val, + ) + if (idx === -1) return null + return ( + + ) + })} + + )} + + {/* Bars */} + + {data.map((d, idx) => { + const isHighlighted = + computedHighlightBarIndices.includes(idx) + const isMaxUnique = + uniqueMaxCount !== null && d.count === uniqueMaxCount + + if (isVertical) { + const barX = + margin.left + + xLabelScaleVertical(idx) - + barWidthVertical / 2 + const barHeight = plotHeight - yScaleVertical(d.count) + return ( + { + setHoveredBar(idx) + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseMove={(e) => { + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseLeave={() => { + setHoveredBar(null) + setMousePos(null) + }} + /> + ) + } + + return ( + { + setHoveredBar(idx) + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseMove={(e) => { + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseLeave={() => { + setHoveredBar(null) + setMousePos(null) + }} + /> + ) + })} + + + {/* Axes */} + {isVertical ? ( + xLabelScaleVertical(idx)} + yScale={yScaleVertical} + /> + ) : ( + xScaleHorizontal(xTicks[idx])} + yScale={() => 0} + /> + )} + + + {/* Tooltip rendered outside SVG, absolutely positioned */} + {hoveredBar !== null && data[hoveredBar] && mousePos && ( +
    + {/* Caret */} +
    +
    + Label: + {String(data[hoveredBar].label)} +
    +
    + Count: + {data[hoveredBar].count} +
    +
    + )} + + ) + }} + + ) + }, +) + +export default ResponsiveFrequencyChart diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveMetricChart.tsx b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveMetricChart.tsx new file mode 100644 index 0000000000..55594d96d4 --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/ResponsiveMetricChart.tsx @@ -0,0 +1,634 @@ +import {FC, memo, useState} from "react" + +import type {ChartDatum} from "../types" + +import {ChartAxis} from "./ChartAxis" +import ChartFrame from "./ChartFrame" +import {format3Sig} from "./utils" + +interface ResponsiveMetricChartProps { + chartData: ChartDatum[] + extraDimensions: Record + highlightValue?: number + labelWidth?: number + direction?: "horizontal" | "vertical" + dynamicMargin?: Partial<{top: number; right: number; bottom: number; left: number}> + /** Optional: color for bars (also used for highlight). Default keeps current blue. */ + barColor?: string + /** Optional: when true, disables gradient and uses a solid color for bars. */ + disableGradient?: boolean +} + +const DEFAULT_PRIMARY = "#69B1FF" + +/** + * ResponsiveMetricChart is a functional component that renders a responsive histogram + * visualization using SVG. This chart displays data as bars with optional highlighted + * bins, reference lines, and tooltips for detailed information. The chart adapts to + * its container's size and provides scale functions for accurately positioning elements. + */ +/** + * ResponsiveMetricChart is a functional component that renders a responsive histogram + * visualization using SVG. This chart displays data as bars with optional highlighted + * bins, reference lines, and tooltips for detailed information. The chart adapts to + * its container's size and provides scale functions for accurately positioning elements. + * + * The highlighted bin is automatically inferred from highlightValue (if provided). + */ +const ResponsiveMetricChart: FC = memo( + ({ + chartData, + extraDimensions, + highlightValue, + labelWidth, + direction = "horizontal", + dynamicMargin: dynamicPropsMargin, + barColor, + disableGradient = false, + }) => { + const binSize = extraDimensions.binSize || 1 + const yMin = Math.min(...(chartData.map((d) => d.edge) as number[])) + const yMax = Math.max(...(chartData.map((d) => d.edge) as number[])) + binSize + const xMax = Math.max(...chartData.map((d) => d.value)) + + // Y axis: bin midpoints + const yTicks: number[] = chartData.map((d) => (d.edge ?? 0) + binSize / 2) + // X axis: value ticks + const xTicks: number[] = [] + const xTickCount = Math.min(4, xMax) + for (let i = 0; i <= xTickCount; i++) { + xTicks.push((i / xTickCount) * xMax) + } + + const clipPathId = `clip-histogram-${Math.random().toString(36).substr(2, 9)}` + // Tooltip state + const [hoveredBin, setHoveredBin] = useState(null) + const [mousePos, setMousePos] = useState<{x: number; y: number} | null>(null) + + // Compute highlighted bin index from highlightValue + let computedHighlightBinIndex: number | null = null + if (typeof highlightValue === "number" && chartData.length > 0) { + const roundTo = (n: number, digits: number) => { + const factor = Math.pow(10, digits) + return Math.round(n * factor) / factor + } + const DECIMALS = 6 + computedHighlightBinIndex = chartData.findIndex((d, i) => { + const binStart = d.edge ?? 0 + const binEnd = (d.edge ?? 0) + binSize + if (i === chartData.length - 1) { + // Last bin: inclusive of upper edge, round both values for robust comparison + const closeEnough = Math.abs(highlightValue - binEnd) < Math.pow(10, -DECIMALS) + return ( + roundTo(highlightValue, DECIMALS) >= roundTo(binStart, DECIMALS) && + (roundTo(highlightValue, DECIMALS) <= roundTo(binEnd, DECIMALS) || + closeEnough) + ) + } + // Other bins: upper edge exclusive, round for robust comparison + return ( + roundTo(highlightValue, DECIMALS) >= roundTo(binStart, DECIMALS) && + roundTo(highlightValue, DECIMALS) < roundTo(binEnd, DECIMALS) + ) + }) + if (computedHighlightBinIndex === -1) computedHighlightBinIndex = null + } + + // Dynamically calculate left margin for long y-labels + const yLabelsFormatted = yTicks.map(format3Sig) + const defaultMargin = {top: 16, right: 16, bottom: 32, left: 40} + let dynamicMargin = defaultMargin + if (direction === "horizontal") { + const longestLabelLength = yLabelsFormatted.reduce( + (max, label) => Math.max(max, String(label).length), + 0, + ) + const dynamicLeftMargin = Math.max(40, Math.min(120, longestLabelLength * 7 + 16)) + dynamicMargin = {...defaultMargin, left: dynamicLeftMargin} + } else { + const yAxisLabels = xTicks.map(format3Sig) + const longestLeft = yAxisLabels.reduce( + (max, label) => Math.max(max, String(label).length), + 0, + ) + const dynamicLeftMargin = Math.max(40, Math.min(120, longestLeft * 7 + 16)) + const xAxisLabels = yTicks.map(format3Sig) + const longestBottom = xAxisLabels.reduce( + (max, label) => Math.max(max, String(label).length), + 0, + ) + const dynamicBottomMargin = Math.max(32, Math.min(120, longestBottom * 7 + 16)) + dynamicMargin = { + ...defaultMargin, + left: dynamicLeftMargin, + bottom: dynamicBottomMargin, + ...dynamicPropsMargin, + } + } + + // NEW: resolve fills (keep defaults) + const baseSolid = barColor || DEFAULT_PRIMARY + const baseFill = disableGradient ? baseSolid : "url(#barGradientBlue)" + const highlightFill = barColor || DEFAULT_PRIMARY + + return ( +
    + + {({svgWidth, svgHeight, plotWidth, plotHeight, margin}) => { + // Scales for both orientations + const xScaleHorizontal = (value: number) => (value / xMax) * plotWidth + const yScaleHorizontal = (value: number) => + ((yMax - value) / (yMax - yMin)) * plotHeight + + const xScaleVertical = (value: number) => + ((value - yMin) / (yMax - yMin)) * plotWidth + const yScaleVertical = (value: number) => + ((xMax - value) / xMax) * plotHeight + + const isVertical = direction === "vertical" + const xScale = isVertical ? xScaleVertical : xScaleHorizontal + const yScale = isVertical ? yScaleVertical : yScaleHorizontal + + return ( + <> + + {/* Bar gradient (use barColor if provided) */} + + {!disableGradient && ( + + + + + )} + + + {/* Bin size overlay */} + {typeof extraDimensions.binSize === "number" && ( + + bin {format3Sig(extraDimensions.binSize)} + + )} + + {/* Grid lines */} + + {(isVertical ? xTicks : yTicks).map((tick) => ( + + ))} + + + {/* Histogram bars */} + + {chartData.map((d, idx) => { + const isHighlighted = idx === computedHighlightBinIndex + if (isVertical) { + const barLeft = + margin.left + xScaleVertical(d.edge as number) + const barRight = + margin.left + xScaleVertical(d.edge + binSize) + const barWidth = Math.abs(barRight - barLeft) + const barHeight = + plotHeight - yScaleVertical(d.value) + return ( + + + { + setHoveredBin(idx) + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: + e.clientX - + (svgRect?.left ?? 0), + y: + e.clientY - + (svgRect?.top ?? 0), + }) + }} + onMouseMove={(e) => { + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: + e.clientX - + (svgRect?.left ?? 0), + y: + e.clientY - + (svgRect?.top ?? 0), + }) + }} + onMouseLeave={() => { + setHoveredBin(null) + setMousePos(null) + }} + /> + + ) + } + + const barTop = + margin.top + yScaleHorizontal(d.edge + binSize) + const barBottom = + margin.top + yScaleHorizontal(d.edge as number) + const barHeight = Math.abs(barBottom - barTop) + const rawBarWidth = xScaleHorizontal(d.value) + const barWidth = Math.min(rawBarWidth, plotWidth) + return ( + + + { + setHoveredBin(idx) + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseMove={(e) => { + const svgRect = ( + e.target as SVGRectElement + ).ownerSVGElement?.getBoundingClientRect() + setMousePos({ + x: e.clientX - (svgRect?.left ?? 0), + y: e.clientY - (svgRect?.top ?? 0), + }) + }} + onMouseLeave={() => { + setHoveredBin(null) + setMousePos(null) + }} + /> + + ) + })} + + + {/* Reference lines */} + {typeof extraDimensions.mean === "number" && + (isVertical ? ( + + + + {`μ=${format3Sig(extraDimensions.mean)}`} + + + ) : ( + + + + {`μ=${format3Sig(extraDimensions.mean)}`} + + + ))} + + {typeof highlightValue === "number" && + highlightValue !== extraDimensions.mean && + (isVertical ? ( + + + + {format3Sig(highlightValue)} + + + {format3Sig(highlightValue)} + + + ) : ( + + + + {format3Sig(highlightValue)} + + + {format3Sig(highlightValue)} + + + ))} + + {/* Y-axis */} + + + + + {/* X/Y Axes */} + {/* + y-axis: categorical labels (formatted bin midpoints) + const yLabels = yTicks.map(format3Sig) + const yLabelScale = (idx: number) => ((yTicks.length - idx - 0.5) * (plotHeight / yTicks.length)) + */} + + isVertical + ? xScaleVertical(yTicks[idx]) + : xScaleHorizontal(xTicks[idx]) + } + yScale={isVertical ? yScaleVertical : yScaleHorizontal} + yLabels={isVertical ? undefined : yTicks.map(format3Sig)} + yLabelScale={ + isVertical + ? undefined + : (idx: number) => + (yTicks.length - idx - 0.5) * + (plotHeight / yTicks.length) + } + /> + + + {/* Tooltip outside SVG */} + {hoveredBin !== null && + chartData[hoveredBin] && + mousePos && + (() => { + const total = chartData.reduce((sum, d) => sum + d.value, 0) + const count = chartData[hoveredBin].value + const percent = total > 0 ? (count / total) * 100 : 0 + const isHighlighted = + hoveredBin === computedHighlightBinIndex + return ( +
    + {/* Caret */} +
    +
    + Range: + + {format3Sig( + chartData[hoveredBin].edge as number, + )} + – + {format3Sig( + chartData[hoveredBin].edge + binSize, + )} + + {isHighlighted && ( + + Highlighted + + )} +
    +
    + Count:{" "} + {count} + + ({percent.toFixed(1)}%) + +
    +
    + ) + })()} + + ) + }} + +
    + ) + }, +) + +export default ResponsiveMetricChart diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/chartUtils.ts b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/chartUtils.ts new file mode 100644 index 0000000000..de84ed5fbd --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/chartUtils.ts @@ -0,0 +1,11 @@ +// Shared chart utility functions for both histogram and frequency charts + +export function getYTicks(yMax: number, nTicks = 3): number[] { + // Returns evenly spaced ticks from 0 to yMax + if (yMax === 0) return [0] + const step = yMax / (nTicks - 1) + return Array.from( + {length: nTicks}, + (_, i) => Math.round((i * step + Number.EPSILON) * 1000) / 1000, + ) +} diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils.ts b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils.ts new file mode 100644 index 0000000000..f08ccb95ff --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils.ts @@ -0,0 +1,170 @@ +import {ChartDatum, MetricFormatter} from "../types" + +/** + * Transforms the input data into an array of ChartDatum objects for chart rendering. + * - If `extra.distribution` is an array of numbers, returns them as ChartDatum with indices as names. + * - If `extra.distribution` is an array of objects, filters and maps them to ChartDatum based on `count` or `value`. + * - If `extra.percentiles` is an object, converts its entries to ChartDatum. + * - If `extra.iqrs` is an object, converts its entries to ChartDatum. + * + * @param {Record} extra - The input data containing distribution, percentiles, or iqrs. + * @returns {ChartDatum[]} An array of ChartDatum objects for use in charts. + */ +export const buildChartData = (extra: Record): ChartDatum[] => { + // distribution could be array of objects or numbers + // 1️⃣ Numeric histogram or object counts ------------------------------------ + if (Array.isArray(extra.distribution)) { + let data: ChartDatum[] = [] + if (extra.distribution.every((d) => typeof d === "number")) { + // If binSize & min are provided, label bins as ranges e.g. "0–0.1" + if (typeof extra.binSize === "number" && typeof extra.min === "number") { + data = extra.distribution.map((v: number, idx: number) => { + const minNum = Number(extra.min) + const start = minNum + idx * extra.binSize + const end = start + extra.binSize + return { + name: `${format3Sig(start)}–${format3Sig(end)}`, + value: v, + edge: start, // Changed from end to start + } + }) + } else { + data = extra.distribution.map((v: number, idx: number) => ({ + name: String(idx), + value: v, + })) + } + } else if (extra.distribution.every((d) => typeof d === "object" && d != null)) { + if (extra.distribution.every((d: any) => typeof d.value === "number")) { + // If binSize & min are provided, label bins as ranges e.g. "0–0.1" + if (typeof extra.binSize === "number" && typeof extra.min === "number") { + data = extra.distribution.map((d: any, idx: number) => { + const minNum = Number(extra.min) + const start = minNum + idx * extra.binSize + const end = start + extra.binSize + + return { + name: `${format3Sig(start)}–${format3Sig(end)}`, + value: d.count ?? d.value ?? 0, + edge: start, + } + }) + } else { + data = extra.distribution.map((d: any) => ({ + name: String(d.value), + value: d.count ?? d.value ?? 0, + })) + } + } else { + data = extra.distribution + .filter((d: any) => (d.count ?? d.value ?? 0) > 0) + .map((d: any, idx: number) => ({ + name: + typeof d.value === "number" + ? Number(d.value).toPrecision(3) + : String(idx), + value: Number(d.count ?? d.value ?? 0), + })) + } + } + // If we only have a single point, add a zero baseline to avoid Recharts/decimal.js errors + if (data.length === 1) { + data = [{name: "", value: 0}, ...data] + } + return data + } + + // 2️⃣ Categorical metrics: use frequency (all labels) falling back to rank --- + const catArray = Array.isArray(extra.frequency) + ? extra.frequency + : Array.isArray(extra.rank) + ? extra.rank + : null + if (Array.isArray(catArray)) { + const sorted = [...catArray].sort((a: any, b: any) => (b.count ?? 0) - (a.count ?? 0)) + return sorted.map((d: any) => ({name: String(d.value), value: Number(d.count ?? 0)})) + } + + // 3️⃣ Percentiles / IQRs ---------------------------------------------------- + if (extra.percentiles && typeof extra.percentiles === "object") { + return Object.entries(extra.percentiles).map(([k, v]) => ({name: k, value: Number(v)})) + } + if (extra.iqrs && typeof extra.iqrs === "object") { + return Object.entries(extra.iqrs).map(([k, v]) => ({name: k, value: Number(v)})) + } + return [] +} + +/** + * Registry mapping metric keys (full string match or RegExp string) to a formatter. + * Extend this map according to your metric naming conventions. + */ +export const METRIC_FORMATTERS: Record = { + // currency-like costs + cost: {prefix: "$", decimals: 6}, + costs: {prefix: "$", decimals: 6}, + price: {prefix: "$", decimals: 4}, + totalCost: {prefix: "$", decimals: 4}, + "attributes.ag.metrics.costs.cumulative.total": {prefix: "$", decimals: 4}, + // latency + latency: {decimals: 2, suffix: "s", multiplier: 0.001}, + duration: {decimals: 2, suffix: "s", multiplier: 0.001}, + "duration.total": {decimals: 2, suffix: "s", multiplier: 0.001}, + "attributes.ag.metrics.duration.cumulative": {decimals: 2, suffix: "s", multiplier: 0.001}, + "attributes.ag.metrics.tokens.cumulative.total": {decimals: 0}, + "attributes.ag.metrics.errors.cumulative": {decimals: 0}, + + // percentages + accuracy: {suffix: "%", decimals: 2}, + recall: {suffix: "%", decimals: 2}, + precision: {suffix: "%", decimals: 2}, +} + +export const format3Sig = (num: number | string): string => { + if (typeof num !== "number") return String(num) + if (!Number.isFinite(num)) return String(num) + + const abs = Math.abs(num) + if (abs === 0) return "0" + + const exponent = Math.floor(Math.log10(abs)) + + // Use scientific notation if exponent >= 10 or <= -10 + if (exponent >= 10 || exponent <= -10) { + return num.toExponential(2) + } + + // Use fixed-point notation with 3 significant digits + const decimals = Math.max(0, 2 - exponent) + const fixed = num.toFixed(decimals) + + // Strip trailing zeros and possible trailing decimal point + return fixed.replace(/\.?0+$/, "") +} + +/** + * Format a metric value using the mapping above. + * Falls back to the raw value when the metric has no formatter or value is non-numeric. + */ +export function formatMetricValue(metricKey: string, value: number | string): string { + const fmt = METRIC_FORMATTERS[metricKey] || { + decimals: 2, + } + + if (Array.isArray(value)) { + return value.map((v) => { + return formatMetricValue(metricKey, v) + }) + } + if (!fmt) return String(value) + + if (fmt.format) { + return fmt.format(value) + } + + let num = typeof value === "number" ? value : Number(value) + num = fmt.multiplier ? num * fmt.multiplier : num + const rounded = + Number.isFinite(num) && fmt.decimals !== undefined ? format3Sig(num) : format3Sig(value) + return `${fmt.prefix ?? ""}${rounded}${fmt.suffix ?? ""}` +} diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/index.tsx b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/index.tsx new file mode 100644 index 0000000000..5a7c03a731 --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/index.tsx @@ -0,0 +1,387 @@ +import {memo, useCallback, useMemo, useState, type FC} from "react" + +import {Popover, Tag, Space} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" + +import {Expandable} from "@/oss/components/Tables/ExpandableCell" +import { + evalAtomStore, + runMetricsStatsCacheFamily, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {extractPrimitive, inferMetricType} from "@/oss/lib/metricUtils" + +import ResponsiveFrequencyChart from "./assets/ResponsiveFrequencyChart" +import ResponsiveMetricChart from "./assets/ResponsiveMetricChart" +import {buildChartData, format3Sig, formatMetricValue} from "./assets/utils" +import {MetricDetailsPopoverProps} from "./types" + +/** + * MetricDetailsPopover is a React functional component that provides a detailed view + * of metric information within a popover. It displays both a tabular representation + * of primitive metric entries and a chart visualization based on the provided metric + * data. The component determines the appropriate chart type dynamically and supports + * categorical and continuous data representations. + * + * Props: + * - metricKey: The key associated with the metric being displayed. + * - extraDimensions: Additional dimensions or metadata for the metric. + * - highlightValue: Optional value for highlighting in the chart. + * - hidePrimitiveTable: Boolean flag to toggle the visibility of the primitive table. + * - children: ReactNode elements to be rendered inside the popover trigger. + */ +const MetricDetailsPopover: FC = memo( + ({ + metricKey, + metricType, + extraDimensions, + highlightValue, + hidePrimitiveTable, + children, + className, + }) => { + const [open, setOpen] = useState(false) + const handleOpenChange = useCallback((v: boolean) => setOpen(v), []) + + const extraEntries = useMemo(() => Object.entries(extraDimensions), [extraDimensions]) + + const chartData = useMemo( + () => (open ? buildChartData(extraDimensions) : []), + [open, extraDimensions], + ) + + // Dynamically compute the pixel width required for Y-axis labels + const labelWidth = useMemo(() => { + if (!chartData.length) return 0 + const canvas = document.createElement("canvas") + const ctx = canvas.getContext("2d") + if (!ctx) return 0 + ctx.font = "10px Inter, sans-serif" // must match tick font + const max = Math.max(...chartData.map((d) => ctx.measureText(String(d.name)).width)) + return Math.ceil(max) + 8 // + padding + }, [chartData]) + + const primitiveEntries = useMemo(() => { + if (!open || hidePrimitiveTable) return [] + // const order = ["mean", "std", "min", "max", "count", "total", "binSize"] + const order = ["mean", "std", "min", "max", "count", "sum", "binSize", "unique", "rank"] + const allowed = new Set(order) + const _primitiveEntries = extraEntries + .filter(([k]) => allowed.has(k as string)) + .sort(([a], [b]) => { + const ia = order.indexOf(a as string) + const ib = order.indexOf(b as string) + const sa = ia === -1 ? Number.POSITIVE_INFINITY : ia + const sb = ib === -1 ? Number.POSITIVE_INFINITY : ib + + return sa - sb || (a as string).localeCompare(b as string) + }) + return _primitiveEntries + }, [open, hidePrimitiveTable, extraEntries]) + + const tableNode = useMemo(() => { + if (!primitiveEntries.length) return null + + return ( +
    + + {primitiveEntries.map(([k, v]) => ( + + + + + ))} + +
    {k} + {(() => { + if (Array.isArray(v)) { + const limit = 5 + if (k === "unique") { + const items = (v as any[]).slice(0, limit) + return ( +
    + {items.map((itm, idx) => ( + + {String(itm)} + + ))} + {v.length > limit && ( + + … + + )} +
    + ) + } + if ((k === "rank" || k === "frequency") && v.length) { + const items = (v as any[]).slice(0, limit) + return ( +
    + {items.map((o: any, idx) => ( + {`${o.value} (${o.count})`} + ))} + {v.length > limit && ( + + … + + )} +
    + ) + } + } + return formatMetricValue(metricKey, v as any) + })()} +
    + ) + }, [primitiveEntries, metricKey]) + + // Chart type logic + const isCategoricalChart = + Array.isArray(extraDimensions.distribution) || + Array.isArray(extraDimensions.rank) || + Array.isArray(extraDimensions.frequency) + const hasEdge = + chartData.length > 0 && Object.prototype.hasOwnProperty.call(chartData[0], "edge") + + const frequencyData = useMemo(() => { + // Only build for categorical/frequency charts without edge + if (isCategoricalChart && !hasEdge) { + // buildChartData returns [{ name, value }] but ResponsiveFrequencyChart expects [{ label, count }] + return buildChartData(extraDimensions).map((d) => ({ + label: d.name, + count: d.value, + })) + } + return [] + }, [extraDimensions, isCategoricalChart, hasEdge]) + + const chartNode = useMemo(() => { + if (!open) return null + // Histogram (hasEdge): use ResponsiveMetricChart + if (chartData.length > 0 && isCategoricalChart && hasEdge) { + return ( + + ) + } + // Frequency/categorical: use ResponsiveFrequencyChart + if (frequencyData.length > 0 && isCategoricalChart && !hasEdge) { + return ( + + ) + } + // No valid chart type available + return null + }, [chartData, isCategoricalChart, hasEdge, labelWidth, highlightValue, extraDimensions]) + + const content = useMemo( + () => ( +
    + {tableNode} + {chartNode} +
    + ), + [tableNode, chartNode], + ) + if (!extraEntries.length || metricType === "string") { + return <>{children} + } + + return ( +
    + + {children} + +
    + ) + }, +) + +MetricDetailsPopover.displayName = "MetricDetailsPopover" + +/** + * A wrapper component around MetricDetailsPopover that: + * - fetches run metrics using useEvaluationRunMetrics + * - computes a summary of the metric + * - passes the extra dimensions to the MetricDetailsPopover + * - conditionally renders the MetricDetailsPopover if the metric is not null + * + * @param scenarioId - the scenario ID + * @param runId - the run ID + * @param evaluatorSlug - the evaluator slug + * @param evaluatorMetricKey - the metric key + * @param hidePrimitiveTable - whether to hide the primitive table + * @param metricType - the type of the metric (optional) + */ +export const MetricDetailsPopoverWrapper = memo( + ({ + scenarioId, + runId, + evaluatorSlug, + evaluatorMetricKey, + hidePrimitiveTable = false, + metricType, + className, + statsOverride, + debug, + evaluator, + }: { + scenarioId?: string | null + runId: string + evaluatorSlug: string + evaluatorMetricKey: string + hidePrimitiveTable?: boolean + metricType?: string + evaluator?: EvaluatorDto + className?: string + statsOverride?: Record + debug?: boolean + }) => { + const metricKey = useMemo( + () => `${evaluatorSlug}.${evaluatorMetricKey}`, + [evaluatorSlug, evaluatorMetricKey], + ) + + const store = evalAtomStore() + + // Use run-scoped stats cache instead of global cache + const runStatsCache = useAtomValue(runMetricsStatsCacheFamily(runId), {store}) + const stats = statsOverride ?? runStatsCache?.[metricKey] + + const rawPrimitive = useMemo(() => extractPrimitive(stats), [stats]) + + const explicitTypeFromEvaluator = useMemo(() => { + return ( + evaluator?.metrics?.[evaluatorMetricKey]?.type || + evaluator?.metrics?.[evaluatorMetricKey]?.anyOf + ) + // as SchemaMetricType | undefined + }, [evaluator, evaluatorMetricKey]) + const resolvedMetricType = useMemo( + () => explicitTypeFromEvaluator ?? inferMetricType(rawPrimitive, metricType), + [explicitTypeFromEvaluator, rawPrimitive, metricType], + ) + + const summary = useMemo(() => { + if (!stats) return "N/A" + // Numeric metrics → mean + if (typeof (stats as any).mean === "number") { + return format3Sig(Number((stats as any).mean)) + } + // Boolean metrics → proportion of `true` + if (resolvedMetricType === "boolean" && Array.isArray((stats as any).frequency)) { + const trueEntry = (stats as any).frequency.find((f: any) => f.value === true) + const total = (stats as any).count ?? 0 + if (total) { + return ( +
    +
    +
    +
    + true + false +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    + {(((trueEntry?.count ?? 0) / total) * 100).toFixed(2)}% +
    +
    + ) + } + } + // Array metrics → show top 3 items + if (resolvedMetricType === "array" || resolvedMetricType === undefined) { + const items = + Array.isArray((stats as any).rank) && (stats as any).rank.length + ? (stats as any).rank + : Array.isArray((stats as any).unique) + ? (stats as any).unique.map((v: any) => ({value: v, count: undefined})) + : [] + const topItems = items.slice(0, 3) + return ( + + {topItems.map((it: any) => ( + + {String(it.value)} + {it.count !== undefined ? ` (${it.count})` : ""} + + ))} + + ) + } + // Categorical metrics → top rank + if (Array.isArray((stats as any).rank) && (stats as any).rank.length) { + const top = (stats as any).rank[0] + return `${top.value} (${top.count})` + } + if (Array.isArray((stats as any).unique) && (stats as any).unique.length) { + return `${(stats as any).unique.length} unique` + } + if (typeof (stats as any).count === "number") { + return (stats as any).count + } + return "–" + }, [stats, resolvedMetricType]) + + return stats ? ( + + {summary} + + ) : ( + "N/A" + ) + }, +) + +export default MetricDetailsPopover diff --git a/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/types.ts b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/types.ts new file mode 100644 index 0000000000..1d06b45a69 --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/MetricDetailsPopover/types.ts @@ -0,0 +1,36 @@ +import type {ReactNode} from "react" + +export interface MetricDetailsPopoverProps { + metricKey: string + primaryLabel?: string + primaryValue?: number | string + extraDimensions: Record + /** Value to highlight (bin/bar will be inferred from this value) */ + highlightValue?: number | string + /** Hide primitives key‒value table; useful for lightweight popovers */ + hidePrimitiveTable?: boolean + /** Force using edge-axis (for debugging) */ + hasEdge?: boolean + className?: string + children: ReactNode +} + +// helper to transform objects to chart data +export interface ChartDatum { + name: string | number + value: number + edge?: number +} + +export interface MetricFormatter { + /** String to prepend before the numeric value, e.g. "$" */ + prefix?: string + /** String to append after the numeric value, e.g. "%" */ + suffix?: string + /** Number of decimal places to round to. If undefined, value is not rounded */ + decimals?: number + /** Multiplier to apply before formatting */ + multiplier?: number + /** Optional custom formatter receives numeric value and returns formatted string */ + format?: (value: number | string) => string +} diff --git a/web/ee/src/components/HumanEvaluations/assets/SingleModelEvaluationHeader/index.tsx b/web/ee/src/components/HumanEvaluations/assets/SingleModelEvaluationHeader/index.tsx new file mode 100644 index 0000000000..9e33f2450b --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/SingleModelEvaluationHeader/index.tsx @@ -0,0 +1,328 @@ +import {useCallback, useEffect, useMemo, useState, memo} from "react" + +import {Export, Plus, Trash} from "@phosphor-icons/react" +import {Button, message, Space, Typography} from "antd" +import clsx from "clsx" +import dynamic from "next/dynamic" +import Link from "next/link" +import {useSWRConfig} from "swr" + +import {statusMapper} from "@/oss/components/pages/evaluations/cellRenderers/cellRenderers" +import useURL from "@/oss/hooks/useURL" +import {EvaluationType} from "@/oss/lib/enums" +import {calculateAvgScore} from "@/oss/lib/helpers/evaluate" +import {convertToCsv, downloadCsv} from "@/oss/lib/helpers/fileManipulations" +import {getEvaluationRunScenariosKey} from "@/oss/lib/hooks/useEvaluationRunScenarios" +import useEvaluations from "@/oss/lib/hooks/useEvaluations" +import {summarizeMetric} from "@/oss/lib/metricUtils" +import {EvaluationStatus} from "@/oss/lib/Types" +import {getAppValues} from "@/oss/state/app" + +import {SingleModelEvaluationHeaderProps} from "../../types" +import {EvaluationRow} from "../../types" +import {useStyles} from "../styles" +import {extractEvaluationStatus, getMetricSummaryValue} from "../utils" + +const NewEvaluationModal = dynamic(() => import("../../../pages/evaluations/NewEvaluation"), { + ssr: false, +}) +const DeleteEvaluationModal = dynamic( + () => import("@/oss/components/DeleteEvaluationModal/DeleteEvaluationModal"), + {ssr: false}, +) + +const SingleModelEvaluationHeader = ({ + viewType, + selectedRowKeys, + mergedEvaluations, + runMetricsMap, + setSelectedRowKeys, + isDeleteEvalModalOpen, + setIsDeleteEvalModalOpen, + selectedEvalRecord, + setSelectedEvalRecord, + scope, + projectURL, + activeAppId, + extractAppId, +}: SingleModelEvaluationHeaderProps) => { + const classes = useStyles() + const {appURL} = useURL() + const {cache} = useSWRConfig() + const {refetch, handleDeleteEvaluations: deleteEvaluations} = useEvaluations({ + withPreview: true, + types: [EvaluationType.single_model_test], + appId: activeAppId, + }) + + const [isEvalModalOpen, setIsEvalModalOpen] = useState(false) + const [isDeletingEvaluations, setIsDeletingEvaluations] = useState(false) + const [isScrolled, setIsScrolled] = useState(false) + + useEffect(() => { + if (viewType === "overview") return + + const handleScroll = () => { + setIsScrolled(window.scrollY > 180) + } + + window.addEventListener("scroll", handleScroll) + + return () => { + window.removeEventListener("scroll", handleScroll) + } + }, [viewType]) + + const selectedEvaluations = useMemo(() => { + return selectedEvalRecord + ? (() => { + const found = mergedEvaluations.find( + (e) => ("id" in e ? e.id : e.key) === selectedEvalRecord?.id, + ) + return found && "name" in found ? found.name : (found?.key ?? "") + })() + : mergedEvaluations + .filter((e) => selectedRowKeys.includes("id" in e ? e.id : e.key)) + .map((e) => ("name" in e ? e.name : e.key)) + .join(" | ") + }, [selectedEvalRecord, selectedRowKeys]) + + const handleDelete = useCallback( + async (ids: string[]) => { + setIsDeletingEvaluations(true) + try { + await deleteEvaluations(ids) + message.success( + ids.length > 1 ? `${ids.length} Evaluations Deleted` : "Evaluation Deleted", + ) + } catch (err) { + message.error("Failed to delete evaluations") + console.error(err) + } finally { + setIsDeletingEvaluations(false) + setIsDeleteEvalModalOpen(false) + setSelectedRowKeys([]) + } + }, + [deleteEvaluations], + ) + + const runStatus = useCallback( + (runId: string, status: EvaluationStatus, isLegacyEval: boolean) => { + if (isLegacyEval) { + const statusLabel = statusMapper({} as any)(status as EvaluationStatus) + .label as EvaluationStatus + return statusLabel + } + + const key = `${getEvaluationRunScenariosKey(runId)}-false` + const cachedData = cache.get(key) + const scenarios = cachedData?.data?.scenarios + + const {runStatus: _status} = extractEvaluationStatus(scenarios, status) + return _status == "success" ? "completed" : _status + }, + [cache], + ) + + const onExport = useCallback(() => { + const exportEvals = mergedEvaluations.filter((e) => + selectedRowKeys.some((selected) => selected === ("id" in e ? e.id : e.key)), + ) + + try { + if (exportEvals.length) { + const {currentApp} = getAppValues() + const filenameBase = + currentApp?.app_name || + (scope === "project" ? "all_applications" : "evaluations") + const filename = `${filenameBase.replace(/\s+/g, "_")}_human_annotation.csv` + + const rows = exportEvals.map((item) => { + const id = "id" in item ? item.id : item.key + const metrics = runMetricsMap?.[id] + const applicationName = (item as any)?.variants?.[0]?.appName || "-" + const applicationId = extractAppId(item as EvaluationRow) || "-" + + // Note: all the 'in' conditions here are for legacy eval + const row: Record = { + Name: "name" in item ? item.name : item.key, + Variant: `${item.variants?.[0]?.variantName} v${"revisions" in item ? item.revisions?.[0] : item.variants?.[0]?.revision}`, + "Test set": + "testset" in item + ? item.testset.name + : (item.testsets?.[0]?.name ?? ""), + Status: + runStatus(id, item.status, item.status.includes("EVALUATION")) || "", + // legacy eval + ...("resultsData" in item + ? {"Average score": `${calculateAvgScore(item) || 0}%`} + : {}), + ...((item as any).createdBy?.user?.username + ? {"Created by": (item as any).createdBy?.user?.username} + : {}), + "Created on": item.createdAt, + } + + if (scope === "project") { + row.Application = applicationName + row["Application ID"] = applicationId + } + + // Track metric keys consumed by evaluator loop so we don't duplicate + const consumedKeys = new Set() + + if ("evaluators" in item && Array.isArray(item.evaluators)) { + item.evaluators.forEach((ev: any) => { + const metricDefs = + ev.data?.service?.format?.properties?.outputs?.properties || {} + Object.entries(metricDefs).forEach( + ([metricKey, def]: [string, any]) => { + const fullKey = `${ev.slug}.${metricKey}` + consumedKeys.add(fullKey) + const stat = metrics?.[fullKey] + const value = summarizeMetric(stat, def?.type) + row[`${ev.name} ${metricKey}`] = + value !== undefined && value !== null ? value : "N/A" + }, + ) + }) + } + + if (metrics) { + Object.entries(metrics).forEach(([metricKey, stat]) => { + if (consumedKeys.has(metricKey)) return + const value = summarizeMetric(stat as any) + row[metricKey] = value !== undefined && value !== null ? value : "N/A" + }) + } + + return row + }) + + const headerSet = new Set() + rows.forEach((r) => Object.keys(r).forEach((h) => headerSet.add(h))) + const headers = Array.from(headerSet) + + const csvData = convertToCsv(rows, headers) + downloadCsv(csvData, filename) + setSelectedRowKeys([]) + } + } catch (error) { + message.error("Failed to export results. Please try again later") + } + }, [mergedEvaluations, selectedRowKeys, runMetricsMap, scope, extractAppId]) + + return ( + <> + {viewType === "overview" ? ( +
    + + Human Annotation + + {(() => { + const href = + scope === "app" + ? appURL + ? `${appURL}/evaluations?selectedEvaluation=human_annotation` + : undefined + : `${projectURL}/evaluations?selectedEvaluation=human_annotation` + + if (!href) return null + + return ( + + ) + })()} + + + {(scope === "app" && activeAppId) || scope === "project" ? ( + + ) : null} +
    + ) : ( +
    + {(scope === "app" && activeAppId) || scope === "project" ? ( + + ) : null} + + + + + +
    + )} + + {((scope === "app" && activeAppId) || scope === "project") && ( + { + setIsEvalModalOpen(false) + }} + onSuccess={() => { + setIsEvalModalOpen(false) + refetch() + }} + preview={true} + evaluationType={"human"} + /> + )} + + { + setIsDeleteEvalModalOpen(false) + setSelectedEvalRecord(undefined) + }} + onOk={async () => { + const idsToDelete = selectedEvalRecord + ? [selectedEvalRecord.id] + : selectedRowKeys.map((key) => key?.toString()) + await handleDelete(idsToDelete.filter(Boolean)) + }} + evaluationType={selectedEvaluations} + isMultiple={!selectedEvalRecord && selectedRowKeys.length > 0} + /> + + ) +} + +export default memo(SingleModelEvaluationHeader) diff --git a/web/ee/src/components/HumanEvaluations/assets/TableDropdownMenu/index.tsx b/web/ee/src/components/HumanEvaluations/assets/TableDropdownMenu/index.tsx new file mode 100644 index 0000000000..c7210088fe --- /dev/null +++ b/web/ee/src/components/HumanEvaluations/assets/TableDropdownMenu/index.tsx @@ -0,0 +1,138 @@ +import {memo, useMemo} from "react" + +import {MoreOutlined} from "@ant-design/icons" +import {Database, Note, Rocket, Trash} from "@phosphor-icons/react" +import {Dropdown, Button, MenuProps} from "antd" +import {useRouter} from "next/router" + +import {EvaluationStatus} from "@/oss/lib/Types" +import { + buildAppScopedUrl, + buildEvaluationNavigationUrl, + extractPrimaryInvocation, +} from "../../../pages/evaluations/utils" + +import {TableDropdownMenuProps} from "./types" + +const TableDropdownMenu = ({ + record, + evalType, + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + onVariantNavigation, + baseAppURL, + extractAppId, + scope, + projectURL, + resolveAppId, +}: TableDropdownMenuProps) => { + const router = useRouter() + const primaryInvocation = extractPrimaryInvocation(record) + const resolvedAppId = resolveAppId ? resolveAppId(record) : undefined + const targetAppId = resolvedAppId || primaryInvocation?.appId || extractAppId(record) + const variantId = primaryInvocation?.revisionId || record.variants?.[0]?.id + + const items: MenuProps["items"] = useMemo( + () => [ + { + key: "details", + label: "Open details", + icon: , + disabled: + [ + EvaluationStatus.PENDING, + EvaluationStatus.RUNNING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + ].includes(record.status) || !targetAppId, + onClick: (e) => { + e.domEvent.stopPropagation() + if ( + evalType === "auto" && + ![ + EvaluationStatus.PENDING, + EvaluationStatus.RUNNING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + ].includes(record.status) && + targetAppId + ) { + const evaluationId = "id" in record ? record.id : record.key + const suffix = + evalType === "auto" + ? `/evaluations/results/${evaluationId}` + : `/evaluations/single_model_test/${evaluationId}` + const pathname = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: targetAppId, + path: suffix, + }) + + if (scope === "project") { + router.push({ + pathname, + query: targetAppId ? {app_id: targetAppId} : undefined, + }) + } else { + router.push(pathname) + } + } + }, + }, + { + key: "variant", + label: "View variant", + icon: , + disabled: !variantId || !targetAppId, + onClick: (e) => { + e.domEvent.stopPropagation() + if (!variantId) return + onVariantNavigation({revisionId: variantId, appId: targetAppId || undefined}) + }, + }, + { + key: "view_testset", + label: "View test set", + icon: , + onClick: (e) => { + e.domEvent.stopPropagation() + router.push(`${projectURL}/testsets/${record.testsets?.[0]?.id}`) + }, + }, + {type: "divider"}, + { + key: "delete_eval", + label: "Delete", + icon: , + danger: true, + onClick: (e) => { + e.domEvent.stopPropagation() + setSelectedEvalRecord(record) + setIsDeleteEvalModalOpen(true) + }, + }, + ], + [ + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + record, + onVariantNavigation, + evalType, + targetAppId, + baseAppURL, + variantId, + projectURL, + primaryInvocation, + scope, + ], + ) + return ( + + + + ), + }, + { + content: ( +
    +
    +
    + 2/2 + What brings you here? +
    + +
    + + + + {( + survey?.questions[3] as MultipleSurveyQuestion + )?.choices?.map((role: string) => ( + + {role} + + ))} + + + + + + + + {( + survey?.questions[4] as MultipleSurveyQuestion + )?.choices?.map((choice: string) => ( + + {choice} + + ))} + + + + + {selectedHearAboutUsOption == "Other" && ( + + + + )} +
    +
    + + +
    + ), + }, + ] + }, [ + classes.container, + classes.formItem, + classes.mainContainer, + form, + formData?.companySize, + formData?.hearAboutUs, + formData?.userExperience, + formData?.userInterests?.length, + formData?.userRole, + handleStepOneFormData, + handleSubmitFormData, + selectedHearAboutUsOption, + survey?.questions, + ]) + + const showSurveyForm = Boolean(survey?.questions?.length) + const isSurveyLoading = loading && !error + + return ( + <> +
    + agenta-ai + + +
    + + + {showSurveyForm ? steps[currentStep]?.content : null} + + + ) +} + +export default PostSignupForm diff --git a/web/ee/src/components/PostSignupForm/assets/styles.ts b/web/ee/src/components/PostSignupForm/assets/styles.ts new file mode 100644 index 0000000000..6a13fc6acb --- /dev/null +++ b/web/ee/src/components/PostSignupForm/assets/styles.ts @@ -0,0 +1,32 @@ +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" + +export const useStyles = createUseStyles((theme: JSSTheme) => ({ + mainContainer: { + width: 400, + marginInline: "auto", + height: "82vh", + display: "flex", + flexDirection: "column", + justifyContent: "space-between", + }, + container: { + padding: theme.paddingLG, + display: "grid", + gap: 32, + borderRadius: theme.borderRadiusLG, + boxShadow: + "0px 9px 28px 8px #0000000D, 0px 3px 6px -4px #0000001F, 0px 6px 16px 0px #00000014", + border: "1px solid", + borderColor: theme.colorBorder, + }, + formItem: { + gap: 8, + "& > .ant-form-item-row": { + "& > .ant-form-item-label": { + fontWeight: theme.fontWeightMedium, + }, + }, + }, +})) diff --git a/web/ee/src/components/PostSignupForm/assets/types.d.ts b/web/ee/src/components/PostSignupForm/assets/types.d.ts new file mode 100644 index 0000000000..3313e82c02 --- /dev/null +++ b/web/ee/src/components/PostSignupForm/assets/types.d.ts @@ -0,0 +1,8 @@ +export interface FormDataType { + companySize?: string + userRole?: string + userExperience?: string + userInterests?: string[] + hearAboutUs?: string + hearAboutUsInputOption: string +} diff --git a/web/ee/src/components/PromptVersioningDrawer/PromptVersioningDrawer.tsx b/web/ee/src/components/PromptVersioningDrawer/PromptVersioningDrawer.tsx new file mode 100644 index 0000000000..12530e2b99 --- /dev/null +++ b/web/ee/src/components/PromptVersioningDrawer/PromptVersioningDrawer.tsx @@ -0,0 +1,152 @@ +import {Button, Divider, Drawer, Empty, Space, Typography} from "antd" +import dayjs from "dayjs" +import duration from "dayjs/plugin/duration" +import relativeTime from "dayjs/plugin/relativeTime" +import {createUseStyles} from "react-jss" + +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import ResultComponent from "@/oss/components/ResultComponent/ResultComponent" +import {IPromptRevisions} from "@/oss/lib/Types" +dayjs.extend(relativeTime) +dayjs.extend(duration) + +const {Text} = Typography + +interface StyleProps { + themeMode: "dark" | "light" +} + +interface PromptVersioningDrawerProps { + historyStatus: { + loading: boolean + error: boolean + } + setIsDrawerOpen: React.Dispatch> + isDrawerOpen: boolean + onStateChange: (isDirty: boolean) => void + setRevisionNum: (val: string) => void + promptRevisions: IPromptRevisions[] | undefined +} + +const useStyles = createUseStyles({ + historyContainer: ({themeMode}: StyleProps) => ({ + display: "flex", + flexDirection: "column", + padding: "10px 20px 20px", + margin: "20px 0", + borderRadius: 10, + backgroundColor: themeMode === "dark" ? "#1f1f1f" : "#fff", + color: themeMode === "dark" ? "#fff" : "#000", + borderColor: themeMode === "dark" ? "#333" : "#eceff1", + border: "1px solid", + boxShadow: `0px 4px 8px ${ + themeMode === "dark" ? "rgba(255, 255, 255, 0.1)" : "rgba(0, 0, 0, 0.1)" + }`, + }), + tagText: { + color: "#656d76", + fontSize: 12, + }, + revisionText: { + fontWeight: "bold", + }, + emptyContainer: { + marginTop: "4rem", + }, + divider: { + margin: "15px 0", + }, +}) + +const PromptVersioningDrawer: React.FC = ({ + historyStatus, + setIsDrawerOpen, + isDrawerOpen, + onStateChange, + setRevisionNum, + promptRevisions, +}) => { + const {appTheme} = useAppTheme() + const classes = useStyles({themeMode: appTheme} as StyleProps) + return ( + setIsDrawerOpen(false)} + > + {historyStatus.loading ? ( +
    + +
    + ) : historyStatus.error ? ( +
    + +
    + ) : ( + <> + {promptRevisions?.length ? ( + promptRevisions + ?.map((item: IPromptRevisions) => ( +
    +
    + + {`# ${item.revision}`} + + + + {dayjs(item.created_at).fromNow()} + +
    + + + + + +
    + Config Name: + {item.config.config_name} +
    +
    + Modified By: + {item.modified_by} +
    +
    + +
    +
    + )) + .reverse() + ) : ( + + )} + + )} +
    + ) +} + +export default PromptVersioningDrawer diff --git a/web/ee/src/components/SaveTestsetModal/SaveTestsetModal.tsx b/web/ee/src/components/SaveTestsetModal/SaveTestsetModal.tsx new file mode 100644 index 0000000000..6cba0bc8fe --- /dev/null +++ b/web/ee/src/components/SaveTestsetModal/SaveTestsetModal.tsx @@ -0,0 +1,86 @@ +import {useCallback, useState} from "react" + +import EnhancedModal from "@agenta/oss/src/components/EnhancedUIs/Modal" +import {Input, message} from "antd" + +import useFocusInput from "@/oss/hooks/useFocusInput" +import {createNewTestset} from "@/oss/services/testsets/api" + +import {SaveTestsetModalProps} from "./types" + +const SaveTestsetModal: React.FC = ({ + evaluation, + rows, + onSuccess, + ...props +}) => { + const [submitLoading, setSubmitLoading] = useState(false) + const [testsetName, setTestsetName] = useState("") + const {inputRef} = useFocusInput({isOpen: props.open as boolean}) + + const onClose = useCallback(() => { + setTestsetName("") + setSubmitLoading(false) + props.onCancel?.({} as any) + }, [props]) + + const handleSave = useCallback(() => { + try { + setSubmitLoading(true) + + const newRows = rows.map((row, index) => { + if (evaluation.testset.testsetChatColumn) { + return { + chat: evaluation.testset.csvdata[index].chat, + correct_answer: row.correctAnswer, + annotation: row.note, + } + } + return { + [row.inputs[0].input_name]: row.inputs[0].input_value, + correct_answer: row.correctAnswer, + annotation: row.note, + } + }) + + createNewTestset(testsetName, newRows) + .then(() => onSuccess?.(testsetName)) + .catch(console.error) + .finally(() => { + setSubmitLoading(false) + }) + } catch (error) { + console.error("Error creating testset:", error) + message.error("Failed to create testset. Please try again!") + } finally { + setSubmitLoading(false) + } + }, [rows, evaluation, testsetName, onSuccess]) + + return ( + { + if (open) { + inputRef.current?.input?.focus() + } + }} + {...props} + > + setTestsetName(e.target.value)} + value={testsetName} + className="my-3" + /> + + ) +} + +export default SaveTestsetModal diff --git a/web/ee/src/components/SaveTestsetModal/types.d.ts b/web/ee/src/components/SaveTestsetModal/types.d.ts new file mode 100644 index 0000000000..83a6a6bcf0 --- /dev/null +++ b/web/ee/src/components/SaveTestsetModal/types.d.ts @@ -0,0 +1,13 @@ +import {ModalProps} from "antd" + +import {Evaluation, EvaluationFlow, EvaluationScenario} from "@/oss/lib/Types" + +export interface EvaluationRow extends EvaluationScenario, Record { + evaluationFlow: EvaluationFlow +} + +export interface SaveTestsetModalProps extends ModalProps { + evaluation: Evaluation + rows: EvaluationRow[] + onSuccess: (testsetName: string) => void +} diff --git a/web/ee/src/components/Scripts/assets/CloudScripts.tsx b/web/ee/src/components/Scripts/assets/CloudScripts.tsx new file mode 100644 index 0000000000..b4cb916c02 --- /dev/null +++ b/web/ee/src/components/Scripts/assets/CloudScripts.tsx @@ -0,0 +1,47 @@ +import {useEffect} from "react" + +import {Crisp} from "crisp-sdk-web" +import Head from "next/head" +import Script from "next/script" + +import {getEnv} from "@/oss/lib/helpers/dynamicEnv" + +const CloudScripts = () => { + useEffect(() => { + const isCrispEnabled = !!getEnv("NEXT_PUBLIC_CRISP_WEBSITE_ID") + + if (!isCrispEnabled) { + return + } + + Crisp.configure(getEnv("NEXT_PUBLIC_CRISP_WEBSITE_ID")) + }, []) + + return ( + <> + + Agenta: The LLMOps platform. + + + +
    + +
    + + ) +} + +export default CloudScripts diff --git a/web/ee/src/components/SidePanel/Subscription.tsx b/web/ee/src/components/SidePanel/Subscription.tsx new file mode 100644 index 0000000000..c05efdca56 --- /dev/null +++ b/web/ee/src/components/SidePanel/Subscription.tsx @@ -0,0 +1,29 @@ +import {useMemo} from "react" + +import FreePlanBanner from "@/oss/components/Banners/BillingPlanBanner/FreePlanBanner" +import FreeTrialBanner from "@/oss/components/Banners/BillingPlanBanner/FreeTrialBanner" +import {isDemo} from "@/oss/lib/helpers/utils" +import {Plan} from "@/oss/lib/Types" +import {useSubscriptionData} from "@/oss/services/billing" + +const SidePanelSubscription = () => { + const {subscription} = useSubscriptionData() + + const isShowFreePlanBannerVisible = useMemo( + () => isDemo() && !subscription?.free_trial && subscription?.plan === Plan.Hobby, + [subscription], + ) + const isShowFreeTrialBannerVisible = useMemo( + () => isDemo() && subscription?.free_trial, + [subscription], + ) + + return ( +
    + {isShowFreePlanBannerVisible ? : null} + {isShowFreeTrialBannerVisible ? : null} +
    + ) +} + +export default SidePanelSubscription diff --git a/web/ee/src/components/pages/app-management/components/ApiKeyInput.tsx b/web/ee/src/components/pages/app-management/components/ApiKeyInput.tsx new file mode 100644 index 0000000000..ad44cdb4ea --- /dev/null +++ b/web/ee/src/components/pages/app-management/components/ApiKeyInput.tsx @@ -0,0 +1,61 @@ +import {useMemo, useState} from "react" + +import {Button, Input, Space, Typography, message} from "antd" + +import {isDemo} from "@/oss/lib/helpers/utils" +import {createApiKey} from "@/oss/services/apiKeys/api" +import {useOrgData} from "@/oss/state/org" + +const {Text} = Typography + +interface ApiKeyInputProps { + apiKeyValue: string + onApiKeyChange: React.Dispatch> +} + +const ApiKeyInput: React.FC = ({apiKeyValue, onApiKeyChange}) => { + const [isLoadingApiKey, setIsLoadingApiKey] = useState(false) + const {selectedOrg} = useOrgData() + + const workspaceId: string = useMemo( + () => selectedOrg?.default_workspace.id || "", + [selectedOrg], + ) + + const handleGenerateApiKey = async () => { + try { + setIsLoadingApiKey(true) + + if (workspaceId && isDemo()) { + const {data} = await createApiKey(workspaceId) + onApiKeyChange(data) + message.success("Successfully generated API Key") + } + } catch (error) { + console.error(error) + message.error("Unable to generate API Key") + } finally { + setIsLoadingApiKey(false) + } + } + + return ( + + Create or enter your API key + + onApiKeyChange(e.target.value)} + /> + + + + + ) +} + +export default ApiKeyInput diff --git a/web/ee/src/components/pages/app-management/components/DemoApplicationsSection.tsx b/web/ee/src/components/pages/app-management/components/DemoApplicationsSection.tsx new file mode 100644 index 0000000000..605ef9a773 --- /dev/null +++ b/web/ee/src/components/pages/app-management/components/DemoApplicationsSection.tsx @@ -0,0 +1,96 @@ +import {Button, Card, Flex, Space, Typography} from "antd" +import Image from "next/image" +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" +import {useOrgData} from "@/oss/state/org" +import {useProjectData} from "@/oss/state/project" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + demoAppCard: { + width: 400, + "& .ant-card-body": { + padding: theme.paddingSM, + "& span.ant-typography": { + textOverflow: "ellipsis", + fontSize: theme.fontSizeLG, + fontWeight: theme.fontWeightMedium, + lineHeight: theme.lineHeightLG, + color: "inherit", + }, + "& div.ant-typography": { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + color: theme.colorTextSecondary, + }, + }, + }, +})) + +const {Text, Title, Paragraph} = Typography + +const DemoApplicationsSection = () => { + const classes = useStyles() + const {projects} = useProjectData() + const {changeSelectedOrg} = useOrgData() + + const handleViewDemoSwitch = () => { + const project = projects.find((p) => !!p.is_demo) + if (project && project.organization_id) { + changeSelectedOrg(project.organization_id) + } + } + + return ( +
    + + Explore demo applications + + See Agenta in action by exploring fully build prompts, evaluations, + observability and traces. Learn how to set your application by watching + tutorials. + + + +
    + + } + > + + + RAG Q&A with Wikipedia + + Use RAG to answer questions by fetching relevant information from + wikipedia + + + + + + + + +
    +
    + ) +} + +export default DemoApplicationsSection diff --git a/web/ee/src/components/pages/app-management/components/ObservabilityDashboardSection.tsx b/web/ee/src/components/pages/app-management/components/ObservabilityDashboardSection.tsx new file mode 100644 index 0000000000..d41ea8d433 --- /dev/null +++ b/web/ee/src/components/pages/app-management/components/ObservabilityDashboardSection.tsx @@ -0,0 +1,180 @@ +import {useMemo, type ComponentProps} from "react" + +import {AreaChart} from "@tremor/react" +import {Spin, Typography} from "antd" +import round from "lodash/round" +import {createUseStyles} from "react-jss" + +import {formatCurrency, formatLatency, formatNumber} from "@/oss/lib/helpers/formatters" +import {JSSTheme} from "@/oss/lib/Types" + +import {useObservabilityDashboard} from "../../../../state/observability" +import WidgetCard from "../../observability/dashboard/widgetCard" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + margin: "1.5rem 0", + display: "flex", + "& .ant-spin-nested-loading": { + width: "100%", + }, + }, + statText: { + "& span.ant-typography": { + fontSize: theme.fontSize, + lineHeight: theme.lineHeight, + fontWeight: "normal", + color: theme.colorTextSecondary, + }, + "& > span": { + fontWeight: theme.fontWeightMedium, + }, + }, + widgetContainer: { + display: "grid", + gridTemplateColumns: "repeat(2, 1fr)", + gap: 16, + "@media (min-width: 1360px)": { + gridTemplateColumns: "repeat(4, 1fr)", + }, + "@media (max-width: 850px)": { + gridTemplateColumns: "repeat(1, 1fr)", + }, + }, +})) + +const ObservabilityDashboardSection = () => { + const classes = useStyles() + const {data, loading, isFetching} = useObservabilityDashboard() + + const chartData = useMemo(() => (data?.data?.length ? data.data : [{}]), [data]) + + const defaultGraphProps = useMemo>( + () => ({ + className: "h-[168px] p-0", + colors: ["cyan", "red"], + connectNulls: true, + tickGap: 15, + curveType: "monotone", + showGridLines: false, + showLegend: false, + index: "timestamp", + data: chartData, + categories: [], + }), + [chartData], + ) + + return ( +
    + +
    +
    + + Total:{" "} + + {data?.total_count ? formatNumber(data?.total_count) : "-"} + +
    + } + rightSubHeading={ + (data?.failure_rate ?? 0) > 0 && ( +
    + Failed:{" "} + + {" "} + {data?.failure_rate + ? `${formatNumber(data?.failure_rate)}%` + : "-"} + +
    + ) + } + > + 0 + ? ["success_count", "failure_count"] + : ["success_count"] + } + /> + +
    +
    + + Avg:{" "} + + {data?.avg_latency + ? `${formatNumber(data.avg_latency)}ms` + : "-"} + +
    + } + > + + +
    +
    + + Total:{" "} + + {data?.total_cost ? formatCurrency(data.total_cost) : "-"} + +
    + } + rightSubHeading={ +
    + Avg:{" "} + + {data?.total_cost ? formatCurrency(data.avg_cost) : "-"} + +
    + } + > + + +
    +
    + + Total:{" "} + + {" "} + {data?.total_tokens + ? formatNumber(data?.total_tokens) + : "-"} + +
    + } + rightSubHeading={ +
    + Avg:{" "} + + {" "} + {data?.avg_tokens ? formatNumber(data?.avg_tokens) : "-"} + +
    + } + > + + +
    +
    + +
    + ) +} + +export default ObservabilityDashboardSection diff --git a/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorModal.tsx b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorModal.tsx new file mode 100644 index 0000000000..b7b4b54b00 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorModal.tsx @@ -0,0 +1,77 @@ +import {ExclamationCircleOutlined} from "@ant-design/icons" +import {Collapse, Modal, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" + +interface EvaluationErrorModalProps { + isErrorModalOpen: boolean + setIsErrorModalOpen: (value: React.SetStateAction) => void + modalErrorMsg: { + message: string + stackTrace: string + errorType: "invoke" | "evaluation" + } +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + errModalStackTrace: { + "& code": { + display: "block", + whiteSpace: "pre-wrap", + }, + maxHeight: 300, + overflow: "auto", + }, +})) + +const EvaluationErrorModal = ({ + isErrorModalOpen, + setIsErrorModalOpen, + modalErrorMsg, +}: EvaluationErrorModalProps) => { + const classes = useStyles() + + const errorText = + modalErrorMsg.errorType === "invoke" + ? "Failed to invoke the LLM application with the following exception:" + : "Failed to compute evaluation with the following exception:" + + return ( + + + Error + + } + onCancel={() => setIsErrorModalOpen(false)} + > + {errorText} + {modalErrorMsg.message && ( + {modalErrorMsg.message} + )} + {modalErrorMsg.stackTrace && ( + + {modalErrorMsg.stackTrace} + + ), + }, + ]} + /> + )} + + ) +} + +export default EvaluationErrorModal diff --git a/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorPopover.tsx b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorPopover.tsx new file mode 100644 index 0000000000..642bc60c70 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorPopover.tsx @@ -0,0 +1,43 @@ +import {InfoCircleOutlined} from "@ant-design/icons" +import {Button, Popover, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {EvaluationError, JSSTheme, TypedValue} from "@/oss/lib/Types" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + errModalStackTrace: { + maxWidth: 300, + "& code": { + display: "block", + width: "100%", + }, + }, +})) + +const EvaluationErrorPopover = (result: { + result: TypedValue & { + error: null | EvaluationError + } +}) => { + const classes = useStyles() + + return ( + + {result.result.error?.stacktrace} + + } + title={result.result.error?.message} + > + + + ) +} + +export default EvaluationErrorPopover diff --git a/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorText.tsx b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorText.tsx new file mode 100644 index 0000000000..017b065a3b --- /dev/null +++ b/web/ee/src/components/pages/evaluations/EvaluationErrorProps/EvaluationErrorText.tsx @@ -0,0 +1,19 @@ +import {Button, Typography} from "antd" + +interface EvaluationErrorTextProps { + text: string + handleOnClick: () => void +} + +const EvaluationErrorText = ({text, handleOnClick}: EvaluationErrorTextProps) => { + return ( + + {text}{" "} + + + ) +} + +export default EvaluationErrorText diff --git a/web/ee/src/components/pages/evaluations/EvaluationsView.tsx b/web/ee/src/components/pages/evaluations/EvaluationsView.tsx new file mode 100644 index 0000000000..69694f0205 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/EvaluationsView.tsx @@ -0,0 +1,160 @@ +import {useEffect, useMemo} from "react" + +import {Radio, Typography} from "antd" +import clsx from "clsx" +import dynamic from "next/dynamic" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" +import {useLocalStorage} from "usehooks-ts" + +import {useAppId} from "@/oss/hooks/useAppId" +import {useQueryParam} from "@/oss/hooks/useQuery" +import {useBreadcrumbsEffect} from "@/oss/lib/hooks/useBreadcrumbs" +import {JSSTheme} from "@/oss/lib/Types" + +const AutoEvaluation = dynamic( + () => import("@/oss/components/pages/evaluations/autoEvaluation/AutoEvaluation"), + {ssr: false}, +) +const SingleModelEvaluation = dynamic( + () => import("@/oss/components/HumanEvaluations/SingleModelEvaluation"), + {ssr: false}, +) +const AbTestingEvaluation = dynamic( + () => import("@/oss/components/HumanEvaluations/AbTestingEvaluation"), + {ssr: false}, +) + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + flexDirection: "column", + gap: theme.marginLG, + }, + title: { + fontSize: theme.fontSizeLG, + fontWeight: theme.fontWeightMedium, + lineHeight: theme.lineHeightHeading4, + }, +})) + +type EvaluationScope = "app" | "project" + +const formatLabel = (value: string) => value.replaceAll("_", " ") + +interface EvaluationsViewProps { + scope?: EvaluationScope +} + +const allowedOptionsByScope: Record> = { + app: [ + {value: "auto_evaluation", label: "Automatic"}, + {value: "human_annotation", label: "Human annotation"}, + {value: "human_ab_testing", label: "A/B Testing"}, + ], + project: [ + {value: "auto_evaluation", label: "Automatic"}, + {value: "human_annotation", label: "Human annotation"}, + ], +} + +const EvaluationsView = ({scope = "app"}: EvaluationsViewProps) => { + const classes = useStyles() + const router = useRouter() + const routeAppId = useAppId() + + const uniqueScopeKey = useMemo(() => { + if (scope !== "app") return "project" + if (!routeAppId) return "app" + const parts = routeAppId.split("-") + return parts[parts.length - 1] || "app" + }, [scope, routeAppId]) + + const [defaultKey, setDefaultKey] = useLocalStorage( + `${uniqueScopeKey}-last-visited-evaluation`, + "auto_evaluation", + ) + const [selectedEvaluation, setSelectedEvaluation] = useQueryParam( + "selectedEvaluation", + defaultKey, + ) + + // Ensure selected evaluation is valid for current scope + useEffect(() => { + const allowed = allowedOptionsByScope[scope].map((option) => option.value) + if (!selectedEvaluation || !router.query.selectedEvaluation) { + setSelectedEvaluation(defaultKey) + return + } + + if (!allowed.includes(selectedEvaluation)) { + const fallback = allowed.includes(defaultKey) ? defaultKey : allowed[0] + setSelectedEvaluation(fallback) + } + }, [ + selectedEvaluation, + defaultKey, + setSelectedEvaluation, + scope, + router.query.selectedEvaluation, + ]) + + useEffect(() => { + if (selectedEvaluation && selectedEvaluation !== defaultKey) { + setDefaultKey(selectedEvaluation) + } + }, [selectedEvaluation, defaultKey, setDefaultKey]) + + useBreadcrumbsEffect( + { + breadcrumbs: + scope === "app" + ? {appPage: {label: formatLabel(selectedEvaluation)}} + : {projectPage: {label: formatLabel(selectedEvaluation)}}, + type: "append", + condition: !!selectedEvaluation, + }, + [selectedEvaluation, scope, router.asPath], + ) + + const renderPage = useMemo(() => { + switch (selectedEvaluation) { + case "human_annotation": + return + case "human_ab_testing": + return scope === "app" ? ( + + ) : ( + + ) + case "auto_evaluation": + default: + return + } + }, [selectedEvaluation, scope]) + + const options = allowedOptionsByScope[scope] + + return ( +
    +
    + Evaluations + setSelectedEvaluation(e.target.value)} + > + {options.map((option) => ( + + {option.label} + + ))} + +
    + +
    {renderPage}
    +
    + ) +} + +export default EvaluationsView diff --git a/web/ee/src/components/pages/evaluations/FilterColumns/FilterColumns.tsx b/web/ee/src/components/pages/evaluations/FilterColumns/FilterColumns.tsx new file mode 100644 index 0000000000..d9d0c785d2 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/FilterColumns/FilterColumns.tsx @@ -0,0 +1,88 @@ +import {type ColDef} from "@ag-grid-community/core" +import {CheckOutlined, DownOutlined} from "@ant-design/icons" +import {Button, Dropdown, Space} from "antd" +import {ItemType} from "antd/es/menu/interface" +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + dropdownMenu: { + "&>.ant-dropdown-menu-item": { + "& .anticon-check": { + display: "none", + }, + }, + "&>.ant-dropdown-menu-item-selected": { + "&:not(:hover)": { + backgroundColor: "transparent !important", + }, + "& .anticon-check": { + display: "inline-flex !important", + }, + }, + }, +})) + +export const generateFilterItems = (colDefs: ColDef[]) => { + return colDefs.map((configs) => ({ + key: configs.headerName as string, + label: ( + + + <>{configs.headerName} + + ), + })) +} + +interface FilterColumnsProps { + isOpen: boolean + handleOpenChange: ( + open: boolean, + info: { + source: "trigger" | "menu" + }, + ) => void + shownCols: string[] + items: ItemType[] + onClick: ({key}: {key: string}) => void + buttonText?: string +} + +const FilterColumns = ({ + isOpen, + handleOpenChange, + shownCols, + items, + onClick, + buttonText, +}: FilterColumnsProps) => { + const classes = useStyles() + + return ( + + + + ) +} + +export default FilterColumns diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx new file mode 100644 index 0000000000..a0394ecc7c --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/AdvancedSettings.tsx @@ -0,0 +1,112 @@ +import {memo, useCallback, useMemo} from "react" + +import {QuestionCircleOutlined} from "@ant-design/icons" +import {Button, Col, Flex, Form, Input, InputNumber, Row, Tooltip, Typography} from "antd" +import deepEqual from "fast-deep-equal" + +import {DEFAULT_ADVANCE_SETTINGS} from "../assets/constants" +import {AdvancedSettingsProps} from "../types" + +const AdvancedSettings = ({advanceSettings, setAdvanceSettings}: AdvancedSettingsProps) => { + const handleChange = (key: string, value: any) => { + setAdvanceSettings((prev) => ({ + ...prev, + [key]: value, + })) + } + + const handleResetDefaults = useCallback(() => { + setAdvanceSettings(DEFAULT_ADVANCE_SETTINGS) + }, []) + + const isAdvancedSettingsChanged = useMemo( + () => !deepEqual(advanceSettings, DEFAULT_ADVANCE_SETTINGS), + [advanceSettings], + ) + + const {correct_answer_column, ...rateLimitConfig} = advanceSettings + + return ( + +
    + + + Rate Limit Configuration + + {isAdvancedSettingsChanged && ( + + )} +
    + } + style={{marginBottom: 0}} + > + + {Object.entries(rateLimitConfig).map(([key, value]) => ( + + + {key + .replace(/_/g, " ") + .replace(/\b\w/g, (c) => c.toUpperCase())} +   + + + + + } + rules={[ + { + validator: (_, value) => { + if (value !== null) { + return Promise.resolve() + } + return Promise.reject("This field is required") + }, + }, + ]} + > + handleChange(key, value)} + style={{width: "100%"}} + min={0} + /> + + + ))} + + + + Correct Answer Column  + + + + + } + > + handleChange("correct_answer_column", e.target.value)} + style={{width: "50%"}} + /> + + + + ) +} + +export default memo(AdvancedSettings) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/NewEvaluationModalContent.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/NewEvaluationModalContent.tsx new file mode 100644 index 0000000000..be0b94bde4 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/NewEvaluationModalContent.tsx @@ -0,0 +1,294 @@ +import {type FC, memo, useMemo} from "react" + +import {CloseCircleOutlined} from "@ant-design/icons" +import {Input, Typography, Tabs, Tag} from "antd" +import clsx from "clsx" +import dynamic from "next/dynamic" + +import useFocusInput from "@/oss/hooks/useFocusInput" + +import {useStyles} from "../assets/styles" +import TabLabel from "../assets/TabLabel" +import {NewEvaluationModalContentProps} from "../types" + +import SelectAppSection from "./SelectAppSection" + +const SelectEvaluatorSection = dynamic( + () => import("./SelectEvaluatorSection/SelectEvaluatorSection"), + {ssr: false}, +) + +const SelectTestsetSection = dynamic(() => import("./SelectTestsetSection"), { + ssr: false, +}) + +const SelectVariantSection = dynamic(() => import("./SelectVariantSection"), { + ssr: false, +}) + +const AdvancedSettings = dynamic(() => import("./AdvancedSettings"), { + ssr: false, +}) + +const NewEvaluationModalContent: FC = ({ + onSuccess, + handlePanelChange, + activePanel, + selectedTestsetId, + setSelectedTestsetId, + selectedVariantRevisionIds, + setSelectedVariantRevisionIds, + selectedEvalConfigs, + setSelectedEvalConfigs, + evaluationName, + setEvaluationName, + preview, + evaluationType, + testSets, + variants, + variantsLoading, + evaluators, + evaluatorConfigs, + advanceSettings, + setAdvanceSettings, + appOptions, + selectedAppId, + onSelectApp, + appSelectionDisabled, + ...props +}) => { + const classes = useStyles() + const {inputRef} = useFocusInput({isOpen: props.isOpen || false}) + const appSelectionComplete = Boolean(selectedAppId) + + const selectedTestset = useMemo( + () => testSets.find((ts) => ts._id === selectedTestsetId) || null, + [testSets, selectedTestsetId], + ) + + const selectedVariants = useMemo( + () => variants?.filter((v) => selectedVariantRevisionIds.includes(v.id)) || [], + [variants, selectedVariantRevisionIds], + ) + + const selectedEvalConfig = useMemo(() => { + const source = preview ? (evaluators as any[]) : (evaluatorConfigs as any[]) + return source.filter((cfg) => selectedEvalConfigs.includes(cfg.id)) + }, [preview, evaluators, evaluatorConfigs, selectedEvalConfigs]) + + const items = useMemo(() => { + const requireAppMessage = ( + + Select an application first to load this section. + + ) + + return [ + { + key: "appPanel", + label: ( + + {appSelectionComplete && ( + } + onClose={() => { + if (!appSelectionDisabled) onSelectApp("") + }} + > + {appOptions.find((opt) => opt.value === selectedAppId)?.label ?? + selectedAppId} + + )} + + ), + children: ( +
    + + {!appSelectionComplete && !appSelectionDisabled ? ( + + Please select an application to continue configuring the evaluation. + + ) : null} +
    + ), + }, + { + key: "variantPanel", + label: ( + 0}> + {selectedVariants.map((v) => ( + } + onClose={() => { + setSelectedVariantRevisionIds( + selectedVariantRevisionIds.filter((id) => id !== v.id), + ) + }} + > + {`${v.variantName} - v${v.revision}`} + + ))} + + ), + children: appSelectionComplete ? ( + + ) : ( + requireAppMessage + ), + }, + { + key: "testsetPanel", + label: ( + + {selectedTestset ? ( + } + onClose={() => { + setSelectedTestsetId("") + }} + > + {selectedTestset.name} + + ) : null} + + ), + children: appSelectionComplete ? ( + + ) : ( + requireAppMessage + ), + }, + { + key: "evaluatorPanel", + label: ( + 0}> + {selectedEvalConfig.map((cfg: any) => { + return ( + } + color={cfg.color} + onClose={() => { + setSelectedEvalConfigs( + selectedEvalConfigs.filter((id) => id !== cfg.id), + ) + }} + > + {cfg.name} + + ) + })} + + ), + children: appSelectionComplete ? ( + + ) : ( + requireAppMessage + ), + }, + ...(evaluationType === "auto" + ? [ + { + key: "advancedSettingsPanel", + label: ( + + {Object.entries(advanceSettings).map(([key, value]) => ( + + {key}: {value} + + ))} + + ), + children: appSelectionComplete ? ( + + ) : ( + requireAppMessage + ), + }, + ] + : []), + ] + }, [ + selectedTestset, + selectedVariants, + selectedEvalConfig, + handlePanelChange, + selectedTestsetId, + selectedVariantRevisionIds, + selectedEvalConfigs, + preview, + evaluationType, + testSets, + variants, + evaluators, + evaluatorConfigs, + advanceSettings, + appSelectionComplete, + appOptions, + selectedAppId, + onSelectApp, + appSelectionDisabled, + ]) + + return ( +
    +
    + Evaluation name + { + setEvaluationName(e.target.value) + }} + /> +
    + + +
    + ) +} + +export default memo(NewEvaluationModalContent) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectAppSection.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectAppSection.tsx new file mode 100644 index 0000000000..890a9fa6a3 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectAppSection.tsx @@ -0,0 +1,118 @@ +import {HTMLProps, useMemo} from "react" + +import {Table, Tag, Typography} from "antd" +import type {ColumnsType} from "antd/es/table" + +import {formatDay} from "@/oss/lib/helpers/dateTimeHelper" + +import type {NewEvaluationAppOption} from "../types" + +const formatAppType = (type?: string | null) => { + if (!type) return null + const normalized = type.replace(/_/g, " ") + return normalized.charAt(0).toUpperCase() + normalized.slice(1) +} + +interface SelectAppSectionProps extends HTMLProps { + apps: NewEvaluationAppOption[] + selectedAppId: string + onSelectApp: (value: string) => void + disabled?: boolean +} + +const SelectAppSection = ({ + apps, + selectedAppId, + onSelectApp, + disabled, + className, +}: SelectAppSectionProps) => { + const columns: ColumnsType = useMemo(() => { + return [ + { + title: "Application", + dataIndex: "label", + key: "label", + render: (value: string) => {value}, + }, + { + title: "Type", + dataIndex: "type", + key: "type", + width: 160, + render: (value: string | null | undefined) => { + const label = formatAppType(value) + return label ? ( + {label} + ) : ( + + ) + }, + }, + { + title: "Created", + dataIndex: "createdAt", + key: "createdAt", + width: 240, + render: (value: string, record) => { + const displayDate = value || record.updatedAt || "" + return displayDate ? ( + + {formatDay({date: displayDate, outputFormat: "DD MMM YYYY | h:mm a"})} + + ) : ( + + ) + }, + }, + ] + }, []) + + const dataSource = useMemo( + () => + apps.map((app) => ({ + key: app.value, + ...app, + })), + [apps], + ) + + return ( +
    + (disabled ? "" : "cursor-pointer")} + onRow={(record) => ({ + onClick: () => { + if (disabled || record.value === selectedAppId) return + onSelectApp(record.value) + }, + })} + rowSelection={{ + type: "radio", + columnWidth: 48, + selectedRowKeys: selectedAppId ? [selectedAppId] : [], + onChange: (selectedRowKeys) => { + if (disabled) return + const [key] = selectedRowKeys + onSelectApp(key as string) + }, + getCheckboxProps: () => ({disabled}), + }} + locale={{ + emptyText: disabled + ? "Application selection is locked in app scope" + : "No applications available", + }} + /> + + ) +} + +export default SelectAppSection diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectEvaluatorSection/SelectEvaluatorSection.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectEvaluatorSection/SelectEvaluatorSection.tsx new file mode 100644 index 0000000000..b87ad02813 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectEvaluatorSection/SelectEvaluatorSection.tsx @@ -0,0 +1,360 @@ +import {memo, useEffect, useMemo, useRef, useState} from "react" + +import {PlusOutlined} from "@ant-design/icons" +import {Button, Input, Table, Tag, Space} from "antd" +import {ColumnsType} from "antd/es/table" +import clsx from "clsx" +import dynamic from "next/dynamic" + +import EnhancedDrawer from "@/oss/components/EnhancedUIs/Drawer" +import AnnotateDrawerTitle from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/AnnotateDrawerTitle" +import CreateEvaluator from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/CreateEvaluator" +import {AnnotateDrawerSteps} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/enum" +import {getMetricsFromEvaluator} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/transforms" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import useFetchEvaluatorsData from "@/oss/lib/hooks/useFetchEvaluatorsData" +import {Evaluator, EvaluatorConfig} from "@/oss/lib/Types" + +import type {SelectEvaluatorSectionProps} from "../../types" + +const EvaluatorsModal = dynamic( + () => import("../../../autoEvaluation/EvaluatorsModal/EvaluatorsModal"), + { + ssr: false, + loading: () => null, // Prevent flash by not rendering until loaded + }, +) +const NoResultsFound = dynamic(() => import("@/oss/components/NoResultsFound/NoResultsFound"), { + ssr: false, +}) + +const EvaluatorMetrics = memo(({evaluator}: {evaluator: EvaluatorDto<"response">}) => { + const metrics = getMetricsFromEvaluator(evaluator) + return ( +
    + {Object.entries(metrics).map(([key, value]) => { + return ( + + {key} + + ) + })} +
    + ) +}) + +// Use a generic type variable Preview and conditionally type filteredEvalConfigs +const SelectEvaluatorSection = ({ + selectedEvalConfigs, + setSelectedEvalConfigs, + className, + handlePanelChange, + preview, + evaluators: propsEvaluators, + evaluatorConfigs: propsEvaluatorConfigs, + selectedAppId, + ...props +}: SelectEvaluatorSectionProps & {preview?: Preview}) => { + const fetchData = useFetchEvaluatorsData({ + preview: preview as boolean, + queries: {is_human: preview}, + appId: selectedAppId || "", + }) + + const evaluationData = useMemo(() => { + if (preview) { + const evaluators = (propsEvaluators || + fetchData.evaluatorsSwr.data || + []) as EvaluatorDto<"response">[] + const evaluatorConfigs = evaluators + const isLoadingEvaluators = fetchData.isLoadingEvaluators + const isLoadingEvaluatorConfigs = fetchData.isLoadingEvaluatorConfigs + return {evaluators, evaluatorConfigs, isLoadingEvaluators, isLoadingEvaluatorConfigs} + } else { + const evaluators = propsEvaluators?.length + ? propsEvaluators + : ((fetchData.evaluatorsSwr.data || []) as Evaluator[]) + const evaluatorConfigs = (propsEvaluatorConfigs || + fetchData.evaluatorConfigsSwr.data || + []) as EvaluatorConfig[] + const isLoadingEvaluators = fetchData.isLoadingEvaluators + const isLoadingEvaluatorConfigs = fetchData.isLoadingEvaluatorConfigs + return {evaluators, evaluatorConfigs, isLoadingEvaluators, isLoadingEvaluatorConfigs} + } + }, [fetchData, preview, propsEvaluators, propsEvaluatorConfigs]) + + const {evaluators, evaluatorConfigs, isLoadingEvaluators, isLoadingEvaluatorConfigs} = + evaluationData + + const [searchTerm, setSearchTerm] = useState("") + const [isEvaluatorsModalOpen, setIsEvaluatorsModalOpen] = useState(false) + const [current, setCurrent] = useState(0) + const prevSelectedAppIdRef = useRef() + const {refetchEvaluatorConfigs} = fetchData + + useEffect(() => { + if (!selectedAppId) { + prevSelectedAppIdRef.current = selectedAppId + return + } + + if (prevSelectedAppIdRef.current === selectedAppId) { + return + } + + prevSelectedAppIdRef.current = selectedAppId + refetchEvaluatorConfigs() + }, [selectedAppId, refetchEvaluatorConfigs]) + + useEffect(() => { + if (isLoadingEvaluators || isLoadingEvaluatorConfigs) return + + const availableIds = new Set( + (preview + ? (evaluators as EvaluatorDto<"response">[]) + : (evaluatorConfigs as EvaluatorConfig[]) + ).map((config) => config.id), + ) + + setSelectedEvalConfigs((prevSelected) => { + const nextSelected = prevSelected.filter((id) => availableIds.has(id)) + return nextSelected.length === prevSelected.length ? prevSelected : nextSelected + }) + }, [ + preview, + evaluators, + evaluatorConfigs, + isLoadingEvaluators, + isLoadingEvaluatorConfigs, + setSelectedEvalConfigs, + ]) + + const columnsPreview: ColumnsType> = useMemo( + () => [ + { + title: "Name", + dataIndex: "name", + key: "name", + render: (_, record: EvaluatorDto<"response">) => { + return
    {record.name}
    + }, + }, + { + title: "Slug", + dataIndex: "slug", + key: "slug", + render: (_, record: EvaluatorDto<"response">) => { + return
    {record.slug}
    + }, + }, + { + title: "Metrics", + dataIndex: "data", + key: "data", + render: (_, record: EvaluatorDto<"response">) => ( + + ), + }, + ], + [], + ) + + const columnsConfig: ColumnsType = useMemo( + () => [ + { + title: "Name", + dataIndex: "name", + key: "name", + render: (_, record: EvaluatorConfig) => { + return
    {record.name}
    + }, + }, + { + title: "Type", + dataIndex: "type", + key: "type", + render: (x, record: EvaluatorConfig) => { + // Find the evaluator by key to display its name + const evaluator = (evaluators as Evaluator[]).find( + (item) => item.key === record.evaluator_key, + ) + return {evaluator?.name} + }, + }, + ], + [evaluators], + ) + + // Conditionally type filteredEvalConfigs based on Preview + const filteredEvalConfigs: Preview extends true + ? EvaluatorDto<"response">[] + : EvaluatorConfig[] = useMemo(() => { + if (preview) { + // Explicitly narrow types for Preview = true + const data = evaluators as EvaluatorDto<"response">[] + if (!searchTerm) return data + return data.filter((item) => + item.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) as any + } else { + // Explicitly narrow types for Preview = false + const data = evaluatorConfigs as EvaluatorConfig[] + if (!searchTerm) return data + return data.filter((item) => + item.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) as any + } + }, [searchTerm, evaluatorConfigs, preview, evaluators]) + + const selectedEvalConfig = useMemo( + () => evaluatorConfigs.filter((config) => selectedEvalConfigs.includes(config.id)), + [evaluatorConfigs, selectedEvalConfigs], + ) + + const onSelectEvalConfig = (selectedRowKeys: React.Key[]) => { + const currentSelected = new Set(selectedEvalConfigs) + const configs = filteredEvalConfigs as EvaluatorDto<"response">[] + configs.forEach((item) => { + if (selectedRowKeys.includes(item.id)) { + currentSelected.add(item.id) + } else { + currentSelected.delete(item.id) + } + }) + setSelectedEvalConfigs(Array.from(currentSelected)) + } + + return ( + <> +
    +
    + setSearchTerm(e.target.value)} + /> + + + +
    + + {filteredEvalConfigs.length === 0 ? ( + { + setCurrent(1) + setIsEvaluatorsModalOpen(true) + }} + /> + ) : preview ? ( + > + rowSelection={{ + type: "checkbox", + columnWidth: 48, + selectedRowKeys: selectedEvalConfigs, + onChange: (selectedRowKeys) => { + onSelectEvalConfig(selectedRowKeys) + }, + }} + onRow={(record) => ({ + style: {cursor: "pointer"}, + onClick: () => { + if (selectedEvalConfigs.includes(record.id)) { + onSelectEvalConfig( + selectedEvalConfigs.filter((id) => id !== record.id), + ) + } else { + onSelectEvalConfig([...selectedEvalConfigs, record.id]) + } + }, + })} + className="ph-no-capture" + columns={columnsPreview} + rowKey={"id"} + dataSource={filteredEvalConfigs as EvaluatorDto<"response">[]} + scroll={{x: true}} + bordered + pagination={false} + /> + ) : ( + + rowSelection={{ + type: "checkbox", + columnWidth: 48, + selectedRowKeys: selectedEvalConfigs, + onChange: (selectedRowKeys) => { + onSelectEvalConfig(selectedRowKeys) + }, + }} + onRow={(record) => ({ + style: {cursor: "pointer"}, + onClick: () => { + if (selectedEvalConfigs.includes(record.id)) { + onSelectEvalConfig( + selectedEvalConfigs.filter((id) => id !== record.id), + ) + } else { + onSelectEvalConfig([...selectedEvalConfigs, record.id]) + } + }, + })} + className="ph-no-capture" + columns={columnsConfig} + rowKey={"id"} + dataSource={filteredEvalConfigs as EvaluatorConfig[]} + scroll={{x: true, y: 455}} + bordered + pagination={false} + /> + )} +
    + + {preview ? ( + setIsEvaluatorsModalOpen(false)} + onClose={() => setIsEvaluatorsModalOpen(false)} + /> + } + closeIcon={null} + width={400} + onClose={() => setIsEvaluatorsModalOpen(false)} + classNames={{body: "!p-0", header: "!p-4"}} + > + { + setSelectedEvalConfigs(updater) + setIsEvaluatorsModalOpen(false) + }} + /> + + ) : ( + setIsEvaluatorsModalOpen(false)} + current={current} + setCurrent={setCurrent} + appId={selectedAppId || null} + openedFromNewEvaluation={true} + /> + )} + + ) +} + +export default memo(SelectEvaluatorSection) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectTestsetSection.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectTestsetSection.tsx new file mode 100644 index 0000000000..0a73c0c42b --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectTestsetSection.tsx @@ -0,0 +1,137 @@ +import {memo, useMemo, useState} from "react" + +import {Input} from "antd" +import Table, {ColumnsType} from "antd/es/table" +import clsx from "clsx" +import dayjs from "dayjs" +import dynamic from "next/dynamic" + +import {formatDate, formatDay} from "@/oss/lib/helpers/dateTimeHelper" +import {testset} from "@/oss/lib/Types" +import {useTestsetsData} from "@/oss/state/testset" + +import type {SelectTestsetSectionProps} from "../types" + +const NoResultsFound = dynamic(() => import("@/oss/components/NoResultsFound/NoResultsFound"), { + ssr: false, +}) + +const SelectTestsetSection = ({ + testSets: propsTestsets, + selectedTestsetId, + setSelectedTestsetId, + handlePanelChange, + className, + ...props +}: SelectTestsetSectionProps) => { + const [searchTerm, setSearchTerm] = useState("") + const {testsets: fetchedTestSets} = useTestsetsData() + const testSets = useMemo(() => { + return propsTestsets && propsTestsets.length > 0 ? propsTestsets : fetchedTestSets || [] + }, [propsTestsets, fetchedTestSets]) + + const columns: ColumnsType = useMemo(() => { + return [ + { + title: "Name", + dataIndex: "name", + key: "name", + onHeaderCell: () => ({ + style: {minWidth: 180}, + }), + }, + { + title: "Date Modified", + dataIndex: "updated_at", + key: "updated_at", + onHeaderCell: () => ({ + style: {minWidth: 180}, + }), + render: (date: string) => { + return formatDay({date, outputFormat: "DD MMM YYYY | h:mm a"}) + }, + }, + { + title: "Date created", + dataIndex: "created_at", + key: "created_at", + render: (date: string) => { + return formatDay({date, outputFormat: "DD MMM YYYY | h:mm a"}) + }, + onHeaderCell: () => ({ + style: {minWidth: 180}, + }), + }, + ] + }, []) + + const filteredTestset = useMemo(() => { + let allTestsets = testSets.sort( + (a: testset, b: testset) => + dayjs(b.updated_at).valueOf() - dayjs(a.updated_at).valueOf(), + ) + if (searchTerm) { + allTestsets = testSets.filter((item: testset) => + item.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) + } + return allTestsets + }, [searchTerm, testSets]) + + const selectedTestset = useMemo( + () => testSets.find((testset) => testset._id === selectedTestsetId) || null, + [selectedTestsetId, testSets], + ) + + return ( +
    +
    + setSearchTerm(e.target.value)} + /> +
    +
    { + setSelectedTestsetId(selectedRowKeys[0] as string) + handlePanelChange("evaluatorPanel") + }, + }} + className={`ph-no-capture`} + columns={columns} + dataSource={filteredTestset} + rowKey="_id" + scroll={{x: "max-content", y: 455}} + bordered + pagination={false} + locale={{ + emptyText: ( + + ), + }} + onRow={(record) => ({ + style: {cursor: "pointer"}, + onClick: () => { + if (selectedTestset?._id === record._id) { + setSelectedTestsetId("") + } else { + setSelectedTestsetId(record._id) + handlePanelChange("evaluatorPanel") + } + }, + })} + /> + + ) +} + +export default memo(SelectTestsetSection) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectVariantSection.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectVariantSection.tsx new file mode 100644 index 0000000000..6b7931d410 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/Components/SelectVariantSection.tsx @@ -0,0 +1,113 @@ +import {memo, useCallback, useMemo, useState} from "react" + +import {Input} from "antd" +import clsx from "clsx" +import {useAtomValue} from "jotai" +import dynamic from "next/dynamic" + +import {useVariants} from "@/oss/lib/hooks/useVariants" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {currentAppAtom} from "@/oss/state/app" + +import type {SelectVariantSectionProps} from "../types" + +const VariantsTable = dynamic(() => import("@/oss/components/VariantsComponents/Table"), { + ssr: false, +}) +const NoResultsFound = dynamic(() => import("@/oss/components/NoResultsFound/NoResultsFound"), { + ssr: false, +}) + +const SelectVariantSection = ({ + selectedVariantRevisionIds, + className, + setSelectedVariantRevisionIds, + handlePanelChange, + evaluationType, + variants: propsVariants, + isVariantLoading: propsVariantLoading, + ...props +}: SelectVariantSectionProps) => { + const currentApp = useAtomValue(currentAppAtom) + const {data, isLoading: fallbackLoading} = useVariants(currentApp) + const variants = useMemo(() => propsVariants || data, [propsVariants, data]) + const isVariantLoading = propsVariantLoading ?? fallbackLoading + + const [searchTerm, setSearchTerm] = useState("") + + const filteredVariant = useMemo(() => { + if (!searchTerm) return variants + return variants?.filter((item) => + item.variantName.toLowerCase().includes(searchTerm.toLowerCase()), + ) + }, [searchTerm, variants]) + + const onSelectVariant = useCallback( + (selectedRowKeys: React.Key[]) => { + const selectedId = selectedRowKeys[0] as string | undefined + if (selectedId) { + setSelectedVariantRevisionIds([selectedId]) + handlePanelChange("testsetPanel") + } else { + setSelectedVariantRevisionIds([]) + } + }, + [setSelectedVariantRevisionIds, handlePanelChange], + ) + + const onRowClick = useCallback( + (record: EnhancedVariant) => { + const _record = record as EnhancedVariant & { + children: EnhancedVariant[] + } + onSelectVariant([_record.id]) + }, + [selectedVariantRevisionIds, onSelectVariant], + ) + + return ( +
    +
    + setSearchTerm(e.target.value)} + /> +
    + { + onSelectVariant(selectedRowKeys) + }, + }} + onRow={(record) => { + return { + style: {cursor: "pointer"}, + onClick: () => onRowClick(record as EnhancedVariant), + } + }} + showActionsDropdown={false} + scroll={{x: "max-content", y: 455}} + isLoading={isVariantLoading} + variants={filteredVariant} + onRowClick={() => {}} + className="ph-no-capture" + rowKey={"id"} + locale={{ + emptyText: ( + + ), + }} + /> +
    + ) +} + +export default memo(SelectVariantSection) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/index.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/index.tsx new file mode 100644 index 0000000000..d69cde5ff8 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/index.tsx @@ -0,0 +1,20 @@ +import {memo} from "react" + +import {CheckCircleOutlined} from "@ant-design/icons" +import {Typography} from "antd" + +import {TabLabelProps} from "./types" + +const TabLabel = ({children, tabTitle, completed}: TabLabelProps) => { + return ( +
    +
    + {tabTitle} + {completed ? : null} +
    + {completed &&
    {children}
    } +
    + ) +} + +export default memo(TabLabel) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/types.ts b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/types.ts new file mode 100644 index 0000000000..8a15c81638 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/TabLabel/types.ts @@ -0,0 +1,6 @@ +import {HTMLProps} from "react" + +export interface TabLabelProps extends HTMLProps { + tabTitle: string + completed?: boolean +} diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/assets/constants.ts b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/constants.ts new file mode 100644 index 0000000000..746f5d838f --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/constants.ts @@ -0,0 +1,7 @@ +export const DEFAULT_ADVANCE_SETTINGS = { + batch_size: 10, + max_retries: 3, + retry_delay: 3, + delay_between_batches: 5, + correct_answer_column: "correct_answer", +} diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/assets/styles.ts b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/styles.ts new file mode 100644 index 0000000000..24e40c0cd4 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/assets/styles.ts @@ -0,0 +1,80 @@ +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" + +export const useStyles = createUseStyles((theme: JSSTheme) => ({ + modalContainer: { + height: 800, + overflowY: "hidden", + "& > div": { + height: "100%", + }, + "& .ant-modal-content": { + height: "100%", + display: "flex", + flexDirection: "column", + "& .ant-modal-body": { + overflowY: "auto", + flex: 1, + paddingTop: theme.padding, + paddingBottom: theme.padding, + }, + }, + }, + collapseContainer: { + "& .ant-collapse-header": { + alignItems: "center !important", + }, + "& .ant-collapse-content": { + maxHeight: 400, + height: "100%", + overflowY: "auto", + "& .ant-collapse-content-box": { + padding: 0, + }, + }, + }, + title: { + fontSize: theme.fontSizeHeading5, + lineHeight: theme.lineHeightHeading5, + fontWeight: theme.fontWeightMedium, + }, + subTitle: { + fontSize: theme.fontSize, + lineHeight: theme.lineHeight, + fontWeight: theme.fontWeightMedium, + }, + container: { + width: 400, + "& .ant-popover-title": { + marginBottom: theme.margin, + }, + "& .ant-popover-inner": { + padding: `${theme.paddingSM}px ${theme.padding}px`, + }, + }, + tabsContainer: { + height: "100%", + display: "flex", + "& .ant-tabs-content-holder": { + paddingLeft: theme.padding, + flex: 1, + overflow: "auto", + }, + "& .ant-tabs-tab": { + color: theme.colorTextSecondary, + "&:hover": { + backgroundColor: theme.colorInfoBg, + }, + }, + "& .ant-tabs-ink-bar": { + display: "none", + }, + "& .ant-tabs-tab-active": { + backgroundColor: theme.controlItemBgActive, + borderRight: `2px solid ${theme.colorPrimary}`, + color: theme.colorPrimary, + fontWeight: `${theme.fontWeightMedium} !important`, + }, + }, +})) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/index.tsx b/web/ee/src/components/pages/evaluations/NewEvaluation/index.tsx new file mode 100644 index 0000000000..f1e0463f54 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/index.tsx @@ -0,0 +1,551 @@ +import {useCallback, memo, useEffect, useMemo, useRef, useState} from "react" + +import {getDefaultStore} from "jotai" +import dynamic from "next/dynamic" +import {useRouter} from "next/router" + +import {message} from "@/oss/components/AppMessageContext" +import EnhancedModal from "@/oss/components/EnhancedUIs/Modal" +import {useAppId} from "@/oss/hooks/useAppId" +import useURL from "@/oss/hooks/useURL" +import {useVaultSecret} from "@/oss/hooks/useVaultSecret" +import {redirectIfNoLLMKeys} from "@/oss/lib/helpers/utils" +import useAppVariantRevisions from "@/oss/lib/hooks/useAppVariantRevisions" +import useFetchEvaluatorsData from "@/oss/lib/hooks/useFetchEvaluatorsData" +import usePreviewEvaluations from "@/oss/lib/hooks/usePreviewEvaluations" +import {extractInputKeysFromSchema} from "@/oss/lib/shared/variant/inputHelpers" +import {createEvaluation} from "@/oss/services/evaluations/api" +import {fetchTestset} from "@/oss/services/testsets/api" +import {useAppsData} from "@/oss/state/app/hooks" +import {stablePromptVariablesAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {useTestsetsData} from "@/oss/state/testset" +import {appSchemaAtom, appUriInfoAtom} from "@/oss/state/variant/atoms/fetcher" + +import {buildEvaluationNavigationUrl} from "../utils" + +import {DEFAULT_ADVANCE_SETTINGS} from "./assets/constants" +import {useStyles} from "./assets/styles" +import type {LLMRunRateLimitWithCorrectAnswer, NewEvaluationModalGenericProps} from "./types" + +const NewEvaluationModalContent = dynamic(() => import("./Components/NewEvaluationModalContent"), { + ssr: false, +}) + +const NewEvaluationModal = ({ + onSuccess, + preview = false as Preview, + evaluationType, + ...props +}: NewEvaluationModalGenericProps) => { + const classes = useStyles() + const appId = useAppId() + const isAppScoped = Boolean(appId) + const {apps: availableApps = []} = useAppsData() + const [selectedAppId, setSelectedAppId] = useState(appId || "") + const appOptions = useMemo(() => { + const options = availableApps.map((app) => ({ + label: app.app_name, + value: app.app_id, + type: app.app_type ?? null, + createdAt: app.created_at ?? null, + updatedAt: app.updated_at ?? null, + })) + if (selectedAppId && !options.some((opt) => opt.value === selectedAppId)) { + options.push({ + label: selectedAppId, + value: selectedAppId, + type: null, + createdAt: null, + updatedAt: null, + }) + } + return options + }, [availableApps, selectedAppId]) + const router = useRouter() + const {baseAppURL, projectURL} = useURL() + + // Fetch evaluation data + const evaluationData = useFetchEvaluatorsData({ + preview, + queries: {is_human: evaluationType === "human"}, + appId: selectedAppId || "", + }) + + // Use useMemo to derive evaluators, evaluatorConfigs, and loading flags based on preview flag + const {evaluators, evaluatorConfigs, loadingEvaluators, loadingEvaluatorConfigs} = + useMemo(() => { + if (preview) { + return { + evaluators: evaluationData.evaluatorsSwr?.data || [], + evaluatorConfigs: [], + loadingEvaluators: evaluationData.evaluatorsSwr?.isLoading ?? false, + loadingEvaluatorConfigs: false, + } + } else { + return { + evaluators: [], + evaluatorConfigs: evaluationData.evaluatorConfigsSwr?.data || [], + loadingEvaluators: false, + loadingEvaluatorConfigs: evaluationData.evaluatorConfigsSwr?.isLoading ?? false, + } + } + }, [ + preview, + evaluationData.evaluatorsSwr?.data, + evaluationData.evaluatorsSwr?.isLoading, + evaluationData.evaluatorConfigsSwr?.data, + evaluationData.evaluatorConfigsSwr?.isLoading, + ]) + + const [submitLoading, setSubmitLoading] = useState(false) + const [selectedTestsetId, setSelectedTestsetId] = useState("") + const [selectedVariantRevisionIds, setSelectedVariantRevisionIds] = useState([]) + const [selectedEvalConfigs, setSelectedEvalConfigs] = useState([]) + const [activePanel, setActivePanel] = useState( + isAppScoped ? "variantPanel" : "appPanel", + ) + const [evaluationName, setEvaluationName] = useState("") + const [nameFocused, setNameFocused] = useState(false) + const [advanceSettings, setAdvanceSettings] = + useState(DEFAULT_ADVANCE_SETTINGS) + + useEffect(() => { + if (isAppScoped) { + setSelectedAppId(appId || "") + } else if (!props.open) { + setSelectedAppId("") + } + }, [appId, isAppScoped, props.open]) + + useEffect(() => { + if (!props.open) return + if (!isAppScoped) return + if (!selectedAppId) return + if (activePanel !== "appPanel") return + setActivePanel("variantPanel") + }, [props.open, isAppScoped, selectedAppId, activePanel]) + + const handleAppSelection = useCallback( + (value: string) => { + if (value === selectedAppId) return + setSelectedAppId(value) + setSelectedTestsetId("") + setSelectedVariantRevisionIds([]) + setSelectedEvalConfigs([]) + setEvaluationName("") + setActivePanel("variantPanel") + setAdvanceSettings(DEFAULT_ADVANCE_SETTINGS) + }, + [ + selectedAppId, + isAppScoped, + setSelectedTestsetId, + setSelectedVariantRevisionIds, + setSelectedEvalConfigs, + setEvaluationName, + setActivePanel, + setAdvanceSettings, + ], + ) + + const {variants: appVariantRevisions, isLoading: variantsLoading} = useAppVariantRevisions( + selectedAppId || null, + ) + const filteredVariants = useMemo(() => { + if (!selectedAppId) return [] + return appVariantRevisions || [] + }, [appVariantRevisions, selectedAppId]) + + const {createNewRun: createPreviewEvaluationRun} = usePreviewEvaluations({ + appId: selectedAppId || appId, + }) + const {testsets, isLoading: testsetsLoading} = useTestsetsData() + + const {secrets} = useVaultSecret() + + const handlePanelChange = useCallback((key: string | string[]) => { + setActivePanel(key as string) + }, []) + + const afterClose = useCallback(() => { + props?.afterClose?.() + setEvaluationName("") + setSelectedEvalConfigs([]) + setSelectedTestsetId("") + setSelectedVariantRevisionIds([]) + setAdvanceSettings(DEFAULT_ADVANCE_SETTINGS) + setActivePanel("appPanel") + if (!isAppScoped) { + setSelectedAppId("") + } + }, [props?.afterClose, isAppScoped]) + + // Track focus on any input within modal to avoid overriding user typing + useEffect(() => { + function handleFocusIn(e: FocusEvent) { + if ((e.target as HTMLElement).tagName === "INPUT") { + setNameFocused(true) + } + } + function handleFocusOut(e: FocusEvent) { + if ((e.target as HTMLElement).tagName === "INPUT") { + setNameFocused(false) + } + } + document.addEventListener("focusin", handleFocusIn) + document.addEventListener("focusout", handleFocusOut) + return () => { + document.removeEventListener("focusin", handleFocusIn) + document.removeEventListener("focusout", handleFocusOut) + } + }, []) + + // Memoised base (deterministic) part of generated name (without random suffix) + const generatedNameBase = useMemo(() => { + if (!selectedVariantRevisionIds.length || !selectedTestsetId) return "" + const variant = filteredVariants?.find((v) => selectedVariantRevisionIds.includes(v.id)) + const testset = testsets?.find((ts) => ts._id === selectedTestsetId) + if (!variant || !testset) return "" + return `${variant.variantName}-v${variant.revision}-${testset.name}` + }, [selectedVariantRevisionIds, selectedTestsetId, filteredVariants, testsets]) + + // Auto-generate / update evaluation name intelligently to avoid loops + const lastAutoNameRef = useRef("") + const lastBaseRef = useRef("") + const randomWordRef = useRef("") + + // Generate a short, readable random suffix (stable per modal open) + const genRandomWord = () => { + // Prefer Web Crypto for better entropy + const n = globalThis.crypto?.getRandomValues?.(new Uint32Array(1))?.[0] ?? 0 + if (n) return n.toString(36).slice(0, 5) + // Fallback to Math.random + return Math.random().toString(36).slice(2, 7) + } + + useEffect(() => { + if (!props.open) return + // New random suffix per open, and reset last suggestion trackers + randomWordRef.current = genRandomWord() + lastAutoNameRef.current = "" + lastBaseRef.current = "" + return () => { + randomWordRef.current = "" + } + }, [props.open]) + useEffect(() => { + if (!generatedNameBase) return + if (nameFocused) return // user typing + + // When base (variant/testset) changed → generate new suggestion + if (generatedNameBase !== lastBaseRef.current) { + // Ensure we have a random word for this session + if (!randomWordRef.current) randomWordRef.current = genRandomWord() + const randomWord = randomWordRef.current + const newName = `${generatedNameBase}-${randomWord}` + const shouldUpdate = !evaluationName || evaluationName === lastAutoNameRef.current + lastBaseRef.current = generatedNameBase + lastAutoNameRef.current = newName + if (shouldUpdate) { + setEvaluationName(newName) + } + return + } + + // If user cleared the field (blur) -> restore auto-name + if (!evaluationName) { + setEvaluationName(lastAutoNameRef.current) + } + }, [generatedNameBase, evaluationName, nameFocused, evaluationType]) + + const validateSubmission = useCallback(async () => { + if (!evaluationName) { + message.error("Please enter evaluation name") + return false + } + if (!selectedTestsetId) { + message.error("Please select a test set") + return false + } + if (selectedVariantRevisionIds.length === 0) { + message.error("Please select app variant") + return false + } + if (selectedEvalConfigs.length === 0) { + message.error("Please select evaluator configuration") + return false + } + if ( + !preview && + selectedEvalConfigs.some( + (id) => + evaluatorConfigs.find((config) => config.id === id)?.evaluator_key === + "auto_ai_critique", + ) && + (await redirectIfNoLLMKeys({secrets})) + ) { + message.error("LLM keys are required for AI Critique configuration") + return false + } + + // Validate variant + if (selectedVariantRevisionIds.length > 0) { + const revisions = filteredVariants?.filter((rev) => + selectedVariantRevisionIds.includes(rev.id), + ) + if (!revisions?.length) { + message.error("Please select variant") + return false + } + + const variantInputs = revisions + .map((rev) => { + const store = getDefaultStore() + const flags = store.get(variantFlagsAtomFamily({revisionId: rev.id})) as any + const spec = store.get(appSchemaAtom) as any + const routePath = store.get(appUriInfoAtom)?.routePath || "" + const schemaKeys = spec + ? extractInputKeysFromSchema(spec as any, routePath) + : [] + if (flags?.isCustom) { + // Custom workflows: strictly use schema-defined input keys + return schemaKeys + } + // Non-custom: use stable variables from saved parameters (ignore live edits) + const stableVars = store.get(stablePromptVariablesAtomFamily(rev.id)) || [] + return Array.from(new Set(stableVars)) + }) + .flat() + + const testset = await fetchTestset(selectedTestsetId) + if (!testset) { + message.error("Please select a test set") + return false + } + const testsetColumns = Object.keys(testset?.csvdata[0] || {}) + + if (!testsetColumns.length) { + message.error("Please select a correct testset which has test cases") + return false + } + + // Validate that testset contains required expected answer columns from selected evaluator configs + const missingColumnConfigs = selectedEvalConfigs + .map((configId) => evaluatorConfigs.find((config) => config.id === configId)) + .filter((config) => { + // Only check configs that have a correct_answer_key setting + if (!config?.settings_values?.correct_answer_key) return false + const expectedColumn = config.settings_values.correct_answer_key + return !testsetColumns.includes(expectedColumn) + }) + + if (missingColumnConfigs.length > 0) { + const missingColumns = missingColumnConfigs + .map((config) => config?.settings_values?.correct_answer_key) + .filter(Boolean) + .join(", ") + message.error( + `Please select a testset that has the required expected answer columns: ${missingColumns}`, + ) + return false + } + + if (variantInputs.length > 0) { + const isInputParamsAndTestsetColumnsMatch = variantInputs.every((input) => { + return testsetColumns.includes(input) + }) + if (!isInputParamsAndTestsetColumnsMatch) { + message.error( + "The testset columns do not match the selected variant input parameters", + ) + return false + } + } + } + return true + }, [ + selectedTestsetId, + selectedVariantRevisionIds, + selectedEvalConfigs, + evaluatorConfigs, + secrets, + preview, + evaluationName, + advanceSettings, + filteredVariants, + testsets, + evaluationType, + ]) + + const onSubmit = useCallback(async () => { + setSubmitLoading(true) + try { + if (!(await validateSubmission())) return + + const targetAppId = selectedAppId || appId + if (!targetAppId) { + message.error("Please select an application") + setSubmitLoading(false) + return + } + + const revisions = filteredVariants + const {correct_answer_column, ...rateLimitValues} = advanceSettings + + // Narrow evalDataSource with runtime guards for correct typing + let evalDataSource: typeof evaluatorConfigs | typeof evaluators + if (preview) { + evalDataSource = evaluators + + const selectionData = { + name: evaluationName, + revisions: revisions + ?.filter((rev) => selectedVariantRevisionIds.includes(rev.id)) + .filter(Boolean), + testset: testsets?.find((testset) => testset._id === selectedTestsetId), + evaluators: selectedEvalConfigs + .map((id) => + (evalDataSource || []).find((config) => { + return config.id === id + }), + ) + .filter(Boolean), + rate_limit: rateLimitValues, + correctAnswerColumn: correct_answer_column, + } + + if ( + !selectionData.revisions?.length || + !selectionData.testset || + !selectionData.evaluators?.length || + (evaluationType === "human" && !evaluationName) + ) { + message.error( + `Please select a test set, app variant, ${evaluationType === "human" ? "evaluation name, and" : " and"} evaluator configuration. Missing: ${ + !selectionData.revisions?.length ? "app revision" : "" + } ${!selectionData.testset ? "test set" : ""} ${ + !selectionData.evaluators?.length + ? "evaluators" + : evaluationType === "human" && !evaluationName + ? "evaluation name" + : "" + }`, + ) + setSubmitLoading(false) + return + } else { + const data = await createPreviewEvaluationRun(structuredClone(selectionData)) + + const runId = data.run.runs[0].id + const scope = isAppScoped ? "app" : "project" + const targetPath = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: targetAppId, + path: `/evaluations/single_model_test/${runId}`, + }) + + if (scope === "project") { + router.push({ + pathname: targetPath, + query: targetAppId ? {app_id: targetAppId} : undefined, + }) + } else { + router.push(targetPath) + } + } + } else { + createEvaluation(targetAppId, { + testset_id: selectedTestsetId, + revisions_ids: selectedVariantRevisionIds, + evaluators_configs: selectedEvalConfigs, + rate_limit: rateLimitValues, + correct_answer_column: correct_answer_column, + name: evaluationName, + }) + .then(onSuccess) + .catch(console.error) + .finally(() => { + // refetch + setSubmitLoading(false) + }) + } + } catch (error) { + console.error(error) + setSubmitLoading(false) + } finally { + setSubmitLoading(false) + } + + return + }, [ + appId, + selectedAppId, + selectedTestsetId, + selectedVariantRevisionIds, + selectedEvalConfigs, + advanceSettings, + evaluatorConfigs, + evaluationName, + filteredVariants, + testsets, + evaluators, + preview, + validateSubmission, + createPreviewEvaluationRun, + baseAppURL, + onSuccess, + ]) + + return ( + New {evaluationType === "auto" ? "Auto" : "Human"} Evaluation} + onOk={onSubmit} + okText="Start Evaluation" + maskClosable={false} + width={1200} + className={classes.modalContainer} + confirmLoading={submitLoading} + afterClose={afterClose} + {...props} + > + + + ) +} + +export default memo(NewEvaluationModal) diff --git a/web/ee/src/components/pages/evaluations/NewEvaluation/types.ts b/web/ee/src/components/pages/evaluations/NewEvaluation/types.ts new file mode 100644 index 0000000000..e40ebe0cc2 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/NewEvaluation/types.ts @@ -0,0 +1,92 @@ +import type {Dispatch, HTMLProps, SetStateAction} from "react" + +import {ModalProps} from "antd" + +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {LLMRunRateLimit, Evaluator, EvaluatorConfig, testset} from "@/oss/lib/Types" + +export interface NewEvaluationAppOption { + label: string + value: string + type?: string | null + createdAt?: string | null + updatedAt?: string | null +} + +export interface LLMRunRateLimitWithCorrectAnswer extends LLMRunRateLimit { + correct_answer_column: string +} + +export interface NewEvaluationModalProps extends ModalProps { + onSuccess?: () => void + evaluationType: "auto" | "human" + preview?: boolean +} + +export interface NewEvaluationModalContentProps extends HTMLProps { + evaluationType: "auto" | "human" + activePanel: string | null + selectedTestsetId: string + selectedVariantRevisionIds: string[] + selectedEvalConfigs: string[] + evaluationName: string + preview?: boolean + isLoading?: boolean + setSelectedTestsetId: Dispatch> + onSuccess?: () => void + handlePanelChange: (key: string | string[]) => void + setSelectedVariantRevisionIds: Dispatch> + setSelectedEvalConfigs: Dispatch> + setEvaluationName: Dispatch> + isOpen?: boolean + testSets: testset[] + variants?: EnhancedVariant[] + variantsLoading?: boolean + evaluators: Evaluator[] | EvaluatorDto<"response">[] + evaluatorConfigs: EvaluatorConfig[] + advanceSettings: LLMRunRateLimitWithCorrectAnswer + setAdvanceSettings: Dispatch> + appOptions: NewEvaluationAppOption[] + selectedAppId: string + onSelectApp: (value: string) => void + appSelectionDisabled?: boolean +} + +export interface SelectVariantSectionProps extends HTMLProps { + isVariantLoading?: boolean + variants?: EnhancedVariant[] + selectedVariantRevisionIds: string[] + setSelectedVariantRevisionIds: Dispatch> + handlePanelChange: (key: string | string[]) => void + evaluationType: "auto" | "human" +} + +export interface SelectTestsetSectionProps extends HTMLProps { + testSets: testset[] + selectedTestsetId: string + setSelectedTestsetId: Dispatch> + handlePanelChange: (key: string | string[]) => void + preview?: boolean +} + +export interface SelectEvaluatorSectionProps extends HTMLProps { + evaluatorConfigs: EvaluatorConfig[] + evaluators: Evaluator[] + selectedEvalConfigs: string[] + setSelectedEvalConfigs: Dispatch> + handlePanelChange: (key: string | string[]) => void + preview?: boolean + selectedAppId?: string +} + +export interface AdvancedSettingsProps { + advanceSettings: LLMRunRateLimitWithCorrectAnswer + setAdvanceSettings: Dispatch> + preview?: boolean +} + +export interface NewEvaluationModalGenericProps + extends Omit { + preview?: Preview +} diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/AutoEvaluation.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/AutoEvaluation.tsx new file mode 100644 index 0000000000..539197f20d --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/AutoEvaluation.tsx @@ -0,0 +1,318 @@ +import {useCallback, useMemo, useState} from "react" + +import {QueryClient, QueryClientProvider} from "@tanstack/react-query" +import {message} from "antd" +import {ColumnsType} from "antd/es/table" +import {useAtom} from "jotai" +import {useRouter} from "next/router" + +import DeleteEvaluationModal from "@/oss/components/DeleteEvaluationModal/DeleteEvaluationModal" +import EnhancedTable from "@/oss/components/EnhancedUIs/Table" +import {filterColumns} from "@/oss/components/Filters/EditColumns/assets/helper" +import {getColumns} from "@/oss/components/HumanEvaluations/assets/utils" +import {EvaluationRow} from "@/oss/components/HumanEvaluations/types" +import {useAppId} from "@/oss/hooks/useAppId" +import useURL from "@/oss/hooks/useURL" +import {EvaluationType} from "@/oss/lib/enums" +import {buildRevisionsQueryParam} from "@/oss/lib/helpers/url" +import useEvaluations from "@/oss/lib/hooks/useEvaluations" +import {tempEvaluationAtom} from "@/oss/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom" +import useRunMetricsMap from "@/oss/lib/hooks/useRunMetricsMap" +import {EvaluationStatus} from "@/oss/lib/Types" +import {useAppsData} from "@/oss/state/app" + +import {buildAppScopedUrl, buildEvaluationNavigationUrl, extractEvaluationAppId} from "../utils" + +import AutoEvaluationHeader from "./assets/AutoEvaluationHeader" + +interface AutoEvaluationProps { + viewType?: "overview" | "evaluation" + scope?: "app" | "project" +} + +const AutoEvaluation = ({viewType = "evaluation", scope = "app"}: AutoEvaluationProps) => { + const routeAppId = useAppId() + const activeAppId = scope === "app" ? routeAppId || undefined : undefined + const router = useRouter() + const {baseAppURL, projectURL} = useURL() + + const [selectedRowKeys, setSelectedRowKeys] = useState([]) + + const [selectedEvalRecord, setSelectedEvalRecord] = useState() + const [isDeleteEvalModalOpen, setIsDeleteEvalModalOpen] = useState(false) + const [isDeletingEvaluations, setIsDeletingEvaluations] = useState(false) + const [hiddenColumns, setHiddenColumns] = useState([]) + const [tempEvaluation, setTempEvaluation] = useAtom(tempEvaluationAtom) + const {apps: availableApps = []} = useAppsData() + + const { + mergedEvaluations: _mergedEvaluations, + isLoadingPreview, + isLoadingLegacy, + refetch, + handleDeleteEvaluations: deleteEvaluations, + previewEvaluations, + } = useEvaluations({ + withPreview: true, + types: [EvaluationType.automatic, EvaluationType.auto_exact_match], + evalType: "auto", + appId: activeAppId, + }) + + const previewAutoEvals = useMemo(() => { + const evals = previewEvaluations.swrData?.data?.runs || [] + if (!evals.length) return [] + + return evals?.filter((run) => + run?.data?.steps.every( + (step) => step?.type !== "annotation" || step?.origin === "auto", + ), + ) + }, [previewEvaluations]) + + const mergedEvaluations = useMemo(() => { + const mergedIds = new Set(_mergedEvaluations?.map((e) => e.id)) + const filteredTempEvals = tempEvaluation.filter((e) => !mergedIds.has(e.id)) + return [..._mergedEvaluations, ...filteredTempEvals] + }, [_mergedEvaluations, tempEvaluation]) + + const runIds = useMemo(() => { + return mergedEvaluations + .map((evaluation) => { + const candidate = "id" in evaluation ? evaluation.id : evaluation.key + return typeof candidate === "string" ? candidate.trim() : undefined + }) + .filter((id): id is string => Boolean(id && id.length > 0)) + }, [mergedEvaluations]) + const evaluatorSlugs = useMemo(() => { + const evaSlugs = new Set() + previewAutoEvals.forEach((e) => { + const key = e?.data.steps?.find((step) => step.type === "annotation")?.key + if (key) evaSlugs.add(key) + }) + return evaSlugs + }, [previewAutoEvals]) + + const {data: runMetricsMap} = useRunMetricsMap(runIds, evaluatorSlugs) + + const resolveAppId = useCallback( + (record: EvaluationRow): string | undefined => { + const candidate = extractEvaluationAppId(record) || activeAppId + return candidate + }, + [activeAppId], + ) + + const isRecordNavigable = useCallback( + (record: EvaluationRow): boolean => { + const status = record.status?.value || record.status + const evaluationId = "id" in record ? record.id : record.key + const recordAppId = resolveAppId(record) + const isActionableStatus = ![ + EvaluationStatus.PENDING, + EvaluationStatus.RUNNING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + ].includes(status) + return Boolean(isActionableStatus && evaluationId && recordAppId) + }, + [resolveAppId], + ) + + const handleVariantNavigation = useCallback( + ({revisionId, appId: recordAppId}: {revisionId: string; appId?: string}) => { + const targetAppId = recordAppId || activeAppId + if (!targetAppId) { + message.warning("This application's variant is no longer available.") + return + } + + router.push({ + pathname: buildAppScopedUrl(baseAppURL, targetAppId, "/playground"), + query: { + revisions: buildRevisionsQueryParam([revisionId]), + }, + }) + }, + [activeAppId, baseAppURL, router], + ) + + const handleDelete = useCallback( + async (ids: string[]) => { + setIsDeletingEvaluations(true) + try { + await deleteEvaluations(ids) + message.success( + ids.length > 1 ? `${ids.length} Evaluations Deleted` : "Evaluation Deleted", + ) + refetch() + } catch (err) { + message.error("Failed to delete evaluations") + console.error(err) + } finally { + setTempEvaluation((prev) => + prev.length > 0 ? prev.filter((e) => !ids.includes(e?.id)) : [], + ) + setIsDeletingEvaluations(false) + setIsDeleteEvalModalOpen(false) + setSelectedRowKeys([]) + } + }, + [refetch, deleteEvaluations], + ) + + const _columns: ColumnsType = useMemo(() => { + return getColumns({ + evaluations: mergedEvaluations, + onVariantNavigation: handleVariantNavigation, + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + runMetricsMap, + evalType: "auto", + scope, + baseAppURL, + extractAppId: extractEvaluationAppId, + projectURL, + resolveAppId, + }) + }, [ + mergedEvaluations, + handleVariantNavigation, + setSelectedEvalRecord, + setIsDeleteEvalModalOpen, + runMetricsMap, + scope, + baseAppURL, + projectURL, + resolveAppId, + ]) + + const visibleColumns = useMemo( + () => filterColumns(_columns, hiddenColumns), + [_columns, hiddenColumns], + ) + + const selectedEvaluations = useMemo(() => { + return selectedEvalRecord + ? (() => { + const found = mergedEvaluations.find( + (e) => ("id" in e ? e.id : e.key) === selectedEvalRecord?.id, + ) + return found && "name" in found ? found.name : (found?.key ?? "") + })() + : mergedEvaluations + .filter((e) => selectedRowKeys.includes("id" in e ? e.id : e.key)) + .map((e) => ("name" in e ? e.name : e.id)) + .join(" | ") + }, [selectedEvalRecord, selectedRowKeys, mergedEvaluations]) + + const dataSource = useMemo(() => { + return viewType === "overview" ? mergedEvaluations.slice(0, 5) : mergedEvaluations + }, [mergedEvaluations, viewType]) + + return ( +
    + + 0 && !mergedEvaluations?.length) + } + rowSelection={ + viewType === "evaluation" + ? { + type: "checkbox", + columnWidth: 48, + selectedRowKeys, + onChange: (selectedRowKeys: React.Key[]) => { + setSelectedRowKeys(selectedRowKeys) + }, + getCheckboxProps: (record: EvaluationRow) => ({ + disabled: !isRecordNavigable(record), + }), + } + : undefined + } + className="ph-no-capture" + showHorizontalScrollBar={true} + columns={visibleColumns} + rowKey={(record: any) => ("id" in record ? record.id : record.key)} + dataSource={dataSource} + tableLayout="fixed" + virtualized + onRow={(record) => { + const navigable = isRecordNavigable(record) + const recordAppId = resolveAppId(record) + const evaluationId = "id" in record ? record.id : record.key + return { + style: { + cursor: navigable ? "pointer" : "not-allowed", + }, + className: navigable ? undefined : "cursor-not-allowed opacity-60", + onClick: () => { + if (!navigable || !recordAppId || !evaluationId) { + message.warning( + "This evaluation's application is no longer available.", + ) + return + } + + const targetPath = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: recordAppId, + path: `/evaluations/results/${evaluationId}`, + }) + + if (scope === "project") { + router.push({ + pathname: targetPath, + query: recordAppId ? {app_id: recordAppId} : undefined, + }) + } else { + router.push(targetPath) + } + }, + } + }} + /> + { + setIsDeleteEvalModalOpen(false) + setSelectedEvalRecord(undefined) + }} + onOk={async () => { + const idsToDelete = selectedEvalRecord + ? [selectedEvalRecord.id] + : selectedRowKeys.map((key) => key?.toString()) + await handleDelete(idsToDelete.filter(Boolean)) + }} + evaluationType={selectedEvaluations} + isMultiple={!selectedEvalRecord && selectedRowKeys.length > 0} + /> +
    + ) +} + +export default AutoEvaluation diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/AdvancedSettings.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/AdvancedSettings.tsx new file mode 100644 index 0000000000..3a71dcb469 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/AdvancedSettings.tsx @@ -0,0 +1,128 @@ +import {CaretRightOutlined, InfoCircleOutlined} from "@ant-design/icons" +import Editor from "@monaco-editor/react" +import { + Form, + Input, + InputNumber, + Switch, + Tooltip, + Collapse, + theme, + AutoComplete, + Select, +} from "antd" +import {createUseStyles} from "react-jss" + +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import {generatePaths} from "@/oss/lib/transformers" + +const useStyles = createUseStyles((theme: any) => ({ + label: { + display: "flex", + alignItems: "center", + gap: "0.5rem", + }, + editor: { + border: `1px solid ${theme.colorBorder}`, + borderRadius: theme.borderRadius, + overflow: "hidden", + }, +})) + +interface AdvancedSettingsProps { + settings: Record[] + selectedTestcase: { + testcase: Record | null + } +} + +const AdvancedSettings: React.FC = ({settings, selectedTestcase}) => { + const classes = useStyles() + const {appTheme} = useAppTheme() + const {token} = theme.useToken() + + return ( + ( + + )} + > + + {settings.map((field) => { + const rules = [ + {required: field.required ?? true, message: "This field is required"}, + ] + + return ( + + {field.label} + {field.description && ( + + + + )} + + } + initialValue={field.default} + rules={rules} + > + {(field.type === "string" || field.type === "regex") && + selectedTestcase.testcase ? ( + + option!.value + .toUpperCase() + .indexOf(inputValue.toUpperCase()) !== -1 + } + /> + ) : field.type === "string" || field.type === "regex" ? ( + + ) : field.type === "number" ? ( + + ) : field.type === "boolean" || field.type === "bool" ? ( + + ) : field.type === "text" ? ( + + ) : field.type === "code" ? ( + + ) : field.type === "multiple_choice" ? ( + + ) : type === "hidden" ? ( + + ) : type === "messages" ? ( + + ) : type === "number" ? ( + + ) : type === "boolean" || type === "bool" ? ( + + ) : type === "text" ? ( + + ) : type === "code" ? ( + + ) : type === "object" ? ( + + ) : null} + + )} + + {ExternalHelpInfo} + + ) +} diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorTestcaseModal.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorTestcaseModal.tsx new file mode 100644 index 0000000000..1f7904602d --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorTestcaseModal.tsx @@ -0,0 +1,174 @@ +import {useEffect, useMemo, useState} from "react" + +import {CloseOutlined} from "@ant-design/icons" +import {Play} from "@phosphor-icons/react" +import {Button, Divider, Input, Menu, Modal, Table, Typography} from "antd" +import {ColumnsType} from "antd/es/table" + +import {TestSet} from "@/oss/lib/Types" +import {fetchTestset} from "@/oss/services/testsets/api" + +import {useEvaluatorTestcaseModalStyles as useStyles} from "./assets/styles" +import {EvaluatorTestcaseModalProps} from "./types" + +const EvaluatorTestcaseModal = ({ + testsets, + setSelectedTestcase, + selectedTestset, + setSelectedTestset, + ...props +}: EvaluatorTestcaseModalProps) => { + const classes = useStyles() + const [isLoadingTestset, setIsLoadingTestset] = useState(false) + const [testsetCsvData, setTestsetCsvData] = useState([]) + const [selectedRowKeys, setSelectedRowKeys] = useState([]) + const [searchTerm, setSearchTerm] = useState("") + + const filteredTestset = useMemo(() => { + if (!searchTerm) return testsets + return testsets.filter((item) => item.name.toLowerCase().includes(searchTerm.toLowerCase())) + }, [searchTerm, testsets]) + + useEffect(() => { + const testsetFetcher = async () => { + try { + setIsLoadingTestset(true) + const data = await fetchTestset(selectedTestset) + if (data) { + setTestsetCsvData(data.csvdata) + } + } catch (error) { + console.error(error) + } finally { + setIsLoadingTestset(false) + } + } + + testsetFetcher() + }, [selectedTestset]) + + type TestcaseRow = Record & {id: number} + const columnDef = useMemo(() => { + const columns: ColumnsType = [] + + if (testsetCsvData.length > 0) { + const keys = Object.keys(testsetCsvData[0]) + + columns.push( + ...keys.map((key, index) => ({ + title: key, + dataIndex: key, + key: index, + width: 300, + onHeaderCell: () => ({ + style: {minWidth: 160}, + }), + render: (_: any, record: TestcaseRow) => { + return
    {record[key]}
    + }, + })), + ) + } + + return columns + }, [testsetCsvData]) + + const rowSelection = { + selectedRowKeys, + onChange: (keys: React.Key[]) => { + setSelectedRowKeys(keys) + }, + } + + const loadTestCase = () => { + const selectedTestCase = testsetCsvData.find((_, index) => index === selectedRowKeys[0]) + + if (selectedTestCase) { + setSelectedTestcase({testcase: selectedTestCase}) + props.onCancel?.({} as any) + } + } + + return ( + , + iconPosition: "end", + disabled: !selectedRowKeys.length, + onClick: loadTestCase, + loading: isLoadingTestset, + }} + className={classes.container} + title={ +
    + Load Testcase +
    + } + {...props} + > +
    +
    + setSearchTerm(e.target.value)} + /> + + + + ({ + key: testset._id, + label: testset.name, + }))} + onSelect={({key}) => { + setSelectedTestset(key) + setSelectedRowKeys([]) + }} + defaultSelectedKeys={[selectedTestset]} + className={classes.menu} + /> +
    + +
    + + Select a testcase + + +
    ({...data, id: index}) as TestcaseRow, + )} + columns={columnDef} + className="flex-1" + bordered + rowKey={"id"} + pagination={false} + scroll={{y: 500, x: "max-content"}} + onRow={(_, rowIndex) => ({ + className: "cursor-pointer", + onClick: () => { + if (rowIndex !== undefined) { + setSelectedRowKeys([rowIndex]) + } + }, + })} + /> + + + + ) +} + +export default EvaluatorTestcaseModal diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorVariantModal.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorVariantModal.tsx new file mode 100644 index 0000000000..71d67c7e1e --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/EvaluatorVariantModal.tsx @@ -0,0 +1,158 @@ +import { + useCallback, + useMemo, + useState, + useEffect, + type ComponentProps, + type Dispatch, + type SetStateAction, + type Key, +} from "react" + +import {CloseOutlined} from "@ant-design/icons" +import {Play} from "@phosphor-icons/react" +import {Button, Input, Modal, Typography} from "antd" +import {useAtomValue} from "jotai" +import {createUseStyles} from "react-jss" + +import VariantsTable from "@/oss/components/VariantsComponents/Table" +import type {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {JSSTheme, Variant as BaseVariant} from "@/oss/lib/Types" +import {revisionMapAtom} from "@/oss/state/variant/atoms/fetcher" + +type Variant = BaseVariant & {id?: string} +type EvaluatorVariantModalProps = { + variants: Variant[] | null + setSelectedVariant: Dispatch> + selectedVariant: Variant | null +} & ComponentProps + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + title: { + fontSize: theme.fontSizeHeading4, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightStrong, + }, + container: { + "& .ant-modal-body": { + height: 600, + }, + }, + table: { + "& .ant-table-thead > tr > th": { + height: 32, + padding: "0 16px", + }, + "& .ant-table-tbody > tr > td": { + height: 48, + padding: "0 16px", + }, + }, +})) + +const EvaluatorVariantModal = ({ + variants, + setSelectedVariant, + selectedVariant, + ...props +}: EvaluatorVariantModalProps) => { + const classes = useStyles() + const [searchTerm, setSearchTerm] = useState("") + const [selectedRowKeys, setSelectedRowKeys] = useState([]) + + // Build a list of latest revisions (EnhancedVariant) for each base variant + const revisionMap = useAtomValue(revisionMapAtom) + const latestRevisions: EnhancedVariant[] = useMemo(() => { + const list: EnhancedVariant[] = [] + ;(variants || []).forEach((v) => { + const arr = revisionMap[v.variantId] || [] + if (arr && arr.length > 0) list.push(arr[0]) + }) + return list + }, [variants, revisionMap]) + + // Clear selection when modal is opened + useEffect(() => { + if (props.open) { + // Preselect currently selected variant's latest revision id + const rev = latestRevisions.find((r) => r.variantId === selectedVariant?.variantId) + setSelectedRowKeys(rev?.id ? [rev.id] : []) + } + }, [props.open, selectedVariant?.variantId, latestRevisions]) + + const filtered = useMemo(() => { + const src = latestRevisions + if (!searchTerm) return src + return (src || []).filter((item) => + (item.variantName || "").toLowerCase().includes(searchTerm.toLowerCase()), + ) + }, [searchTerm, latestRevisions]) + + const loadVariant = useCallback(() => { + const selectedRevision = filtered?.find((rev) => rev.id === selectedRowKeys[0]) + if (selectedRevision) { + // Find the base variant matching this revision and pass it back + const base = (variants || []).find((v) => v.variantId === selectedRevision.variantId) + if (base) setSelectedVariant(base) + props.onCancel?.({} as any) + } + }, [filtered, selectedRowKeys, setSelectedVariant, props, variants]) + + return ( + , + iconPosition: "end", + disabled: !selectedRowKeys.length, + onClick: loadVariant, + }} + title={ +
    + + Select variant to run + +
    + } + centered + {...props} + > +
    + setSearchTerm(e.target.value)} + placeholder="Search" + allowClear + className="w-[240px]" + /> + + setSelectedRowKeys(value), + type: "radio", + }} + isLoading={false} + onRowClick={() => {}} + // Use revision id for table and selection, so the cell renderers resolve correctly + rowKey={"id"} + // Use stable name display to avoid showing Draft tag in selection UI + showStableName + className={classes.table} + showActionsDropdown={false} + /> +
    +
    + ) +} + +export default EvaluatorVariantModal diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/Messages.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/Messages.tsx new file mode 100644 index 0000000000..273bcf6f4d --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/Messages.tsx @@ -0,0 +1,158 @@ +import {useEffect, useMemo} from "react" + +import {PlusOutlined} from "@ant-design/icons" +import {MinusCircle} from "@phosphor-icons/react" +import {Button, Form, Input} from "antd" +import isEqual from "lodash/isEqual" + +import MessageEditor from "@/oss/components/Playground/Components/ChatCommon/MessageEditor" +import EnhancedButton from "@/oss/components/Playground/assets/EnhancedButton" + +interface Message { + role: string + content: string +} + +interface MessagesProps { + value?: Message[] + onChange?: (messages: Message[]) => void +} + +const roleOptions = [ + {label: "system", value: "system"}, + {label: "user", value: "user"}, + {label: "assistant", value: "assistant"}, +] + +const normalizeMessages = (messages?: Message[] | string): Message[] => { + if (typeof messages === "string") { + return [{role: "system", content: messages}] + } + if (Array.isArray(messages)) { + return messages.filter(Boolean).map((message) => ({ + role: message.role || "user", + content: message.content || "", + })) + } + return [] +} + +export const Messages: React.FC = ({value = [], onChange}) => { + const form = Form.useFormInstance() + const normalizedValue = useMemo(() => normalizeMessages(value), [value]) + const watchedMessages = Form.useWatch("messages", form) + const currentMessages = watchedMessages ?? normalizedValue + + useEffect(() => { + const currentMessages = form.getFieldValue("messages") + if (!isEqual(currentMessages, normalizedValue)) { + form.setFieldsValue({messages: normalizedValue}) + } + }, [normalizedValue, form]) + + const updateMessages = (updater: (messages: Message[]) => Message[]) => { + const existing = normalizeMessages(form.getFieldValue("messages")) + const updated = updater(existing) + form.setFieldsValue({messages: updated}) + onChange?.(updated) + } + + return ( + + {(fields, {add, remove}) => ( + <> + {fields.map(({key, name, ...restField}, index) => { + const message = currentMessages?.[index] ?? { + role: "user", + content: "", + } + + return ( +
    +
    + + + + + + + + updateMessages((prev) => { + const next = [...prev] + next[index] = { + ...next[index], + role, + } + return next + }) + } + onChangeText={(content) => + updateMessages((prev) => { + const next = [...prev] + next[index] = { + ...next[index], + content: content || "", + } + return next + }) + } + roleOptions={roleOptions} + editorType="border" + headerRight={ + fields.length > 1 ? ( +
    + } + type="text" + onClick={() => { + remove(name) + const updated = normalizeMessages( + form.getFieldValue("messages"), + ) + onChange?.(updated) + }} + tooltipProps={{title: "Remove"}} + /> +
    + ) : undefined + } + /> +
    +
    + ) + })} + + + + + )} +
    + ) +} diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/assets/styles.ts b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/assets/styles.ts new file mode 100644 index 0000000000..5449f65bdd --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/assets/styles.ts @@ -0,0 +1,32 @@ +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" +export const useEvaluatorTestcaseModalStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + "& .ant-modal-body": { + height: 600, + overflowY: "auto", + }, + }, + title: { + fontSize: theme.fontSizeHeading4, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightStrong, + }, + subTitle: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + sidebar: { + display: "flex", + flexDirection: "column", + gap: theme.padding, + width: 200, + }, + menu: { + height: 500, + overflowY: "auto", + borderInlineEnd: `0px !important`, + }, +})) diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/index.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/index.tsx new file mode 100644 index 0000000000..55eaeabc18 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/index.tsx @@ -0,0 +1,340 @@ +import {useEffect, useMemo, useState} from "react" + +import {CloseOutlined} from "@ant-design/icons" +import {ArrowLeft, CaretDoubleRight} from "@phosphor-icons/react" +import {Button, Flex, Form, Input, message, Space, Tooltip, Typography} from "antd" +import dynamic from "next/dynamic" +import {createUseStyles} from "react-jss" + +import {useAppId} from "@/oss/hooks/useAppId" +import {isDemo} from "@/oss/lib/helpers/utils" +import {Evaluator, EvaluatorConfig, JSSTheme, testset, Variant} from "@/oss/lib/Types" +import { + CreateEvaluationConfigData, + createEvaluatorConfig, + updateEvaluatorConfig, +} from "@/oss/services/evaluations/api" + +import AdvancedSettings from "./AdvancedSettings" +import {DynamicFormField} from "./DynamicFormField" + +const DebugSection: any = dynamic( + () => + import( + "@/oss/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/DebugSection" + ), +) + +interface ConfigureEvaluatorProps { + setCurrent: React.Dispatch> + handleOnCancel: () => void + onSuccess: () => void + selectedEvaluator: Evaluator + variants: Variant[] | null + testsets: testset[] | null + selectedTestcase: { + testcase: Record | null + } + setSelectedVariant: React.Dispatch> + selectedVariant: Variant | null + editMode: boolean + editEvalEditValues: EvaluatorConfig | null + setEditEvalEditValues: React.Dispatch> + setEditMode: (value: React.SetStateAction) => void + cloneConfig: boolean + setCloneConfig: React.Dispatch> + setSelectedTestcase: React.Dispatch< + React.SetStateAction<{ + testcase: Record | null + }> + > + setDebugEvaluator: React.Dispatch> + debugEvaluator: boolean + setSelectedTestset: React.Dispatch> + selectedTestset: string + appId?: string | null +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + headerText: { + "& .ant-typography": { + lineHeight: theme.lineHeightLG, + fontSize: theme.fontSizeHeading4, + fontWeight: theme.fontWeightStrong, + }, + }, + title: { + fontSize: theme.fontSizeLG, + fontWeight: theme.fontWeightMedium, + lineHeight: theme.lineHeightLG, + }, + formContainer: { + display: "flex", + flexDirection: "column", + gap: theme.padding, + height: "100%", + width: "100%", + maxWidth: "100%", + overflow: "hidden", + "& .ant-form-item": { + marginBottom: 0, + }, + "& .ant-form-item-label": { + paddingBottom: theme.paddingXXS, + }, + }, + formTitleText: { + fontSize: theme.fontSize, + lineHeight: theme.lineHeight, + fontWeight: theme.fontWeightMedium, + }, +})) + +const ConfigureEvaluator = ({ + setCurrent, + selectedEvaluator, + handleOnCancel, + variants, + testsets, + onSuccess, + selectedTestcase, + selectedVariant, + setSelectedVariant, + editMode, + editEvalEditValues, + setEditEvalEditValues, + setEditMode, + cloneConfig, + setCloneConfig, + setSelectedTestcase, + debugEvaluator, + setDebugEvaluator, + selectedTestset, + setSelectedTestset, + appId: appIdOverride, +}: ConfigureEvaluatorProps) => { + const routeAppId = useAppId() + const appId = appIdOverride ?? routeAppId + const classes = useStyles() + const [form] = Form.useForm() + const [submitLoading, setSubmitLoading] = useState(false) + const [traceTree, setTraceTree] = useState<{ + trace: Record | string | null + }>({ + trace: null, + }) + + const evalFields = useMemo( + () => + Object.keys(selectedEvaluator?.settings_template || {}) + .filter((key) => !!selectedEvaluator?.settings_template[key]?.type) + .map((key) => ({ + key, + ...selectedEvaluator?.settings_template[key]!, + advanced: selectedEvaluator?.settings_template[key]?.advanced || false, + })), + [selectedEvaluator], + ) + + const advancedSettingsFields = evalFields.filter((field) => field.advanced) + const basicSettingsFields = evalFields.filter((field) => !field.advanced) + + const onSubmit = (values: CreateEvaluationConfigData) => { + try { + setSubmitLoading(true) + if (!selectedEvaluator.key) throw new Error("No selected key") + const settingsValues = values.settings_values || {} + + const data = { + ...values, + evaluator_key: selectedEvaluator.key, + settings_values: settingsValues, + } + ;(editMode + ? updateEvaluatorConfig(editEvalEditValues?.id!, data) + : createEvaluatorConfig(appId, data) + ) + .then(onSuccess) + .catch(console.error) + .finally(() => setSubmitLoading(false)) + } catch (error: any) { + setSubmitLoading(false) + console.error(error) + message.error(error.message) + } + } + + useEffect(() => { + form.resetFields() + if (editMode) { + form.setFieldsValue(editEvalEditValues) + } else if (cloneConfig) { + form.setFieldValue("settings_values", editEvalEditValues?.settings_values) + } + }, [editMode, cloneConfig]) + + return ( +
    +
    + + {editMode ? ( + <> +
    + + +
    + + + + {selectedEvaluator.name} + + + + + + + + {selectedEvaluator.description} + + + +
    +
    + +
    + + + +
    +
    + + {basicSettingsFields.length ? ( + + + Parameters + + {basicSettingsFields.map((field) => ( + + ))} + + ) : ( + "" + )} + + {advancedSettingsFields.length > 0 && ( + + )} + +
    + + + + + +
    + + +
    +
    + ) +} + +export default ConfigureEvaluator diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/types.ts b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/types.ts new file mode 100644 index 0000000000..e8c600a519 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/ConfigureEvaluator/types.ts @@ -0,0 +1,14 @@ +import {Modal} from "antd" + +import {testset} from "@/oss/lib/Types" + +export type EvaluatorTestcaseModalProps = { + testsets: testset[] + setSelectedTestcase: React.Dispatch< + React.SetStateAction<{ + testcase: Record | null + }> + > + setSelectedTestset: React.Dispatch> + selectedTestset: string +} & React.ComponentProps diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/DeleteModal.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/DeleteModal.tsx new file mode 100644 index 0000000000..f9c76c2d33 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/DeleteModal.tsx @@ -0,0 +1,73 @@ +import {useState} from "react" + +import {ExclamationCircleOutlined} from "@ant-design/icons" +import {Modal, Space, theme, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {checkIfResourceValidForDeletion} from "@/oss/lib/helpers/evaluate" +import {EvaluatorConfig, JSSTheme} from "@/oss/lib/Types" +import {deleteEvaluatorConfig} from "@/oss/services/evaluations/api" + +type DeleteModalProps = { + selectedEvalConfig: EvaluatorConfig + onSuccess: () => void +} & React.ComponentProps + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + title: { + fontSize: theme.fontSizeLG, + fontWeight: theme.fontWeightStrong, + lineHeight: theme.lineHeightLG, + }, +})) + +const DeleteModal = ({selectedEvalConfig, onSuccess, ...props}: DeleteModalProps) => { + const classes = useStyles() + const { + token: {colorWarning}, + } = theme.useToken() + const [isLoading, setIsLoading] = useState(false) + + const handleDelete = async () => { + try { + if ( + !(await checkIfResourceValidForDeletion({ + resourceType: "evaluator_config", + resourceIds: [selectedEvalConfig.id], + })) + ) + return + try { + setIsLoading(true) + await deleteEvaluatorConfig(selectedEvalConfig.id) + await onSuccess() + props.onCancel?.({} as any) + } catch (error) { + console.error(error) + } + } catch (error) { + console.error(error) + } finally { + setIsLoading(false) + } + } + return ( + + + Delete evaluator + + } + centered + okText={"Delete"} + okButtonProps={{danger: true, loading: isLoading}} + onOk={handleDelete} + {...props} + > + Are you sure you want to delete this evaluator? + + ) +} + +export default DeleteModal diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorCard.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorCard.tsx new file mode 100644 index 0000000000..4b83b35d1b --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorCard.tsx @@ -0,0 +1,213 @@ +import {useState} from "react" + +import {MoreOutlined} from "@ant-design/icons" +import {Copy, Note, Trash} from "@phosphor-icons/react" +import {Button, Card, Dropdown, Empty, Tag, Typography} from "antd" +import {useAtom} from "jotai" +import {createUseStyles} from "react-jss" + +import {evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import {formatDay} from "@/oss/lib/helpers/dateTimeHelper" +import {Evaluator, EvaluatorConfig, JSSTheme} from "@/oss/lib/Types" + +import DeleteModal from "./DeleteModal" + +interface EvaluatorCardProps { + evaluatorConfigs: EvaluatorConfig[] + setEditMode: React.Dispatch> + setCloneConfig: React.Dispatch> + setCurrent: React.Dispatch> + setSelectedEvaluator: React.Dispatch> + setEditEvalEditValues: React.Dispatch> + onSuccess: () => void +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + flexWrap: "wrap", + gap: theme.padding, + height: "100%", + maxHeight: 600, + overflowY: "auto", + }, + cardTitle: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + evaluatorCard: { + width: 276, + display: "flex", + height: "fit-content", + flexDirection: "column", + transition: "all 0.025s ease-in", + cursor: "pointer", + "& > .ant-card-head": { + minHeight: 0, + padding: theme.paddingSM, + + "& .ant-card-head-title": { + fontSize: theme.fontSize, + fontWeight: theme.fontWeightMedium, + lineHeight: theme.lineHeight, + }, + }, + "& > .ant-card-body": { + padding: theme.paddingSM, + display: "flex", + flexDirection: "column", + gap: theme.marginXS, + "& div": { + display: "flex", + alignItems: "center", + justifyContent: "space-between", + }, + }, + "&:hover": { + boxShadow: theme.boxShadowTertiary, + }, + }, + centeredItem: { + display: "grid", + placeItems: "center", + width: "100%", + height: 600, + }, +})) + +const EvaluatorCard = ({ + evaluatorConfigs, + setEditMode, + setCurrent, + setSelectedEvaluator, + setEditEvalEditValues, + onSuccess, + setCloneConfig, +}: EvaluatorCardProps) => { + const classes = useStyles() + const evaluators = useAtom(evaluatorsAtom)[0] + const [openDeleteModal, setOpenDeleteModal] = useState(false) + const [selectedDelEval, setSelectedDelEval] = useState(null) + + return ( +
    + {evaluatorConfigs.length ? ( + evaluatorConfigs.map((item) => { + const evaluator = evaluators.find((e) => e.key === item.evaluator_key) + + return ( + { + const selectedEval = evaluators.find( + (e) => e.key === item.evaluator_key, + ) + if (selectedEval) { + setEditMode(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(item) + setCurrent(2) + } + }} + title={item.name} + extra={ + , + onClick: (e: any) => { + e.domEvent.stopPropagation() + const selectedEval = evaluators.find( + (e) => e.key === item.evaluator_key, + ) + if (selectedEval) { + setEditMode(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(item) + setCurrent(2) + } + }, + }, + { + key: "clone", + label: "Clone", + icon: , + onClick: (e: any) => { + e.domEvent.stopPropagation() + const selectedEval = evaluators.find( + (e) => e.key === item.evaluator_key, + ) + if (selectedEval) { + setCloneConfig(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(item) + setCurrent(2) + } + }, + }, + {type: "divider"}, + { + key: "delete_app", + label: "Delete", + icon: , + danger: true, + onClick: (e: any) => { + e.domEvent.stopPropagation() + setOpenDeleteModal(true) + setSelectedDelEval(item) + }, + }, + ], + }} + > +
    + ) +} + +export default EvaluatorCard diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorList.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorList.tsx new file mode 100644 index 0000000000..da70be6772 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/EvaluatorList.tsx @@ -0,0 +1,172 @@ +import {useState} from "react" + +import {MoreOutlined} from "@ant-design/icons" +import {Copy, GearSix, Note, Trash} from "@phosphor-icons/react" +import {Button, Dropdown, Table, Tag} from "antd" +import {ColumnsType} from "antd/es/table" +import {useAtom} from "jotai" + +import {evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import {Evaluator, EvaluatorConfig} from "@/oss/lib/Types" + +import DeleteModal from "./DeleteModal" + +interface EvaluatorListProps { + evaluatorConfigs: EvaluatorConfig[] + setEditMode: React.Dispatch> + setCloneConfig: React.Dispatch> + setCurrent: React.Dispatch> + setSelectedEvaluator: React.Dispatch> + setEditEvalEditValues: React.Dispatch> + onSuccess: () => void +} + +const EvaluatorList = ({ + evaluatorConfigs, + setCloneConfig, + setCurrent, + setEditEvalEditValues, + setEditMode, + setSelectedEvaluator, + onSuccess, +}: EvaluatorListProps) => { + const evaluators = useAtom(evaluatorsAtom)[0] + const [openDeleteModal, setOpenDeleteModal] = useState(false) + const [selectedDelEval, setSelectedDelEval] = useState(null) + + const columns: ColumnsType = [ + // { + // title: "Version", + // dataIndex: "version", + // key: "version", + // onHeaderCell: () => ({ + // style: {minWidth: 80}, + // }), + // }, + { + title: "Name", + dataIndex: "name", + key: "name", + render: (_, record) => { + return
    {record.name}
    + }, + }, + { + title: "Type", + dataIndex: "type", + key: "type", + render: (_, record) => { + const evaluator = evaluators.find((item) => item.key === record.evaluator_key) + return {evaluator?.name} + }, + }, + { + title: , + key: "key", + width: 56, + fixed: "right", + align: "center", + render: (_, record) => { + return ( + , + onClick: (e: any) => { + e.domEvent.stopPropagation() + const selectedEval = evaluators.find( + (e) => e.key === record.evaluator_key, + ) + if (selectedEval) { + setEditMode(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(record) + setCurrent(2) + } + }, + }, + { + key: "clone", + label: "Clone", + icon: , + onClick: (e: any) => { + e.domEvent.stopPropagation() + const selectedEval = evaluators.find( + (e) => e.key === record.evaluator_key, + ) + if (selectedEval) { + setCloneConfig(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(record) + setCurrent(2) + } + }, + }, + {type: "divider"}, + { + key: "delete_app", + label: "Delete", + icon: , + danger: true, + onClick: (e: any) => { + e.domEvent.stopPropagation() + setOpenDeleteModal(true) + setSelectedDelEval(record) + }, + }, + ], + }} + > +
    ({ + style: {cursor: "pointer"}, + onClick: () => { + const selectedEval = evaluators.find((e) => e.key === record.evaluator_key) + if (selectedEval) { + setEditMode(true) + setSelectedEvaluator(selectedEval) + setEditEvalEditValues(record) + setCurrent(2) + } + }, + })} + /> + {selectedDelEval && ( + setOpenDeleteModal(false)} + selectedEvalConfig={selectedDelEval} + onSuccess={onSuccess} + /> + )} + + ) +} + +export default EvaluatorList diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/index.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/index.tsx new file mode 100644 index 0000000000..fbb9f81837 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/Evaluators/index.tsx @@ -0,0 +1,197 @@ +import {useMemo, useState} from "react" + +import {CloseOutlined, PlusOutlined} from "@ant-design/icons" +import {Cards, Table} from "@phosphor-icons/react" +import {Button, Divider, Flex, Input, Radio, Space, Spin, Typography} from "antd" +import {useAtom} from "jotai" +import {createUseStyles} from "react-jss" + +import {evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import {getEvaluatorTags} from "@/oss/lib/helpers/evaluate" +import {Evaluator, EvaluatorConfig, JSSTheme} from "@/oss/lib/Types" + +import EvaluatorCard from "./EvaluatorCard" +import EvaluatorList from "./EvaluatorList" + +interface EvaluatorsProps { + evaluatorConfigs: EvaluatorConfig[] + handleOnCancel: () => void + setCurrent: React.Dispatch> + setSelectedEvaluator: React.Dispatch> + fetchingEvalConfigs: boolean + setEditMode: React.Dispatch> + setCloneConfig: React.Dispatch> + setEditEvalEditValues: React.Dispatch> + onSuccess: () => void + setEvaluatorsDisplay: any + evaluatorsDisplay: string +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + titleContainer: { + display: "flex", + alignItems: "center", + justifyContent: "space-between", + "& .ant-typography": { + fontSize: theme.fontSizeHeading4, + fontWeight: theme.fontWeightStrong, + lineHeight: theme.lineHeightLG, + }, + }, + header: { + display: "flex", + flexDirection: "column", + gap: theme.padding, + }, + radioBtnContainer: { + display: "flex", + alignItems: "center", + gap: theme.marginXS, + "& .ant-radio-button-wrapper": { + borderRadius: theme.borderRadius, + borderInlineStartWidth: "1px", + "&:before": { + width: 0, + }, + "&:not(.ant-radio-button-wrapper-checked)": { + border: "none", + "&:hover": { + backgroundColor: theme.colorBgTextHover, + }, + }, + }, + }, +})) + +const Evaluators = ({ + evaluatorConfigs, + handleOnCancel, + setCurrent, + setSelectedEvaluator, + fetchingEvalConfigs, + setEditMode, + setEditEvalEditValues, + onSuccess, + setCloneConfig, + setEvaluatorsDisplay, + evaluatorsDisplay, +}: EvaluatorsProps) => { + const classes = useStyles() + const [searchTerm, setSearchTerm] = useState("") + const evaluatorTags = getEvaluatorTags() + const evaluators = useAtom(evaluatorsAtom)[0] + const [selectedEvaluatorCategory, setSelectedEvaluatorCategory] = useState("view_all") + + const updatedEvaluatorConfigs = useMemo(() => { + return evaluatorConfigs.map((config) => { + const matchingEvaluator = evaluators.find( + (evaluator) => evaluator.key === config.evaluator_key, + ) + return matchingEvaluator ? {...config, tags: matchingEvaluator.tags} : config + }) + }, [evaluatorConfigs, evaluators]) + + const filteredEvaluators = useMemo(() => { + let filtered = updatedEvaluatorConfigs + + if (selectedEvaluatorCategory !== "view_all") { + filtered = filtered.filter((item) => item.tags?.includes(selectedEvaluatorCategory)) + } + + if (searchTerm) { + filtered = filtered.filter((item) => + item.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) + } + + return filtered + }, [searchTerm, selectedEvaluatorCategory, updatedEvaluatorConfigs]) + + return ( +
    +
    +
    + Configure evaluators + + + +
    +
    +
    + setSelectedEvaluatorCategory(e.target.value)} + > + + View all + + + {evaluatorTags.map((val, idx) => ( + + {val.label} + + ))} + + + + setSearchTerm(e.target.value)} + /> + setEvaluatorsDisplay(e.target.value)} + className="shrink-0" + > + +
    + + + + + + + + + + + + + {evaluatorsDisplay === "list" ? ( + + ) : ( + + )} + + + ) +} + +export default Evaluators diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/EvaluatorsModal.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/EvaluatorsModal.tsx new file mode 100644 index 0000000000..99b6e9f1b6 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/EvaluatorsModal.tsx @@ -0,0 +1,201 @@ +// @ts-nocheck +import {memo, useEffect, useMemo, useState} from "react" + +import {ModalProps} from "antd" +import clsx from "clsx" +import {useAtom, useAtomValue} from "jotai" +import {useLocalStorage} from "usehooks-ts" + +import EnhancedModal from "@/oss/components/EnhancedUIs/Modal" +import {useAppId} from "@/oss/hooks/useAppId" +import {evaluatorConfigsAtom, evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import {groupVariantsByParent} from "@/oss/lib/helpers/variantHelper" +import useFetchEvaluatorsData from "@/oss/lib/hooks/useFetchEvaluatorsData" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import {useVariants} from "@/oss/lib/hooks/useVariants" +import {Evaluator, EvaluatorConfig, Variant} from "@/oss/lib/Types" +import {currentAppAtom} from "@/oss/state/app" +import {useTestsetsData} from "@/oss/state/testset" + +import ConfigureEvaluator from "./ConfigureEvaluator" +import Evaluators from "./Evaluators" +import NewEvaluator from "./NewEvaluator" + +interface EvaluatorsModalProps extends ModalProps { + current: number + setCurrent: React.Dispatch> + openedFromNewEvaluation?: boolean + appId?: string | null +} + +const EvaluatorsModal = ({ + current, + setCurrent, + openedFromNewEvaluation = false, + appId: appIdOverride, + ...modalProps +}: EvaluatorsModalProps) => { + const routeAppId = useAppId() + const appId = appIdOverride ?? routeAppId + const [debugEvaluator, setDebugEvaluator] = useLocalStorage("isDebugSelectionOpen", false) + const [evaluators] = useAtom(evaluatorsAtom) + const [evaluatorConfigs] = useAtom(evaluatorConfigsAtom) + const [selectedEvaluator, setSelectedEvaluator] = useState(null) + const {refetchEvaluatorConfigs, isLoadingEvaluatorConfigs: fetchingEvalConfigs} = + useFetchEvaluatorsData({appId: appId ?? ""}) + const [selectedTestcase, setSelectedTestcase] = useState<{ + testcase: Record | null + }>({ + testcase: null, + }) + const currentApp = useAtomValue(currentAppAtom) + const [selectedVariant, setSelectedVariant] = useState(null) + const [editMode, setEditMode] = useState(false) + const [cloneConfig, setCloneConfig] = useState(false) + const [editEvalEditValues, setEditEvalEditValues] = useState(null) + const [evaluatorsDisplay, setEvaluatorsDisplay] = useLocalStorage<"card" | "list">( + "evaluator_view", + "list", + ) + const [selectedTestset, setSelectedTestset] = useState("") + const {testsets} = useTestsetsData() + + useEffect(() => { + if (testsets?.length) { + setSelectedTestset(testsets[0]._id) + } + }, [testsets]) + + const {variants: data} = useStatelessVariants() + + const variants = useMemo(() => groupVariantsByParent(data, true), [data]) + + useEffect(() => { + if (variants?.length) { + setSelectedVariant(variants[0]) + } + }, [data]) + + const steps = useMemo(() => { + return [ + { + content: ( + modalProps.onCancel?.({} as any)} + setCurrent={setCurrent} + setSelectedEvaluator={setSelectedEvaluator} + fetchingEvalConfigs={fetchingEvalConfigs} + setEditMode={setEditMode} + setEditEvalEditValues={setEditEvalEditValues} + onSuccess={refetchEvaluatorConfigs} + setCloneConfig={setCloneConfig} + setEvaluatorsDisplay={setEvaluatorsDisplay} + evaluatorsDisplay={evaluatorsDisplay} + /> + ), + }, + { + content: ( + modalProps.onCancel?.({} as any)} + setSelectedEvaluator={setSelectedEvaluator} + setEvaluatorsDisplay={setEvaluatorsDisplay} + evaluatorsDisplay={evaluatorsDisplay} + /> + ), + }, + ] + }, [ + evaluatorConfigs, + fetchingEvalConfigs, + evaluatorsDisplay, + evaluators, + modalProps.onCancel, + setCurrent, + setSelectedEvaluator, + debugEvaluator, + selectedTestcase, + selectedVariant, + selectedTestset, + editMode, + cloneConfig, + editEvalEditValues, + variants, + testsets, + ]) + + if (selectedEvaluator) { + steps.push({ + content: ( + { + modalProps.onCancel?.({} as any) + setEditMode(false) + setCloneConfig(false) + setEditEvalEditValues(null) + }} + variants={variants || []} + testsets={testsets || []} + onSuccess={() => { + refetchEvaluatorConfigs() + setEditMode(false) + if (openedFromNewEvaluation) { + modalProps.onCancel?.({} as any) + } else { + setCurrent(0) + } + }} + selectedTestcase={selectedTestcase} + selectedVariant={selectedVariant} + setSelectedVariant={setSelectedVariant} + editMode={editMode} + editEvalEditValues={editEvalEditValues} + setEditEvalEditValues={setEditEvalEditValues} + setEditMode={setEditMode} + cloneConfig={cloneConfig} + setCloneConfig={setCloneConfig} + setSelectedTestcase={setSelectedTestcase} + setDebugEvaluator={setDebugEvaluator} + debugEvaluator={debugEvaluator} + selectedTestset={selectedTestset} + setSelectedTestset={setSelectedTestset} + appId={appId} + /> + ), + }) + } + + return ( + +
    _div]:!h-full", + { + "max-w-[600px]": current === 2 && !debugEvaluator, + "max-w-[95vw]": current !== 2 || debugEvaluator, + }, + ])} + > + {steps[current]?.content} +
    +
    + ) +} + +export default memo(EvaluatorsModal) diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorCard.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorCard.tsx new file mode 100644 index 0000000000..5f40bcc974 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorCard.tsx @@ -0,0 +1,114 @@ +import {ArrowRight} from "@phosphor-icons/react" +import {Card, Empty, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {Evaluator, JSSTheme} from "@/oss/lib/Types" + +interface CreateEvaluatorCardProps { + evaluators: Evaluator[] + setSelectedEvaluator: React.Dispatch> + setCurrent: (value: React.SetStateAction) => void +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + flexWrap: "wrap", + gap: theme.padding, + height: "100%", + maxHeight: 600, + overflowY: "auto", + }, + cardTitle: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + evaluatorCard: { + flexDirection: "column", + width: 276, + display: "flex", + height: "fit-content", + transition: "all 0.025s ease-in", + cursor: "pointer", + position: "relative", + "& > .ant-card-head": { + minHeight: 0, + padding: theme.paddingSM, + + "& .ant-card-head-title": { + fontSize: theme.fontSize, + fontWeight: theme.fontWeightMedium, + lineHeight: theme.lineHeight, + display: "flex", + justifyContent: "space-between", + alignItems: "center", + }, + }, + "& > .ant-card-body": { + height: 122, + overflowY: "auto", + padding: theme.paddingSM, + "& .ant-typography": { + color: theme.colorTextSecondary, + }, + }, + "&:hover": { + boxShadow: theme.boxShadowTertiary, + }, + }, + arrowIcon: { + opacity: 0, + transition: "opacity 0.3s", + }, + evaluatorCardHover: { + "&:hover $arrowIcon": { + opacity: 1, + }, + }, + centeredItem: { + display: "grid", + placeItems: "center", + width: "100%", + height: 600, + }, +})) + +const CreateEvaluatorCard = ({ + evaluators, + setSelectedEvaluator, + setCurrent, +}: CreateEvaluatorCardProps) => { + const classes = useStyles() + + return ( +
    + {evaluators.length ? ( + evaluators.map((evaluator) => ( + + {evaluator.name} + + + } + onClick={() => { + setSelectedEvaluator(evaluator) + setCurrent(2) + }} + > + {evaluator.description} + + )) + ) : ( +
    + +
    + )} +
    + ) +} + +export default CreateEvaluatorCard diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorList.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorList.tsx new file mode 100644 index 0000000000..790747250a --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/NewEvaluatorList.tsx @@ -0,0 +1,85 @@ +import {ArrowRight} from "@phosphor-icons/react" +import {Table, Tag, Typography} from "antd" +import {ColumnsType} from "antd/es/table" +import {createUseStyles} from "react-jss" + +import {Evaluator, JSSTheme} from "@/oss/lib/Types" + +interface CreateEvaluatorListProps { + evaluators: Evaluator[] + setSelectedEvaluator: React.Dispatch> + setCurrent: (value: React.SetStateAction) => void +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + arrowIcon: { + opacity: 0, + transition: "opacity 0.3s", + }, + evaluatorCardHover: { + "&:hover $arrowIcon": { + opacity: 1, + }, + }, +})) + +const CreateEvaluatorList = ({ + evaluators, + setSelectedEvaluator, + setCurrent, +}: CreateEvaluatorListProps) => { + const classes = useStyles() + + const columns: ColumnsType = [ + { + title: "Name", + dataIndex: "key", + key: "key", + width: 200, + render: (_, record) => { + return ( +
    + {record.name} +
    + ) + }, + }, + { + title: "Description", + dataIndex: "description", + key: "description", + render: (_, record) => { + return ( +
    + + {record.description} + + + +
    + ) + }, + }, + ] + return ( +
    ({ + className: classes.evaluatorCardHover, + onClick: () => { + setSelectedEvaluator(record) + setCurrent(2) + }, + })} + pagination={false} + /> + ) +} + +export default CreateEvaluatorList diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/index.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/index.tsx new file mode 100644 index 0000000000..6a6366d809 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/EvaluatorsModal/NewEvaluator/index.tsx @@ -0,0 +1,142 @@ +import {useMemo, useState} from "react" + +import {CloseOutlined} from "@ant-design/icons" +import {ArrowLeft} from "@phosphor-icons/react" +import {Button, Divider, Flex, Input, Radio, Space, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {getEvaluatorTags} from "@/oss/lib/helpers/evaluate" +import {Evaluator, JSSTheme} from "@/oss/lib/Types" + +import NewEvaluatorList from "./NewEvaluatorList" + +interface NewEvaluatorProps { + setCurrent: React.Dispatch> + handleOnCancel: () => void + evaluators: Evaluator[] + setSelectedEvaluator: React.Dispatch> + setEvaluatorsDisplay: any + evaluatorsDisplay: string +} + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + title: { + display: "flex", + alignItems: "center", + justifyContent: "space-between", + "& .ant-typography": { + fontSize: theme.fontSizeHeading4, + fontWeight: theme.fontWeightStrong, + lineHeight: theme.lineHeightLG, + }, + }, + subTitle: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + radioBtnContainer: { + display: "flex", + alignItems: "center", + gap: theme.marginXS, + "& .ant-radio-button-wrapper": { + borderRadius: theme.borderRadius, + borderInlineStartWidth: "1px", + "&:before": { + width: 0, + }, + "&:not(.ant-radio-button-wrapper-checked)": { + border: "none", + "&:hover": { + backgroundColor: theme.colorBgTextHover, + }, + }, + }, + }, +})) + +const NewEvaluator = ({ + evaluators, + setCurrent, + handleOnCancel, + setSelectedEvaluator, + setEvaluatorsDisplay, + evaluatorsDisplay, +}: NewEvaluatorProps) => { + const classes = useStyles() + const [searchTerm, setSearchTerm] = useState("") + const evaluatorTags = getEvaluatorTags() + const [selectedEvaluatorCategory, setSelectedEvaluatorCategory] = useState("view_all") + + const filteredEvaluators = useMemo(() => { + let filtered = evaluators + + if (selectedEvaluatorCategory !== "view_all") { + filtered = filtered.filter((item) => item.tags.includes(selectedEvaluatorCategory)) + } + + if (searchTerm) { + filtered = filtered.filter((item) => + item.name.toLowerCase().includes(searchTerm.toLowerCase()), + ) + } + + return filtered + }, [searchTerm, selectedEvaluatorCategory, evaluators]) + + return ( +
    +
    +
    + +
    +
    +
    + setSelectedEvaluatorCategory(e.target.value)} + > + View all + + {evaluatorTags.map((val, idx) => ( + + {val.label} + + ))} + + + + setSearchTerm(e.target.value)} + placeholder="Search" + allowClear + /> + +
    +
    + +
    + +
    + +
    +
    + ) +} + +export default NewEvaluator diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/Filters/SearchFilter.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/Filters/SearchFilter.tsx new file mode 100644 index 0000000000..bdbdec8ecf --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/Filters/SearchFilter.tsx @@ -0,0 +1,78 @@ +import {Input, TableColumnType, DatePicker} from "antd" +import {FilterDropdownProps} from "antd/es/table/interface" +import dayjs from "dayjs" + +import {statusMapper} from "@/oss/components/pages/evaluations/cellRenderers/cellRenderers" +import {_Evaluation, EvaluationStatus} from "@/oss/lib/Types" + +type DataIndex = keyof _Evaluation + +type CellDataType = "number" | "text" | "date" + +export function getFilterParams( + dataIndex: DataIndex, + type: CellDataType, +): TableColumnType<_Evaluation> { + const filterDropdown = ({setSelectedKeys, selectedKeys, confirm}: FilterDropdownProps) => { + return ( +
    e.stopPropagation()}> + {type === "date" ? ( + { + setSelectedKeys(dateString ? [dateString] : []) + confirm() + }} + /> + ) : ( + { + setSelectedKeys(e.target.value ? [e.target.value] : []) + confirm({closeDropdown: false}) + }} + style={{display: "block"}} + step={0.1} + type={type} + /> + )} +
    + ) + } + + const onFilter = (value: any, record: any) => { + try { + const cellValue = record[dataIndex] + + if (type === "date") { + return dayjs(cellValue).isSame(dayjs(value), "day") + } + if (dataIndex === "status") { + const statusLabel = statusMapper({} as any)(record.status.value as EvaluationStatus) + .label as EvaluationStatus + return statusLabel.toLowerCase().includes(value.toLowerCase()) + } + + if (typeof cellValue === "object" && cellValue !== null) { + if (Array.isArray(cellValue)) { + return cellValue.some((item) => + item.variantName?.toLowerCase().includes(value.toLowerCase()), + ) + } else if (cellValue.hasOwnProperty("name")) { + return cellValue.name.toString().toLowerCase().includes(value.toLowerCase()) + } else if (cellValue.hasOwnProperty("value")) { + return cellValue.value.toString().toLowerCase().includes(value.toLowerCase()) + } + } + return cellValue?.toString().toLowerCase().includes(value.toLowerCase()) + } catch (error) { + console.error(error) + } + } + + return { + filterDropdown, + onFilter, + } +} diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/assets/AutoEvaluationHeader.tsx b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/AutoEvaluationHeader.tsx new file mode 100644 index 0000000000..52ce792956 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/AutoEvaluationHeader.tsx @@ -0,0 +1,679 @@ +import {memo, useCallback, useMemo, useState} from "react" + +import {ArrowsLeftRight, Export, Gauge, Plus, Trash} from "@phosphor-icons/react" +import {Button, Space, Input, message, theme, Typography} from "antd" +import {ColumnsType} from "antd/es/table" +import {useAtom, useSetAtom} from "jotai" +import dynamic from "next/dynamic" +import Link from "next/link" +import {useRouter} from "next/router" + +import EditColumns from "@/oss/components/Filters/EditColumns" +import {formatColumnTitle} from "@/oss/components/Filters/EditColumns/assets/helper" +import {formatMetricValue} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover/assets/utils" +import {EvaluationRow} from "@/oss/components/HumanEvaluations/types" +import {useQueryParam} from "@/oss/hooks/useQuery" +import useURL from "@/oss/hooks/useURL" +import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" +import {formatDate24, formatDay} from "@/oss/lib/helpers/dateTimeHelper" +import dayjs from "@/oss/lib/helpers/dateTimeHelper/dayjs" +import {getTypedValue} from "@/oss/lib/helpers/evaluate" +import {convertToCsv, downloadCsv} from "@/oss/lib/helpers/fileManipulations" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import {searchQueryAtom} from "@/oss/lib/hooks/usePreviewEvaluations/states/queryFilterAtoms" +import {tempEvaluationAtom} from "@/oss/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom" +import {getMetricConfig} from "@/oss/lib/metrics/utils" +import {EvaluationStatus} from "@/oss/lib/Types" +import {getAppValues} from "@/oss/state/app" + +import {statusMapper} from "../../../evaluations/cellRenderers/cellRenderers" +import {useStyles} from "../assets/styles" +import EvaluatorsModal from "../EvaluatorsModal/EvaluatorsModal" + +import {buildAppScopedUrl, buildEvaluationNavigationUrl} from "../../utils" +import {AutoEvaluationHeaderProps} from "./types" + +const isLegacyEvaluation = (evaluation: any): boolean => "aggregated_results" in evaluation + +const getEvaluationKey = (evaluation: any): string | undefined => + (evaluation?.id ?? evaluation?.key)?.toString() + +const disallowedCompareStatuses = new Set([ + EvaluationStatus.RUNNING, + EvaluationStatus.PENDING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + EvaluationStatus.STARTED, +]) + +const NewEvaluationModal = dynamic(() => import("../../NewEvaluation"), { + ssr: false, +}) + +const AutoEvaluationHeader = ({ + selectedRowKeys, + evaluations, + columns, + setSelectedRowKeys, + setHiddenColumns, + setIsDeleteEvalModalOpen, + viewType, + runMetricsMap, + refetch, + scope, + baseAppURL, + projectURL, + activeAppId, + extractAppId, +}: AutoEvaluationHeaderProps) => { + const classes = useStyles() + const router = useRouter() + + const {token} = theme.useToken() + const {appURL} = useURL() + // atoms + const [searchQuery, setSearchQuery] = useAtom(searchQueryAtom) + const setTempEvaluation = useSetAtom(tempEvaluationAtom) + + // local states + const [searchTerm, setSearchTerm] = useState("") + const [newEvalModalOpen, setNewEvalModalOpen] = useState(false) + const [current, setCurrent] = useState(0) + const [isConfigEvaluatorModalOpen, setIsConfigEvaluatorModalOpen] = useQueryParam( + "configureEvaluatorModal", + "", + ) + + const onExport = useCallback(() => { + try { + const selectedKeySet = new Set(selectedRowKeys.map((key) => key?.toString())) + const exportEvals = evaluations.filter((evaluation) => { + const key = getEvaluationKey(evaluation) + return key ? selectedKeySet.has(key) : false + }) + if (!exportEvals.length) return + + const legacyEvals = exportEvals.filter((e) => "aggregated_results" in e) + const newEvals = exportEvals.filter((e) => "name" in e) + + const {currentApp} = getAppValues() + const filenameBase = + currentApp?.app_name || (scope === "project" ? "all_applications" : "evaluations") + const filename = `${filenameBase.replace(/\s+/g, "_")}_evaluation_scenarios.csv` + + const exportableEvals = [] + + if (legacyEvals.length) { + const legacyEvalsData = legacyEvals.map((item) => { + const record: Record = {} + + // 1. variant name + record.Variant = variantNameWithRev({ + variant_name: item.variants[0].variantName ?? "", + revision: item.revisions?.[0], + }) + // 2. testset name + record.Testset = item.testset?.name + + // 3. status + record.Status = statusMapper(token)( + item.status?.value as EvaluationStatus, + ).label + + // 4. aggregated results for legacy evals + item.aggregated_results?.forEach((result) => { + record[result.evaluator_config.name] = getTypedValue(result?.result) + }) + + // 5. avg latency legacy evals + record["Avg. Latency"] = getTypedValue(item?.average_latency) + + // 6. total cost for legacy evals + record["Total Cost"] = getTypedValue(item?.average_cost) + + // 7. created at + record["Created At"] = formatDate24(item?.created_at) + + return record + }) + + exportableEvals.push(...legacyEvalsData) + } + + if (newEvals.length) { + const newEvalsData = newEvals.map((item) => { + // Instead of using a plain object, use a Map to maintain insertion order + const record = new Map() + + // Add properties in the desired order + record.set("Name", item.name) + + // 1. variant name + record.set( + "Variant", + variantNameWithRev({ + variant_name: item.variants[0].variantName ?? "", + revision: item.variants[0].revision, + }), + ) + + // 2. testset name + record.set("Testset", item.testsets?.[0]?.name) + + // 3. status + record.set("Status", statusMapper(token)(item.status as EvaluationStatus).label) + + // 5. evaluator metrics + // 5. metrics (evaluator and invocation) + const metrics = runMetricsMap?.[item.id] || {} + const evaluators = item.evaluators || [] + + // First, collect all metrics and sort them + const sortedMetrics = Object.entries(metrics).sort(([a], [b]) => { + // Evaluator metrics (with dots) come first + const aIsEvaluator = a.includes(".") + const bIsEvaluator = b.includes(".") + + // If both are evaluator metrics, sort them alphabetically + if (aIsEvaluator && bIsEvaluator) { + return a.localeCompare(b) + } + + // If one is evaluator and one is not, evaluator comes first + if (aIsEvaluator) return -1 + if (bIsEvaluator) return 1 + + // Both are not evaluator metrics, sort them alphabetically + return a.localeCompare(b) + }) + + // Then process them in the sorted order + for (const [k, v] of sortedMetrics) { + if (k.includes(".")) { + // Handle evaluator metrics + const [evaluatorSlug, metricKey] = k.split(".") + const evaluator = evaluators.find((e: any) => e.slug === evaluatorSlug) + if (!evaluator) continue + + const key = `${evaluator.name}.${metricKey}` + + if (v.mean !== undefined) { + record.set(key, v.mean) + } else if (v.unique) { + const trueEntry = v?.frequency?.find((f: any) => f?.value === true) + const total = v?.count ?? 0 + const pct = total ? ((trueEntry?.count ?? 0) / total) * 100 : 0 + record.set(key, `${pct.toFixed(2)}%`) + } + } else { + // Handle invocation metrics + const key = formatColumnTitle(k) + + if (v.mean !== undefined) { + const {primary: primaryKey, label} = getMetricConfig(k) + record.set(label || key, formatMetricValue(k, v?.[primaryKey])) + } else if (v.unique) { + const trueEntry = v?.frequency?.find((f: any) => f?.value === true) + const total = v?.count ?? 0 + const pct = total ? ((trueEntry?.count ?? 0) / total) * 100 : 0 + record.set(key, `${pct.toFixed(2)}%`) + } + } + } + // 6. created by + record.set("Created By", item?.createdBy?.user?.username) + + // 7. created at + record.set("Created At", item?.createdAt) + + return Object.fromEntries(record) + }) + + exportableEvals.push(...newEvalsData) + } + + // Get all unique column keys + const columnKeys = new Set() + exportableEvals.forEach((row) => { + Object.keys(row).forEach((key) => columnKeys.add(key)) + }) + + // Build ordered columns according to the desired export order + const startColumns = ["Name", "Variant", "Testset", "Status"].filter((k) => + columnKeys.has(k), + ) + const endColumns = ["Created By", "Created At"].filter((k) => columnKeys.has(k)) + + // Evaluator metrics first (keys with a dot), sorted alphabetically for stability + const evaluatorMetricColumns = Array.from(columnKeys) + .filter((k) => k.includes(".")) + .sort((a, b) => a.localeCompare(b)) + + // Remaining metrics/columns (excluding the above), sorted alphabetically + const remainingColumns = Array.from(columnKeys) + .filter( + (k) => !startColumns.includes(k) && !endColumns.includes(k) && !k.includes("."), + ) + .sort((a, b) => a.localeCompare(b)) + + const _columns = [ + ...startColumns, + ...evaluatorMetricColumns, + ...remainingColumns, + ...endColumns, + ] + + const csvData = convertToCsv(exportableEvals, _columns) + downloadCsv(csvData, filename) + message.success("Results exported successfully!") + } catch (error) { + message.error("Failed to export evaluations") + } + }, [evaluations, selectedRowKeys, runMetricsMap, scope]) + + const onSearch = useCallback( + (text: string) => { + if (!text && !searchQuery) return + if (text === searchQuery) return + + setSearchQuery(text) + }, + [searchQuery], + ) + + const selectedEvaluations = useMemo(() => { + if (!selectedRowKeys.length) return [] + const selectedSet = new Set(selectedRowKeys.map((key) => key?.toString())) + + return evaluations.filter((evaluation: any) => { + const key = getEvaluationKey(evaluation) + return key ? selectedSet.has(key) : false + }) + }, [evaluations, selectedRowKeys]) + + const selectedAppId = useMemo(() => { + const ids = (selectedEvaluations as EvaluationRow[]) + .map((evaluation) => extractAppId(evaluation)) + .filter((id): id is string => typeof id === "string" && id.length > 0) + const uniqueIds = Array.from(new Set(ids)) + const commonId = uniqueIds.length === 1 ? uniqueIds[0] : undefined + return commonId || activeAppId + }, [selectedEvaluations, activeAppId, extractAppId]) + + const selectionType = useMemo(() => { + if (!selectedEvaluations.length) return "none" + + const hasLegacy = selectedEvaluations.some((evaluation) => isLegacyEvaluation(evaluation)) + const hasModern = selectedEvaluations.some((evaluation) => !isLegacyEvaluation(evaluation)) + + if (hasLegacy && hasModern) return "mixed" + if (hasLegacy) return "legacy" + return "modern" + }, [selectedEvaluations]) + + const legacySelections = useMemo( + () => selectedEvaluations.filter((evaluation) => isLegacyEvaluation(evaluation)), + [selectedEvaluations], + ) + + const modernSelections = useMemo( + () => selectedEvaluations.filter((evaluation) => !isLegacyEvaluation(evaluation)), + [selectedEvaluations], + ) + + const legacyCompareDisabled = useMemo(() => { + if (selectionType !== "legacy") return true + if (scope === "app" && !selectedAppId) return true + if (legacySelections.length < 2) return true + + const [first] = legacySelections + + return legacySelections.some((item: any) => { + const status = item.status?.value as EvaluationStatus + return ( + status === EvaluationStatus.STARTED || + status === EvaluationStatus.INITIALIZED || + item.testset?.id !== first?.testset?.id + ) + }) + }, [selectionType, selectedAppId, scope, legacySelections]) + + const modernCompareDisabled = useMemo(() => { + if (selectionType !== "modern") return true + if (!selectedEvaluations.length) return true + // users can compare up to 5 evals at a time + if (selectedEvaluations.length > 5) return true + if (scope === "app" && !selectedAppId) return true + if (modernSelections.length < 2) return true + + const [first] = modernSelections + const baseTestsetId = first?.testsets?.[0]?.id + if (!baseTestsetId) return true + + if (process.env.NODE_ENV !== "production") { + console.debug("[AutoEvaluationHeader] modern compare check", { + scope, + selectedAppId, + baseTestsetId, + selectionCount: modernSelections.length, + statusList: modernSelections.map((run: any) => run?.status ?? run?.status?.value), + testsetIds: modernSelections.map((run: any) => run?.testsets?.[0]?.id), + }) + } + + return modernSelections.some((run: any) => { + const status = (run?.status?.value ?? run?.status) as EvaluationStatus | string + const testsetId = run?.testsets?.[0]?.id + + return ( + !testsetId || + testsetId !== baseTestsetId || + (status && disallowedCompareStatuses.has(status)) + ) + }) + }, [selectionType, selectedEvaluations, selectedAppId, scope, modernSelections]) + + const compareDisabled = useMemo(() => { + if (selectionType === "legacy") return legacyCompareDisabled + if (selectionType === "modern") return modernCompareDisabled + return true + }, [selectionType, legacyCompareDisabled, modernCompareDisabled]) + + const handleCompare = useCallback(() => { + if (compareDisabled) return + const selectedCommonAppId = selectedAppId + if (process.env.NODE_ENV !== "production") { + console.debug("[AutoEvaluationHeader] handleCompare invoked", { + scope, + selectionType, + selectedCommonAppId, + selectedCount: selectedEvaluations.length, + selectedIds: selectedRowKeys, + }) + } + + if (selectionType === "legacy") { + const legacyIds = selectedEvaluations + .filter((evaluation) => isLegacyEvaluation(evaluation)) + .map((evaluation: any) => evaluation.id) + + if (!legacyIds.length) return + + const primaryLegacyAppId = + selectedCommonAppId || + (legacySelections[0] ? extractAppId(legacySelections[0]) : undefined) + if (scope === "app" && !primaryLegacyAppId) return + + const pathname = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: primaryLegacyAppId, + path: "/evaluations/results/compare", + }) + + router.push({ + pathname, + query: { + evaluations: legacyIds.join(","), + ...(scope === "project" && primaryLegacyAppId + ? {app_id: primaryLegacyAppId} + : {}), + }, + }) + return + } + + if (selectionType === "modern") { + const modernSelectionSet = new Set( + selectedEvaluations + .filter((evaluation) => !isLegacyEvaluation(evaluation)) + .map((evaluation: any) => evaluation.id?.toString()), + ) + const modernIds = selectedRowKeys + .map((key) => key?.toString()) + .filter((id) => (id ? modernSelectionSet.has(id) : false)) + const [baseId, ...compareIds] = modernIds + if (!baseId) return + + const baseRun = + modernSelections.find((evaluation) => getEvaluationKey(evaluation) === baseId) || + undefined + const baseAppId = baseRun ? extractAppId(baseRun) : undefined + const effectiveAppId = selectedCommonAppId || baseAppId + if (process.env.NODE_ENV !== "production") { + console.debug("[AutoEvaluationHeader] navigating to compare view", { + baseId, + compareIds, + baseAppId, + selectedCommonAppId, + effectiveAppId, + }) + } + if (scope === "app" && !effectiveAppId) return + + const pathname = buildEvaluationNavigationUrl({ + scope, + baseAppURL, + projectURL, + appId: effectiveAppId, + path: `/evaluations/results/${baseId}`, + }) + + router.push({ + pathname, + query: { + ...(compareIds.length ? {compare: compareIds} : {}), + ...(scope === "project" && effectiveAppId ? {app_id: effectiveAppId} : {}), + }, + }) + } + }, [ + compareDisabled, + selectionType, + selectedEvaluations, + router, + baseAppURL, + projectURL, + selectedRowKeys, + scope, + extractAppId, + modernSelections, + legacySelections, + selectedAppId, + ]) + + return ( +
    + {viewType === "overview" ? ( +
    +
    + + Automatic Evaluation + + {(() => { + const href = + scope === "app" + ? appURL + ? `${appURL}/evaluations?selectedEvaluation=auto_evaluation` + : undefined + : `${projectURL}/evaluations?selectedEvaluation=auto_evaluation` + if (!href) return null + return ( + + + + ) + })()} +
    + {(scope === "app" || scope === "project") && ( + + )} +
    + ) : ( + <> +
    + + {(scope === "app" && activeAppId) || scope === "project" ? ( + <> + + + + ) : null} + + {/*
    + setPagination({page: p, size: s})} + className="flex items-center xl:hidden shrink-0 [&_.ant-pagination-options]:hidden lg:[&_.ant-pagination-options]:block [&_.ant-pagination-options]:!ml-2" + /> + setPagination({page: p, size: s})} + className="hidden xl:flex xl:items-center" + /> +
    */} +
    + +
    + setSearchTerm(e.target.value)} + onClear={() => { + setSearchTerm("") + if (searchQuery) { + setSearchQuery("") + } + }} + onKeyDown={(e) => { + if (e.key === "Enter") { + onSearch(searchTerm) + } + }} + /> +
    + + + + + { + setHiddenColumns(keys) + }} + /> +
    +
    + + setIsConfigEvaluatorModalOpen("")} + current={current} + setCurrent={setCurrent} + /> + + )} + + {(scope === "app" && activeAppId) || scope === "project" ? ( + { + setNewEvalModalOpen(false) + }} + onSuccess={(res) => { + const runningEvaluations = res.data.runs || [] + setTempEvaluation((prev) => { + const existingIds = new Set([ + ...prev.map((e) => e.id), + ...evaluations.map((e) => e.id), + ]) + const newEvaluations = runningEvaluations + .filter((e) => !existingIds.has(e.id)) + .map((e) => { + const camelCase = snakeToCamelCaseKeys(e) + return { + ...camelCase, + data: {steps: [{origin: "auto", type: "annotation"}]}, + status: "running", + createdAt: formatDay({ + date: camelCase.createdAt, + outputFormat: "DD MMM YYYY | h:mm a", + }), + createdAtTimestamp: dayjs( + camelCase.createdAt, + "YYYY/MM/DD H:mm:ssAZ", + ).valueOf(), + } + }) + + return [...prev, ...newEvaluations] + }) + + refetch() + setNewEvalModalOpen(false) + }} + evaluationType="auto" + preview={false} + /> + ) : null} +
    + ) +} + +export default memo(AutoEvaluationHeader) diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/assets/styles.ts b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/styles.ts new file mode 100644 index 0000000000..6da350e19a --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/styles.ts @@ -0,0 +1,8 @@ +import {createUseStyles} from "react-jss" + +export const useStyles = createUseStyles(() => ({ + button: { + display: "flex", + alignItems: "center", + }, +})) diff --git a/web/ee/src/components/pages/evaluations/autoEvaluation/assets/types.ts b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/types.ts new file mode 100644 index 0000000000..c98818da34 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/autoEvaluation/assets/types.ts @@ -0,0 +1,22 @@ +import {ColumnsType} from "antd/es/table" + +import {EvaluationRow} from "@/oss/components/HumanEvaluations/types" +import {BasicStats} from "@/oss/lib/metricUtils" + +export interface AutoEvaluationHeaderProps { + selectedRowKeys: React.Key[] + evaluations: EvaluationRow[] + columns: ColumnsType + setSelectedRowKeys: React.Dispatch> + setHiddenColumns: React.Dispatch> + selectedEvalRecord?: EvaluationRow + setIsDeleteEvalModalOpen: React.Dispatch> + viewType?: "overview" | "evaluation" + runMetricsMap: Record> | undefined + refetch: () => void + scope: "app" | "project" + baseAppURL: string + projectURL: string + activeAppId?: string + extractAppId: (evaluation: EvaluationRow) => string | undefined +} diff --git a/web/ee/src/components/pages/evaluations/cellRenderers/StatusRenderer.tsx b/web/ee/src/components/pages/evaluations/cellRenderers/StatusRenderer.tsx new file mode 100644 index 0000000000..7ba13179c0 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/cellRenderers/StatusRenderer.tsx @@ -0,0 +1,62 @@ +import {InfoCircleOutlined} from "@ant-design/icons" +import {theme, Tooltip, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {useDurationCounter} from "@/oss/hooks/useDurationCounter" +import {_Evaluation, EvaluationStatus, JSSTheme} from "@/oss/lib/Types" + +import {runningStatuses, statusMapper} from "./cellRenderers" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + statusCell: { + display: "flex", + alignItems: "center", + gap: "0.25rem", + height: "100%", + marginBottom: 0, + + "& > div:nth-of-type(1)": { + height: 6, + aspectRatio: 1 / 1, + borderRadius: "50%", + }, + }, + dot: { + height: 3, + aspectRatio: 1 / 1, + borderRadius: "50%", + backgroundColor: "#8c8c8c", + marginTop: 2, + }, + date: { + color: "#8c8c8c", + }, +})) + +const StatusRenderer = (record: _Evaluation) => { + const classes = useStyles() + const {token} = theme.useToken() + const value = record.status.value + const duration = useDurationCounter(record.duration || 0, runningStatuses.includes(value)) + const {label, color} = statusMapper(token)(record.status.value as EvaluationStatus) + const errorMsg = record.status.error?.message + const errorStacktrace = record.status.error?.stacktrace + + return ( + +
    + {label} + {errorMsg && ( + + + + + + )} + + {duration} + + ) +} + +export default StatusRenderer diff --git a/web/ee/src/components/pages/evaluations/cellRenderers/cellRenderers.tsx b/web/ee/src/components/pages/evaluations/cellRenderers/cellRenderers.tsx new file mode 100644 index 0000000000..594a7a416b --- /dev/null +++ b/web/ee/src/components/pages/evaluations/cellRenderers/cellRenderers.tsx @@ -0,0 +1,270 @@ +import {memo, useCallback, useEffect, useState} from "react" + +import {type ICellRendererParams} from "@ag-grid-community/core" +import { + CopyOutlined, + FullscreenExitOutlined, + FullscreenOutlined, + InfoCircleOutlined, +} from "@ant-design/icons" +import {GlobalToken, Space, Tooltip, Typography, message, theme} from "antd" +import dayjs from "dayjs" +import duration from "dayjs/plugin/duration" +import relativeTime from "dayjs/plugin/relativeTime" +import Link from "next/link" +import {createUseStyles} from "react-jss" + +import {useDurationCounter} from "@/oss/hooks/useDurationCounter" +import {getTypedValue} from "@/oss/lib/helpers/evaluate" +import { + EvaluationStatus, + EvaluatorConfig, + JSSTheme, + _Evaluation, + _EvaluationScenario, +} from "@/oss/lib/Types" +dayjs.extend(relativeTime) +dayjs.extend(duration) + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + statusCell: { + display: "flex", + alignItems: "center", + gap: "0.25rem", + height: "100%", + marginBottom: 0, + + "& > div:nth-of-type(1)": { + height: 6, + aspectRatio: 1 / 1, + borderRadius: "50%", + }, + }, + dot: { + height: 3, + aspectRatio: 1 / 1, + borderRadius: "50%", + backgroundColor: "#8c8c8c", + marginTop: 2, + }, + date: { + color: "#8c8c8c", + }, + longCell: { + height: "100%", + position: "relative", + overflow: "hidden", + textOverflow: "ellipsis", + whiteSpace: "nowrap", + "& .ant-space": { + position: "absolute", + bottom: 2, + right: 0, + height: 35, + backgroundColor: theme.colorBgContainer, + padding: "0.5rem", + borderRadius: theme.borderRadius, + border: `1px solid ${theme.colorBorder}`, + display: "none", + }, + "&:hover .ant-space": { + display: "inline-flex", + }, + }, +})) + +export function LongTextCellRenderer(params: ICellRendererParams, output?: any) { + const {value, api, node} = params + const [expanded, setExpanded] = useState( + node.rowHeight !== api.getSizesForCurrentTheme().rowHeight, + ) + const classes = useStyles() + + const onCopy = useCallback(() => { + navigator.clipboard + .writeText(value as string) + .then(() => { + message.success("Copied to clipboard") + }) + .catch(console.error) + }, [value]) + + const onExpand = useCallback(() => { + const cells = document.querySelectorAll(`[row-id='${node.id}'] .ag-cell > *`) + const cellsArr = Array.from(cells || []) + const defaultHeight = api.getSizesForCurrentTheme().rowHeight + if (!expanded) { + cellsArr.forEach((cell) => { + cell.setAttribute( + "style", + "overflow: visible; white-space: pre-wrap; text-overflow: unset;", + ) + }) + const height = Math.max(...cellsArr.map((cell) => cell.scrollHeight)) + node.setRowHeight(height <= defaultHeight ? defaultHeight * 2 : height + 10) + } else { + cellsArr.forEach((cell) => { + cell.setAttribute( + "style", + "overflow: hidden; white-space: nowrap; text-overflow: ellipsis;", + ) + }) + node.setRowHeight(defaultHeight) + } + api.onRowHeightChanged() + }, [expanded, api, node]) + + useEffect(() => { + node.addEventListener("heightChanged", () => { + setExpanded(node.rowHeight !== api.getSizesForCurrentTheme().rowHeight) + }) + }, [api, node]) + + return ( +
    + {output ? output : value} + + {expanded ? ( + + ) : ( + + )} + + +
    + ) +} + +export const ResultRenderer = memo( + ( + params: ICellRendererParams<_EvaluationScenario> & { + config: EvaluatorConfig + }, + ) => { + const result = params.data?.results.find( + (item) => item.evaluator_config === params.config.id, + )?.result + + return {getTypedValue(result)} + }, + (prev, next) => prev.value === next.value, +) + +export const runningStatuses = [EvaluationStatus.INITIALIZED, EvaluationStatus.STARTED] +export const statusMapper = (token: GlobalToken) => (status: EvaluationStatus) => { + const statusMap = { + [EvaluationStatus.PENDING]: { + label: "Pending", + color: token.colorTextSecondary, + }, + [EvaluationStatus.INCOMPLETE]: { + label: "Incomplete", + color: token.colorTextSecondary, + }, + [EvaluationStatus.INITIALIZED]: { + label: "Queued", + color: token.colorTextSecondary, + }, + [EvaluationStatus.RUNNING]: { + label: "Running", + color: token.colorTextSecondary, + }, + [EvaluationStatus.STARTED]: { + label: "Running", + color: token.colorWarning, + }, + [EvaluationStatus.FINISHED]: { + label: "Success", + color: token.colorSuccess, + }, + [EvaluationStatus.SUCCESS]: { + label: "Success", + color: token.colorSuccess, + }, + [EvaluationStatus.ERROR]: { + label: "Failure", + color: token.colorError, + }, + [EvaluationStatus.ERRORS]: { + label: "Failure", + color: token.colorError, + }, + [EvaluationStatus.FAILURE]: { + label: "Failure", + color: token.colorError, + }, + [EvaluationStatus.FINISHED_WITH_ERRORS]: { + label: "Completed with Errors", + color: token.colorWarning, + }, + [EvaluationStatus.AGGREGATION_FAILED]: { + label: "Result Aggregation Failed", + color: token.colorWarning, + }, + } + + return ( + statusMap[status] || { + label: "Unknown", + color: "purple", + } + ) +} + +export const StatusRenderer = memo( + (params: ICellRendererParams<_Evaluation>) => { + const classes = useStyles() + const {token} = theme.useToken() + const duration = useDurationCounter( + params.data?.duration || 0, + runningStatuses.includes(params.value), + ) + const {label, color} = statusMapper(token)(params.data?.status.value as EvaluationStatus) + const errorMsg = params.data?.status.error?.message + const errorStacktrace = params.data?.status.error?.stacktrace + + return ( + +
    + {label} + {errorMsg && ( + + + + + + )} + + {duration} + + ) + }, + (prev, next) => prev.value === next.value && prev.data?.duration === next.data?.duration, +) + +export const LinkCellRenderer = memo( + (params: ICellRendererParams & {href: string}) => { + const {value, href} = params + return {value} + }, + (prev, next) => prev.value === next.value && prev.href === next.href, +) + +export const DateFromNowRenderer = memo( + (params: ICellRendererParams) => { + const [date, setDate] = useState(params.value) + + useEffect(() => { + const interval = setInterval(() => { + setDate((date: any) => dayjs(date).add(1, "second").valueOf()) + }, 60000) + return () => clearInterval(interval) + }, []) + + return {dayjs(date).fromNow()} + }, + (prev, next) => prev.value === next.value, +) diff --git a/web/ee/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx b/web/ee/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx new file mode 100644 index 0000000000..653d405753 --- /dev/null +++ b/web/ee/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx @@ -0,0 +1,632 @@ +import {useEffect, useMemo, useState, type FC} from "react" + +import {type ColDef, type ICellRendererParams} from "@ag-grid-community/core" +import {CheckOutlined, CloseCircleOutlined, DownloadOutlined, UndoOutlined} from "@ant-design/icons" +import {Button, DropdownProps, Space, Spin, Tag, Tooltip, Typography} from "antd" +import {useAtom} from "jotai" +import uniqBy from "lodash/uniqBy" +import Link from "next/link" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" + +import AgCustomHeader from "@/oss/components/AgCustomHeader/AgCustomHeader" +import CompareOutputDiff from "@/oss/components/CompareOutputDiff/CompareOutputDiff" +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import {useAppId} from "@/oss/hooks/useAppId" +import {useQueryParam} from "@/oss/hooks/useQuery" +import useURL from "@/oss/hooks/useURL" +import {evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import AgGridReact, {type AgGridReactType} from "@/oss/lib/helpers/agGrid" +import {getColorPairFromStr, getRandomColors} from "@/oss/lib/helpers/colors" +import {getFilterParams, getTypedValue, removeCorrectAnswerPrefix} from "@/oss/lib/helpers/evaluate" +import {escapeNewlines} from "@/oss/lib/helpers/fileManipulations" +import {formatCurrency, formatLatency} from "@/oss/lib/helpers/formatters" +import {getStringOrJson} from "@/oss/lib/helpers/utils" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import {useBreadcrumbsEffect} from "@/oss/lib/hooks/useBreadcrumbs" +import {ComparisonResultRow, EvaluatorConfig, JSSTheme, TestSet, _Evaluation} from "@/oss/lib/Types" +import {fetchAllComparisonResults} from "@/oss/services/evaluations/api" +import {getAppValues} from "@/oss/state/app" + +import {LongTextCellRenderer} from "../cellRenderers/cellRenderers" +import EvaluationErrorModal from "../EvaluationErrorProps/EvaluationErrorModal" +import EvaluationErrorText from "../EvaluationErrorProps/EvaluationErrorText" +import FilterColumns, {generateFilterItems} from "../FilterColumns/FilterColumns" +import {isValidId} from "@/oss/lib/helpers/serviceValidations" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + table: { + height: "calc(100vh - 240px)", + }, + infoRow: { + marginTop: "1rem", + margin: "0.75rem 0", + display: "flex", + alignItems: "center", + justifyContent: "space-between", + }, + tag: { + "& a": { + color: "inherit !important", + fontWeight: 600, + "&:hover": { + color: "inherit !important", + textDecoration: "underline", + }, + }, + }, + dropdownMenu: { + "&>.ant-dropdown-menu-item": { + "& .anticon-check": { + display: "none", + }, + }, + "&>.ant-dropdown-menu-item-selected": { + "&:not(:hover)": { + backgroundColor: "transparent !important", + }, + "& .anticon-check": { + display: "inline-flex !important", + }, + }, + }, +})) + +interface Props {} + +const EvaluationCompareMode: FC = () => { + const router = useRouter() + const appId = useAppId() + const classes = useStyles() + const {appTheme} = useAppTheme() + const [evaluationIdsStr = ""] = useQueryParam("evaluations") + const evaluationIdsArray = evaluationIdsStr + .split(",") + .filter((item) => !!item && isValidId(item)) + const [evalIds, setEvalIds] = useState(evaluationIdsArray) + const [hiddenVariants, setHiddenVariants] = useState([]) + const [fetching, setFetching] = useState(false) + const [scenarios, setScenarios] = useState<_Evaluation[]>([]) + const [rows, setRows] = useState([]) + const [testset, setTestset] = useState() + const [evaluators] = useAtom(evaluatorsAtom) + const [gridRef, setGridRef] = useState>() + const [isFilterColsDropdownOpen, setIsFilterColsDropdownOpen] = useState(false) + const [isDiffDropdownOpen, setIsDiffDropdownOpen] = useState(false) + const [selectedCorrectAnswer, setSelectedCorrectAnswer] = useState(["noDiffColumnIsSelected"]) + const [modalErrorMsg, setModalErrorMsg] = useState({ + message: "", + stackTrace: "", + errorType: "invoke" as "invoke" | "evaluation", + }) + const [isErrorModalOpen, setIsErrorModalOpen] = useState(false) + const {baseAppURL, projectURL} = useURL() + // breadcrumbs + const isAppScope = router.asPath.includes("/apps/") + const evaluationsHref = + isAppScope && appId + ? `${baseAppURL}/${appId}/evaluations?selectedEvaluation=auto_evaluation` + : `${projectURL}/evaluations?selectedEvaluation=auto_evaluation` + const breadcrumbKey = isAppScope && appId ? "appPage" : "projectPage" + + useBreadcrumbsEffect( + { + breadcrumbs: { + [breadcrumbKey]: { + label: "auto evaluation", + href: evaluationsHref, + }, + "eval-compare": { + label: "compare", + }, + }, + type: "append", + condition: evaluationIdsArray.length > 0, + }, + [evaluationIdsArray, evaluationsHref, breadcrumbKey], + ) + + const handleOpenChangeDiff: DropdownProps["onOpenChange"] = (nextOpen, info) => { + if (info.source === "trigger" || nextOpen) { + setIsDiffDropdownOpen(nextOpen) + } + } + + const handleOpenChangeFilterCols: DropdownProps["onOpenChange"] = (nextOpen, info) => { + if (info.source === "trigger" || nextOpen) { + setIsFilterColsDropdownOpen(nextOpen) + } + } + + const variants = useMemo(() => { + return rows[0]?.variants || [] + }, [rows]) + + const colors = useMemo(() => { + const previous = new Set() + const colors = getRandomColors() + return variants.map((v) => { + const {textColor} = getColorPairFromStr(v.evaluationId) + if (previous.has(textColor)) return colors.find((c) => !previous.has(c))! + previous.add(textColor) + return textColor + }) + }, [variants]) + + const evaluationIds = useMemo( + () => evaluationIdsStr.split(",").filter((item) => !!item && isValidId(item)), + [evaluationIdsStr], + ) + + const colDefs = useMemo(() => { + const colDefs: ColDef[] = [] + const {inputs, variants} = rows[0] || {} + + if (!rows.length || !variants.length) return [] + + inputs.forEach((input, ix) => { + colDefs.push({ + headerName: `Input: ${input.name}`, + headerComponent: (props: any) => { + return ( + + + {input.name} + Input + + + ) + }, + minWidth: 200, + flex: 1, + field: `inputs.${ix}.value` as any, + ...getFilterParams("text"), + pinned: "left", + cellRenderer: (params: any) => LongTextCellRenderer(params), + }) + }) + + Object.keys(rows[0]) + .filter((item) => item.startsWith("correctAnswer_")) + .forEach((key) => + colDefs.push({ + headerName: `${removeCorrectAnswerPrefix(key)}`, + hide: hiddenVariants.includes(`${removeCorrectAnswerPrefix(key)}`), + headerComponent: (props: any) => { + return ( + + + {removeCorrectAnswerPrefix(key)} + Ground Truth + + + ) + }, + minWidth: 280, + flex: 1, + field: key, + ...getFilterParams("text"), + cellRenderer: (params: any) => LongTextCellRenderer(params), + }), + ) + + variants.forEach((variant, vi) => { + colDefs.push({ + headerComponent: (props: any) => ( + + + Output + {variant.variantName} + + + ), + headerName: "Output", + minWidth: 300, + flex: 1, + field: `variants.${vi}.output` as any, + ...getFilterParams("text"), + hide: hiddenVariants.includes("Output"), + cellRenderer: (params: ICellRendererParams) => { + const result = params.data?.variants.find( + (item: any) => item.evaluationId === variant.evaluationId, + )?.output?.result + + if (result && result.error && result.type == "error") { + return ( + { + setModalErrorMsg({ + message: result.error?.message || "", + stackTrace: result.error?.stacktrace || "", + errorType: "invoke", + }) + setIsErrorModalOpen(true) + }} + /> + ) + } + + return ( + <> + {selectedCorrectAnswer[0] !== "noDiffColumnIsSelected" + ? LongTextCellRenderer( + params, + , + ) + : LongTextCellRenderer(params, getStringOrJson(result?.value))} + + ) + }, + valueGetter: (params) => { + return getStringOrJson( + params.data?.variants.find( + (item) => item.evaluationId === variant.evaluationId, + )?.output?.result.value, + ) + }, + }) + }) + + const confgisMap: Record< + string, + {config: EvaluatorConfig; variant: ComparisonResultRow["variants"][0]; color: string}[] + > = {} + variants.forEach((variant, vi) => { + variant.evaluatorConfigs.forEach(({evaluatorConfig: config}, ix) => { + if (!confgisMap[config.id]) confgisMap[config.id] = [] + confgisMap[config.id].push({variant, config, color: colors[vi]}) + }) + }) + + Object.entries(confgisMap).forEach(([_, configs]) => { + configs.forEach(({config, variant, color}, idx) => { + colDefs.push({ + flex: 1, + minWidth: 200, + headerComponent: (props: any) => { + const evaluator = evaluators.find( + (item) => item.key === config.evaluator_key, + ) + return ( + + + + {config.name} + {evaluator?.name} + + {variant.variantName} + + + ) + }, + headerName: config.name, + type: `evaluator_${idx}`, + field: "variants.0.evaluatorConfigs.0.result" as any, + ...getFilterParams("text"), + hide: hiddenVariants.includes(config.name), + cellRenderer: (params: ICellRendererParams) => { + const result = params.data?.variants + .find((item) => item.evaluationId === variant.evaluationId) + ?.evaluatorConfigs.find( + (item) => item.evaluatorConfig.id === config.id, + )?.result + + return result?.type === "error" && result.error ? ( + { + setModalErrorMsg({ + message: result.error?.message || "", + stackTrace: result.error?.stacktrace || "", + errorType: "evaluation", + }) + setIsErrorModalOpen(true) + }} + /> + ) : ( + {getTypedValue(result)} + ) + }, + valueGetter: (params) => { + return getTypedValue( + params.data?.variants + .find((item) => item.evaluationId === variant.evaluationId) + ?.evaluatorConfigs.find( + (item) => item.evaluatorConfig.id === config.id, + )?.result, + ) + }, + }) + }) + }) + + variants.forEach((variant, vi) => { + colDefs.push({ + headerComponent: (props: any) => ( + + + Latency + {variant.variantName} + + + ), + hide: hiddenVariants.includes("Latency"), + minWidth: 120, + headerName: "Latency", + field: `latency.${vi}` as any, + flex: 1, + valueGetter: (params) => { + const latency = params.data?.variants.find( + (item) => item.evaluationId === variant.evaluationId, + )?.output?.latency + return latency === undefined ? "-" : formatLatency(latency) + }, + ...getFilterParams("text"), + }) + }) + + variants.forEach((variant, vi) => { + colDefs.push({ + headerComponent: (props: any) => ( + + + Cost + {variant.variantName} + + + ), + field: `cost.${vi}` as any, + headerName: "Cost", + minWidth: 120, + hide: !evalIds.includes(variant.evaluationId) || hiddenVariants.includes("Cost"), + flex: 1, + valueGetter: (params) => { + const cost = params.data?.variants.find( + (item) => item.evaluationId === variant.evaluationId, + )?.output?.cost + return cost === undefined ? "-" : formatCurrency(cost) + }, + ...getFilterParams("text"), + }) + }) + + return colDefs + }, [rows, hiddenVariants, evalIds, selectedCorrectAnswer, colors, evaluators]) + + const fetcher = () => { + setFetching(true) + fetchAllComparisonResults(evaluationIds) + .then(({rows, testset, evaluations}) => { + setScenarios(evaluations) + setRows(rows) + setTestset(testset) + setTimeout(() => { + if (!gridRef) return + + const ids: string[] = + gridRef.api + .getColumns() + ?.filter((column) => column.getColDef().field?.endsWith("result")) + ?.map((item) => item.getColId()) || [] + gridRef.api.autoSizeColumns(ids, false) + setFetching(false) + }, 100) + }) + .catch(() => setFetching(false)) + } + + useEffect(() => { + if (!gridRef) return + fetcher() + }, [appId, evaluationIdsStr, gridRef]) + + const handleToggleVariantVisibility = (evalId: string) => { + if (!hiddenVariants.includes(evalId)) { + setHiddenVariants([...hiddenVariants, evalId]) + setEvalIds(evalIds.filter((val) => val !== evalId)) + } else { + setHiddenVariants(hiddenVariants.filter((item) => item !== evalId)) + if (evaluationIdsArray.includes(evalId)) { + setEvalIds([...evalIds, evalId]) + } + } + } + + const shownCols = useMemo( + () => + colDefs + .map((item) => item.headerName) + .filter((item) => item !== undefined && !hiddenVariants.includes(item)) as string[], + [colDefs, hiddenVariants], + ) + + const getDynamicHeaderName = (params: ColDef): string => { + const {headerName, field, type}: any = params + + const getVariantNameWithRev = (index: number): string => { + const scenario = scenarios[index] + const variantName = scenario?.variants[0]?.variantName ?? "" + const revision = scenario?.revisions[0] ?? "" + return variantNameWithRev({variant_name: variantName, revision}) + } + + if (headerName === "Output" || headerName === "Latency" || headerName === "Cost") { + const index = Number(field.split(".")[1]) + return `${headerName} ${getVariantNameWithRev(index)}` + } + + if (type && type.startsWith("evaluator")) { + const index = Number(type.split("_")[1]) + return `${headerName} ${getVariantNameWithRev(index)}` + } + + return headerName + } + + const onExport = (): void => { + const gridApi = gridRef?.api + if (!gridApi) return + + const {currentApp} = getAppValues() + const fileName = `${currentApp?.app_name ?? "export"}_${variants.map(({variantName}) => variantName).join("_")}.csv` + + gridApi.exportDataAsCsv({ + fileName, + processHeaderCallback: (params) => getDynamicHeaderName(params.column.getColDef()), + processCellCallback: (params) => + typeof params.value === "string" ? escapeNewlines(params.value) : params.value, + }) + } + + return ( +
    + Evaluations Comparison +
    + + + Testset: + // TODO: REPLACE WITH NEXT/LINK + + {testset?.name || ""} + + + + + Variants: +
    + {scenarios?.map((v, vi) => ( + + handleToggleVariantVisibility(v.id) + } + style={{cursor: "pointer"}} + /> + ) : ( + + handleToggleVariantVisibility(v.id) + } + style={{cursor: "pointer"}} + /> + ) + } + > + + {variantNameWithRev({ + variant_name: v.variants[0].variantName ?? "", + revision: v.revisions[0], + })} + + + ))} +
    +
    +
    +
    + + !item.headerName?.startsWith("Input")), + "headerName", + ), + )} + isOpen={isFilterColsDropdownOpen} + handleOpenChange={handleOpenChangeFilterCols} + shownCols={shownCols} + onClick={({key}) => { + handleToggleVariantVisibility(key) + setIsFilterColsDropdownOpen(true) + }} + /> + {!!rows.length && ( +
    + Apply difference with: + item.startsWith("correctAnswer_")) + .map((key) => ({ + key: key as string, + label: ( + + + <>{removeCorrectAnswerPrefix(key)} + + ), + }))} + buttonText={ + removeCorrectAnswerPrefix(selectedCorrectAnswer[0]) === + "noDiffColumnIsSelected" + ? "Select Ground Truth" + : removeCorrectAnswerPrefix(selectedCorrectAnswer[0]) + } + isOpen={isDiffDropdownOpen} + handleOpenChange={handleOpenChangeDiff} + shownCols={selectedCorrectAnswer} + onClick={({key}) => { + if (key === selectedCorrectAnswer[0]) { + setSelectedCorrectAnswer(["noDiffColumnIsSelected"]) + } else { + setSelectedCorrectAnswer([key]) + } + setIsDiffDropdownOpen(true) + }} + /> +
    + )} + + + +
    +
    + + +
    + + gridRef={setGridRef} + rowData={rows} + columnDefs={colDefs} + getRowId={(params) => params.data.rowId} + headerHeight={64} + /> +
    +
    + + +
    + ) +} + +export default EvaluationCompareMode diff --git a/web/ee/src/components/pages/evaluations/evaluationScenarios/EvaluationScenarios.tsx b/web/ee/src/components/pages/evaluations/evaluationScenarios/EvaluationScenarios.tsx new file mode 100644 index 0000000000..d627f7c75c --- /dev/null +++ b/web/ee/src/components/pages/evaluations/evaluationScenarios/EvaluationScenarios.tsx @@ -0,0 +1,474 @@ +import {type FC, useEffect, useMemo, useState} from "react" + +import {type ColDef, type ICellRendererParams} from "@ag-grid-community/core" +import {CheckOutlined, DeleteOutlined, DownloadOutlined} from "@ant-design/icons" +import {DropdownProps, Space, Spin, Tag, Tooltip, Typography} from "antd" +import {useAtom, useAtomValue} from "jotai" +import uniqBy from "lodash/uniqBy" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" + +import AgCustomHeader from "@/oss/components/AgCustomHeader/AgCustomHeader" +import AlertPopup from "@/oss/components/AlertPopup/AlertPopup" +import CompareOutputDiff from "@/oss/components/CompareOutputDiff/CompareOutputDiff" +import {useAppTheme} from "@/oss/components/Layout/ThemeContextProvider" +import VariantDetailsWithStatus from "@/oss/components/VariantDetailsWithStatus" +import {useAppId} from "@/oss/hooks/useAppId" +import useURL from "@/oss/hooks/useURL" +import {evaluatorsAtom} from "@/oss/lib/atoms/evaluation" +import AgGridReact, {type AgGridReactType} from "@/oss/lib/helpers/agGrid" +import {formatDate} from "@/oss/lib/helpers/dateTimeHelper" +import {getFilterParams, getTypedValue} from "@/oss/lib/helpers/evaluate" +import {escapeNewlines} from "@/oss/lib/helpers/fileManipulations" +import {formatCurrency, formatLatency} from "@/oss/lib/helpers/formatters" +import {getStringOrJson} from "@/oss/lib/helpers/utils" +import {variantNameWithRev} from "@/oss/lib/helpers/variantHelper" +import {useBreadcrumbsEffect} from "@/oss/lib/hooks/useBreadcrumbs" +import {CorrectAnswer, EvaluatorConfig, JSSTheme, _EvaluationScenario} from "@/oss/lib/Types" +import {deleteEvaluations} from "@/oss/services/evaluations/api" +import {fetchAllEvaluators} from "@/oss/services/evaluators" +import {currentAppAtom} from "@/oss/state/app" + +import {LongTextCellRenderer, ResultRenderer} from "../cellRenderers/cellRenderers" +import EvaluationErrorModal from "../EvaluationErrorProps/EvaluationErrorModal" +import EvaluationErrorText from "../EvaluationErrorProps/EvaluationErrorText" +import FilterColumns, {generateFilterItems} from "../FilterColumns/FilterColumns" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + infoRow: { + marginTop: "1rem", + margin: "0.75rem 0", + display: "flex", + alignItems: "center", + justifyContent: "space-between", + }, + date: { + fontSize: "0.75rem", + color: "#8c8c8c", + display: "inline-block", + }, + table: { + height: "calc(100vh - 240px)", + }, +})) + +interface Props { + scenarios: _EvaluationScenario[] +} + +const EvaluationScenarios: FC = ({scenarios: _scenarios}) => { + const router = useRouter() + const appId = useAppId() + const currentApp = useAtomValue(currentAppAtom) + const classes = useStyles() + const {appTheme} = useAppTheme() + const evaluationId = router.query.evaluation_id as string + const [scenarios, setScenarios] = useState<_EvaluationScenario[]>([]) + const [fetching, setFetching] = useState(false) + const [evaluators, setEvaluators] = useAtom(evaluatorsAtom) + const [gridRef, setGridRef] = useState>() + const evalaution = scenarios?.[0]?.evaluation + const [selectedCorrectAnswer, setSelectedCorrectAnswer] = useState(["noDiffColumnIsSelected"]) + const [isFilterColsDropdownOpen, setIsFilterColsDropdownOpen] = useState(false) + const [isDiffDropdownOpen, setIsDiffDropdownOpen] = useState(false) + const [hiddenCols, setHiddenCols] = useState([]) + const {baseAppURL, projectURL} = useURL() + + // breadcrumbs + useBreadcrumbsEffect( + { + breadcrumbs: { + appPage: { + label: "auto evaluation", + href: `${baseAppURL}/${appId}/evaluations?selectedEvaluation=auto_evaluation`, + }, + "eval-detail": { + label: evaluationId, + value: evaluationId, + }, + }, + type: "append", + condition: !!evaluationId, + }, + [evaluationId, baseAppURL], + ) + + const handleOpenChangeFilterCols: DropdownProps["onOpenChange"] = (nextOpen, info) => { + if (info.source === "trigger" || nextOpen) { + setIsFilterColsDropdownOpen(nextOpen) + } + } + + const handleOpenChangeDiff: DropdownProps["onOpenChange"] = (nextOpen, info) => { + if (info.source === "trigger" || nextOpen) { + setIsDiffDropdownOpen(nextOpen) + } + } + + const uniqueCorrectAnswers: CorrectAnswer[] = uniqBy( + scenarios?.[0]?.correct_answers || [], + "key", + ) + const [modalErrorMsg, setModalErrorMsg] = useState({ + message: "", + stackTrace: "", + errorType: "evaluation" as "invoke" | "evaluation", + }) + const [isErrorModalOpen, setIsErrorModalOpen] = useState(false) + + const colDefs = useMemo(() => { + const colDefs: ColDef<_EvaluationScenario>[] = [] + if (!scenarios.length || !evalaution) return colDefs + + scenarios?.[0]?.inputs?.forEach((input, index) => { + colDefs.push({ + flex: 1, + minWidth: 240, + headerName: `Input: ${input.name}`, + hide: hiddenCols.includes(`Input: ${input.name}`), + headerComponent: (props: any) => { + return ( + + + {input.name} + Input + + + ) + }, + ...getFilterParams(input.type === "number" ? "number" : "text"), + field: `inputs.${index}`, + valueGetter: (params) => { + return getTypedValue(params.data?.inputs[index]) + }, + cellRenderer: (params: any) => LongTextCellRenderer(params), + }) + }) + + uniqueCorrectAnswers.forEach((answer: CorrectAnswer, index: number) => { + colDefs.push({ + headerName: answer.key, + hide: hiddenCols.includes(answer.key), + headerComponent: (props: any) => { + return ( + + + {answer.key} + Ground Truth + + + ) + }, + minWidth: 200, + flex: 1, + ...getFilterParams("text"), + valueGetter: (params) => params.data?.correct_answers?.[index]?.value || "", + cellRenderer: (params: any) => LongTextCellRenderer(params), + }) + }) + + const evalVariants = evalaution?.variants || [] + + evalVariants.forEach((_, index) => { + colDefs.push({ + flex: 1, + minWidth: 300, + headerName: "Output", + hide: hiddenCols.includes("Output"), + ...getFilterParams("text"), + field: `outputs.0`, + cellRenderer: (params: ICellRendererParams<_EvaluationScenario>) => { + const correctAnswer = params?.data?.correct_answers?.find( + (item: any) => item.key === selectedCorrectAnswer[0], + ) + const result = params.data?.outputs[index].result + + if (result && result.error && result.type == "error") { + return ( + { + setModalErrorMsg({ + message: result.error?.message || "", + stackTrace: result.error?.stacktrace || "", + errorType: "evaluation", + }) + setIsErrorModalOpen(true) + }} + /> + ) + } + return selectedCorrectAnswer[0] !== "noDiffColumnIsSelected" + ? LongTextCellRenderer( + params, + , + ) + : LongTextCellRenderer(params) + }, + valueGetter: (params: any) => { + const result = params.data?.outputs[index].result.value + return getStringOrJson(result) + }, + }) + }) + + const evaluatorConfigs = scenarios?.[0]?.evaluators_configs || [] + + evaluatorConfigs.forEach((config, index) => { + colDefs.push({ + headerName: config?.name, + hide: hiddenCols.includes(config.name), + headerComponent: (props: any) => { + const evaluator = evaluators.find((item) => item.key === config?.evaluator_key)! + return ( + + + {config.name} + {evaluator?.name} + + + ) + }, + autoHeaderHeight: true, + field: `results`, + ...getFilterParams("text"), + cellRenderer: ( + params: ICellRendererParams<_EvaluationScenario> & { + config: EvaluatorConfig + }, + ) => { + const result = params.data?.results.find( + (item) => item.evaluator_config === params.config.id, + )?.result + + return result?.type === "error" && result.error ? ( + { + setModalErrorMsg({ + message: result.error?.message || "", + stackTrace: result.error?.stacktrace || "", + errorType: "evaluation", + }) + setIsErrorModalOpen(true) + }} + /> + ) : ( + + ) + }, + cellRendererParams: { + config, + }, + valueGetter: (params) => { + return params.data?.results[index].result.value + }, + }) + }) + colDefs.push({ + flex: 1, + minWidth: 120, + headerName: "Cost", + hide: hiddenCols.includes("Cost"), + ...getFilterParams("text"), + valueGetter: (params) => { + return params.data?.outputs[0].cost == undefined + ? "-" + : formatCurrency(params.data.outputs[0].cost) + }, + }) + + colDefs.push({ + flex: 1, + minWidth: 120, + headerName: "Latency", + hide: hiddenCols.includes("Latency"), + ...getFilterParams("text"), + valueGetter: (params) => { + return params.data?.outputs[0].latency == undefined + ? "-" + : formatLatency(params.data.outputs[0].latency) + }, + }) + return colDefs + }, [evalaution, scenarios, selectedCorrectAnswer, hiddenCols, evaluators, uniqueCorrectAnswers]) + + const shownCols = useMemo( + () => + colDefs + .map((item) => item.headerName) + .filter((item) => item !== undefined && !hiddenCols.includes(item)) as string[], + [colDefs, hiddenCols], + ) + + const onToggleEvaluatorVisibility = (evalConfigId: string) => { + if (!hiddenCols.includes(evalConfigId)) { + setHiddenCols([...hiddenCols, evalConfigId]) + } else { + setHiddenCols(hiddenCols.filter((item) => item !== evalConfigId)) + } + } + + const fetcher = () => { + setFetching(true) + Promise.all([evaluators.length ? Promise.resolve(evaluators) : fetchAllEvaluators()]) + .then(([evaluators]) => { + setScenarios(_scenarios) + setEvaluators(evaluators) + setTimeout(() => { + if (!gridRef) return + + const ids: string[] = + gridRef.api + .getColumns() + ?.filter((column) => column.getColDef().field === "results") + ?.map((item) => item.getColId()) || [] + gridRef.api.autoSizeColumns(ids, false) + setFetching(false) + }, 100) + }) + .catch(console.error) + .finally(() => setFetching(false)) + } + + useEffect(() => { + if (!gridRef) return + fetcher() + }, [appId, gridRef, evaluationId]) + + const onExport = () => { + if (!gridRef) return + gridRef.api.exportDataAsCsv({ + fileName: `${currentApp?.app_name}_${evalaution.variants[0].variantName}.csv`, + processHeaderCallback: (params) => { + if (params.column.getColDef().headerName === "Output") { + return `Output ${variantNameWithRev({ + variant_name: evalaution?.variants[0].variantName ?? "", + revision: evalaution.revisions[0], + })}` + } + return params.column.getColDef().headerName as string + }, + processCellCallback: (params) => + typeof params.value === "string" ? escapeNewlines(params.value) : params.value, + }) + } + + const onDelete = () => { + AlertPopup({ + title: "Delete Evaluation", + message: "Are you sure you want to delete this evaluation?", + onOk: () => + deleteEvaluations([evaluationId]) + .then(() => router.push(`${baseAppURL}/${appId}/evaluations`)) + .catch(console.error), + }) + } + + return ( +
    + Evaluation Results +
    + + + {formatDate(evalaution?.created_at)} + + + Testset: + // TODO: REPLACE WITH NEXT/LINK + + {evalaution?.testset.name || ""} + + + + Variant: + + + + + + + { + onToggleEvaluatorVisibility(key) + setIsFilterColsDropdownOpen(true) + }} + /> + {!!scenarios.length && !!scenarios[0].correct_answers?.length && ( +
    + Apply difference with: + ({ + key: answer.key as string, + label: ( + + + <>{answer.key} + + ), + }))} + buttonText={ + selectedCorrectAnswer[0] === "noDiffColumnIsSelected" + ? "Select Ground Truth" + : selectedCorrectAnswer[0] + } + isOpen={isDiffDropdownOpen} + handleOpenChange={handleOpenChangeDiff} + shownCols={selectedCorrectAnswer} + onClick={({key}) => { + if (key === selectedCorrectAnswer[0]) { + setSelectedCorrectAnswer(["noDiffColumnIsSelected"]) + } else { + setSelectedCorrectAnswer([key]) + } + setIsDiffDropdownOpen(true) + }} + /> +
    + )} + + + + + + +
    +
    + + +
    + + gridRef={setGridRef} + rowData={scenarios} + columnDefs={colDefs} + getRowId={(params) => params.data.id} + /> +
    +
    + + +
    + ) +} + +export default EvaluationScenarios diff --git a/web/ee/src/components/pages/evaluations/utils.ts b/web/ee/src/components/pages/evaluations/utils.ts new file mode 100644 index 0000000000..af9d1c59ac --- /dev/null +++ b/web/ee/src/components/pages/evaluations/utils.ts @@ -0,0 +1,185 @@ +import {EvaluationRow} from "@/oss/components/HumanEvaluations/types" + +type Nullable = T | null | undefined + +const parseInvocationMetadata = ( + evaluation: EvaluationRow, +): { + appId?: string + appName?: string + revisionId?: string + variantName?: string + revisionLabel?: string | number +} | null => { + const dataSteps: any[] = (evaluation as any)?.data?.steps || [] + const invocationStep = dataSteps.find((step) => step?.type === "invocation") + if (!invocationStep) return null + + const references = invocationStep.references ?? invocationStep ?? {} + const applicationRevision = + references.applicationRevision || references.application_revision || references.revision + const applicationRef = + references.application || + applicationRevision?.application || + references.applicationRef || + references.application_ref + const variantRef = references.variant || references.variantRef || references.variant_ref + + const rawAppId = + applicationRef?.id || + applicationRef?.app_id || + applicationRef?.appId || + references.application?.id || + references.application?.app_id || + applicationRevision?.application_id || + applicationRevision?.applicationId + + const rawAppName = + applicationRef?.name || + applicationRef?.slug || + references.application?.name || + references.application?.slug + + const rawVariantName = + variantRef?.name || + variantRef?.slug || + variantRef?.variantName || + variantRef?.variant_name || + applicationRef?.name || + applicationRef?.slug || + references.application?.name || + references.application?.slug || + invocationStep.key + + const rawRevisionId = + variantRef?.id || + variantRef?.revisionId || + variantRef?.revision_id || + applicationRevision?.id || + applicationRevision?.revisionId || + applicationRevision?.revision_id + + const revisionLabel = + variantRef?.version ?? + variantRef?.revision ?? + variantRef?.revisionLabel ?? + applicationRevision?.revision ?? + applicationRevision?.version ?? + applicationRevision?.name ?? + null + + if (!rawAppId && !rawRevisionId && !rawVariantName) return null + + return { + appId: typeof rawAppId === "string" ? rawAppId : undefined, + appName: typeof rawAppName === "string" ? rawAppName : undefined, + revisionId: typeof rawRevisionId === "string" ? rawRevisionId : undefined, + variantName: typeof rawVariantName === "string" ? rawVariantName : undefined, + revisionLabel: revisionLabel ?? undefined, + } +} + +export const extractPrimaryInvocation = ( + evaluation: EvaluationRow, +): { + appId?: string + appName?: string + revisionId?: string + variantName?: string + revisionLabel?: string | number +} | null => { + if (!evaluation) return null + + const variants = (evaluation as any)?.variants + if (Array.isArray(variants) && variants.length) { + const variant = variants[0] + return { + appId: + variant?.appId || + (typeof variant?.app_id === "string" ? variant.app_id : undefined) || + (typeof variant?.applicationId === "string" ? variant.applicationId : undefined), + appName: (variant as any)?.appName || (variant as any)?.appSlug, + revisionId: + (variant as any)?.id || + (typeof variant?.revisionId === "string" ? variant.revisionId : undefined) || + (typeof variant?.revision_id === "string" ? variant.revision_id : undefined), + variantName: variant?.variantName || variant?.name || (variant as any)?.slug, + revisionLabel: + (variant as any)?.revisionLabel || + (variant as any)?.revision || + (variant as any)?.version, + } + } + + return parseInvocationMetadata(evaluation) +} + +export const extractEvaluationAppId = (evaluation: EvaluationRow): string | undefined => { + const invocation = extractPrimaryInvocation(evaluation) + if (invocation?.appId) return invocation.appId + + const directAppId: Nullable = (evaluation as any)?.appId + if (typeof directAppId === "string" && directAppId.length > 0) { + return directAppId + } + + const variants = (evaluation as any)?.variants + if (Array.isArray(variants) && variants.length) { + const candidate = variants[0] + const variantAppId = + (typeof candidate?.appId === "string" && + candidate.appId.length > 0 && + candidate.appId) || + (typeof candidate?.app_id === "string" && + candidate.app_id.length > 0 && + candidate.app_id) || + (typeof candidate?.applicationId === "string" && + candidate.applicationId.length > 0 && + candidate.applicationId) + if (variantAppId) return variantAppId + } + + return undefined +} + +export const getCommonEvaluationAppId = (evaluations: EvaluationRow[]): string | undefined => { + if (!Array.isArray(evaluations) || evaluations.length === 0) return undefined + const ids = new Set( + evaluations + .map((evaluation) => extractEvaluationAppId(evaluation)) + .filter((id): id is string => Boolean(id)), + ) + + if (ids.size !== 1) return undefined + const [only] = Array.from(ids) + return only +} + +export const buildAppScopedUrl = (baseAppURL: string, appId: string, path: string): string => { + const normalizedPath = path.startsWith("/") ? path : `/${path}` + return `${baseAppURL}/${encodeURIComponent(appId)}${normalizedPath}` +} + +export const buildProjectEvaluationUrl = (projectURL: string, path: string): string => { + const normalizedPath = path.startsWith("/") ? path : `/${path}` + return `${projectURL}${normalizedPath}` +} + +export const buildEvaluationNavigationUrl = ({ + scope, + baseAppURL, + projectURL, + appId, + path, +}: { + scope: "app" | "project" + baseAppURL: string + projectURL: string + appId?: string + path: string +}) => { + if (scope === "app" && appId) { + return buildAppScopedUrl(baseAppURL, appId, path) + } + return buildProjectEvaluationUrl(projectURL, path) +} diff --git a/web/ee/src/components/pages/observability/dashboard/widgetCard.tsx b/web/ee/src/components/pages/observability/dashboard/widgetCard.tsx new file mode 100644 index 0000000000..27038d1279 --- /dev/null +++ b/web/ee/src/components/pages/observability/dashboard/widgetCard.tsx @@ -0,0 +1,85 @@ +import {ReactNode, useState} from "react" + +import {Tabs, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {JSSTheme} from "@/oss/lib/Types" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + root: { + borderRadius: theme.borderRadiusLG, + border: `1px solid ${theme.colorBorder}`, + display: "flex", + flexDirection: "column", + padding: theme.padding, + }, + title: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + subHeadingRoot: { + display: "flex", + gap: 8, + marginBottom: theme.padding, + }, +})) + +interface WidgetData { + leftSubHeading?: ReactNode + rightSubHeading?: ReactNode + children?: ReactNode + title: string +} + +interface Props extends WidgetData { + tabs?: WidgetData[] +} + +const WidgetInnerContent: React.FC & {loading?: boolean}> = ({ + leftSubHeading, + rightSubHeading, + children, +}) => { + const classes = useStyles() + + return ( + <> +
    + {leftSubHeading ?? null} + {rightSubHeading ?? null} +
    + {children ?? null} + + ) +} + +const WidgetCard: React.FC = ({title, leftSubHeading, rightSubHeading, tabs, children}) => { + const classes = useStyles() + const [tab, setTab] = useState(tabs?.[0]?.title ?? "") + + return ( +
    + {title} + {tabs?.length ? ( + ({ + key: tab.title, + label: tab.title, + children: , + }))} + /> + ) : ( + + )} +
    + ) +} + +export default WidgetCard diff --git a/web/ee/src/components/pages/overview/deployments/DeploymentHistoryModal.tsx b/web/ee/src/components/pages/overview/deployments/DeploymentHistoryModal.tsx new file mode 100644 index 0000000000..77af135044 --- /dev/null +++ b/web/ee/src/components/pages/overview/deployments/DeploymentHistoryModal.tsx @@ -0,0 +1,415 @@ +// @ts-nocheck +import {useCallback, useEffect, useMemo, useRef, useState} from "react" + +import {CloseOutlined, MoreOutlined, SwapOutlined} from "@ant-design/icons" +import {ClockCounterClockwise, GearSix} from "@phosphor-icons/react" +import {Button, Dropdown, message, Modal, Space, Spin, Table, Typography} from "antd" +import {ColumnsType} from "antd/es/table" +import {useRouter} from "next/router" +import {createUseStyles} from "react-jss" + +import VariantPopover from "@/oss/components/pages/overview/variants/VariantPopover" +import ContentSpinner from "@/oss/components/Spinner/ContentSpinner" +import {formatDay} from "@/oss/lib/helpers/dateTimeHelper" +import {Environment, JSSTheme, Variant} from "@/oss/lib/Types" +import {DeploymentRevision, DeploymentRevisionConfig, DeploymentRevisions} from "@/oss/lib/types_ee" + +import DeploymentRevertModal from "./DeploymentRevertModal" +import HistoryConfig from "./HistoryConfig" + +type DeploymentHistoryModalProps = { + setIsHistoryModalOpen: (value: React.SetStateAction) => void + selectedEnvironment: Environment + variant: Variant +} & React.ComponentProps + +const {Title} = Typography + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + display: "flex", + gap: theme.paddingLG, + padding: `${theme.paddingLG}px 0`, + height: 760, + }, + title: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + subTitle: { + fontSize: theme.fontSize, + lineHeight: theme.lineHeight, + fontWeight: theme.fontWeightMedium, + }, + modalTitle: { + "& h1.ant-typography": { + fontSize: theme.fontSizeHeading5, + fontWeight: theme.fontWeightMedium, + textTransform: "capitalize", + }, + }, +})) + +const DeploymentHistoryModal = ({ + selectedEnvironment, + setIsHistoryModalOpen, + variant, + ...props +}: DeploymentHistoryModalProps) => { + const classes = useStyles() + const router = useRouter() + const appId = router.query.app_id as string + + const [depRevisionsList, setDepRevisionsList] = useState(null) + const [depRevisionConfig, setDepRevisionConfig] = useState( + null, + ) + const [activeDepRevisionConfig, setActiveDepRevisionConfig] = + useState(null) + const [isDepRevisionLoading, setIsDepRevisionLoading] = useState(false) + const [isDepRevisionConfigLoading, setIsDepRevisionConfigLoading] = useState(false) + const [selectedDepRevision, setSelectedDepRevision] = useState(null) + const [compareDeployment, setCompareDeployment] = useState(false) + const [confirmDepModalOpen, setConfirmDepModalOpen] = useState(false) + + const [isRevertDeploymentLoading, setIsRevertDeploymentLoading] = useState(false) + const [selectedRevert, setSelectedRevert] = useState(null) + + const [selectedRevisionNumber, setSelectedRevisionNumber] = useState(null) + + const fetchControllerRef = useRef(null) + + const deployedAppRevisionId = useMemo(() => { + return depRevisionsList?.deployed_app_variant_revision_id || null + }, [depRevisionsList]) + + const deployedAppRevision = useMemo(() => { + return depRevisionsList?.revisions.find( + (rev) => rev.deployed_app_variant_revision === deployedAppRevisionId, + ) + }, [depRevisionsList, deployedAppRevisionId]) + + const fetchDevRevisionConfig = useCallback(async (record: string) => { + try { + const mod = await import("@/oss/services/deploymentVersioning/api") + const fetchAllDeploymentRevisionConfig = mod?.fetchAllDeploymentRevisionConfig + if (!mod || !fetchAllDeploymentRevisionConfig) return + + const data = await fetchAllDeploymentRevisionConfig(record, undefined, true) + setActiveDepRevisionConfig(data) + } catch (error) { + console.error("Failed to fetch deployment revision config:", error) + } + }, []) + + useEffect(() => { + if (deployedAppRevision?.id) { + fetchDevRevisionConfig(deployedAppRevision.id) + } + }, [deployedAppRevision, fetchDevRevisionConfig]) + + const isShowingCurrentDeployment = useMemo(() => { + return deployedAppRevisionId === selectedDepRevision?.deployed_app_variant_revision + }, [deployedAppRevisionId, selectedDepRevision]) + + const fetchDevRevisions = useCallback(async () => { + setIsDepRevisionLoading(true) + try { + const mod = await import("@/oss/services/deploymentVersioning/api") + const fetchAllDeploymentRevisions = mod?.fetchAllDeploymentRevisions + if (!mod || !fetchAllDeploymentRevisions) return + + const data = await fetchAllDeploymentRevisions(appId, selectedEnvironment.name) + setDepRevisionsList(data) + setSelectedDepRevision(data.revisions.reverse()[0] || null) + const totalRows = data?.revisions.length as number + setSelectedRevisionNumber(totalRows || null) + } catch (error) { + console.error("Failed to fetch deployment revisions:", error) + } finally { + setIsDepRevisionLoading(false) + } + }, [appId, selectedEnvironment]) + + const handleRevertDeployment = async (deploymentRevisionId: string) => { + try { + setIsRevertDeploymentLoading(true) + const mod = await import("@/oss/services/deploymentVersioning/api") + const createRevertDeploymentRevision = mod?.createRevertDeploymentRevision + if (!mod || !createRevertDeploymentRevision) return + + await createRevertDeploymentRevision(deploymentRevisionId) + await fetchDevRevisions() + message.success("Environment successfully reverted to deployment revision") + } catch (error) { + console.error(error) + } finally { + setIsRevertDeploymentLoading(false) + } + } + + const fetchDevRevisionConfigById = useCallback(async (revisionId: string) => { + fetchControllerRef.current?.abort() + const controller = new AbortController() + fetchControllerRef.current = controller + + try { + setIsDepRevisionConfigLoading(true) + const mod = await import("@/oss/services/deploymentVersioning/api") + const fetchAllDeploymentRevisionConfig = mod?.fetchAllDeploymentRevisionConfig + if (!mod || !fetchAllDeploymentRevisionConfig) return + + const data = await fetchAllDeploymentRevisionConfig(revisionId, controller.signal) + setDepRevisionConfig(data) + } catch (error) { + console.error(error) + } finally { + setIsDepRevisionConfigLoading(false) + } + }, []) + + useEffect(() => { + if (appId && selectedEnvironment) { + fetchDevRevisions() + } + }, [appId, selectedEnvironment, fetchDevRevisions]) + + useEffect(() => { + if (selectedDepRevision) { + fetchDevRevisionConfigById(selectedDepRevision.id) + } + }, [selectedDepRevision, fetchDevRevisionConfigById]) + + const columns: ColumnsType = [ + { + title: "Revision", + dataIndex: "revision", + key: "revision", + width: 48, + render: (_, record, index) => { + const totalRows = depRevisionsList?.revisions.length as number + const versionNumber = totalRows - index + return v{versionNumber} + }, + }, + { + title: "Modified by", + dataIndex: "modified_by", + key: "modified_by", + render: (_, record) => {record.modified_by}, + }, + { + title: "Created on", + dataIndex: "created_at", + key: "created_at", + render: (_, record) => {formatDay({date: record.created_at})}, + }, + { + title: , + key: "actions", + width: 56, + fixed: "right", + align: "center", + render: (_, record) => ( + , + onClick: (event) => { + event.domEvent.stopPropagation() + setConfirmDepModalOpen(true) + setSelectedRevert(record) + }, + disabled: + activeDepRevisionConfig?.current_version === record.revision, + }, + { + key: "compare_to_current", + label: "Compare to current", + icon: , + onClick: (event) => { + event.domEvent.stopPropagation() + setSelectedDepRevision(record) + setCompareDeployment(true) + }, + disabled: + activeDepRevisionConfig?.current_version === record.revision, + }, + ], + }} + > +
    ({ + onClick: () => { + const totalRows = depRevisionsList?.revisions + .length as number + setSelectedRevisionNumber(totalRows - (index ?? 0)) + setSelectedDepRevision(record) + }, + style: {cursor: "pointer"}, + })} + pagination={false} + /> + + )} + + +
    + {isDepRevisionConfigLoading || !depRevisionConfig ? ( + + ) : ( + <> + + + Revision v{selectedRevisionNumber} + + {isShowingCurrentDeployment ? ( + Current Deployment + ) : ( + + + {compareDeployment ? ( + + ) : ( + + )} + + )} + + + + Variant Deployed + + {variant && ( + + )} + + {variant ? ( + + ) : null} + + )} +
    + + + + {selectedRevert && ( + setConfirmDepModalOpen(false)} + onOk={async () => { + await handleRevertDeployment(selectedRevert.id) + setConfirmDepModalOpen(false) + }} + selectedRevert={selectedRevert} + selectedEnvironment={selectedEnvironment} + okButtonProps={{loading: isRevertDeploymentLoading}} + selectedDeployedVariant={variant} + /> + )} + + ) +} + +export default DeploymentHistoryModal diff --git a/web/ee/src/components/pages/overview/deployments/DeploymentRevertModal.tsx b/web/ee/src/components/pages/overview/deployments/DeploymentRevertModal.tsx new file mode 100644 index 0000000000..a4279b9136 --- /dev/null +++ b/web/ee/src/components/pages/overview/deployments/DeploymentRevertModal.tsx @@ -0,0 +1,79 @@ +import {Rocket} from "@phosphor-icons/react" +import {Modal, Typography} from "antd" +import {createUseStyles} from "react-jss" + +import {Environment, JSSTheme, Variant} from "@/oss/lib/Types" +import {DeploymentRevision} from "@/oss/lib/types_ee" + +type DeploymentModalProps = { + selectedRevert: DeploymentRevision + selectedEnvironment: Environment + selectedDeployedVariant: Variant +} & React.ComponentProps + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + container: { + "& .ant-modal-footer": { + display: "flex", + alignItems: "center", + justifyContent: "flex-end", + }, + }, + wrapper: { + "& h1": { + fontSize: theme.fontSizeLG, + fontWeight: theme.fontWeightStrong, + lineHeight: theme.lineHeightLG, + marginBottom: 8, + }, + "& span": { + color: theme.colorPrimary, + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + }, +})) + +const DeploymentModal = ({ + selectedEnvironment, + selectedRevert, + selectedDeployedVariant, + ...props +}: DeploymentModalProps) => { + const classes = useStyles() + + return ( + + + Deploy + + } + centered + destroyOnHidden + zIndex={3000} + {...props} + > +
    + Revert Deployment + +
    +
    + You are about to deploy {selectedDeployedVariant.variantName} to{" "} + {selectedEnvironment.name} environment. This will overwrite the existing + configuration. This change will affect all future calls to this environment. +
    +
    + You are about to deploy {selectedEnvironment.name} environment: + Revision v{selectedRevert.revision || 0} +
    +
    +
    +
    + ) +} + +export default DeploymentModal diff --git a/web/ee/src/components/pages/overview/deployments/HistoryConfig.tsx b/web/ee/src/components/pages/overview/deployments/HistoryConfig.tsx new file mode 100644 index 0000000000..5c915177a2 --- /dev/null +++ b/web/ee/src/components/pages/overview/deployments/HistoryConfig.tsx @@ -0,0 +1,112 @@ +import {useMemo} from "react" + +import {Typography} from "antd" +import {useAtomValue} from "jotai" +import {createUseStyles} from "react-jss" + +import {NewVariantParametersView} from "@/oss/components/VariantsComponents/Drawers/VariantDrawer/assets/Parameters" +import {filterVariantParameters} from "@/oss/lib/helpers/utils" +import {useVariants} from "@/oss/lib/hooks/useVariants" +import {JSSTheme, Variant} from "@/oss/lib/Types" +import {DeploymentRevisionConfig} from "@/oss/lib/types_ee" +import {currentAppAtom} from "@/oss/state/app" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + title: { + fontSize: theme.fontSizeLG, + lineHeight: theme.lineHeightLG, + fontWeight: theme.fontWeightMedium, + }, + subTitle: { + fontSize: theme.fontSize, + lineHeight: theme.lineHeight, + fontWeight: theme.fontWeightMedium, + }, + resultTag: { + minWidth: 150, + display: "flex", + borderRadius: theme.borderRadiusSM, + border: `1px solid ${theme.colorBorder}`, + textAlign: "center", + "& > div:nth-child(1)": { + backgroundColor: "rgba(0, 0, 0, 0.02)", + lineHeight: theme.lineHeight, + flex: 1, + minWidth: 50, + borderRight: `1px solid ${theme.colorBorder}`, + padding: "0 7px", + }, + "& > div:nth-child(2)": { + padding: "0 7px", + }, + }, + promptTextField: { + padding: theme.paddingXS, + backgroundColor: theme.colorBgContainerDisabled, + borderRadius: theme.borderRadius, + }, + noParams: { + color: theme.colorTextDescription, + fontWeight: theme.fontWeightMedium, + textAlign: "center", + marginTop: 48, + }, +})) + +interface HistoryConfigProps { + depRevisionConfig: DeploymentRevisionConfig + variant: Variant +} + +const HistoryConfig = ({depRevisionConfig, variant: propsVariant}: HistoryConfigProps) => { + const classes = useStyles() + + const currentApp = useAtomValue(currentAppAtom) + // @ts-ignore + const {data, isLoading} = useVariants(currentApp, [propsVariant]) + const variant = useMemo( + // @ts-ignore + () => data?.variants.find((v) => v.id === propsVariant.id), + [data?.variants, propsVariant.id], + ) + + return ( +
    + Configuration + + {Object.keys(depRevisionConfig.parameters).length ? ( +
    +
    + {!isLoading && !!variant ? ( + + ) : null} +
    + + {depRevisionConfig.parameters && + Object.entries( + filterVariantParameters({ + record: depRevisionConfig.parameters, + key: "prompt", + }), + ).map(([key, value], index) => ( +
    + + {key} + +
    + {JSON.stringify(value)} +
    +
    + ))} +
    + ) : ( + No Parameters + )} +
    + ) +} + +export default HistoryConfig diff --git a/web/ee/src/components/pages/overview/observability/ObservabilityOverview.tsx b/web/ee/src/components/pages/overview/observability/ObservabilityOverview.tsx new file mode 100644 index 0000000000..34ffe5bf56 --- /dev/null +++ b/web/ee/src/components/pages/overview/observability/ObservabilityOverview.tsx @@ -0,0 +1,135 @@ +import {useMemo} from "react" + +import {AreaChart} from "@tremor/react" +import {Col, Row, Spin, Typography} from "antd" +import round from "lodash/round" +import {createUseStyles} from "react-jss" + +import {formatCurrency, formatLatency, formatNumber} from "@/oss/lib/helpers/formatters" +import {JSSTheme} from "@/oss/lib/Types" + +import {useObservabilityDashboard} from "../../../../state/observability" +import WidgetCard from "../../observability/dashboard/widgetCard" + +const useStyles = createUseStyles((theme: JSSTheme) => ({ + statText: { + fontWeight: 400, + }, +})) + +const ObservabilityOverview = () => { + const classes = useStyles() + const {data, loading, isFetching} = useObservabilityDashboard() + + const chartData = useMemo(() => (data?.data?.length ? data.data : [{}]), [data]) + + const defaultGraphProps = useMemo>( + () => ({ + className: "h-[168px] p-0", + colors: ["cyan", "red"], + connectNulls: true, + tickGap: 15, + curveType: "monotone", + showGridLines: false, + showLegend: false, + index: "timestamp", + data: chartData, + categories: [], + }), + [chartData], + ) + + return ( +
    + + +
    + + Total:{" "} + {data?.total_count ? formatNumber(data?.total_count) : "-"} + + } + rightSubHeading={ + (data?.failure_rate ?? 0) > 0 && ( + + Failed:{" "} + {data?.failure_rate + ? `${formatNumber(data?.failure_rate)}%` + : "-"} + + ) + } + > + 0 + ? ["success_count", "failure_count"] + : ["success_count"] + } + /> + + + + + Avg:{" "} + {data?.avg_latency + ? `${formatNumber(data.avg_latency)}ms` + : "-"} + + } + > + + + + + + Total:{" "} + {data?.total_cost ? formatCurrency(data.total_cost) : "-"} + + } + rightSubHeading={ + + Avg:{" "} + {data?.total_cost ? formatCurrency(data.avg_cost) : "-"} + + } + > + + + + + + Total:{" "} + {data?.total_tokens ? formatNumber(data?.total_tokens) : "-"} + + } + rightSubHeading={ + + Avg:{" "} + {data?.avg_tokens ? formatNumber(data?.avg_tokens) : "-"} + + } + > + + + + + + + ) +} + +export default ObservabilityOverview diff --git a/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/AutoRenewalCancelModalContent/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/AutoRenewalCancelModalContent/index.tsx new file mode 100644 index 0000000000..d793165378 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/AutoRenewalCancelModalContent/index.tsx @@ -0,0 +1,33 @@ +import {memo} from "react" + +import {Input, Radio, Typography} from "antd" + +import {CANCEL_REASONS} from "../constants" +import {AutoRenewalCancelModalContentProps} from "../types" + +const AutoRenewalCancelModalContent = ({ + inputValue, + onChangeInput, + ...props +}: AutoRenewalCancelModalContentProps) => { + const _value = props.value + return ( +
    + + Please select one of the reasons + + + {CANCEL_REASONS.map((option) => ( + + {option.label} + + ))} + + {_value === "something-else" && ( + + )} +
    + ) +} + +export default memo(AutoRenewalCancelModalContent) diff --git a/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/constants.ts b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/constants.ts new file mode 100644 index 0000000000..685ee31611 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/constants.ts @@ -0,0 +1,10 @@ +export const CANCEL_REASONS = [ + {value: "dont-want-auto-renewal", label: "I don't want to continue with auto-renewal"}, + {value: "done-with-project", label: "I am done with my project"}, + {value: "switching-to-another-service", label: "I'm switching to another service"}, + {value: "technical-issues", label: "The product has technical issues"}, + {value: "missing-features", label: "The product doesn't have features I wanted"}, + {value: "too-expensive", label: "It's too expensive"}, + {value: "not-used-enough", label: "I don't use it enough"}, + {value: "something-else", label: "Something else"}, +] diff --git a/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/types.d.ts b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/types.d.ts new file mode 100644 index 0000000000..4aa4b6530e --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/assets/types.d.ts @@ -0,0 +1,8 @@ +import {ModalProps, RadioGroupProps} from "antd" + +export interface AutoRenewalCancelModalProps extends ModalProps {} + +export interface AutoRenewalCancelModalContentProps extends RadioGroupProps { + inputValue: string + onChangeInput: (e: ChangeEvent) => void +} diff --git a/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/index.tsx new file mode 100644 index 0000000000..73523d2ce9 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/AutoRenewalCancelModal/index.tsx @@ -0,0 +1,74 @@ +import {useCallback, useState} from "react" + +import {message} from "antd" +import dynamic from "next/dynamic" + +import EnhancedModal from "@/oss/components/EnhancedUIs/Modal" +import {cancelSubscription, useSubscriptionData, useUsageData} from "@/oss/services/billing" + +import {AutoRenewalCancelModalProps} from "./assets/types" + +const AutoRenewalCancelModalContent = dynamic( + () => import("./assets/AutoRenewalCancelModalContent"), + {ssr: false}, +) + +const AutoRenewalCancelModal = ({...props}: AutoRenewalCancelModalProps) => { + const [selectOption, setSelectOption] = useState("") + const [inputOption, setInputOption] = useState("") + const [isLoading, setIsLoading] = useState(false) + + const {mutateSubscription} = useSubscriptionData() + const {mutateUsage} = useUsageData() + + const onConfirmCancel = useCallback(async () => { + // TODO: add posthog here to send the select form option data + try { + setIsLoading(true) + const data = await cancelSubscription() + + if (data.data.status === "success") { + message.success("Your subscription has been successfully canceled.") + setTimeout(() => { + mutateUsage() + mutateSubscription() + props.onCancel?.({} as any) + }, 500) + } else { + message.error( + "We were unable to cancel your subscription. Please try again later or contact support if the issue persists.", + ) + } + } catch (error) { + message.error( + "An error occurred while processing your request. Please try again later or contact support if the issue persists.", + ) + } finally { + setIsLoading(false) + } + }, [mutateSubscription, mutateUsage, cancelSubscription]) + + return ( + setSelectOption("")} + {...props} + > + setSelectOption(e.target.value)} + inputValue={inputOption} + onChangeInput={(e) => setInputOption(e.target.value)} + /> + + ) +} + +export default AutoRenewalCancelModal diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingCard/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingCard/index.tsx new file mode 100644 index 0000000000..31d60e6dda --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingCard/index.tsx @@ -0,0 +1,96 @@ +import {memo, useMemo} from "react" + +import {Card, Button, Typography} from "antd" + +import {Plan} from "@/oss/lib/Types" + +import {PricingCardProps} from "../types" + +const PricingCard = ({plan, currentPlan, onOptionClick, isLoading}: PricingCardProps) => { + const _isLoading = isLoading === plan.plan + const isDisabled = useMemo( + () => + (isLoading !== null && isLoading !== plan.plan) || + currentPlan?.plan == plan.plan || + currentPlan?.plan == Plan.Business || + currentPlan?.plan == Plan.Enterprise, + [isLoading, currentPlan, plan], + ) + + return ( + + window.open("https://cal.com/mahmoud-mabrouk-ogzgey/demo", "_blank") + } + > + {currentPlan?.plan == Plan.Business || currentPlan?.plan == Plan.Enterprise + ? "Current plan" + : "Talk to us"} + + ) : ( + + ), + ]} + > +
    +
    + + {plan.price + ? `${plan.price.base?.starting_at ? "Starts at " : ""} $ + ${plan.price?.base?.amount} /month` + : "Contact us"} + + + + {plan.description} + +
    + +
      + {plan.features?.map((point, idx) => { + return ( +
    • + {point} +
    • + ) + })} +
    +
    +
    + ) +} + +export default memo(PricingCard) diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalContent/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalContent/index.tsx new file mode 100644 index 0000000000..12a1aad499 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalContent/index.tsx @@ -0,0 +1,95 @@ +import {useCallback, useState} from "react" + +import {message, Spin, Typography} from "antd" + +import useURL from "@/oss/hooks/useURL" +import {getEnv} from "@/oss/lib/helpers/dynamicEnv" +import {Plan} from "@/oss/lib/Types" +import { + checkoutNewSubscription, + usePricingPlans, + useSubscriptionData, + useUsageData, +} from "@/oss/services/billing" +import {BillingPlan} from "@/oss/services/billing/types" + +import PricingCard from "../PricingCard" +import {PricingModalContentProps} from "../types" + +const PricingModalContent = ({onCancelSubscription, onCloseModal}: PricingModalContentProps) => { + const {plans, isLoadingPlan} = usePricingPlans() + const {subscription, mutateSubscription} = useSubscriptionData() + const {mutateUsage} = useUsageData() + const {projectURL} = useURL() + + const [isLoading, setIsLoading] = useState(null) + + const onOptionClick = useCallback( + async (plan: BillingPlan) => { + try { + setIsLoading(plan.plan) + // 1. if the selected plan is cloud_v0_hobby and the subscription-plan is not then we trigger the cancel endpoint + // 2. subscription-pan is cloud_v0_hobby then we trigger the checkout endpoint + // 3. if the user can custom plan like cloud_v0_business then we trigger the switch endpoint + + if (plan.plan === Plan.Hobby && subscription?.plan !== Plan.Hobby) { + onCancelSubscription() + return + } else { + const data = await checkoutNewSubscription({ + plan: plan.plan, + success_url: `${getEnv("NEXT_PUBLIC_AGENTA_WEB_URL")}${projectURL || ""}/settings?tab=billing`, + }) + + window.open(data.data.checkout_url, "_blank") + } + + setTimeout(() => { + mutateSubscription() + mutateUsage() + onCloseModal() + }, 500) + } catch (error) { + message.error( + "An error occurred while processing the checkout. Please try again later or contact support if the issue persists.", + ) + } finally { + setIsLoading(null) + } + }, + [ + onCancelSubscription, + checkoutNewSubscription, + mutateSubscription, + mutateUsage, + projectURL, + ], + ) + + if (isLoadingPlan) { + return ( +
    + +
    + ) + } + + return ( +
    + Choose your plan +
    + {plans?.map((plan) => ( + + ))} +
    +
    + ) +} + +export default PricingModalContent diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalTitle/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalTitle/index.tsx new file mode 100644 index 0000000000..b57ff2b041 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/PricingModalTitle/index.tsx @@ -0,0 +1,17 @@ +import {memo} from "react" + +import {Button, Typography} from "antd" + +const PricingModalTitle = () => { + return ( +
    + Plans + + +
    + ) +} + +export default memo(PricingModalTitle) diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/SubscriptionPlanDetails/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/SubscriptionPlanDetails/index.tsx new file mode 100644 index 0000000000..c6af8e86f8 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/SubscriptionPlanDetails/index.tsx @@ -0,0 +1,22 @@ +import dayjs from "dayjs" + +import {SubscriptionType} from "@/oss/services/billing/types" + +const SubscriptionPlanDetails = ({subscription}: {subscription: SubscriptionType}) => { + return ( + <> + {subscription?.plan?.split("_")[2]}{" "} + + {subscription.free_trial + ? `trial ends in ${dayjs.unix(subscription.period_end).diff(dayjs(), "day")} ${ + dayjs.unix(subscription.period_end).diff(dayjs(), "day") === 1 + ? "day" + : "days" + }` + : ""} + + + ) +} + +export default SubscriptionPlanDetails diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/constants.ts b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/constants.ts new file mode 100644 index 0000000000..8bf4b40d4c --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/constants.ts @@ -0,0 +1,45 @@ +export const PRICING_PLANS_INFO = [ + { + key: "free", + title: "Free", + price: "Free", + description: "Great for hobby projects", + priceDescription: "2 users and 5k traces per month included", + bulletPoints: [ + "2 prompts", + "2 users included", + "5k traces per month included", + "20 evaluations / month included", + ], + }, + { + key: "pro", + title: "Pro", + price: "$49/month", + description: "For production projects", + priceDescription: "2 users and 5k traces per month included", + bulletPoints: [ + "Unlimited prompts", + "3 seats included then $20 per seat", + "Up to 10 seats", + "Unlimited evaluations", + "10k traces / month included", + ], + }, + { + key: "business", + title: "Business", + price: "$49/month", + description: "For teams with security and support needs", + priceDescription: "2 users and 5k traces per month included", + bulletPoints: ["2 prompts", "2 seats", "5k traces", "20 evaluations / month included"], + }, + { + key: "enterprise", + title: "Enterprise", + price: "$49/month", + description: "For teams with security and support needs", + priceDescription: "2 users and 5k traces per month included", + bulletPoints: ["2 prompts", "2 seats", "5k traces", "20 evaluations / month included"], + }, +] diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/types.d.ts b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/types.d.ts new file mode 100644 index 0000000000..5c34b9fa15 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/assets/types.d.ts @@ -0,0 +1,28 @@ +import {ModalProps} from "antd" + +import {BillingPlan, SubscriptionType} from "@/oss/services/billing/types" + +export interface PricingModalProps extends ModalProps { + onCancelSubscription: () => void +} + +export interface PricingModalContentProps { + onCloseModal: () => void + onCancelSubscription: () => void +} + +export interface PricingPlan { + key: string + title: string + price: string + description: string + priceDescription: string + bulletPoints: string[] +} + +export interface PricingCardProps { + plan: BillingPlan + currentPlan: SubscriptionType | null + onOptionClick: (plan: BillingPlan) => void + isLoading: string | null +} diff --git a/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/index.tsx b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/index.tsx new file mode 100644 index 0000000000..ace1e472be --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/Modals/PricingModal/index.tsx @@ -0,0 +1,27 @@ +import clsx from "clsx" +import dynamic from "next/dynamic" + +import EnhancedModal from "@/oss/components/EnhancedUIs/Modal" + +import PricingModalTitle from "./assets/PricingModalTitle" +import {PricingModalProps} from "./assets/types" +const PricingModalContent = dynamic(() => import("./assets/PricingModalContent"), {ssr: false}) + +const PricingModal = ({onCancelSubscription, ...props}: PricingModalProps) => { + return ( + } + footer={null} + {...props} + > + props.onCancel?.({} as any)} + onCancelSubscription={onCancelSubscription} + /> + + ) +} + +export default PricingModal diff --git a/web/ee/src/components/pages/settings/Billing/assets/UsageProgressBar/index.tsx b/web/ee/src/components/pages/settings/Billing/assets/UsageProgressBar/index.tsx new file mode 100644 index 0000000000..cbdf21f2f9 --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/assets/UsageProgressBar/index.tsx @@ -0,0 +1,33 @@ +import {memo} from "react" + +import {WarningFilled} from "@ant-design/icons" +import {Space, Typography} from "antd" + +import {UsageProgressBarProps} from "../types" + +const UsageProgressBar = ({ + label, + limit, + used: value, + isUnlimited = false, + free, +}: UsageProgressBarProps) => { + return ( +
    + + {label}{" "} + {!isUnlimited && value >= limit && } + + + + {`${value} / ${limit ? limit : "-"}`} + {`${free ? `(${value > free ? free : value} / ${free} free)` : ``}`} + +
    + ) +} + +export default memo(UsageProgressBar) diff --git a/web/ee/src/components/pages/settings/Billing/assets/types.d.ts b/web/ee/src/components/pages/settings/Billing/assets/types.d.ts new file mode 100644 index 0000000000..95e0578f1a --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/assets/types.d.ts @@ -0,0 +1,12 @@ +interface UsedMetric { + limit: number + used: number +} +export interface UsageProgressBarProps { + label: string + isUnlimited?: boolean + strict?: boolean + limit: number + used: number + free: number +} diff --git a/web/ee/src/components/pages/settings/Billing/index.tsx b/web/ee/src/components/pages/settings/Billing/index.tsx new file mode 100644 index 0000000000..5240c8c53c --- /dev/null +++ b/web/ee/src/components/pages/settings/Billing/index.tsx @@ -0,0 +1,177 @@ +import {useCallback, useState} from "react" + +import {Button, message, Spin, Typography} from "antd" +import dayjs from "dayjs" +import {useRouter} from "next/router" + +import {Plan} from "@/oss/lib/Types" +import {editSubscriptionInfo, useSubscriptionData, useUsageData} from "@/oss/services/billing" + +import UsageProgressBar from "./assets/UsageProgressBar" +import AutoRenewalCancelModal from "./Modals/AutoRenewalCancelModal" +import PricingModal from "./Modals/PricingModal" +import SubscriptionPlanDetails from "./Modals/PricingModal/assets/SubscriptionPlanDetails" +import useURL from "@/oss/hooks/useURL" + +const {Link} = Typography + +const Billing = () => { + const router = useRouter() + const {projectURL} = useURL() + const [isLoadingOpenBillingPortal, setIsLoadingOpenBillingPortal] = useState(false) + const {subscription, isSubLoading} = useSubscriptionData() + const {usage, isUsageLoading} = useUsageData() + const [isOpenPricingModal, setIsOpenPricingModal] = useState(false) + const [isOpenCancelModal, setIsOpenCancelModal] = useState(false) + + const onCancelSubscription = useCallback(() => { + setIsOpenCancelModal(true) + }, []) + + const handleOpenBillingPortal = useCallback(async () => { + try { + setIsLoadingOpenBillingPortal(true) + const data = await editSubscriptionInfo() + + window.open(data.data.portal_url, "_blank") + } catch (error) { + message.error( + "We encountered an issue while opening the Stripe portal. Please try again in a few minutes. If the problem persists, contact support.", + ) + } finally { + setIsLoadingOpenBillingPortal(false) + } + }, [editSubscriptionInfo]) + + const navigateToWorkspaceTab = useCallback(() => { + router.push(`${projectURL}/settings`, {query: {tab: "workspace"}}) + }, [router, projectURL]) + + if (isSubLoading || isUsageLoading) { + return ( +
    + +
    + ) + } + + return ( +
    +
    +
    + Current plan + + + + {subscription?.plan !== Plan.Hobby && ( + + {subscription?.free_trial + ? "Trial period will end on " + : "Auto renews on "} + + {dayjs.unix(subscription?.period_end).format("MMM D, YYYY")} + + + )} + + {subscription?.plan === Plan.Enterprise || + subscription?.plan === Plan.Business ? ( + + For queries regarding your plan,{" "} + + click here to contact us + + + ) : subscription?.plan === Plan.Pro ? ( +
    + + + setIsOpenCancelModal(true)}> + Cancel subscription + +
    + ) : ( + + )} +
    +
    + +
    + Limits + +
    + {Object.entries(usage) + ?.filter(([key]) => key !== "users") + ?.map(([key, info]) => { + return ( + + ) + })} +
    +
    + +
    +
    + Members + +
    + +
    + + + +
    +
    + +
    + + Billing information + + + +
    + + setIsOpenCancelModal(false)} + /> + setIsOpenPricingModal(false)} + onCancelSubscription={onCancelSubscription} + /> +
    + ) +} + +export default Billing diff --git a/web/ee/src/contexts/RunIdContext.tsx b/web/ee/src/contexts/RunIdContext.tsx new file mode 100644 index 0000000000..4c94ed92c4 --- /dev/null +++ b/web/ee/src/contexts/RunIdContext.tsx @@ -0,0 +1,40 @@ +import React, {createContext, useContext} from "react" + +/** + * Context for providing the current evaluation run ID to components. + * This enables components to use run-scoped atoms without prop drilling. + */ +export const RunIdContext = createContext(null) + +/** + * Provider component that supplies the run ID to all child components. + */ +export const RunIdProvider: React.FC<{ + runId: string + children: React.ReactNode +}> = ({runId, children}) => { + return {children} +} + +/** + * Hook to access the current run ID from context. + * Throws an error if used outside of a RunIdProvider. + */ +export const useRunId = (): string => { + const runId = useContext(RunIdContext) + if (!runId) { + throw new Error( + "useRunId must be used within a RunIdProvider. " + + "Make sure your component is wrapped with ", + ) + } + return runId +} + +/** + * Hook to safely access the run ID, returning null if not available. + * Useful for components that can work with or without a run ID. + */ +export const useOptionalRunId = (): string | null => { + return useContext(RunIdContext) +} diff --git a/web/ee/src/hooks/useCrispChat.ts b/web/ee/src/hooks/useCrispChat.ts new file mode 100644 index 0000000000..19db37e57f --- /dev/null +++ b/web/ee/src/hooks/useCrispChat.ts @@ -0,0 +1,43 @@ +import {useState, useCallback, useEffect} from "react" + +import {Crisp} from "crisp-sdk-web" + +import {getEnv} from "@/oss/lib/helpers/dynamicEnv" + +export const useCrispChat = () => { + const isCrispEnabled = !!getEnv("NEXT_PUBLIC_CRISP_WEBSITE_ID") + + const [isVisible, setIsVisible] = useState(false) + + const updateVisibility = useCallback( + (visible: boolean) => { + if (isCrispEnabled) { + if (visible) { + Crisp.chat.show() + Crisp.chat.open() + } else { + Crisp.chat.hide() + } + setIsVisible(visible) + } + }, + [isCrispEnabled], + ) + + const toggle = useCallback(() => { + if (isCrispEnabled) { + updateVisibility(!isVisible) + } + }, [isVisible, updateVisibility, isCrispEnabled]) + + useEffect(() => { + updateVisibility(false) + }, [updateVisibility]) + + return { + isVisible, + setVisible: updateVisibility, + toggle, + isCrispEnabled, + } +} diff --git a/web/ee/src/lib/helpers/evaluate.ts b/web/ee/src/lib/helpers/evaluate.ts new file mode 100644 index 0000000000..82a87f62a2 --- /dev/null +++ b/web/ee/src/lib/helpers/evaluate.ts @@ -0,0 +1,449 @@ +import {EvaluationType} from "@agenta/oss/src/lib/enums" +import {convertToCsv, downloadCsv} from "@agenta/oss/src/lib/helpers/fileManipulations" +import {formatCurrency, formatLatency} from "@agenta/oss/src/lib/helpers/formatters" +import {isDemo} from "@agenta/oss/src/lib/helpers/utils" +import { + Evaluation, + GenericObject, + TypedValue, + Variant, + _Evaluation, + EvaluationScenario, +} from "@agenta/oss/src/lib/Types" +import dayjs from "dayjs" +import capitalize from "lodash/capitalize" +import round from "lodash/round" + +import AlertPopup from "@/oss/components/AlertPopup/AlertPopup" +import {runningStatuses} from "@/oss/components/pages/evaluations/cellRenderers/cellRenderers" +import { + HumanEvaluationListTableDataType, + SingleModelEvaluationListTableDataType, +} from "@/oss/lib/Types" +import {fetchEvaluatonIdsByResource} from "@/oss/services/evaluations/api" + +export const exportExactEvaluationData = (evaluation: Evaluation, rows: GenericObject[]) => { + const exportRow = rows.map((data, ix) => { + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Correct answer"]: data.correctAnswer, + ["Evaluation"]: data.score, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportSimilarityEvaluationData = (evaluation: Evaluation, rows: GenericObject[]) => { + const exportRow = rows.map((data, ix) => { + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Correct answer"]: data.correctAnswer, + ["Score"]: data.score, + ["Evaluation"]: data.similarity, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportAICritiqueEvaluationData = (evaluation: Evaluation, rows: GenericObject[]) => { + const exportRow = rows.map((data, ix) => { + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Correct answer"]: data.correctAnswer, + ["Score"]: data.score, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportABTestingEvaluationData = ( + evaluation: Evaluation, + scenarios: EvaluationScenario[], + rows: GenericObject[], +) => { + const exportRow = rows.map((data, ix) => { + const inputColumns = evaluation.testset.testsetChatColumn + ? {Input: evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn]} + : data.inputs.reduce( + (columns: any, input: {input_name: string; input_value: string}) => { + columns[`${input.input_name}`] = input.input_value + return columns + }, + {}, + ) + return { + ...inputColumns, + [`App Variant ${evaluation.variants[0].variantName} Output 0`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + [`App Variant ${evaluation.variants[1].variantName} Output 1`]: data?.columnData1 + ? data?.columnData1 + : data.outputs[1]?.variant_output, + ["Vote"]: + evaluation.variants.find((v: Variant) => v.variantId === data.vote)?.variantName || + data.vote, + ["Expected Output"]: + scenarios[ix]?.correctAnswer || evaluation.testset.csvdata[ix].correct_answer, + ["Additional notes"]: scenarios[ix]?.note, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.variants[1].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportSingleModelEvaluationData = ( + evaluation: Evaluation, + scenarios: EvaluationScenario[], + rows: GenericObject[], +) => { + const exportRow = rows.map((data, ix) => { + const inputColumns = evaluation.testset.testsetChatColumn + ? {Input: evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn]} + : data.inputs.reduce( + (columns: any, input: {input_name: string; input_value: string}) => { + columns[`${input.input_name}`] = input.input_value + return columns + }, + {}, + ) + const numericScore = parseInt(data.score) + return { + ...inputColumns, + [`App Variant ${evaluation.variants[0].variantName} Output 0`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Score"]: isNaN(numericScore) ? "-" : numericScore, + ["Expected Output"]: + scenarios[ix]?.correctAnswer || evaluation.testset.csvdata[ix].correct_answer, + ["Additional notes"]: scenarios[ix]?.note, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportRegexEvaluationData = ( + evaluation: Evaluation, + rows: GenericObject[], + settings: GenericObject, +) => { + const exportRow = rows.map((data, ix) => { + const isCorrect = data.score === "correct" + const isMatch = settings.regexShouldMatch ? isCorrect : !isCorrect + + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Match / Mismatch"]: isMatch ? "Match" : "Mismatch", + ["Evaluation"]: data.score, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportWebhookEvaluationData = (evaluation: Evaluation, rows: GenericObject[]) => { + const exportRow = rows.map((data, ix) => { + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Correct answer"]: data.correctAnswer, + ["Score"]: data.score, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const exportCustomCodeEvaluationData = (evaluation: Evaluation, rows: GenericObject[]) => { + const exportRow = rows.map((data, ix) => { + return { + ["Inputs"]: + evaluation.testset.csvdata[ix]?.[evaluation.testset.testsetChatColumn] || + data.inputs[0].input_value, + [`App Variant ${evaluation.variants[0].variantName} Output`]: data?.columnData0 + ? data?.columnData0 + : data.outputs[0]?.variant_output, + ["Correct answer"]: data.correctAnswer, + ["Score"]: data.score, + } + }) + const exportCol = Object.keys(exportRow[0]) + + const csvData = convertToCsv(exportRow, exportCol) + const filename = `${evaluation.appName}_${evaluation.variants[0].variantName}_${evaluation.evaluationType}.csv` + downloadCsv(csvData, filename) +} + +export const calculateResultsDataAvg = (resultsData: Record, multiplier = 10) => { + const obj = {...resultsData} + Object.keys(obj).forEach((key) => { + if (isNaN(+key)) delete obj[key] + }) + + const count = Object.values(obj).reduce((acc, value) => acc + +value, 0) + const sum = Object.keys(obj).reduce((acc, key) => acc + (parseFloat(key) || 0) * +obj[key], 0) + return (sum / count) * multiplier +} + +export const getVotesPercentage = (record: HumanEvaluationListTableDataType, index: number) => { + const variant = record.votesData.variants[index] + return record.votesData.variants_votes_data[variant]?.percentage +} + +export const checkIfResourceValidForDeletion = async ( + data: Omit[0], "appId">, +) => { + if (isDemo()) { + const response = await fetchEvaluatonIdsByResource(data) + if (response.data.length > 0) { + const name = + (data.resourceType === "testset" + ? "Testset" + : data.resourceType === "evaluator_config" + ? "Evaluator" + : "Variant") + (data.resourceIds.length > 1 ? "s" : "") + + const suffix = response.data.length > 1 ? "s" : "" + AlertPopup({ + title: `${name} is in use`, + message: `The ${name} is currently in used by ${response.data.length} evaluation${suffix}. Please delete the evaluation${suffix} first.`, + cancelText: null, + okText: "Ok", + }) + return false + } + } + return true +} + +export function getTypedValue(res?: TypedValue) { + const {value, type, error} = res || {} + if (type === "error") { + return error?.message + } + + if (value === undefined) return "-" + + switch (type) { + case "number": + return round(Number(value), 2) + case "boolean": + case "bool": + return capitalize(value?.toString()) + case "cost": + return formatCurrency(Number(value)) + case "latency": + return formatLatency(Number(value)) + default: + return value?.toString() + } +} + +type CellDataType = "number" | "text" | "date" +export function getFilterParams(type: CellDataType) { + const filterParams: GenericObject = {} + if (type == "date") { + filterParams.comparator = function ( + filterLocalDateAtMidnight: Date, + cellValue: string | null, + ) { + if (cellValue == null) return -1 + const cellDate = dayjs(cellValue).startOf("day").toDate() + if (filterLocalDateAtMidnight.getTime() === cellDate.getTime()) { + return 0 + } + if (cellDate < filterLocalDateAtMidnight) { + return -1 + } + if (cellDate > filterLocalDateAtMidnight) { + return 1 + } + } + } + + return { + sortable: true, + floatingFilter: true, + filter: + type === "number" + ? "agNumberColumnFilter" + : type === "date" + ? "agDateColumnFilter" + : "agTextColumnFilter", + cellDataType: type === "number" ? "text" : type, + filterParams, + comparator: getCustomComparator(type), + } +} + +export const calcEvalDuration = (evaluation: _Evaluation) => { + return dayjs( + runningStatuses.includes(evaluation.status.value) ? Date.now() : evaluation.updated_at, + ).diff(dayjs(evaluation.created_at), "milliseconds") +} + +const getCustomComparator = (type: CellDataType) => (valueA: string, valueB: string) => { + const getNumber = (val: string) => { + const num = parseFloat(val || "0") + return isNaN(num) ? 0 : num + } + + valueA = String(valueA) + valueB = String(valueB) + + switch (type) { + case "date": + return dayjs(valueA).diff(dayjs(valueB)) + case "text": + return valueA.localeCompare(valueB) + case "number": + return getNumber(valueA) - getNumber(valueB) + default: + return 0 + } +} + +export const removeCorrectAnswerPrefix = (str: string) => { + return str.replace(/^correctAnswer_/, "") +} + +export const mapTestcaseAndEvalValues = ( + settingsValues: Record, + selectedTestcase: Record, +) => { + const testcaseObj: Record = {} + const evalMapObj: Record = {} + + Object.entries(settingsValues).forEach(([key, value]) => { + if (typeof value === "string" && value.startsWith("testcase.")) { + testcaseObj[key] = selectedTestcase[value.split(".")[1]] + } else { + evalMapObj[key] = value + } + }) + + return {testcaseObj, evalMapObj} +} + +export const transformTraceKeysInSettings = ( + settingsValues: Record, +): Record => { + return Object.keys(settingsValues).reduce( + (acc, curr) => { + if ( + !acc[curr] && + typeof settingsValues[curr] === "string" && + settingsValues[curr].startsWith("trace.") + ) { + acc[curr] = settingsValues[curr].replace("trace.", "") + } else { + acc[curr] = settingsValues[curr] + } + + return acc + }, + {} as Record, + ) +} + +export const getEvaluatorTags = () => { + const evaluatorTags = [ + { + label: "Classifiers", + value: "classifiers", + }, + { + label: "Similarity", + value: "similarity", + }, + { + label: "AI / LLM", + value: "ai_llm", + }, + { + label: "Functional", + value: "functional", + }, + ] + + if (isDemo()) { + evaluatorTags.unshift({ + label: "RAG", + value: "rag", + }) + } + + return evaluatorTags +} + +export const calculateAvgScore = (evaluation: SingleModelEvaluationListTableDataType) => { + let score = 0 + if (evaluation.scoresData) { + score = + ((evaluation.scoresData.correct?.length || evaluation.scoresData.true?.length || 0) / + evaluation.scoresData.nb_of_rows) * + 100 + } else if (evaluation.resultsData) { + const multiplier = { + [EvaluationType.auto_webhook_test]: 100, + [EvaluationType.single_model_test]: 1, + } + score = calculateResultsDataAvg( + evaluation.resultsData, + multiplier[evaluation.evaluationType as keyof typeof multiplier], + ) + score = isNaN(score) ? 0 : score + } else if (evaluation.avgScore) { + score = evaluation.avgScore * 100 + } + + return score +} diff --git a/web/ee/src/lib/helpers/hashUtils.ts b/web/ee/src/lib/helpers/hashUtils.ts new file mode 100644 index 0000000000..5c66724e5a --- /dev/null +++ b/web/ee/src/lib/helpers/hashUtils.ts @@ -0,0 +1,73 @@ +// Utility to generate a hash ID for annotation/invocation steps, aligned with backend make_hash_id +// Uses blake2b if available, otherwise falls back to SHA-256 + +import blake from "blakejs" +// import { v4 as uuidv4 } from "uuid" // Use this for UUIDs if needed + +const REFERENCE_KEYS = [ + "application", + "application_variant", + "application_revision", + "testset", + "testcase", + "evaluator", +] + +// Recursively stable, whitespace-free JSON stringifier +function stableStringifyRecursive(obj: any): string { + if (obj === null || typeof obj !== "object") { + return JSON.stringify(obj) + } + if (Array.isArray(obj)) { + return `[${obj.map(stableStringifyRecursive).join(",")}]` + } + const keys = Object.keys(obj).sort() + const entries = keys.map( + (key) => `${JSON.stringify(key)}:${stableStringifyRecursive(obj[key])}`, + ) + return `{${entries.join(",")}}` +} + +export function makeHashId({ + references, + links, +}: { + references?: Record + links?: Record +}): string { + if (!references && !links) return "" + const payload: Record = {} + + for (const k of Object.keys(references || {})) { + if (REFERENCE_KEYS.includes(k)) { + const v = references![k] + // Only include 'id' field, not 'slug' + if (v.id != null) { + payload[k] = {id: v.id} + } + } + } + for (const k of Object.keys(links || {})) { + const v = links![k] + payload[k] = { + span_id: v.span_id, + trace_id: v.trace_id, + } + } + // Stable, deep, whitespace-free JSON + const serialized = stableStringifyRecursive(payload) + + // blake2b hash (digest_size=16) + try { + // Use blakejs (same as backend example) + return blake.blake2bHex(serialized, null, 16) + } catch (e) { + // Fallback: SHA-256 + if (window.crypto?.subtle) { + throw new Error( + "blake2b not available and crypto.subtle is async. Provide a polyfill or use a sync fallback.", + ) + } + return btoa(serialized) + } +} diff --git a/web/ee/src/lib/helpers/traceUtils.ts b/web/ee/src/lib/helpers/traceUtils.ts new file mode 100644 index 0000000000..17909bddcb --- /dev/null +++ b/web/ee/src/lib/helpers/traceUtils.ts @@ -0,0 +1,146 @@ +import {uuidToTraceId} from "@/oss/lib/hooks/useAnnotations/assets/helpers" + +import {TraceData, TraceTree} from "../hooks/useEvaluationRunScenarioSteps/types" + +export function findTraceForStep(traces: any[] | undefined, traceId?: string): any | undefined { + if (!traces?.length || !traceId) return undefined + const noDash = uuidToTraceId(traceId) + + return traces.find((t) => { + // Case 1: wrapper with trees array (new shape) + if (t?.trees?.length) { + const firstTree = t.trees[0] + if (firstTree?.tree?.id === traceId) return true + if (firstTree?.nodes?.[0]?.trace_id === noDash) return true + } + // Case 2: flat shape { tree, nodes } + if (t?.tree?.id === traceId) return true + if (t?.nodes?.[0]?.trace_id === noDash) return true + return false + }) +} + +// generic safe path resolver +export function resolvePath(obj: any, path: string): any { + const parts = path.split(".") + let current: any = obj + for (let i = 0; i < parts.length && current !== undefined; i++) { + const key = parts[i] + if (key in current) { + current = current[key] + continue + } + // if the exact key not found, try joining the remaining parts as a whole key (to support dots inside actual key names) + const remainder = parts.slice(i).join(".") + if (remainder in current) { + current = current[remainder] + return current + } + return undefined + } + return current +} + +// Unified helper to obtain trace and response value for a specific invocation step +// Manual mapping for legacy/compatibility keys to canonical keys +const INVOCATION_OUTPUT_KEY_MAP: Record = { + "attributes.ag.data.outputs": "data.outputs", + // Add more mappings here if needed +} + +export function readInvocationResponse({ + scenarioData, + stepKey, + path, + optimisticResult, + forceTrace, + scenarioId, +}: { + scenarioData: any + stepKey: string + path?: string + optimisticResult?: any + forceTrace?: TraceTree + scenarioId?: string +}): {trace?: any; value?: any; rawValue?: any; testsetId?: string; testcaseId?: string} { + if (!scenarioData) return {} + + const invocationSteps: any[] = Array.isArray(scenarioData.invocationSteps) + ? scenarioData.invocationSteps + : [] + const stepByKey = stepKey ? invocationSteps.find((s: any) => s?.stepKey === stepKey) : undefined + const stepByScenario = + !stepByKey && scenarioId + ? invocationSteps.find((s: any) => s?.scenarioId === scenarioId) + : undefined + const invocationStep = stepByKey ?? stepByScenario ?? invocationSteps[0] + const effectiveStepKey = invocationStep?.stepKey ?? stepKey + + // --- PATH RESOLUTION LOGIC --- + let resolvedPath: string | undefined = undefined + if (path) { + resolvedPath = path + } else if (scenarioData.mappings && Array.isArray(scenarioData.mappings) && effectiveStepKey) { + const mapEntry = scenarioData.mappings.find((m: any) => m.step?.key === effectiveStepKey) + if (mapEntry?.step?.path) { + resolvedPath = mapEntry.step.path + } + } + // After resolving, apply legacy/custom mapping if needed + if (resolvedPath && INVOCATION_OUTPUT_KEY_MAP[resolvedPath]) { + resolvedPath = INVOCATION_OUTPUT_KEY_MAP[resolvedPath] + } + // --- END PATH RESOLUTION LOGIC --- + + // --- MAPPING LOGIC FOR TESTSET/TESTCASE INFERENCE --- + let testsetId: string | undefined = undefined + let testcaseId: string | undefined = undefined + if (scenarioData.mappings && Array.isArray(scenarioData.mappings) && effectiveStepKey) { + const mapping = scenarioData.mappings.find( + (m: any) => + m.invocationStep?.stepKey === effectiveStepKey || + m.step?.stepKey === effectiveStepKey, + ) + if (mapping && mapping.inputStep?.stepKey) { + const inputStep = scenarioData.inputSteps?.find( + (s: any) => s.stepKey === mapping.inputStep.stepKey, + ) + if (inputStep) { + testsetId = inputStep.testsetId + testcaseId = inputStep.testcaseId + } + } + } + // ----------------------------------------------------- + + // Access trace directly attached to the invocation step (set during enrichment) + const trace = (forceTrace || invocationStep?.trace?.nodes?.[0]) ?? undefined + + // First priority: optimistic result override (e.g., UI enqueue) + let rawValue = optimisticResult + + if (rawValue === undefined && resolvedPath) { + rawValue = resolvePath(trace, resolvedPath) + } + + // Convert raw value to displayable string where possible + let value: any = rawValue + if ( + typeof rawValue === "string" || + typeof rawValue === "number" || + typeof rawValue === "boolean" + ) { + value = String(rawValue) + } else if (rawValue && typeof rawValue === "object") { + if (typeof (rawValue as any).content === "string") { + value = (rawValue as any).content + } else { + try { + value = JSON.stringify(rawValue, null, 2) + } catch { + value = String(rawValue as any) + } + } + } + return {trace, value, rawValue, testsetId, testcaseId} +} diff --git a/web/ee/src/lib/hooks/useEvalScenarioQueue/index.ts b/web/ee/src/lib/hooks/useEvalScenarioQueue/index.ts new file mode 100644 index 0000000000..9843fedbee --- /dev/null +++ b/web/ee/src/lib/hooks/useEvalScenarioQueue/index.ts @@ -0,0 +1,348 @@ +import {useCallback, useEffect, useMemo, useRef} from "react" + +import {loadable} from "jotai/utils" + +// import {triggerScenarioRevalidation} from "@/oss/components/EvalRunDetails/assets/annotationUtils" +// import {getCurrentProject} from "@/oss/contexts/project.context" +// import {useAppId} from "@/oss/hooks/useAppId" +// import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +// import {evalAtomStore} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +// Import EE run-scoped atoms for multi-run support +import {triggerScenarioRevalidation} from "@/oss/components/EvalRunDetails/HumanEvalRun/assets/annotationUtils" +import {setOptimisticStepData} from "@/oss/components/EvalRunDetails/HumanEvalRun/assets/optimisticUtils" +import {useAppId} from "@/oss/hooks/useAppId" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {evaluationRunStateFamily} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms/runScopedAtoms" +import {useJwtRefresher} from "@/oss/lib/hooks/useJWT" +import {EvaluationStatus} from "@/oss/lib/Types" +import {slugify} from "@/oss/lib/utils/slugify" +import type {ConfigMessage, ResultMessage, RunEvalMessage} from "@/oss/lib/workers/evalRunner/types" +import {getProjectValues} from "@/oss/state/project" + +// import {setOptimisticStepData} from "../../../components/EvalRunDetails/assets/optimisticUtils" +import {evalAtomStore} from "../useEvaluationRunData/assets/atoms" +import {triggerMetricsFetch} from "../useEvaluationRunData/assets/atoms/runScopedMetrics" +import {scenarioStepFamily} from "../useEvaluationRunData/assets/atoms/runScopedScenarios" +import {IInvocationStep} from "../useEvaluationRunScenarioSteps/types" + +import {BatchingQueue} from "./responseQueue" + +let sharedWorker: Worker | null = null +let isWorkerInitialized = false + +const MAX_RETRIES = 1 + +export function useEvalScenarioQueue(options?: {concurrency?: number; runId?: string}) { + const {jwt} = useJwtRefresher() + const {runId: optionsRunId} = options || {} + + /* -------- helpers that read atoms lazily -------- */ + const getRunMeta = useCallback(() => { + const store = evalAtomStore() + const effectiveRunId = optionsRunId + if (!effectiveRunId) { + console.warn("[useEvalScenarioQueue] No runId provided, cannot get run metadata") + return {runId: undefined, revision: undefined} + } + const runState = store.get(evaluationRunStateFamily(effectiveRunId)) + const run = runState?.enrichedRun + return { + runId: effectiveRunId, + revision: run?.variants?.[0], + } + }, [optionsRunId]) + + const workerRef = useRef(null) + const retryCountRef = useRef>(new Map()) + const abortedRef = useRef>(new Set()) + // New refs for timestamps and transitions + const timestampsRef = useRef>(new Map()) + const transitionsRef = useRef>(new Map()) + const appId = useAppId() + + // placeholder for batching queue ref – will init after handleResult + const queueRef = useRef>(undefined) + + // ---- handle single worker message ---- + const handleResult = useCallback( + (data: ResultMessage) => { + const {runId} = getRunMeta() + const {invocationStepTarget, invocationKey, scenarioId, status, result} = data + + if (abortedRef.current.has(scenarioId)) return + if (!invocationStepTarget) return + + if (status === EvaluationStatus.FAILURE) { + const retryCount = retryCountRef.current.get(scenarioId) ?? 0 + if (retryCount < MAX_RETRIES) { + if (!runId) return + const nextRetry = retryCount + 1 + retryCountRef.current.set(scenarioId, nextRetry) + setOptimisticStepData( + scenarioId, + [ + { + ...structuredClone(invocationStepTarget), + status: EvaluationStatus.RUNNING, + }, + ], + runId, + ) + + workerRef.current?.postMessage({ + type: "run-invocation", + jwt, + appId, + scenarioId, + runId, + requestBody: result?.requestBody ?? {}, + endpoint: result?.endpoint ?? "", + apiUrl: getAgentaApiUrl(), + projectId: getProjectValues().projectId, + invocationKey, + invocationStepTarget, + }) + return + } + } else { + retryCountRef.current.delete(scenarioId) + } + + try { + const optimisticData: IInvocationStep = { + ...structuredClone(invocationStepTarget), + status, + traceId: result.traceId, + trace: result?.tree, + } + + if ("invocationParameters" in invocationStepTarget) { + optimisticData.invocationParameters = + status === EvaluationStatus.SUCCESS + ? undefined + : (invocationStepTarget as IInvocationStep).invocationParameters + } + + if (runId) { + // Apply optimistic updates directly to maintain loading state continuity + setOptimisticStepData(scenarioId, [optimisticData], runId) + } + + // Delay the server revalidation to allow optimistic state to be visible + // This prevents immediate overwrite of the "running" status + triggerScenarioRevalidation(runId, scenarioId, [optimisticData]) + } catch (err) { + console.error("Failed to trigger scenario step refetch", err) + } + + const now = Date.now() + const existingTransitions = transitionsRef.current.get(scenarioId) ?? [] + transitionsRef.current.set(scenarioId, [ + ...existingTransitions, + {status, timestamp: now}, + ]) + const existingTimestamps = timestampsRef.current.get(scenarioId) ?? {} + if (status === "pending" && existingTimestamps.startedAt === undefined) { + timestampsRef.current.set(scenarioId, {...existingTimestamps, startedAt: now}) + } + if ( + (status === EvaluationStatus.SUCCESS || status === EvaluationStatus.FAILURE) && + existingTimestamps.endedAt === undefined + ) { + timestampsRef.current.set(scenarioId, {...existingTimestamps, endedAt: now}) + + // Trigger metrics refresh when scenario completes (success or failure) + if (runId) { + triggerMetricsFetch(runId) + } + } + }, + [jwt, retryCountRef, abortedRef, appId], + ) + + // initialize queue after we have stable handleResult + if (!queueRef.current) { + queueRef.current = new BatchingQueue((batch) => { + batch.forEach((item) => handleResult(item.payload)) + }) + } + + useEffect(() => { + if (!sharedWorker) { + sharedWorker = new Worker( + new URL("@/oss/lib/workers/evalRunner/evalRunner.worker.ts", import.meta.url), + ) + } + + workerRef.current = sharedWorker + + if (!isWorkerInitialized) { + const concurrency = options?.concurrency ?? 5 + const configMsg: ConfigMessage = {type: "config", maxConcurrent: concurrency} + sharedWorker.postMessage(configMsg) + isWorkerInitialized = true + } + + sharedWorker.onmessage = (e: MessageEvent) => { + handleResult(e.data) + // if (e.data.type === "result") { + // queueRef.current?.push(e.data) + // } + } + }, [jwt, options?.concurrency, appId]) + + const enqueueScenario = useCallback( + (scenarioId: string, stepKey?: string) => { + const store = evalAtomStore() + // Use run-scoped atom - runId should always be available in EE version + if (!optionsRunId) { + console.warn( + "[useEvalScenarioQueue] No runId provided, cannot get scenario step data", + ) + return undefined + } + + const stepLoadable = store.get( + loadable(scenarioStepFamily({scenarioId, runId: optionsRunId})), + ) + + if (stepLoadable.state === "hasData") { + const stepData = stepLoadable.data + // use data safely here + const invSteps = stepData?.invocationSteps ?? [] + const target = stepKey + ? invSteps.find((s) => s.stepKey === stepKey) + : invSteps.find((s) => s.invocationParameters) + + if (!target?.invocationParameters) return + const {runId, revision} = getRunMeta() + if (!jwt || !runId) return + + const invocationSteps: any[] | undefined = stepData?.invocationSteps + let requestBody: any, endpoint: string | undefined + let invocationStepTarget: any | undefined + if (invocationSteps) { + if (stepKey) { + invocationStepTarget = invocationSteps.find((s) => s.stepKey === stepKey) + } else { + invocationStepTarget = invocationSteps.find((s) => s.invocationParameters) + } + if (invocationStepTarget?.invocationParameters) { + requestBody = structuredClone( + invocationStepTarget.invocationParameters?.requestBody, + ) + endpoint = invocationStepTarget.invocationParameters?.endpoint + } + } + // Optimistic running override using shared helper + queueMicrotask(() => { + setOptimisticStepData( + scenarioId, + [ + { + ...structuredClone(invocationStepTarget), + status: EvaluationStatus.RUNNING, + }, + ], + runId, + ) + }) + retryCountRef.current.set(scenarioId, 0) + abortedRef.current.delete(scenarioId) + + let invocationKey: string | undefined + if (revision) { + invocationKey = slugify( + revision.name ?? revision.variantName ?? "invocation", + revision.id, + ) + } + + // Append required references to invocation request body before sending to worker + // invocationStepTarget is defined above in this scope + try { + if (requestBody && typeof requestBody === "object") { + const references: Record = + (requestBody.references as any) || {} + + // Testset id – derive from graph: find input step with same testcaseId + let testsetId: string | undefined + const inputSteps: any[] | undefined = stepData?.inputSteps + if (Array.isArray(inputSteps) && invocationStepTarget) { + const matchingInput = inputSteps.find( + (s) => s.testcaseId === (invocationStepTarget as any).testcaseId, + ) + testsetId = + matchingInput?.testcase?.testset_id || + matchingInput?.references?.testset?.id || + matchingInput?.refs?.testset?.id + } + if (testsetId) { + references.testset = {id: testsetId} + } + + // Application related references + if (appId) references.application = {id: appId} + const variantId = revision?.variantId || revision?.id || undefined + if (variantId) references.application_variant = {id: String(variantId)} + if (revision?.id) + references.application_revision = {id: String(revision.id)} + + requestBody.references = references + } + } catch (err) { + console.error("Failed to append references to invocation payload", err) + } + + if (endpoint) { + const message: RunEvalMessage = { + type: "run-invocation", + appId: appId, + jwt, + scenarioId, + runId, + requestBody, + endpoint, + invocationKey, + invocationStepTarget, + apiUrl: getAgentaApiUrl(), + projectId: getProjectValues().projectId, + } + + workerRef.current?.postMessage(message) + + // Update timestamps and transitions on enqueue + const now = Date.now() + const existingTransitions = transitionsRef.current.get(scenarioId) ?? [] + transitionsRef.current.set(scenarioId, [ + ...existingTransitions, + {status: "pending", timestamp: now}, + ]) + const existingTimestamps = timestampsRef.current.get(scenarioId) ?? {} + if (existingTimestamps.startedAt === undefined) { + timestampsRef.current.set(scenarioId, { + ...existingTimestamps, + startedAt: now, + }) + } + } + } + }, + [jwt, getRunMeta], + ) + + const cancelScenario = useCallback((scenarioId: string) => { + if (process.env.NODE_ENV !== "production") { + console.debug(`[EvalQueue] Cancelling scenario ${scenarioId}`) + } + abortedRef.current.add(scenarioId) + }, []) + + return useMemo( + () => ({ + enqueueScenario, + cancelScenario, + }), + [enqueueScenario, cancelScenario], + ) +} diff --git a/web/ee/src/lib/hooks/useEvalScenarioQueue/responseQueue.ts b/web/ee/src/lib/hooks/useEvalScenarioQueue/responseQueue.ts new file mode 100644 index 0000000000..b575de18ce --- /dev/null +++ b/web/ee/src/lib/hooks/useEvalScenarioQueue/responseQueue.ts @@ -0,0 +1,48 @@ +export interface QueueItem { + payload: T + receivedAt: number +} + +/** + * Generic in-memory batching queue. Push items and they will be flushed + * either when we reach `maxBatch` length or after `maxWaitMs` timeout, + * whichever comes first. The consumer provides a `processBatch` callback + * that receives all pending items in the order they were received. + */ +export class BatchingQueue { + private pending: QueueItem[] = [] + private flushTimer: ReturnType | null = null + + constructor( + private readonly processBatch: (items: QueueItem[]) => void, + private readonly maxBatch = 20, + private readonly maxWaitMs = 150, + ) {} + + push(item: T) { + this.pending.push({payload: item, receivedAt: Date.now()}) + // If we already reached the batch size, flush synchronously + if (this.pending.length >= this.maxBatch) { + this.flush() + return + } + // Otherwise ensure a timer exists + if (!this.flushTimer) { + this.flushTimer = setTimeout(() => this.flush(), this.maxWaitMs) + } + } + + flush() { + if (this.flushTimer) { + clearTimeout(this.flushTimer) + this.flushTimer = null + } + if (this.pending.length === 0) return + const batch = this.pending.splice(0, this.pending.length) + try { + this.processBatch(batch) + } catch (err) { + console.error("[BatchingQueue] processBatch failed", err) + } + } +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/bulkFetch.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/bulkFetch.ts new file mode 100644 index 0000000000..bd13a5dab0 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/bulkFetch.ts @@ -0,0 +1,96 @@ +import {createStore} from "jotai" + +import {UseEvaluationRunScenarioStepsFetcherResult} from "../../../useEvaluationRunScenarioSteps/types" +import {fetchScenarioViaWorkerAndCache} from "../helpers/fetchScenarioViaWorker" + +import { + bulkStepsStatusFamily, + bulkStepsCacheFamily, + evaluationRunStateFamily, + enrichedRunFamily, +} from "./runScopedAtoms" + +/* + Bulk scenario-step prefetching for Evaluation Run screen. + Updated to work with run-scoped atom families instead of global atoms. + This allows multiple evaluation runs to have independent bulk fetch states. +*/ + +// Legacy exports for backward compatibility during migration +// These will be removed once all components are migrated +export const bulkStepsStatusAtom = bulkStepsStatusFamily("__legacy__") + +// Bulk fetch logic updated to work with run-scoped atom families +export async function runBulkFetch( + store: ReturnType, + runId: string, + scenarioIds: string[], + opts: { + force?: boolean + onComplete?: (map: Map) => void + } = {}, +): Promise> { + if (!scenarioIds || !scenarioIds.length) { + return new Map() + } + + const status = store.get(bulkStepsStatusFamily(runId)) + + if (!opts.force && (status === "loading" || status === "done")) { + const cachedData = store.get(bulkStepsCacheFamily(runId)) + + return cachedData + } + + const enrichedRun = store.get(enrichedRunFamily(runId)) + const evaluationRunState = store.get(evaluationRunStateFamily(runId)) + const runIndex = evaluationRunState?.runIndex + + // Validate scenario IDs and filter out skeleton/placeholder IDs + const validScenarioIds = scenarioIds.filter((id) => { + if (!id || typeof id !== "string") return false + + // Skip skeleton/placeholder IDs gracefully + if (id.startsWith("skeleton-") || id.startsWith("placeholder-")) { + return false + } + + const uuidRegex = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i + return uuidRegex.test(id) + }) + + // Use filtered valid IDs + scenarioIds = validScenarioIds + + // Early return if no valid scenario IDs remain after filtering + if (scenarioIds.length === 0) { + return store.get(bulkStepsCacheFamily(runId)) + } + + if (!runId || !enrichedRun || !runIndex) { + return store.get(bulkStepsCacheFamily(runId)) + } + + store.set(bulkStepsStatusFamily(runId), "loading") + // return + try { + const params = {runId, evaluation: enrichedRun, runIndex} + + const workerResult = + (await fetchScenarioViaWorkerAndCache(params, scenarioIds)) || new Map() + + // Write all results to the bulk cache atom at once + store.set(bulkStepsCacheFamily(runId), (draft) => { + for (const [scenarioId, scenarioSteps] of workerResult?.entries() || []) { + if (scenarioSteps) { + draft.set(scenarioId, scenarioSteps) + } + } + }) + + store.set(bulkStepsStatusFamily(runId), "done") + } catch (err) { + console.error("[bulk-steps] bulk fetch ERROR", err) + store.set(bulkStepsStatusFamily(runId), "error") + } +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/cache.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/cache.ts new file mode 100644 index 0000000000..50338bd00a --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/cache.ts @@ -0,0 +1,6 @@ +import {atom} from "jotai" + +import type {TraceData} from "../../../useEvaluationRunScenarioSteps/types" + +// traceId -> TraceData +export const traceCacheAtom = atom(new Map()) diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/index.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/index.ts new file mode 100644 index 0000000000..474aab8bef --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/index.ts @@ -0,0 +1,19 @@ +import {atom} from "jotai" + +// New run-scoped atoms +export * from "./runScopedAtoms" +export * from "./runScopedScenarios" + +// Migration helper for backward compatibility - only export specific functions +export {getCurrentRunId} from "./migrationHelper" + +// Legacy atoms and functions (for backward compatibility during migration) +import {evalAtomStore, initializeRun} from "./store" + +// re-export legacy store helpers (will be deprecated) +export {evalAtomStore, initializeRun} + +export * from "./utils" +export * from "./bulkFetch" +export * from "./progress" +export * from "./cache" diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/migrationHelper.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/migrationHelper.ts new file mode 100644 index 0000000000..52e43467c5 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/migrationHelper.ts @@ -0,0 +1,18 @@ +/** + * Migration helper for gradual transition from global atoms to run-scoped atoms + * + * This provides compatibility layers that allow existing components to work + * while we gradually migrate them to use the new run-scoped atom families. + */ + +// Current active run ID - this is a temporary bridge during migration +let currentRunId: string | null = null + +export const getCurrentRunId = (): string => { + if (!currentRunId) { + throw new Error( + "No current run ID set. Make sure to call setCurrentRunId() before using legacy atoms.", + ) + } + return currentRunId +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/progress.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/progress.ts new file mode 100644 index 0000000000..be83fcba04 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/progress.ts @@ -0,0 +1,263 @@ +import deepEqual from "fast-deep-equal" +import {Atom, atom} from "jotai" +import {atomFamily, loadable, selectAtom} from "jotai/utils" +import {eagerAtom} from "jotai-eager" +import {atomWithImmer} from "jotai-immer" + +import {evalTypeAtom} from "@/oss/components/EvalRunDetails/state/evalType" + +import {EvaluationLoadingState} from "../../types" +import {defaultLoadingState} from "../constants" + +// import {bulkStepsCacheAtom} from "./bulkFetch" + +import {evaluationRunStateFamily} from "./runScopedAtoms" +import { + displayedScenarioIdsFamily, + scenarioIdsFamily, + scenarioStepFamily, + scenarioStepLocalFamily, +} from "./runScopedScenarios" +import {ScenarioCounts, StatusCounters} from "./types" + +// ---------------- Shared counter helper ---------------- +const emptyCounters = (): StatusCounters => ({ + pending: 0, + running: 0, + completed: 0, + cancelled: 0, + unannotated: 0, + failed: 0, +}) + +const tallyStatus = (counters: StatusCounters, status: string): void => { + switch (status) { + case "pending": + case "revalidating": + counters.pending += 1 + break + case "running": + counters.running += 1 + break + case "success": + case "done": + counters.completed += 1 + break + case "incomplete": + counters.unannotated += 1 + break + case "failed": + case "failure": + case "error": + counters.failed += 1 + break + case "cancelled": + counters.cancelled += 1 + break + default: + counters.pending += 1 + } +} + +export const progressFamily = atomFamily( + (runId: string) => + eagerAtom((get) => { + const scenarios = get(evaluationRunStateFamily(runId)).scenarios || [] + const counters = emptyCounters() + + scenarios.forEach((s) => { + const statusLoadable = get( + loadable(scenarioStatusFamily({scenarioId: s.id, runId})), + ) + const status = + statusLoadable.state === "hasData" ? statusLoadable.data.status : "pending" + tallyStatus(counters, status) + }) + + const percentComplete = + counters.completed + counters.failed + counters.cancelled + counters.unannotated > 0 + ? Math.round((counters.completed / scenarios.length) * 100) + : 0 + + return { + total: scenarios.length, + pending: counters.pending, + inProgress: counters.running, + completed: counters.completed, + error: counters.failed, + cancelled: counters.cancelled, + percentComplete, + } + }), + deepEqual, +) + +export const loadingStateAtom = atomWithImmer(defaultLoadingState) + +// Run-scoped atom family to compute scenario step progress for displayedScenarioIds +export const scenarioStepProgressFamily = atomFamily( + (runId: string) => + atom((get) => { + const loadingStates = get(loadingStateAtom) + // If we're still fetching the evaluation or scenarios list, reflect that first + if ( + loadingStates.activeStep && + ["eval-run", "scenarios"].includes(loadingStates.activeStep) + ) { + return { + completed: 0, + total: 0, + percent: 0, + loadingStep: loadingStates.activeStep, + } + } + const loadableIds = get(loadable(displayedScenarioIdsFamily(runId))) + + if (loadableIds.state !== "hasData") { + return {completed: 0, total: 0, percent: 0, loadingStep: null} + } + const scenarioIds: string[] = Array.isArray(loadableIds.data) ? loadableIds.data : [] + const total = scenarioIds.length + + let completed = 0 + scenarioIds.forEach((scenarioId: string) => { + if (get(scenarioStepLocalFamily({runId, scenarioId}))) completed++ + }) + const percent = total > 0 ? Math.round((completed / total) * 100) : 0 + return { + completed, + total, + percent, + allStepsFetched: completed === total && total > 0, + loadingStep: completed < total ? "scenario-steps" : null, + } + }), + deepEqual, +) + +export const scenarioStatusFamily = atomFamily((params: {scenarioId: string; runId: string}) => { + return atom(async (get) => { + const data = await get(scenarioStepFamily(params)) + const evalType = get(evalTypeAtom) + + const invocationSteps: any[] = Array.isArray(data?.invocationSteps) + ? data.invocationSteps + : [] + const annotationSteps: any[] = Array.isArray(data?.annotationSteps) + ? data.annotationSteps + : [] + + const isRunning = + data?.invocationSteps.some((s) => s.status === "running") || + data?.annotationSteps.some((s) => s.status === "running") || + data?.inputSteps.some((s) => s.status === "running") + + const isAnnotating = data?.annotationSteps.some((s) => s.status === "annotating") + const isRevalidating = data?.annotationSteps.some((s) => s.status === "revalidating") + + // Determine scenario status based on step outcomes + let computedStatus = "pending" + const allInvSucceeded = + invocationSteps.length > 0 && invocationSteps.every((s) => s.status === "success") + const allAnnSucceeded = + annotationSteps.length > 0 && annotationSteps.every((s) => s.status === "success") + const anyFailed = + data?.invocationSteps.some((s) => s.status === "failure") || + data?.annotationSteps.some((s) => s.status === "failure") || + data?.inputSteps.some((s) => s.status === "failure") + + if (isRunning) { + computedStatus = "running" + } else if (isAnnotating) { + computedStatus = "annotating" + } else if (isRevalidating) { + computedStatus = "revalidating" + } else if (allAnnSucceeded) { + computedStatus = "success" + } else if (allInvSucceeded) { + // In auto eval we don't have any annotation steps for now + computedStatus = evalType === "auto" ? "success" : "incomplete" + } else if (anyFailed) { + computedStatus = "failure" + } else { + computedStatus = "pending" + } + + return { + status: computedStatus, + isAnnotating, + isRevalidating, + } + }) +}, deepEqual) + +export const scenarioStatusAtomFamily = atomFamily((params: {scenarioId: string; runId: string}) => + atom((get) => { + const loadableStatus = get(loadable(scenarioStatusFamily(params))) + return loadableStatus.state === "hasData" ? loadableStatus.data : {status: "pending"} + }), +) + +// Aggregate all scenario steps into a single object keyed by scenarioId (loadable) +// Convenience wrapper so components can safely read status without suspending +export const loadableScenarioStatusFamily = atomFamily( + (params: {scenarioId: string; runId: string}) => loadable(scenarioStatusFamily(params)), + deepEqual, +) + +// Lightweight UI flags derived from scenario status +export const scenarioUiFlagsFamily = atomFamily((params: {scenarioId: string; runId: string}) => { + return atom((get) => { + const statusLoadable = get(loadable(scenarioStatusFamily(params))) + if (statusLoadable.state !== "hasData") { + return {isAnnotating: false, isRevalidating: false} + } + const {isAnnotating, isRevalidating, status} = statusLoadable.data as any + return { + isAnnotating: isAnnotating ?? status === "annotating", + isRevalidating: isRevalidating ?? status === "revalidating", + } + }) +}, deepEqual) + +export const scenarioCountsFamily = atomFamily((runId: string) => { + return atom((get) => { + const ids = get(scenarioIdsFamily(runId)) + const c = emptyCounters() + for (const id of ids) { + const st = get(scenarioStatusAtomFamily({scenarioId: id, runId})) as any + tallyStatus(c, st?.status ?? "pending") + } + return { + total: ids.length, + pending: c.pending, + unannotated: c.unannotated, + failed: c.failed, + } + }) +}, deepEqual) + +// Run-scoped count atoms +export const pendingCountFamily = atomFamily((runId: string) => { + return selectAtom( + scenarioCountsFamily(runId), + (c) => c.pending, + deepEqual, + ) +}, deepEqual) + +export const unannotatedCountFamily = atomFamily((runId: string) => { + return selectAtom( + scenarioCountsFamily(runId), + (c) => c.unannotated, + deepEqual, + ) +}, deepEqual) + +export const failedCountFamily = atomFamily((runId: string) => { + return selectAtom( + scenarioCountsFamily(runId), + (c) => c.failed, + deepEqual, + ) +}, deepEqual) diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedAtoms.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedAtoms.ts new file mode 100644 index 0000000000..882e4b169f --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedAtoms.ts @@ -0,0 +1,105 @@ +import deepEqual from "fast-deep-equal" +import {atom} from "jotai" +import {atomFamily} from "jotai/utils" +import {atomWithImmer} from "jotai-immer" + +import {UseEvaluationRunScenarioStepsFetcherResult} from "../../../useEvaluationRunScenarioSteps/types" +import {EvaluationRunState} from "../../types" +import {initialState} from "../constants" +import type {BasicStats} from "../types" + +/** + * Run-scoped atom families + * + * These atoms replace the global atoms that were previously tied to a single "active" run. + * Each atom family is keyed by runId, allowing multiple evaluation runs to coexist + * without interfering with each other. + */ + +// Core evaluation run state - replaces global evaluationRunStateAtom +export const evaluationRunStateFamily = atomFamily((runId: string) => { + if (runId === undefined || runId === null || runId === "") { + console.error(`[evaluationRunStateFamily] ERROR: Invalid runId received: ${runId}`) + console.trace("Stack trace for invalid runId:") + } + return atomWithImmer(initialState) +}, deepEqual) + +// Bulk fetch status - replaces global bulkStepsStatusAtom +export const bulkStepsStatusFamily = atomFamily( + (runId: string) => atom<"idle" | "loading" | "done" | "error">("idle"), + deepEqual, +) + +// Bulk fetch cache - replaces global bulkStepsCacheAtom +export const bulkStepsCacheFamily = atomFamily( + (runId: string) => atom>(new Map()), + deepEqual, +) + +// Bulk fetch requested flag - for tracking if bulk fetch has been initiated +export const bulkStepsRequestedFamily = atomFamily((runId: string) => atom(false), deepEqual) + +// Bulk started flag - guard so init fires once per run +export const bulkStartedFamily = atomFamily((runId: string) => atom(false), deepEqual) + +// Derived atoms that depend on run state +export const enrichedRunFamily = atomFamily( + (runId: string) => atom((get) => get(evaluationRunStateFamily(runId)).enrichedRun), + deepEqual, +) + +export const runIndexFamily = atomFamily( + (runId: string) => atom((get) => get(evaluationRunStateFamily(runId)).runIndex), + deepEqual, +) + +export const evaluationRunIdFamily = atomFamily( + (runId: string) => + atom(() => { + // Use runId directly since it's the identifier we need + return runId + }), + deepEqual, +) + +// Loading state family - replaces global loadingStateAtom +export const loadingStateFamily = atomFamily( + (runId: string) => + atomWithImmer({ + isLoadingEvaluation: false, + isLoadingScenarios: false, + isLoadingMetrics: false, + activeStep: null as string | null, + }), + deepEqual, +) + +// Run-scoped metric atom families - replaces global metric atoms +export const runMetricsRefreshFamily = atomFamily((runId: string) => atom(0), deepEqual) + +export const runMetricsCacheFamily = atomFamily((runId: string) => atom([]), deepEqual) + +export const runMetricsStatsCacheFamily = atomFamily( + (runId: string) => atom>({}), + deepEqual, +) + +/** + * Helper type for accessing all run-scoped atoms for a specific run + */ +export interface RunScopedAtoms { + runId: string + evaluationRunState: ReturnType + bulkStepsStatus: ReturnType + bulkStepsCache: ReturnType + bulkStepsRequested: ReturnType + bulkStarted: ReturnType + enrichedRun: ReturnType + runIndex: ReturnType + evaluationRunId: ReturnType + loadingState: ReturnType + runMetricsRefresh: ReturnType + runMetricsCache: ReturnType + runMetricsStatsCache: ReturnType +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics.ts new file mode 100644 index 0000000000..ab029794bc --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedMetrics.ts @@ -0,0 +1,546 @@ +import deepEqual from "fast-deep-equal" +import {Atom} from "jotai" +import {atomFamily, selectAtom} from "jotai/utils" +import {eagerAtom} from "jotai-eager" + +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {BasicStats, canonicalizeMetricKey, getMetricValueWithAliases} from "@/oss/lib/metricUtils" +import {slugify} from "@/oss/lib/utils/slugify" +import {getJWT} from "@/oss/services/api" +import {getProjectValues} from "@/oss/state/project" + +import { + evaluationRunStateFamily, + loadingStateFamily, + runMetricsCacheFamily, + runMetricsRefreshFamily, + runMetricsStatsCacheFamily, +} from "./runScopedAtoms" +import {evalAtomStore} from "./store" + +// Re-export the atom families for external use +export {runMetricsCacheFamily, runMetricsStatsCacheFamily} + +import {fetchRunMetricsViaWorker} from "@/agenta-oss-common/lib/workers/evalRunner/runMetricsWorker" + +// Helper: flatten acc object and nested metrics similar to legacy mergedMetricsAtom +export function flattenMetrics(raw: Record): Record { + const flat: Record = {} + Object.entries(raw || {}).forEach(([k, v]) => { + if (k === "acc" && v && typeof v === "object") { + const acc: any = v + if (acc?.costs?.total !== undefined) flat.totalCost = acc.costs.total + if (acc?.duration?.total !== undefined) + flat["duration.total"] = Number((acc.duration.total / 1000).toFixed(6)) + if (acc?.tokens?.total !== undefined) flat.totalTokens = acc.tokens.total + if (acc?.tokens?.prompt !== undefined) flat.promptTokens = acc.tokens.prompt + if (acc?.tokens?.completion !== undefined) flat.completionTokens = acc.tokens.completion + } else if (v && typeof v === "object" && !Array.isArray(v)) { + Object.entries(v).forEach(([sub, sv]) => { + flat[`${k}.${sub}`] = sv + }) + } else { + flat[k] = v + } + }) + return flat +} + +// Deduplicate inflight requests per runId +const inFlight = new Map>() + +const runFetchMetrics = async ( + store: any, + runId: string, + evaluatorSlugs: string[] = [], + revisionSlugs: string[] = [], +) => { + if (inFlight.has(runId)) return inFlight.get(runId)! as Promise + + const promise = (async () => { + evalAtomStore().set(loadingStateFamily(runId), (draft) => { + draft.isLoadingMetrics = true + }) + try { + const apiUrl = getAgentaApiUrl() + const jwt = await getJWT() + const proj = getProjectValues() as any + const projectId = (proj?.id ?? proj?.projectId ?? "") as string + + if (!projectId || !jwt || !apiUrl) { + console.error(`[runScopedMetrics] Missing context for runId: ${runId}`, { + hasProjectId: !!projectId, + hasJwt: !!jwt, + hasApiUrl: !!apiUrl, + }) + throw new Error("Project ID, JWT or API URL not found") + } + + const {metrics, stats} = await fetchRunMetricsViaWorker(runId, { + apiUrl, + jwt, + projectId, + evaluatorSlugs, + revisionSlugs, + }) + + const scenarioMetrics = Array.isArray(metrics) ? metrics : [] + + // Update run-scoped cache atoms + store.set(runMetricsCacheFamily(runId), scenarioMetrics) + store.set(runMetricsStatsCacheFamily(runId), stats || {}) + + // Reset refresh counter back to 0 + store.set(runMetricsRefreshFamily(runId), 0) + } catch (err) { + console.error(`[runScopedMetrics] Error fetching metrics for runId: ${runId}:`, err) + } finally { + inFlight.delete(runId) // cleanup + evalAtomStore().set(loadingStateFamily(runId), (draft) => { + draft.isLoadingMetrics = false + }) + } + })() + inFlight.set(runId, promise) + return promise +} + +// Run-scoped metrics atom family that fetches metrics for a specific runId +export const runMetricsFamily = atomFamily>((runId: string) => { + return eagerAtom((get) => { + if (!runId) { + return [] + } + + // Depend on refresh signal + const refresh = get(runMetricsRefreshFamily(runId)) + + const cached = get(runMetricsCacheFamily(runId)) + + // Normal path: no refresh requested + if (refresh === 0) { + return cached || [] + } + + // Refresh requested (stale-while-revalidate) + if (cached && cached.length > 0) { + // Kick off background revalidation if not already running + if (!inFlight.has(runId)) { + runFetchMetrics(evalAtomStore(), runId) + } + return cached // serve stale data while revalidating + } + + // No cached data → start background fetch and return empty list (no suspense) + if (!inFlight.has(runId)) { + const state = get(evaluationRunStateFamily(runId)) + const evaluators = state?.enrichedRun?.evaluators + if (!evaluators) return [] + + // Handle both array and object formats + const evaluatorsList = Array.isArray(evaluators) + ? evaluators + : Object.values(evaluators) + + const evaluatorSlugs = evaluatorsList.map((ev: any) => ev.slug || ev.id || ev.name) + + const revisions = state?.enrichedRun?.variants + const revisionSlugs = revisions + ? revisions.map((v: any) => slugify(v.name, v.id)) + : // ? revisions.map((v: any) => slugify("comp-1", v.id)) + [] + + const p = runFetchMetrics(evalAtomStore(), runId, evaluatorSlugs, revisionSlugs) + inFlight.set(runId, p) + } + return [] + }) +}, deepEqual) + +// Run-scoped scenario metrics map atom family +const scenarioMetricsCache = new WeakMap>>() + +const normalizeStatValue = (value: any) => { + if (!value || typeof value !== "object" || Array.isArray(value)) return value + const next: any = {...value} + + if (Array.isArray(next.freq)) { + next.frequency = next.freq + delete next.freq + } + if (Array.isArray(next.uniq)) { + next.unique = next.uniq + delete next.uniq + } + + if (Array.isArray(next.frequency)) { + next.frequency = next.frequency.map((entry: any) => ({ + value: entry?.value, + count: entry?.count ?? entry?.frequency ?? 0, + })) + + const sorted = [...next.frequency].sort( + (a, b) => b.count - a.count || (a.value === true ? -1 : 1), + ) + next.rank = sorted + if (!Array.isArray(next.unique) || !next.unique.length) { + next.unique = sorted.map((entry) => entry.value) + } + } else if (Array.isArray(next.rank)) { + next.rank = next.rank.map((entry: any) => ({ + value: entry?.value, + count: entry?.count ?? entry?.frequency ?? 0, + })) + } + + return next +} + +export const scenarioMetricsMapFamily = atomFamily< + string, + Atom>> +>((runId: string) => { + return eagerAtom>>((get) => { + // Explicitly depend on refresh signal to ensure reactivity + const refresh = get(runMetricsRefreshFamily(runId)) + + const arr = get(runMetricsFamily(runId)) as any[] + + if (!arr) { + return {} + } + + const cached = scenarioMetricsCache.get(arr) + if (cached && refresh === 0) { + return cached + } + + const map: Record> = {} + arr.forEach((entry: any, index: number) => { + const sid = entry?.scenarioId || entry?.scenario_id || entry?.scenarioID || entry?.id + if (!sid) { + return + } + // The data might already be processed/flattened or still nested + const rawData = entry?.data || {} + + // Check if data is already flat (has direct metric values) or nested (has variant objects) + const firstNonEmptyKey = Object.keys(rawData).find((key) => { + const value = rawData[key] + return ( + value !== null && + value !== undefined && + (typeof value === "object" ? Object.keys(value).length > 0 : true) + ) + }) + + // If you want the first non-empty value: + const firstValue = firstNonEmptyKey ? rawData[firstNonEmptyKey] : undefined + const isAlreadyFlat = + typeof firstValue === "number" || + (typeof firstValue === "object" && (firstValue?.mean || firstValue?.unique)) + + if (isAlreadyFlat) { + // Data is already flat, ensure canonical aliases are present + const normalized: Record = {...rawData} + Object.keys(rawData).forEach((rawKey) => { + normalized[rawKey] = normalizeStatValue(normalized[rawKey]) + const canonical = canonicalizeMetricKey(rawKey) + if (canonical !== rawKey && normalized[canonical] === undefined) { + normalized[canonical] = normalizeStatValue(rawData[rawKey]) + } + }) + map[String(sid)] = normalized + } else { + // Data is nested, process it + const processedData: Record = {} + + // Extract metrics from all variants (usually just one) + Object.values(rawData).forEach((variantData: any) => { + if (variantData && typeof variantData === "object") { + Object.entries(variantData).forEach( + ([metricKey, metricValue]: [string, any]) => { + // Extract the mean value from metric objects like {"mean": 0.000059} + const value = metricValue?.mean ?? metricValue + + // Apply key mapping for common metrics + let mappedKey = metricKey + if (metricKey === "costs.total") mappedKey = "totalCost" + else if (metricKey === "tokens.total") mappedKey = "totalTokens" + else if (metricKey === "tokens.prompt") mappedKey = "promptTokens" + else if (metricKey === "tokens.completion") + mappedKey = "completionTokens" + + const canonical = canonicalizeMetricKey(mappedKey) + processedData[mappedKey] = normalizeStatValue(value) + if (canonical !== mappedKey) { + processedData[canonical] = processedData[canonical] ?? value + } + }, + ) + } + }) + + map[String(sid)] = processedData + } + }) + + scenarioMetricsCache.set(arr, map) + return map + }) +}, deepEqual) + +/** + * Run-scoped scenario metrics selector + * Returns a single metric primitive for a given scenario without triggering wide re-renders. + * Specialized for the case where you only need a single metric value. like table cells + */ +export const scenarioMetricSelectorFamily = atomFamily< + {runId: string; scenarioId: string}, + Atom>> +>(({runId, scenarioId}) => { + return selectAtom(scenarioMetricsMapFamily(runId), (s) => s?.[scenarioId], deepEqual) +}, deepEqual) + +/** + * Run-scoped single metric value selector + * Mirrors the legacy scenarioMetricValueFamily but adds runId and optional stepSlug support. + * Returns a single metric primitive for a given scenario without triggering wide re-renders. + */ +export const scenarioMetricValueFamily = atomFamily( + ({ + runId, + scenarioId, + metricKey, + stepSlug, + }: { + runId: string + scenarioId: string + metricKey: string + stepSlug?: string + }) => + selectAtom( + scenarioMetricsMapFamily(runId), + (map) => { + const metrics = map?.[scenarioId] || {} + + const buildCandidateKeys = (base: string): string[] => { + const candidates: string[] = [] + const push = (candidate?: string) => { + if (!candidate) return + if (candidates.includes(candidate)) return + candidates.push(candidate) + } + + push(base) + const slug = stepSlug || base.split(".")[0] + const withoutSlug = + slug && base.startsWith(`${slug}.`) ? base.slice(slug.length + 1) : base + + if (slug) { + push(`${slug}.${withoutSlug}`) + push(`${slug}.attributes.ag.data.outputs.${withoutSlug}`) + push(`${slug}.attributes.ag.metrics.${withoutSlug}`) + } + + push(`attributes.ag.data.outputs.${withoutSlug}`) + push(`attributes.ag.metrics.${withoutSlug}`) + + return candidates + } + + const needsPrefix = Boolean(stepSlug && !metricKey.startsWith(`${stepSlug}.`)) + const key = needsPrefix ? `${stepSlug}.${metricKey}` : metricKey + const candidateKeys = Array.from( + new Set([...buildCandidateKeys(metricKey), ...buildCandidateKeys(key)]), + ) + + for (const candidate of candidateKeys) { + const resolved = getMetricValueWithAliases(metrics, candidate) + if (resolved !== undefined) return resolved + } + return undefined + }, + deepEqual, + ), +) + +// Helper function to trigger metric fetch for a specific runId +export const triggerMetricsFetch = (targetRunId: string) => { + const store = evalAtomStore() + store.set(runMetricsRefreshFamily(targetRunId), (prev) => prev + 1) +} + +/** + * Run-scoped metrics prefetch attachment + * This replaces the legacy attachRunMetricsPrefetch for multi-run support + */ +export function attachRunMetricsPrefetchForRun( + runId: string, + store: ReturnType, +) { + const fetched = new Set() + + // Subscribe to changes in evaluation run state for this specific run + const unsubscribe = store.sub(evaluationRunStateFamily(runId), () => { + const state = store.get(evaluationRunStateFamily(runId)) + const currentRunId = runId + + if (!currentRunId) { + return + } + + if (!state?.enrichedRun?.evaluators) { + return // wait until evaluators are loaded + } + + // Check if metrics are already cached using the actual currentRunId + const cached = store.get(runMetricsCacheFamily(currentRunId)) + if (cached && cached.length > 0) { + if (!fetched.has(currentRunId)) { + fetched.add(currentRunId) // Mark as fetched since cache exists + } + return + } + + // Check if we're already in the process of fetching + if (fetched.has(currentRunId)) { + return + } + + fetched.add(currentRunId) + + // Trigger metrics fetch for the actual currentRunId + triggerMetricsFetch(currentRunId) + }) + + return unsubscribe +} + +/** + * Run-scoped metric data family + * This replaces the legacy metricDataFamily for multi-run support + * Returns { value, distInfo } for a specific metric key on a scenario within a run + */ +export const runScopedMetricDataFamily = atomFamily( + ({ + runId, + scenarioId, + stepSlug, + metricKey, + }: { + runId: string + scenarioId: string + stepSlug?: string + metricKey: string + }) => + eagerAtom<{value: any; distInfo?: any}>((get) => { + // Get the scenario metrics map for this run + const scenarioMetricsMap = get(scenarioMetricsMapFamily(runId)) + // Get the metrics for this specific scenario + const scenarioMetrics = scenarioMetricsMap[scenarioId] + + if (!scenarioMetrics) { + return {value: undefined, distInfo: undefined} + } + + const metricPath = stepSlug ? `${stepSlug}.${metricKey}` : metricKey + + const buildCandidateKeys = (base: string): string[] => { + const candidates: string[] = [] + const push = (candidate?: string) => { + if (!candidate) return + if (candidates.includes(candidate)) return + candidates.push(candidate) + } + + push(base) + + const slug = stepSlug || base.split(".")[0] + const withoutSlug = + slug && base.startsWith(`${slug}.`) ? base.slice(slug.length + 1) : base + + if (slug) { + push(`${slug}.${withoutSlug}`) + push(`${slug}.attributes.ag.data.outputs.${withoutSlug}`) + push(`${slug}.attributes.ag.metrics.${withoutSlug}`) + } + + push(`attributes.ag.data.outputs.${withoutSlug}`) + push(`attributes.ag.metrics.${withoutSlug}`) + + return candidates + } + + const candidateKeys = Array.from( + new Set([...buildCandidateKeys(metricKey), ...buildCandidateKeys(metricPath)]), + ) + + const resolveFromSource = (source?: Record) => { + if (!source) return undefined + for (const candidate of candidateKeys) { + const resolved = getMetricValueWithAliases(source, candidate) + if (resolved !== undefined) return resolved + } + return undefined + } + + const value = resolveFromSource(scenarioMetrics) + + // Get distribution info from stats cache (if available) + const statsCache = get(runMetricsStatsCacheFamily(runId)) + const distInfo = resolveFromSource(statsCache) + + return {value, distInfo} + }), +) + +// Cache for computed stats maps (adds binSize lazily) to preserve identity per raw object +const computedStatsCache = new WeakMap, Record>() + +// Atom family to read the entire stats map for a run, lazily adding binSize per entry. +// IMPORTANT: It also subscribes to runMetricsFamily(runId) to ensure that refresh triggers +// fetching even when only stats are being read by the UI. +export const runMetricStatsFamily = atomFamily( + ({runId}: {runId: string}) => + eagerAtom>((get) => { + // Wire up to metrics array to drive fetching on refresh + // This ensures that setting runMetricsRefreshFamily(runId) will cause + // runMetricsFamily(runId) to evaluate and kick off the background fetch. + // We ignore its value here and continue to return the stats map. + get(runMetricsFamily(runId)) + + const obj = get(runMetricsStatsCacheFamily(runId)) as Record + if (!obj) return obj + + const cached = computedStatsCache.get(obj) + if (cached) return cached + + let mutated = false + const result: Record = {} + for (const [key, s] of Object.entries(obj)) { + if ( + s && + (s as any).binSize === undefined && + (s as any).distribution && + (s as any).distribution.length + ) { + const bins = (s as any).distribution.length + const range = ((s as any).max ?? 0) - ((s as any).min ?? 0) + result[key] = { + ...(s as any), + binSize: bins ? (range !== 0 ? range / bins : 1) : 1, + } as BasicStats + mutated = true + } else { + result[key] = s as BasicStats + } + } + + const finalMap = mutated ? result : obj + // memoize for this raw object identity + computedStatsCache.set(obj, finalMap) + return finalMap + }), + deepEqual, +) diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios.ts new file mode 100644 index 0000000000..61a77a362d --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/runScopedScenarios.ts @@ -0,0 +1,376 @@ +import deepEqual from "fast-deep-equal" +import {Atom, atom} from "jotai" +import {atomFamily, loadable} from "jotai/utils" +import {Loadable} from "jotai/vanilla/utils/loadable" +import {atomWithImmer} from "jotai-immer" + +import {urlStateAtom} from "@/oss/components/EvalRunDetails/state/urlState" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import { + evalAtomStore, + evalScenarioFilterAtom, +} from "@/oss/lib/hooks/useEvaluationRunData/assets/atoms" +import {getJWT} from "@/oss/services/api" +import {getProjectValues} from "@/oss/state/project" + +import {UseEvaluationRunScenarioStepsFetcherResult} from "../../../useEvaluationRunScenarioSteps/types" +import {fetchScenarioListViaWorker} from "../helpers/fetchScenarioListViaWorker" +import {fetchScenarioViaWorkerAndCache} from "../helpers/fetchScenarioViaWorker" + +import {scenarioStatusAtomFamily} from "./progress" +import {bulkStepsStatusFamily, enrichedRunFamily, evaluationRunStateFamily} from "./runScopedAtoms" + +/** + * Run-scoped scenario atoms + * + * These atoms replace the global scenario atoms and are scoped to specific evaluation runs. + * Each atom family is keyed by runId, allowing multiple evaluation runs to have + * independent scenario state. + */ + +// Atom family to force refetch of scenario steps - now scoped by runId +export const scenarioStepRefreshFamily = atomFamily( + (params: {runId: string; scenarioId: string}) => atom(0), + deepEqual, +) + +// Per-scenario local cache that can be mutated independently - now scoped by runId +export const scenarioStepLocalFamily = atomFamily( + (params: {runId: string; scenarioId: string}) => + atomWithImmer({}), + deepEqual, +) + +// Deduplicate in-flight fetches for scenario steps - now per runId +const scenarioStepInFlightMap = new Map>>() + +export const scenarioStepFamily = atomFamily< + {runId: string; scenarioId: string}, + Atom> +>((params) => { + const {runId, scenarioId} = params + return atom(async (get): Promise => { + // Depend on refresh version so that incrementing it triggers refetch + const refresh = get(scenarioStepRefreshFamily(params)) + + // Access data directly from run-scoped atom instead of derived atoms + const runState = get(evaluationRunStateFamily(runId)) + const evaluation = runState?.enrichedRun + const runIndex = runState?.runIndex + + const testsetData = evaluation?.testsets?.[0] + if (!runId || !evaluation || !testsetData || !runIndex) { + console.warn( + `[scenarioStepFamily] Missing runId/evaluation/testsetData for ${scenarioId}`, + ) + return undefined + } + + // Wait if bulk fetch in-flight to avoid duplicate per-scenario fetches + const status = get(bulkStepsStatusFamily(runId)) + if (status === "loading") { + while (get(bulkStepsStatusFamily(runId)) === "loading") { + await new Promise((r) => setTimeout(r, 16)) + } + } + + const fetchParams = { + runId, + evaluation, + runIndex, + } + + // Get or create in-flight map for this runId + if (!scenarioStepInFlightMap.has(runId)) { + scenarioStepInFlightMap.set(runId, new Map()) + } + const inFlightMap = scenarioStepInFlightMap.get(runId)! + + // Local cached value first + const local = get(scenarioStepLocalFamily(params)) + if (local && Object.keys(local).length > 0) { + if (refresh > 0 && !inFlightMap.has(scenarioId)) { + const bgPromise = (async () => { + await fetchScenarioViaWorkerAndCache(fetchParams, [scenarioId]) + evalAtomStore().set(scenarioStepRefreshFamily(params), 0) + })() + inFlightMap.set(scenarioId, bgPromise) + bgPromise.finally(() => inFlightMap.delete(scenarioId)) + } + return local + } + + // Fallback to bulk cache - return undefined if not cached + return undefined + }) +}, deepEqual) + +// Loadable version of scenario step family - scoped by runId +export const loadableScenarioStepFamily = atomFamily( + (params: {runId: string; scenarioId: string}) => loadable(scenarioStepFamily(params)), + deepEqual, +) + +// Scenarios atom - scoped by runId +export const scenariosFamily = atomFamily( + (runId: string) => + atom((get) => { + const state = get(evaluationRunStateFamily(runId)) + const scenarios = state.scenarios || [] + return scenarios + }), + deepEqual, +) + +// Scenario IDs atom - scoped by runId +export const scenarioIdsFamily = atomFamily( + (runId: string) => + atom((get) => { + const scenarios = get(scenariosFamily(runId)) + return scenarios.map((s) => s.id) + }), + deepEqual, +) + +// Total count atom - scoped by runId +export const totalCountFamily = atomFamily( + (runId: string) => + atom((get) => { + const scenarios = get(scenariosFamily(runId)) + return scenarios.length + }), + deepEqual, +) + +// Scenario steps atom - aggregates all scenario steps for a run +export const scenarioStepsFamily = atomFamily( + (runId: string) => + atom((get) => { + const scenarioIds = get(scenarioIdsFamily(runId)) + const stepsMap: Record< + string, + Loadable + > = {} + + scenarioIds.forEach((scenarioId) => { + stepsMap[scenarioId] = get(loadableScenarioStepFamily({runId, scenarioId})) + }) + + return stepsMap + }), + deepEqual, +) + +// Displayed scenario IDs with filtering - scoped by runId +export const displayedScenarioIdsFamily = atomFamily( + (runId: string) => + atom((get) => { + const scenarios = get(scenariosFamily(runId)) + const scenarioIds = scenarios.map((s: any) => s.id || s._id) + + // Get the current filter value from the global filter atom + // Note: evalScenarioFilterAtom is global but that's OK since filter preference is shared across runs + const filter = get(evalScenarioFilterAtom) + + // If filter is "all", return all scenarios + if (filter === "all") { + return scenarioIds + } + + // Filter scenarios based on their status + const filteredScenarioIds = scenarioIds.filter((scenarioId: string) => { + const statusData = get(scenarioStatusAtomFamily({scenarioId, runId})) + const status = statusData?.status || "pending" + + switch (filter) { + case "pending": + return status === "pending" || status === "revalidating" + case "unannotated": + return status === "incomplete" + case "failed": + return status === "failure" + default: + return true + } + }) + return filteredScenarioIds + }), + deepEqual, +) + +/** + * Helper functions for run-scoped scenario operations + */ + +// Revalidate scenario function - now requires runId +export async function revalidateScenarioForRun( + runId: string, + scenarioId: string, + store: ReturnType, + updatedSteps?: UseEvaluationRunScenarioStepsFetcherResult["steps"], +) { + // Apply optimistic override if requested + if (updatedSteps) { + // Apply optimistic updates to maintain continuous loading state + store.set(scenarioStepLocalFamily({runId, scenarioId}), (draft: any) => { + if (!draft) return draft + updatedSteps.forEach((updatedStep) => { + const targetStep = + draft.invocationSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) || + draft.inputSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) || + draft.annotationSteps?.find((s: any) => s.stepKey === updatedStep.stepKey) + if (!targetStep) return + // Merge updated step data + Object.entries(updatedStep).forEach(([k, v]) => { + // @ts-ignore – dynamic merge + targetStep[k] = v as any + }) + }) + return draft + }) + } + + // Bump refresh counter so the specific scenario refetches + try { + store.set(scenarioStepRefreshFamily({runId, scenarioId}), (v = 0) => v + 1) + } catch (err) { + console.error("[atoms] failed to bump scenario refresh counter", err) + } + + // Return a promise that resolves when the refreshed data is available + return store.get(scenarioStepFamily({runId, scenarioId})) +} + +// Bulk prefetch function for run-scoped scenarios +export function attachBulkPrefetchForRun( + runId: string, + store: ReturnType, +) { + // Subscribe to changes in displayed scenario IDs for this specific run + const unsubscribe = store.sub(displayedScenarioIdsFamily(runId), () => { + const scenarioIds = store.get(displayedScenarioIdsFamily(runId)) + if (scenarioIds.length > 0) { + // Trigger bulk fetch for this specific run + // The bulk fetch logic should work with run-scoped atoms + try { + // Import the bulk fetch function + import("./bulkFetch").then(({runBulkFetch}) => { + runBulkFetch(store, runId, scenarioIds) + }) + } catch (error) { + console.error( + `attachBulkPrefetchForRun: Error triggering bulk fetch for ${runId.slice(0, 8)}:`, + error, + ) + } + } + }) + + return unsubscribe +} + +// Scenario list prefetch function for run-scoped scenarios +// This fetches the scenarios for a run when the enriched run becomes available +export function attachScenarioListPrefetchForRun( + runId: string, + store: ReturnType, +) { + // Subscribe to changes in enriched run for this specific run + const unsubscribe = store.sub(enrichedRunFamily(runId), () => { + const enrichedRun = store.get(enrichedRunFamily(runId)) + const currentScenarios = store.get(scenariosFamily(runId)) + + // Only fetch scenarios if we have an enriched run but no scenarios yet + if (enrichedRun && currentScenarios.length === 0) { + const fetchScenarios = async () => { + try { + const {projectId} = getProjectValues() + const apiUrl = getAgentaApiUrl() + const jwt = await getJWT() + + if (!jwt) { + console.warn( + `[attachScenarioListPrefetchForRun] No JWT available for ${runId}`, + ) + return + } + + const scenarios = await fetchScenarioListViaWorker({ + apiUrl, + jwt, + projectId, + runId, + }) + store.set(evaluationRunStateFamily(runId), (draft: any) => { + draft.scenarios = scenarios.map((s, idx) => ({ + ...s, + scenarioIndex: idx + 1, + })) + }) + } catch (error) { + console.error( + `[attachScenarioListPrefetchForRun] Error fetching scenarios for ${runId}:`, + error, + ) + } + } + + fetchScenarios() + } + }) + + return unsubscribe +} + +// Neighbor prefetch function for run-scoped scenarios +export function attachNeighbourPrefetchForRun( + runId: string, + store: ReturnType, +) { + let lastScenarioId: string | null = null + let latestUrl = store.get(urlStateAtom) + let latestIds = store.get(scenarioIdsFamily(runId)) + + const maybePrefetch = () => { + const {view, scenarioId} = latestUrl + if (view !== "focus" || !scenarioId) return + if (!latestIds.length) return + if (scenarioId === lastScenarioId) return + + const idx = latestIds.indexOf(scenarioId) + if (idx === -1) return + + lastScenarioId = scenarioId + const neighbours = latestIds.filter((_, i) => Math.abs(i - idx) === 1) + const allIds = [scenarioId, ...neighbours] + const toFetch = allIds.filter( + (id) => !store.get(scenarioStepLocalFamily({runId, scenarioId: id})), + ) + if (!toFetch.length) { + return + } + + // Import and use run-scoped bulk fetch + import("./bulkFetch").then(({runBulkFetch}) => { + runBulkFetch(store, runId, toFetch, {force: true}) + }) + } + + // Subscribe to URL changes + const unsubscribeUrl = store.sub(urlStateAtom, () => { + latestUrl = store.get(urlStateAtom) + maybePrefetch() + }) + + // Subscribe to scenario IDs availability/changes for this specific run + const unsubscribeScenarios = store.sub(scenarioIdsFamily(runId), () => { + latestIds = store.get(scenarioIdsFamily(runId)) + maybePrefetch() + }) + + // Return cleanup function that unsubscribes from both subscriptions + return () => { + unsubscribeUrl() + unsubscribeScenarios() + } +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/store.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/store.ts new file mode 100644 index 0000000000..bb2014fd41 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/store.ts @@ -0,0 +1,74 @@ +import {createStore, getDefaultStore} from "jotai" + +import {attachRunMetricsPrefetchForRun} from "./runScopedMetrics" +import { + attachBulkPrefetchForRun, + attachNeighbourPrefetchForRun, + attachScenarioListPrefetchForRun, +} from "./runScopedScenarios" + +/** + * Single global Jotai store for all evaluation runs. + * Uses run-scoped atom families instead of multiple stores. + * This is the proper Jotai pattern for multi-entity state management. + */ +const globalStoreKey = "__agenta_globalEvalStore__" + +// Create or retrieve the single global store +function createGlobalStore() { + const store = getDefaultStore() + + return store +} + +// Global singleton store that persists across HMR +const globalStore: ReturnType = + (globalThis as any)[globalStoreKey] || createGlobalStore() +;(globalThis as any)[globalStoreKey] = globalStore + +// Track which runs have been initialized to avoid duplicate subscriptions +const initializedRuns = new Set() + +/** + * Returns the single global Jotai store. + * All evaluation runs use the same store with run-scoped atom families. + */ +export function evalAtomStore(): ReturnType { + return getDefaultStore() +} + +/** + * Initialize a run in the global store. + * This ensures that run-scoped atoms are properly set up for the given runId. + * Sets up run-specific subscriptions for prefetching. + */ +export function initializeRun(runId: string): void { + if (!runId) { + console.warn("[initializeRun] No runId provided") + return + } + + // Avoid duplicate initialization + if (initializedRuns.has(runId)) { + return + } + + // Mark as initialized + initializedRuns.add(runId) + + // Set up run-specific subscriptions for prefetching + // These will work with run-scoped atom families + try { + // Attach scenario list prefetch to fetch scenarios when enriched run is available + attachScenarioListPrefetchForRun(runId, globalStore) + attachBulkPrefetchForRun(runId, globalStore) + attachNeighbourPrefetchForRun(runId, globalStore) + + // Attach metrics prefetch to auto-fetch metrics when evaluators are available + attachRunMetricsPrefetchForRun(runId, globalStore) + } catch (error) { + console.error(`[initializeRun] Error setting up subscriptions for ${runId}:`, error) + // Remove from initialized set if setup failed + initializedRuns.delete(runId) + } +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/types.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/types.ts new file mode 100644 index 0000000000..bf9f8b7124 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/types.ts @@ -0,0 +1,16 @@ +// Aggregated scenario counts used in filters +export interface ScenarioCounts { + total: number + pending: number + unannotated: number + failed: number +} + +export interface StatusCounters { + pending: number + running: number + completed: number + cancelled: number + unannotated: number + failed: number +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/utils.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/utils.ts new file mode 100644 index 0000000000..609da22ebe --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/atoms/utils.ts @@ -0,0 +1,24 @@ +import deepEqual from "fast-deep-equal" +import {Atom, atom} from "jotai" +import {atomFamily} from "jotai/utils" + +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" + +import {evaluationRunStateFamily} from "./runScopedAtoms" + +type HumanEvalViewTypes = "focus" | "list" | "table" | "results" +type AutoEvalViewTypes = "overview" | "test-cases" | "prompt" + +// UI atom to track current scenario view type ("focus" or "table") +// export const runViewTypeAtom = atom("focus") + +export const evaluationEvaluatorsFamily = atomFamily( + (runId: string) => + atom((get) => get(evaluationRunStateFamily(runId)).enrichedRun?.evaluators) as Atom< + EvaluatorDto[] + >, + deepEqual, +) + +export type ScenarioFilter = "all" | "pending" | "unannotated" | "failed" +export const evalScenarioFilterAtom = atom("all") diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/constants.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/constants.ts new file mode 100644 index 0000000000..4aed7b3526 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/constants.ts @@ -0,0 +1,25 @@ +import type {EvaluationLoadingState, EvaluationRunState, IStatusMeta} from "../types" + +export const initialState: EvaluationRunState = { + rawRun: undefined, + isPreview: undefined, + enrichedRun: undefined, + isComparison: false, + isBase: false, + compareIndex: undefined, + scenarios: undefined, + statusMeta: {} as IStatusMeta, + steps: undefined, + metrics: undefined, + isLoading: {run: false, scenarios: false, steps: false, metrics: false}, + isError: {run: false, scenarios: false, steps: false, metrics: false}, +} + +export const defaultLoadingState: EvaluationLoadingState = { + isLoadingEvaluation: true, + isLoadingScenarios: false, + isLoadingSteps: false, + isLoadingMetrics: false, + activeStep: null, + scenarioStepProgress: {completed: 0, total: 0, percent: 0}, +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex.ts new file mode 100644 index 0000000000..4574658171 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex.ts @@ -0,0 +1,124 @@ +/** + * Step roles we care about in the evaluation workflow. + */ +export type StepKind = "input" | "invocation" | "annotation" + +/** Mapping entry for a single column extracted from a step */ +export interface ColumnDef { + /** Column (human-readable) name e.g. "country" or "outputs" */ + name: string + /** "input" | "invocation" | "annotation" */ + kind: StepKind + /** Optional evaluator metric primitive type ("number", "boolean", etc.) */ + metricType?: string + /** Dot-path used to resolve the value inside the owning step payload / testcase */ + path: string + /** Key of the step that owns this column */ + stepKey: string + /** Unique column key used by UI tables */ + key?: string +} + +/** Metadata we store per step key */ +export interface StepMeta { + key: string + kind: StepKind + /** List of upstream step keys declared in `inputs` */ + upstream: string[] + /** Raw references blob – may contain application, evaluator, etc. */ + refs: Record +} + +export interface RunIndex { + /** Map stepKey -> meta */ + steps: Record + /** Map stepKey -> array of ColumnDefs */ + columnsByStep: Record + /** Convenience sets for quick lookup */ + invocationKeys: Set + annotationKeys: Set + inputKeys: Set +} + +/** + * Build a ready-to-use index for an evaluation run. + * Call this **once** right after fetching the raw run and cache the result. + * The index can then be shared by single-scenario and bulk fetchers. + */ +export function buildRunIndex(rawRun: any): RunIndex { + const steps: Record = {} + const columnsByStep: Record = {} + + // Build evaluator slug->key set later + const evaluatorSlugToId = new Map() + + // 1️⃣ Index steps ------------------------------------------------------- + for (const s of rawRun?.data?.steps ?? []) { + let kind: StepKind = "annotation" + if (s.references?.testset) { + kind = "input" + } else if (s.references?.applicationRevision || s.references?.application) { + kind = "invocation" + } else if (s.references?.evaluator) { + kind = "annotation" + if (s.references.evaluator.slug) { + evaluatorSlugToId.set(s.references.evaluator.slug, s.references.evaluator.id) + } + } + + steps[s.key] = { + key: s.key, + kind, + upstream: (s.inputs ?? []).map((i: any) => i.key), + refs: s.references ?? {}, + } + } + + // 2️⃣ Group column defs by step --------------------------------------- + for (const m of rawRun?.data?.mappings ?? []) { + const colKind: StepKind = + m.column.kind === "testset" + ? "input" + : m.column.kind === "invocation" + ? "invocation" + : "annotation" + const col: ColumnDef = { + name: m.column.name, + kind: colKind, + path: m.step.path, + stepKey: m.step.key, + } + ;(columnsByStep[col.stepKey] ||= []).push(col) + } + + // 3️⃣ Precompute key sets by role ---------------------- + const invocationKeys = new Set() + const annotationKeys = new Set() + const inputKeys = new Set() + + for (const meta of Object.values(steps)) { + if (meta.kind === "invocation") invocationKeys.add(meta.key) + if (meta.kind === "annotation") annotationKeys.add(meta.key) + if (meta.kind === "input") inputKeys.add(meta.key) + } + + return {steps, columnsByStep, invocationKeys, annotationKeys, inputKeys} +} + +export function serializeRunIndex(idx: RunIndex) { + return { + ...idx, + invocationKeys: [...idx.invocationKeys], + annotationKeys: [...idx.annotationKeys], + inputKeys: [...idx.inputKeys], + } +} + +export function deserializeRunIndex(idx: any): RunIndex { + return { + ...idx, + invocationKeys: new Set(idx.invocationKeys), + annotationKeys: new Set(idx.annotationKeys), + inputKeys: new Set(idx.inputKeys), + } +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioListViaWorker.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioListViaWorker.ts new file mode 100644 index 0000000000..026e609c2b --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioListViaWorker.ts @@ -0,0 +1,48 @@ +import {v4 as uuid} from "uuid" + +import type {IScenario} from "@/oss/lib/hooks/useEvaluationRunScenarios/types" + +// Dynamically imported to avoid main bundle weight +let _worker: Worker | null = null +function getWorker() { + if (!_worker) { + _worker = new Worker( + new URL("@/oss/lib/workers/evalRunner/scenarioListWorker.ts", import.meta.url), + { + type: "module", + }, + ) + } + return _worker +} + +interface Params { + apiUrl: string + jwt: string + projectId: string + runId: string +} + +export async function fetchScenarioListViaWorker( + params: Params, + timeoutMs = 120000, +): Promise { + const worker = getWorker() + const requestId = uuid() + return new Promise((resolve, reject) => { + const handle = (e: MessageEvent) => { + const {requestId: rid, ok, data, error} = e.data + if (rid !== requestId) return + worker.removeEventListener("message", handle) + clearTimeout(timer) + if (ok) resolve(data as IScenario[]) + else reject(new Error(error)) + } + worker.addEventListener("message", handle) + const timer = setTimeout(() => { + worker.removeEventListener("message", handle) + reject(new Error("scenario list worker timeout")) + }, timeoutMs) + worker.postMessage({requestId, payload: params}) + }) +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioViaWorker.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioViaWorker.ts new file mode 100644 index 0000000000..667b7858ba --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/fetchScenarioViaWorker.ts @@ -0,0 +1,184 @@ +import {UseEvaluationRunScenarioStepsFetcherResult} from "../../../useEvaluationRunScenarioSteps/types" +import {evalAtomStore} from "../atoms" + +import {buildAuthContext, buildEvalWorkerContext} from "./workerContext" + +/** + * Fetch one or more scenarios' steps via the Web-Worker in bulk and cache the + * results inside `bulkStepsCacheAtom`. + * + * The helper returns the `Map` produced by the worker where each key is a + * `scenarioId` and the value is the enriched steps result for that scenario. + * If the worker fails to return data for a given scenario the entry will be + * missing from the map – callers should handle that case by falling back to a + * direct network request. + */ +// Deduplication cache to prevent multiple simultaneous calls for the same run +const inFlightFetches = new Map< + string, + Promise> +>() + +export const fetchScenarioViaWorkerAndCache = async ( + params: { + runId: string + evaluation: any + runIndex: any + }, + scenarioIds: string[], +): Promise> => { + // Safety checks for parameters + if (!params || !params.runId) { + return new Map() + } + + // Ensure scenarioIds is an array + const scenarioIdsArray = Array.isArray(scenarioIds) ? scenarioIds : [] + const cacheKey = `${params.runId}-${scenarioIdsArray.join(",")}` + + if (scenarioIdsArray.length === 0) { + return new Map() + } + + // Check if there's already an in-flight fetch for this exact request + if (inFlightFetches.has(cacheKey)) { + return inFlightFetches.get(cacheKey)! + } + + // Create the promise and cache it immediately + const fetchPromise = performFetch(params, scenarioIdsArray) + inFlightFetches.set(cacheKey, fetchPromise) + + try { + const result = await fetchPromise + return result + } finally { + // Clean up the cache entry when done + inFlightFetches.delete(cacheKey) + } +} + +const performFetch = async ( + params: { + runId: string + evaluation: any + runIndex: any + }, + scenarioIds: string[], +): Promise> => { + // Import run-scoped atoms at the top level + + const {scenarioStepLocalFamily: runScopedLocalFamily} = await import( + "../atoms/runScopedScenarios" + ) + + let context + try { + context = buildEvalWorkerContext({ + runId: params.runId, + evaluation: params.evaluation, + runIndex: params.runIndex, + }) + } catch (error) { + throw error + } + + const {jwt, apiUrl, projectId} = await buildAuthContext() + const {fetchStepsViaWorker} = await import( + "@/agenta-oss-common/lib/workers/evalRunner/bulkWorker" + ) + + const store = evalAtomStore() + + // Create a map to collect processed data for return + const processedResults = new Map() + + await fetchStepsViaWorker({ + context: { + ...context, + jwt, + apiUrl, + projectId, + }, + scenarioIds, + onChunk: (chunk) => { + chunk.forEach((val, key) => { + // Save to individual scenario atoms + store.set(runScopedLocalFamily({runId: params.runId, scenarioId: key}), (draft) => { + if (!draft) { + draft = { + steps: [], + annotationSteps: [], + invocationSteps: [], + inputSteps: [], + } + } + + // Store existing optimistic step statuses before overwriting + const preserveOptimisticStatuses = (existingSteps: any[], newSteps: any[]) => { + if (!existingSteps || !newSteps) return newSteps + + const shouldHoldOptimistic = ( + existingStatus: string, + serverStatus?: string, + ) => { + if (!existingStatus) return false + const optimisticStates = ["running", "revalidating"] + if (!optimisticStates.includes(existingStatus)) return false + + if (!serverStatus) return true + + // Only keep optimistic states while the server still reports a non-final status + const transitionalStates = new Set([ + "pending", + "running", + "annotating", + "revalidating", + ]) + + return transitionalStates.has(serverStatus) + } + + return newSteps.map((newStep: any) => { + const existingStep = existingSteps.find( + (s: any) => s.stepKey === newStep.stepKey, + ) + if ( + existingStep?.status && + shouldHoldOptimistic(existingStep.status, newStep.status) + ) { + return {...newStep, status: existingStep.status} + } + return newStep + }) + } + + // Merge server data while preserving optimistic statuses + for (const [k, v] of Object.entries(val)) { + if ( + k === "invocationSteps" || + k === "annotationSteps" || + k === "inputSteps" + ) { + ;(draft as any)[k] = preserveOptimisticStatuses( + (draft as any)[k], + v as any[], + ) + } else { + ;(draft as any)[k] = v + } + } + }) + + // Also collect the processed data for bulk cache return + processedResults.set(key, { + state: "hasData", + data: val, + }) + }) + }, + }) + + // Return the aggregated results map so callers receive data + return processedResults +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/scenarioFilters.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/scenarioFilters.ts new file mode 100644 index 0000000000..6b5aef797f --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/scenarioFilters.ts @@ -0,0 +1,66 @@ +import {Getter} from "jotai" +import {loadable} from "jotai/utils" + +import {IScenario} from "@/oss/lib/hooks/useEvaluationRunScenarios/types" + +import {scenarioStatusFamily} from "../atoms/progress" +import {scenarioStepFamily} from "../atoms/runScopedScenarios" + +export type ScenarioFilter = "all" | "pending" | "failed" | "unannotated" + +/** + * Determine whether a scenario matches the active filter. + * + * All atoms that need scenario filtering (counts, displayed list, etc.) should + * use this utility to guarantee that numbers and UI stay in sync. + */ +export const scenarioMatchesFilter = ( + get: Getter, + scenario: IScenario, + filter: ScenarioFilter, + runId: string, +): boolean => { + if (filter === "all") return true + + const scenarioId = (scenario as any).id || (scenario as any)._id + + if (filter === "pending") { + const statusLoad = get(loadable(scenarioStatusFamily({runId, scenarioId}))) + if (statusLoad.state !== "hasData") return true // treat unknown as pending while loading + const st = statusLoad.data.status + return ["pending", "running", "initialized", "started"].includes(st) + } + + if (filter === "failed") { + const statusLoad = get(loadable(scenarioStatusFamily({runId, scenarioId}))) + if (statusLoad.state !== "hasData") return false + const st = statusLoad.data.status + return st === "failure" || st === "error" + } + + if (filter === "unannotated") { + const stepLoad = get(loadable(scenarioStepFamily({runId, scenarioId}))) + if (stepLoad.state !== "hasData") return true // include while loading + const data = stepLoad.data + const hasAnn = + Array.isArray(data?.annotationSteps) && + data.annotationSteps.length > 0 && + data.annotationSteps.every((s: any) => !!s?.annotation) + const allInvSucceeded = + Array.isArray(data?.invocationSteps) && + data.invocationSteps.every((s) => s.status === "success") + return allInvSucceeded && !hasAnn + } + + return true +} + +export const filterScenarios = ( + get: Getter, + scenarios: IScenario[], + filter: ScenarioFilter, + runId: string, +): IScenario[] => { + if (!filter || filter === "all") return scenarios + return scenarios.filter((s) => scenarioMatchesFilter(get, s, filter, runId)) +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/index.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/index.ts new file mode 100644 index 0000000000..5328207590 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/index.ts @@ -0,0 +1,145 @@ +import {getDefaultStore} from "jotai" + +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {EnrichedEvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" +import {transformToRequestBody} from "@/oss/lib/shared/variant/transformer/transformToRequestBody" +import type {WorkspaceMember} from "@/oss/lib/Types" +import {getJWT} from "@/oss/services/api" +import {currentAppAtom} from "@/oss/state/app" +import {currentAppContextAtom} from "@/oss/state/app/selectors/app" +import {transformedPromptsAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {requestSchemaMetaAtomFamily} from "@/oss/state/newPlayground/core/requestSchemaMeta" +import {getOrgValues} from "@/oss/state/org" +import {getProjectValues} from "@/oss/state/project" +import {appUriInfoAtom, appSchemaAtom} from "@/oss/state/variant/atoms/fetcher" + +import {RunIndex} from "../buildRunIndex" + +import {EvalWorkerContextBase, WorkerAuthContext} from "./types" + +/** + * Build the evaluation-specific context for a worker fetch based on the current jotai store state. + */ +export const buildEvalWorkerContext = (params: { + runId: string + evaluation: EnrichedEvaluationRun + runIndex: RunIndex +}): EvalWorkerContextBase => { + const {selectedOrg} = getOrgValues() + const members = (selectedOrg?.default_workspace?.members as WorkspaceMember[]) || [] + + const store = getDefaultStore() + const appType = store.get(currentAppAtom)?.app_type + + const chatVariantIds: string[] = (params.evaluation?.variants || []) + .filter(Boolean) + .map((v: any) => { + const routePath = store.get(appUriInfoAtom)?.routePath + const meta = store.get(requestSchemaMetaAtomFamily({variant: v as any, routePath})) + return meta?.hasMessages ? (v as any).id : undefined + }) + .filter(Boolean) as string[] + + // Build a stable parameters map per revision using transformedPromptsAtomFamily(useStableParams) + const parametersByRevisionId: Record = {} + const revisionIds = (params.evaluation?.variants || []) + .map((v: any) => v?.id) + .filter(Boolean) as string[] + for (const rid of revisionIds) { + const stable = store.get( + transformedPromptsAtomFamily({revisionId: rid, useStableParams: true}), + ) + if (stable) parametersByRevisionId[rid] = stable + } + + return { + runId: params.runId, + mappings: params.evaluation?.data?.mappings ?? [], + members, + appType, + evaluators: params.evaluation?.evaluators || [], + testsets: params.evaluation?.testsets || [], + variants: (params.evaluation?.variants || []).map((v) => { + try { + const routePath = store.get(appUriInfoAtom)?.routePath + const spec = store.get(appSchemaAtom) + const meta = store.get(requestSchemaMetaAtomFamily({variant: v as any, routePath})) + // Custom workflow detection: + // - no messages container, and no `inputs` container => top-level custom inputs + // Completion apps usually have `inputs`; treat them as non-custom. + const hasInputsContainer = Array.isArray(meta?.inputKeys) + ? meta.inputKeys.includes("inputs") + : false + const isCustom = Boolean(!meta?.hasMessages && !hasInputsContainer) + const appType = (store.get(currentAppContextAtom)?.appType as any) || undefined + const rid = (v as any)?.id as string | undefined + const stableOptional = rid + ? store.get( + transformedPromptsAtomFamily({ + revisionId: rid, + useStableParams: true, + }), + ) + : undefined + return { + ...v, + isCustom, + // precompute optionalParameters to avoid metadata lookup in worker + optionalParameters: + stableOptional || + transformToRequestBody({ + variant: v, + isChat: meta?.hasMessages, + isCustom, + appType, + spec: spec as any, + routePath, + }), + } + } catch { + return { + ...v, + optionalParameters: transformToRequestBody({ + variant: v, + appType: + ((() => { + try { + return store.get(currentAppContextAtom)?.appType as any + } catch { + return undefined + } + })() as any) || undefined, + spec: ((): any => { + try { + return store.get(appSchemaAtom) + } catch { + return undefined + } + })(), + routePath: ((): any => { + try { + return store.get(appUriInfoAtom)?.routePath + } catch { + return undefined + } + })(), + }), + } + } + }), + runIndex: params.runIndex, + chatVariantIds, + uriObject: store.get(appUriInfoAtom) || undefined, + parametersByRevisionId, + } +} + +/** + * Resolve JWT, apiUrl and projectId in a single place. + */ +export const buildAuthContext = async (): Promise => { + const jwt = (await getJWT()) || "" + const apiUrl = getAgentaApiUrl() + const {projectId} = getProjectValues() ?? "" + return {jwt, apiUrl, projectId} +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/types.ts b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/types.ts new file mode 100644 index 0000000000..55296ce340 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/assets/helpers/workerContext/types.ts @@ -0,0 +1,31 @@ +import type {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import type {PreviewTestSet, WorkspaceMember} from "@/oss/lib/Types" + +import {RunIndex} from "../buildRunIndex" + +/** + * Minimal context object that the evaluation worker expects for enrichment. + * It purposefully contains only clone-safe data (no functions, Dates, etc.). + */ +export interface EvalWorkerContextBase { + runId: string + mappings: unknown[] + members: WorkspaceMember[] + evaluators: EvaluatorDto[] + testsets: PreviewTestSet[] + variants: EnhancedVariant[] + runIndex: RunIndex + uriObject?: {runtimePrefix: string; routePath?: string} + /** Stable transformed parameters keyed by revision id */ + parametersByRevisionId?: Record +} + +/** + * Authentication / environment info passed separately to the worker. + */ +export interface WorkerAuthContext { + jwt: string + apiUrl: string + projectId: string +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/index.ts b/web/ee/src/lib/hooks/useEvaluationRunData/index.ts new file mode 100644 index 0000000000..dd333d153c --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/index.ts @@ -0,0 +1,272 @@ +import {useCallback, useMemo} from "react" + +import deepEqual from "fast-deep-equal" +import {type WritableDraft} from "immer" +import {atom, useAtomValue, useSetAtom} from "jotai" +import {selectAtom} from "jotai/utils" +import useSWR from "swr" + +import {evalTypeAtom} from "@/oss/components/EvalRunDetails/state/evalType" +import {useAppId} from "@/oss/hooks/useAppId" +import axios from "@/oss/lib/api/assets/axiosConfig" +import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" +import {isDemo} from "@/oss/lib/helpers/utils" +import useEnrichEvaluationRun from "@/oss/lib/hooks/usePreviewEvaluations/assets/utils" +import {Evaluation, GenericObject, PreviewTestSet} from "@/oss/lib/Types" +import { + fetchAllEvaluationScenarios as fetchAllLegacyAutoEvaluationScenarios, + fetchEvaluation as fetchLegacyAutoEvaluation, +} from "@/oss/services/evaluations/api" +import { + fetchAllLoadEvaluationsScenarios, + fetchLoadEvaluation as fetchLegacyEvaluationData, +} from "@/oss/services/human-evaluations/api" +import {fetchTestset} from "@/oss/services/testsets/api" +import {userAtom} from "@/oss/state/profile/selectors/user" +import {projectIdAtom} from "@/oss/state/project/selectors/project" +import { + prefetchProjectVariantConfigs, + setProjectVariantReferencesAtom, +} from "@/oss/state/projectVariantConfig" + +import {evalAtomStore, evaluationRunStateFamily, loadingStateAtom} from "./assets/atoms" +import {buildRunIndex} from "./assets/helpers/buildRunIndex" +import {collectProjectVariantReferences} from "../usePreviewEvaluations/projectVariantConfigs" + +const fetchLegacyScenariosData = async ( + evaluationId: string, + evaluationObj: Evaluation, + type: "auto" | "human" | null, +): Promise => { + if (type === "auto") { + return fetchAllLegacyAutoEvaluationScenarios(evaluationId) + } else { + return new Promise((resolve) => { + fetchAllLoadEvaluationsScenarios(evaluationId, evaluationObj).then((data) => { + resolve( + data.map((item: GenericObject) => { + const numericScore = parseInt(item.score) + return {...item, score: isNaN(numericScore) ? null : numericScore} + }), + ) + }) + }) + } +} + +/** + * Hook to manage and fetch evaluation run data and scenarios. + * + * This hook supports both preview and legacy evaluation runs, providing + * functionality to fetch, enrich, and manage the state of evaluation data. + * It utilizes SWR for data fetching and caching, and Jotai for state management. + * + * @param {string | null} evaluationTableId - The ID of the evaluation table to fetch data for. + * @param {boolean} [debug=false] - Flag for enabling debug mode, which might provide additional logging or behavior. + * @param {() => void} [onScenariosLoaded] - Optional callback to be invoked when scenarios are successfully loaded. + * + * @returns {object} An object containing SWR mutate functions and methods to refetch evaluation and scenarios data. + */ +const useEvaluationRunData = (evaluationTableId: string | null, debug = false, runId?: string) => { + const evalType = useAtomValue(evalTypeAtom) + const routeAppId = useAppId() + // Get isPreview from run-scoped atom if runId is available + const isPreviewSelector = useCallback((state: any) => state.isPreview, []) + const isPreview = useAtomValue( + useMemo(() => { + if (!runId) return atom(false) + return selectAtom(evaluationRunStateFamily(runId), isPreviewSelector, deepEqual) + }, [runId, isPreviewSelector]), + ) + + const projectId = useAtomValue(projectIdAtom) + const setProjectVariantReferences = useSetAtom(setProjectVariantReferencesAtom) + const user = useAtomValue(userAtom) + const requireUser = isDemo() + const enrichRun = useEnrichEvaluationRun({debug, evalType}) + + // New fetcher for preview runs that fetches and enriches with testsetData + const fetchAndEnrichPreviewRun = useCallback(async () => { + if (!evaluationTableId || !projectId || (requireUser && !user?.id)) { + evalAtomStore().set(loadingStateAtom, (draft) => { + draft.isLoadingEvaluation = false + draft.activeStep = null + }) + return null + } + + evalAtomStore().set(loadingStateAtom, (draft) => { + draft.isLoadingEvaluation = true + draft.activeStep = "eval-run" + }) + + try { + const runRes = await axios.get( + `/preview/evaluations/runs/${evaluationTableId}?project_id=${projectId}`, + ) + const rawRun = snakeToCamelCaseKeys(runRes.data?.run) + const runIndex = buildRunIndex(rawRun) + + const testsetIds = Array.from( + Object.values(runIndex.steps || {}) + .map((m: any) => m?.refs?.testset?.id) + .filter(Boolean) + .reduce((acc: Set, id: string) => acc.add(id), new Set()), + ) as string[] + + const fetchedTestsets = ( + await Promise.all( + testsetIds.map((tid) => fetchTestset(tid, true).catch(() => null)), + ) + ).filter(Boolean) as PreviewTestSet[] + + if (!fetchedTestsets.length) { + evalAtomStore().set( + evaluationRunStateFamily(runId || evaluationTableId), + (draft: any) => { + draft.rawRun = runRes.data?.run + draft.enrichedRun = rawRun + draft.runIndex = runIndex + draft.isPreview = true + }, + ) + return rawRun + } + + if (!rawRun) { + if (runId) { + evalAtomStore().set(evaluationRunStateFamily(runId), (draft) => { + draft.isPreview = false + }) + } + return null + } + + const enrichedRun = enrichRun ? enrichRun(rawRun, fetchedTestsets, runIndex) : null + if (enrichedRun && (runId || evaluationTableId)) { + const effectiveRunId = runId || evaluationTableId + evalAtomStore().set( + evaluationRunStateFamily(effectiveRunId), + (draft: WritableDraft) => { + draft.rawRun = runRes.data?.run + draft.isPreview = true + draft.enrichedRun = enrichedRun + draft.runIndex = runIndex + }, + ) + } + + if (!routeAppId && projectId && enrichedRun) { + const references = collectProjectVariantReferences([enrichedRun], projectId) + setProjectVariantReferences(references) + prefetchProjectVariantConfigs(references) + } + + return enrichedRun + } catch (error: any) { + if (axios.isCancel?.(error) || error?.code === "ERR_CANCELED") { + return null + } + throw error + } finally { + evalAtomStore().set(loadingStateAtom, (draft) => { + draft.isLoadingEvaluation = false + draft.activeStep = null + }) + } + }, [enrichRun, evaluationTableId, projectId, runId, user?.id, requireUser]) + + const swrKey = + !!enrichRun && evaluationTableId && (!requireUser || !!user?.id) + ? [ + "previewRun", + evaluationTableId, + evalType, + projectId ?? "none", + requireUser ? (user?.id ?? "anon") : "no-user", + ] + : null + + const previewRunSwr = useSWR(swrKey, fetchAndEnrichPreviewRun, { + revalidateIfStale: false, + revalidateOnFocus: false, + revalidateOnReconnect: false, + }) + + // New fetcher for legacy runs that fetches and enriches with testsetData + const fetchAndEnrichLegacyRun = async () => { + const rawRun = + evalType === "auto" + ? await fetchLegacyAutoEvaluation(evaluationTableId as string) + : await fetchLegacyEvaluationData(evaluationTableId as string) + if (!rawRun) return null + + if (evalType === "auto") { + return rawRun + } + + const testsetId = (rawRun?.testset as any)?._id + let testsetData = testsetId ? await fetchTestset(testsetId) : null + + if (testsetData) { + // @ts-ignore + rawRun.testset = testsetData + } + return rawRun + } + + // Legacy: Use SWR to load evaluation data if not a preview + const legacyEvaluationSWR = useSWR( + !!enrichRun && previewRunSwr.data === null && evaluationTableId + ? ["legacyEval", evaluationTableId, evalType] + : null, + fetchAndEnrichLegacyRun, + { + onSuccess(data, key, config) { + if (!data) return + // Populate run-scoped atoms + if (runId) { + evalAtomStore().set(evaluationRunStateFamily(runId), (draft) => { + draft.rawRun = data + draft.isPreview = false + // @ts-ignore + draft.enrichedRun = data + }) + } + }, + }, + ) + + // Legacy: Load scenarios once legacyEvaluation is available + const legacyScenariosSWR = useSWR( + !(isPreview ?? true) && legacyEvaluationSWR.data?.id && !!projectId + ? ["legacyScenarios", evaluationTableId, projectId] + : null, + () => + fetchLegacyScenariosData( + evaluationTableId as string, + legacyEvaluationSWR.data as Evaluation, + evalType, + ), + ) + + return { + // Mutate functions + legacyEvaluationSWR, + legacyScenariosSWR, + refetchEvaluation() { + if (isPreview) { + previewRunSwr.mutate() + } else { + legacyEvaluationSWR.mutate() + } + }, + refetchScenarios() { + if (!isPreview) { + legacyScenariosSWR.mutate() + } + }, + } +} + +export default useEvaluationRunData diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/types.ts b/web/ee/src/lib/hooks/useEvaluationRunData/types.ts new file mode 100644 index 0000000000..23b28f7915 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/types.ts @@ -0,0 +1,141 @@ +import {EvaluationStatus, PreviewTestSet, WorkspaceMember} from "@/oss/lib/Types" + +import {Evaluation} from "../../Types" +import type {Metric} from "../useEvaluationRunMetrics/types" +import type {IScenario} from "../useEvaluationRunScenarios/types" +import type { + IStepResponse, + UseEvaluationRunScenarioStepsFetcherResult, +} from "../useEvaluationRunScenarioSteps/types" +import {EvaluatorDto} from "../useEvaluators/types" +import type {EnrichedEvaluationRun, EvaluationRun} from "../usePreviewEvaluations/types" + +import {RunIndex} from "./assets/helpers/buildRunIndex" + +export interface ScenarioStatus { + status: + | "pending" + | "running" + | "revalidating" + | "success" + | "error" + | "cancelled" + | "done" + | "failed" + result?: { + data?: unknown + } + error?: string +} + +export interface ScenarioStatusCounts { + total: number + pending: number + running: number + done: number + success: number + failed: number + cancelled: number +} + +export type ScenarioStatusMap = Record + +export interface IStatusMeta { + total: number + completed: number + pending: number + inProgress: number + error: number + cancelled: number + success: number + percentComplete: number + statusSummary: Record + timeline: {scenarioId: string; status: string}[] + timestamps: Record + transitions: Record + durations: Record + statusDurations: Record> +} + +export interface EvaluationRunState { + rawRun?: EvaluationRun | Evaluation + isPreview?: boolean + enrichedRun?: EnrichedEvaluationRun + /** Whether this evaluation is being used for comparison */ + isComparison?: boolean + /** Whether this is the base evaluation being compared against */ + isBase?: boolean + /** Position in comparison view (1 for base, 2+ for comparisons) */ + compareIndex?: number + scenarios?: IScenario[] + /** Summary of scenario statuses and timings */ + statusMeta: IStatusMeta + steps?: { + inputStep?: IStepResponse + invocationStep?: IStepResponse + annotationSteps?: IStepResponse[] + mainInputParams: any + secondaryInputParams: any + scenarioIndex: string + count: number + next?: string + } + metrics?: { + data: Metric[] + count: number + next?: string + } + isLoading: {run: boolean; scenarios: boolean; steps: boolean; metrics: boolean} + isError: {run: boolean; scenarios: boolean; steps: boolean; metrics: boolean} + /** + * Map of scenarioId to scenario steps and related data + */ + scenarioSteps?: Record + /** Pre-computed index of steps and mappings for this run */ + runIndex?: import("./assets/helpers/buildRunIndex").RunIndex +} + +export type LoadingStep = "eval-run" | "scenarios" | "scenario-steps" | null +export interface ScenarioStepProgress { + completed: number + total: number + percent: number +} + +export interface EvaluationLoadingState { + isLoadingEvaluation: boolean + isLoadingScenarios: boolean + isLoadingSteps: boolean + isLoadingMetrics: boolean + activeStep: LoadingStep + scenarioStepProgress: ScenarioStepProgress +} + +export interface OptimisticScenarioOverride { + status: EvaluationStatus + /** + * UI-only status used to indicate intermediate states like + * "revalidating" or "annotating" that are not recognised by the backend + */ + uiStatus?: "revalidating" | "annotating" + result?: any +} + +export interface EvalRunDataContextType { + runId: string + mappings: any + members: WorkspaceMember[] + evaluators: EvaluatorDto[] + testsets: PreviewTestSet[] + variants: any[] + /** + * Given an array of scenario IDs, fetches step data for each, and then + * enriches each step list with inputStep, invocationStep, trace, annotationSteps, + * and invocationParameters. Caches the results in `bulkStepsCacheAtom`. + * + * @param scenarioIds array of scenario IDs + * @param context the `EvalRunDataContextType` object containing runId, mappings, members, evaluators, testsets, and variants + * @param set the jotai `set` callback + */ + runIndex?: RunIndex +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunData/useEvalRunScenarioData.tsx b/web/ee/src/lib/hooks/useEvaluationRunData/useEvalRunScenarioData.tsx new file mode 100644 index 0000000000..85c5bf75c2 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunData/useEvalRunScenarioData.tsx @@ -0,0 +1,43 @@ +import {useMemo} from "react" + +import {useAtomValue} from "jotai" +import {loadable} from "jotai/utils" + +import {UseEvaluationRunScenarioStepsFetcherResult} from "../useEvaluationRunScenarioSteps/types" + +import {getCurrentRunId} from "./assets/atoms/migrationHelper" +import {scenarioStepFamily} from "./assets/atoms/runScopedScenarios" +import {evalAtomStore} from "./assets/atoms/store" + +const useEvalRunScenarioData = (scenarioId: string, runId?: string) => { + const store = evalAtomStore() + + // Memoize runId calculation to prevent infinite loops + const effectiveRunId = useMemo(() => { + if (runId) return runId + try { + return getCurrentRunId() + } catch (error) { + console.warn("[useEvalRunScenarioData] No run ID available:", error) + return null + } + }, [runId]) + + // Read from the same global store that writes are going to + const stepLoadable = useAtomValue( + loadable(scenarioStepFamily({scenarioId, runId: effectiveRunId || ""})), + {store}, + ) + + return useMemo(() => { + let data: UseEvaluationRunScenarioStepsFetcherResult | undefined = + stepLoadable.state === "hasData" ? stepLoadable.data : undefined + + if (stepLoadable.state === "hasData" && stepLoadable.data?.trace) { + data = stepLoadable.data + } + return data + }, [stepLoadable]) +} + +export default useEvalRunScenarioData diff --git a/web/ee/src/lib/hooks/useEvaluationRunMetrics/assets/utils.ts b/web/ee/src/lib/hooks/useEvaluationRunMetrics/assets/utils.ts new file mode 100644 index 0000000000..b990f89ad6 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunMetrics/assets/utils.ts @@ -0,0 +1,24 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" + +import type {MetricResponse} from "../types" + +/** + * SWR fetcher for fetching metrics from the API. + * + * Given a URL, this function performs a GET request to the URL, extracts the + * `metrics` array, `count`, and `next` properties from the response, and + * returns them in an object. + * + * @param {string} url The URL to fetch + * @return {Promise<{metrics: MetricResponse[], count: number, next?: string}>} + */ +export const fetcher = (url: string) => + axios.get(url).then((res) => { + const raw = res.data + const metrics: MetricResponse[] = Array.isArray(raw.metrics) ? raw.metrics : [] + return { + metrics, + count: raw.count as number, + next: raw.next as string | undefined, + } + }) diff --git a/web/ee/src/lib/hooks/useEvaluationRunMetrics/index.ts b/web/ee/src/lib/hooks/useEvaluationRunMetrics/index.ts new file mode 100644 index 0000000000..3f5f158ef0 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunMetrics/index.ts @@ -0,0 +1,112 @@ +import {useMemo} from "react" + +import useSWR from "swr" + +import { + METRICS_ENDPOINT, + createScenarioMetrics, + updateMetric, + updateMetrics, + computeRunMetrics, +} from "@/oss/services/runMetrics/api" + +import {fetcher} from "./assets/utils" +import type { + MetricResponse, + Metric, + UseEvaluationRunMetricsOptions, + UseEvaluationRunMetricsResult, +} from "./types" + +/** + * Hook to fetch and create metrics for a specific evaluation run (and optionally scenario). + * + * @param runId The UUID of the evaluation run. If falsy, fetching is skipped. + * @param options Optional filters/pagination: { limit, next, scenarioIds, statuses }. + */ +const useEvaluationRunMetrics = ( + runIds: string | string[] | null | undefined, + scenarioId?: string | null, + options?: UseEvaluationRunMetricsOptions, +): UseEvaluationRunMetricsResult => { + // Build query parameters + const queryParams = new URLSearchParams() + + // Append one or many run_ids query params + if (runIds) { + if (Array.isArray(runIds) && runIds.length > 0) { + // Ensure deterministic ordering for SWR key stability + const sorted = [...runIds].sort() + sorted.forEach((id) => queryParams.append("run_ids", id)) + } else { + queryParams.append("run_ids", runIds) + } + } + if (options?.limit !== undefined) { + queryParams.append("limit", options.limit.toString()) + } + if (options?.next) { + queryParams.append("next", options.next) + } + if (scenarioId) { + queryParams.append("scenario_ids", scenarioId) + } else if (options?.scenarioIds) { + options.scenarioIds.forEach((sid) => queryParams.append("scenario_ids", sid)) + } + if (options?.statuses) { + options.statuses.forEach((st) => queryParams.append("status", st)) + } + + const swrKey = useMemo(() => { + const queryRunIds = queryParams.getAll("run_ids").filter((a) => a !== "undefined" && !!a) + const queryScenarioIds = queryParams + .getAll("scenario_ids") + .filter((a) => a !== "undefined" && !!a) + + return queryRunIds.length > 0 || queryScenarioIds.length > 0 + ? `${METRICS_ENDPOINT}?${queryParams.toString()}` + : null + }, [queryParams]) + + // SWR response typed to raw MetricResponse[] + const swrData = useSWR<{ + metrics: MetricResponse[] + count: number + next?: string + }>(swrKey, fetcher) + + // Convert raw MetricResponse[] to camelCase Metric[] + const rawMetrics = swrData.data?.metrics + const camelMetrics: Metric[] | undefined = rawMetrics + ? rawMetrics.map((item) => item) + : undefined + + const totalCount = swrData.data?.count + const nextToken = swrData.data?.next + + return { + get metrics() { + return camelMetrics + }, + get count() { + return totalCount + }, + get next() { + return nextToken + }, + get isLoading() { + return !swrData.error && !swrData.data + }, + get isError() { + return !!swrData.error + }, + swrData, + mutate: () => swrData.mutate(), + createScenarioMetrics, + updateMetric, + updateMetrics, + computeRunMetrics, + } +} + +export default useEvaluationRunMetrics diff --git a/web/ee/src/lib/hooks/useEvaluationRunMetrics/types.ts b/web/ee/src/lib/hooks/useEvaluationRunMetrics/types.ts new file mode 100644 index 0000000000..20de372a60 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunMetrics/types.ts @@ -0,0 +1,75 @@ +import {EvaluationStatus, SnakeToCamelCaseKeys} from "@/oss/lib/Types" + +// Raw API response type for one metric (snake_case) +export interface MetricResponse { + id: string + run_id: string + scenario_id?: string + status?: EvaluationStatus + data: { + outputs: Record + } + created_at?: string + // …other fields in snake_case if backend adds more… +} + +// CamelCased version of MetricResponse +export type Metric = SnakeToCamelCaseKeys + +// Options for fetching metrics (pagination & filters) +export interface UseEvaluationRunMetricsOptions { + limit?: number + next?: string + scenarioIds?: string[] + statuses?: string[] +} + +// Result returned by useEvaluationRunMetrics hook +export interface UseEvaluationRunMetricsResult { + metrics: Metric[] | undefined + count?: number + next?: string + isLoading: boolean + isError: boolean + swrData: import("swr").SWRResponse< + { + metrics: MetricResponse[] + count: number + next?: string + }, + any + > + mutate: () => Promise + createScenarioMetrics: ( + apiUrl: string, + jwt: string, + runId: string, + entries: { + scenarioId: string + data: Record + }[], + ) => Promise + updateMetric: ( + apiUrl: string, + jwt: string, + metricId: string, + changes: { + data?: Record + status?: string + tags?: Record + meta?: Record + }, + ) => Promise + updateMetrics: ( + apiUrl: string, + jwt: string, + metrics: { + id: string + data?: Record + status?: string + tags?: Record + meta?: Record + }[], + ) => Promise + computeRunMetrics: (metrics: {data: Record}[]) => Record +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunScenarioSteps/types.ts b/web/ee/src/lib/hooks/useEvaluationRunScenarioSteps/types.ts new file mode 100644 index 0000000000..df4a15447b --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunScenarioSteps/types.ts @@ -0,0 +1,162 @@ +import {SWRResponse, SWRConfiguration} from "swr" + +import type {PreviewTestSet, SnakeToCamelCaseKeys} from "../../Types" +import {AnnotationDto} from "../useAnnotations/types" +import {RunIndex} from "../useEvaluationRunData/assets/helpers/buildRunIndex" + +// Step type for useEvaluationRunScenarioSteps fetcher result (camelCase, derived from StepResponseStep) +// Options for fetching steps (pagination, filters) +export interface UseEvaluationRunScenarioStepsOptions { + limit?: number + next?: string + keys?: string[] + statuses?: string[] +} + +// Result type returned by the hook +export interface UseEvaluationRunScenarioStepsResult { + isLoading: boolean + swrData: SWRResponse + // Function to revalidate + mutate: () => Promise +} + +export interface UseEvaluationRunScenarioStepsConfig extends SWRConfiguration { + concurrency?: number +} + +// --- Types for useEvaluationRunScenarioSteps fetcher result --- +export interface StepResponse { + steps: StepResponseStep[] + count: number + next?: string +} +export interface StepResponseStep { + id: string + // + run_id: string + scenario_id: string + step_key: string + repeat_idx?: number + timestamp?: string + interval?: number + // + status: string + // + // hash_id?: string + trace_id?: string + testcase_id?: string + error?: Record + // + created_at?: string + created_by_id?: string + // + is_legacy?: boolean + inputs?: Record + ground_truth?: Record +} +export type IStepResponse = SnakeToCamelCaseKeys + +export interface TraceNode { + trace_id: string + span_id: string + lifecycle: { + created_at: string + } + root: { + id: string + } + tree: { + id: string + } + node: { + id: string + name: string + type: string + } + parent?: { + id: string + } + time: { + start: string + end: string + } + status: { + code: string + } + data: Record + metrics: Record + refs: Record + otel: { + kind: string + attributes: Record + } + nodes?: Record +} + +export interface TraceData { + trees: TraceTree[] + version: string + count: number +} + +export interface TraceTree { + tree: { + id: string + } + nodes: TraceNode[] +} + +export type InvocationParameters = Record< + string, + { + requestBody: { + ag_config: { + prompt: { + messages: {role: string; content: string}[] + template_format: string + input_keys: string[] + llm_config: { + model: string + tools: any[] + } + } + } + inputs: Record + } + endpoint: string + } | null +> + +export interface IInvocationStep extends IStepResponse { + trace?: TraceTree + invocationParameters?: InvocationParameters +} + +export interface IInputStep extends IStepResponse { + inputs?: Record + groundTruth?: Record + testcase?: PreviewTestSet["data"]["testcases"][number] +} +export interface IAnnotationStep extends IStepResponse { + annotation?: AnnotationDto +} + +export interface UseEvaluationRunScenarioStepsFetcherResult { + steps: IStepResponse[] + mappings?: any[] + + // Single primary steps (kept for backward compatibility) + // invocationStep?: IStepResponse + annotationSteps: IAnnotationStep[] + invocationSteps: IInvocationStep[] + inputSteps: IInputStep[] + annotations?: AnnotationDto[] | null + + // NEW: support multiple role steps per scenario + inputStep?: IStepResponse + scenarioId?: string + trace?: TraceTree | TraceData | null + // annotation?: AnnotationDto | null + invocationParameters?: InvocationParameters +} diff --git a/web/ee/src/lib/hooks/useEvaluationRunScenarios/index.ts b/web/ee/src/lib/hooks/useEvaluationRunScenarios/index.ts new file mode 100644 index 0000000000..5b4396480b --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunScenarios/index.ts @@ -0,0 +1,133 @@ +import {useCallback} from "react" + +import {useSetAtom} from "jotai" +import useSWR, {SWRConfiguration} from "swr" + +import axios from "@/oss/lib/api/assets/axiosConfig" +import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" + +import {evalAtomStore, loadingStateAtom} from "../useEvaluationRunData/assets/atoms" +import {evaluationRunStateFamily} from "../useEvaluationRunData/assets/atoms/runScopedAtoms" + +import {IScenario, ScenarioResponse, UseEvaluationRunScenariosOptions} from "./types" + +// Fetcher factory that posts a query to the new endpoint and syncs atoms of current store +const makeFetcher = ( + endpoint: string, + syncAtom: boolean, + setLoading: ReturnType, + runId?: string | null, + params?: UseEvaluationRunScenariosOptions, +): (() => Promise<{ + scenarios: IScenario[] + count: number + next?: string +}>) => { + return () => { + if (syncAtom) { + setLoading((draft) => { + draft.isLoadingScenarios = true + draft.isLoadingEvaluation = false + draft.activeStep = "scenarios" + }) + } + + // Build request body for /preview/evaluations/scenarios/query + const body: Record = { + scenario: { + ...(runId ? {run_ids: [runId]} : {}), + }, + windowing: { + ...(params?.limit !== undefined ? {limit: params.limit} : {}), + ...(params?.next ? {next: params.next} : {}), + }, + } + + return axios.post(endpoint, body).then((res) => { + const raw = res.data + const scenarios = Array.isArray(raw.scenarios) + ? (raw.scenarios.map((scenario: ScenarioResponse, index: number) => ({ + ...snakeToCamelCaseKeys(scenario), + scenarioIndex: (scenario.meta?.index || 0) + 1, + })) as IScenario[]) + : ([] as IScenario[]) + + if (syncAtom) { + setLoading((draft) => { + draft.isLoadingScenarios = false + draft.activeStep = null + }) + // Only sync to run-scoped atom if runId is available + if (runId) { + evalAtomStore().set(evaluationRunStateFamily(runId), (draft) => { + draft.scenarios = scenarios + }) + } + } + return { + scenarios, + count: raw.count as number, + next: raw.next as string | undefined, + } + }) + } +} + +/** + * @deprecated + * @param runId + * @param params + * @returns + */ +export const getEvaluationRunScenariosKey = ( + runId?: string | null | undefined, + params?: UseEvaluationRunScenariosOptions, +) => { + if (!runId) return null + const parts: string[] = ["scenarios-query", `run:${runId}`] + if (params?.limit !== undefined) parts.push(`limit:${params.limit}`) + if (params?.next) parts.push(`next:${params.next}`) + return parts.join("|") +} +/** + * @deprecated + * Hook to fetch scenarios belonging to a specific evaluation run, + * plus some “progress” aggregates (pending vs. completed). + * + * @param runId The UUID of the run. If falsy, fetching is skipped. + * @param params Optional pagination: { limit, next }. + */ + +interface UseEvaluationRunScenariosHookOptions extends SWRConfiguration { + syncAtom?: boolean +} +const useEvaluationRunScenarios = ( + runId: string | null | undefined, + params?: UseEvaluationRunScenariosOptions, + {syncAtom = true, ...options}: UseEvaluationRunScenariosHookOptions = {}, +) => { + const setLoading = useSetAtom(loadingStateAtom) + + // Build query string only if runId is provided + const swrKey = getEvaluationRunScenariosKey(runId, params) + + const fetcher = useCallback( + makeFetcher("/preview/evaluations/scenarios/query", syncAtom, setLoading, runId, params), + [syncAtom, setLoading, runId, params?.limit, params?.next], + ) + + const swrData = useSWR<{ + scenarios: IScenario[] + count: number + next?: string + }>(swrKey ? `${swrKey}-${syncAtom}` : null, swrKey ? fetcher : null, { + ...options, + revalidateIfStale: false, + revalidateOnFocus: false, + revalidateOnReconnect: false, + }) + + return swrData +} + +export default useEvaluationRunScenarios diff --git a/web/ee/src/lib/hooks/useEvaluationRunScenarios/types.ts b/web/ee/src/lib/hooks/useEvaluationRunScenarios/types.ts new file mode 100644 index 0000000000..3a83ebc6ef --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluationRunScenarios/types.ts @@ -0,0 +1,24 @@ +import {SnakeToCamelCaseKeys} from "../../Types" + +// Raw API response type for one scenario (snake_case) +export interface ScenarioResponse { + id: string + run_id: string + status: string + created_by_id: string + created_at: string + // …other fields in snake_case if backend adds more… +} + +// CamelCased version of ScenarioResponse +export interface IScenario extends SnakeToCamelCaseKeys { + scenarioIndex: number +} + +// +// Pagination/options for the hook: +// +export interface UseEvaluationRunScenariosOptions { + limit?: number + next?: string +} diff --git a/web/ee/src/lib/hooks/useEvaluations.ts b/web/ee/src/lib/hooks/useEvaluations.ts new file mode 100644 index 0000000000..b7a40bf875 --- /dev/null +++ b/web/ee/src/lib/hooks/useEvaluations.ts @@ -0,0 +1,345 @@ +import {useMemo, useCallback} from "react" + +// import {useAppId} from "@/oss/hooks/useAppId" + +import axios from "@agenta/oss/src/lib/api/assets/axiosConfig" +import {EvaluationType} from "@agenta/oss/src/lib/enums" +import { + abTestingEvaluationTransformer, + fromEvaluationResponseToEvaluation, + singleModelTestEvaluationTransformer, +} from "@agenta/oss/src/lib/transformers" +import {Evaluation, EvaluationResponseType, ListAppsItem} from "@agenta/oss/src/lib/Types" +import {useAtomValue} from "jotai" +import useSWR from "swr" + +import {useAppId} from "@/oss/hooks/useAppId" +import {deleteEvaluations as deleteAutoEvaluations} from "@/oss/services/evaluations/api" +import {fetchAllEvaluations} from "@/oss/services/evaluations/api" +import {deleteEvaluations as deleteHumanEvaluations} from "@/oss/services/human-evaluations/api" +import {fetchAllLoadEvaluations, fetchEvaluationResults} from "@/oss/services/human-evaluations/api" +import {useAppsData} from "@/oss/state/app" +import {getProjectValues, projectIdAtom} from "@/oss/state/project" + +import usePreviewEvaluations from "./usePreviewEvaluations" + +const deleteRuns = async (ids: string[]) => { + const {projectId} = getProjectValues() + await axios.delete(`/preview/evaluations/runs/?project_id=${projectId}`, { + data: { + run_ids: ids, + }, + }) + + return ids +} + +/** + * Custom hook to manage evaluations, combining legacy evaluations and preview evaluations. + * + * @param {Object} params - Configuration object. + * @param {boolean} [params.withPreview] - Whether to include preview evaluations. + * @param {EvaluationType[]} params.types - List of evaluation types to filter. + * + * @returns {Object} An object containing: + * - `legacyEvaluations`: SWR object with data, error, and loading state for legacy evaluations. + * - `previewEvaluations`: Object with data and loading state for preview evaluations. + * - `mergedEvaluations`: Combined list of legacy and preview evaluations. + * - `isLoadingLegacy`: Loading state of legacy evaluations. + * - `isLoadingPreview`: Loading state of preview evaluations. + * - `refetch`: Function to refetch both legacy and preview evaluations. + * - `handleDeleteEvaluations`: Function to delete evaluations by IDs. + */ +const useEvaluations = ({ + withPreview, + types, + evalType, + appId: appIdOverride, +}: { + withPreview?: boolean + types: EvaluationType[] + evalType?: "human" | "auto" + appId?: string | null +}) => { + const routeAppId = useAppId() + const appId = (appIdOverride ?? routeAppId) || undefined + const {apps: availableApps = []} = useAppsData() + const projectId = useAtomValue(projectIdAtom) + + const appIdsForScope = useMemo(() => { + if (appId) return [appId] + return (availableApps as ListAppsItem[]) + .map((application) => application.app_id) + .filter((id): id is string => typeof id === "string" && id.length > 0) + }, [appId, availableApps]) + + /** + * Fetches legacy evaluations for the given appId and transforms them into the required format. + * Also fetches auto evaluations if the selected types require it. + * Returns an object containing human and auto evaluations. + */ + const legacyFetcher = useCallback(async () => { + if (!projectId || appIdsForScope.length === 0) { + return { + humanEvals: [], + autoEvals: [], + } + } + + const needsAutoEvaluations = types.some((type) => + [ + EvaluationType.human_a_b_testing, + EvaluationType.single_model_test, + EvaluationType.human_scoring, + EvaluationType.auto_exact_match, + EvaluationType.automatic, + ].includes(type), + ) + + const responses = await Promise.all( + appIdsForScope.map(async (targetAppId) => { + const rawEvaluations: EvaluationResponseType[] = await fetchAllLoadEvaluations( + targetAppId, + projectId, + ) + + const preparedEvaluations = rawEvaluations + .map((evaluationResponse) => ({ + evaluation: { + ...fromEvaluationResponseToEvaluation(evaluationResponse), + appId: targetAppId, + }, + raw: evaluationResponse, + })) + .filter(({evaluation}) => types.includes(evaluation.evaluationType)) + + const results = await Promise.all( + preparedEvaluations.map(({evaluation}) => + fetchEvaluationResults(evaluation.id), + ), + ) + + const humanEvaluations = results + .map((result, index) => { + const {evaluation, raw} = preparedEvaluations[index] + if (!result) return undefined + + if (evaluation.evaluationType === EvaluationType.single_model_test) { + const transformed = singleModelTestEvaluationTransformer({ + item: evaluation, + result, + }) + return { + ...transformed, + appId: targetAppId, + appName: evaluation.appName, + } + } + + if (evaluation.evaluationType === EvaluationType.human_a_b_testing) { + if (Object.keys(result.votes_data || {}).length > 0) { + const transformed = abTestingEvaluationTransformer({ + item: raw, + results: result.votes_data, + }) + return { + ...transformed, + appId: targetAppId, + appName: evaluation.appName, + } + } + } + + return undefined + }) + .filter((item): item is Record => Boolean(item)) + .filter( + (item: any) => + item.resultsData !== undefined || + !(Object.keys(item.scoresData || {}).length === 0) || + item.avgScore !== undefined, + ) + + const autoEvaluations = needsAutoEvaluations + ? (await fetchAllEvaluations(targetAppId)) + .sort( + (a, b) => + new Date(b.created_at || 0).getTime() - + new Date(a.created_at || 0).getTime(), + ) + .map((evaluation) => ({ + ...evaluation, + appId: targetAppId, + })) + : [] + + return { + humanEvals: humanEvaluations, + autoEvals: autoEvaluations, + } + }), + ) + + const humanEvals = responses + .flatMap((response) => response.humanEvals) + .sort( + (a, b) => + new Date(b?.createdAt ?? 0).getTime() - new Date(a?.createdAt ?? 0).getTime(), + ) + const autoEvals = responses.flatMap((response) => response.autoEvals) + + return { + humanEvals, + autoEvals, + } + }, [appIdsForScope, projectId, types]) + + /** + * SWR hook for fetching and caching legacy evaluations using the legacyFetcher. + */ + const legacyEvaluations = useSWR( + !projectId || appIdsForScope.length === 0 + ? null + : ["legacy-evaluations", projectId, ...appIdsForScope], + legacyFetcher, + ) + + /** + * Hook for fetching preview evaluations if withPreview is enabled. + */ + const previewEvaluations = usePreviewEvaluations({ + skip: !withPreview, + types, + appId, + }) + + // Extract runs from preview evaluations + const {runs} = previewEvaluations || {} + + /** + * Lazily combines legacy and preview evaluations into a single array. + * Returns an empty array if either source is not yet loaded. + */ + const computeMergedEvaluations = useCallback( + (evalType?: "human" | "auto") => { + const legacyData = legacyEvaluations.data || {autoEvals: [], humanEvals: []} + const legacyAuto = legacyData.autoEvals || [] + const legacyHuman = legacyData.humanEvals || [] + let filteredLegacy = [] + if (types.includes(EvaluationType.single_model_test)) { + filteredLegacy = legacyHuman + } else { + filteredLegacy = legacyAuto + } + + if (!runs || !Array.isArray(runs)) { + return filteredLegacy + } + + // Filtering out evaluations based on eval type + let filteredRuns = [] + if (evalType === "human") { + filteredRuns = runs.filter((run) => + run?.data?.steps.some( + (step) => step.type === "annotation" && step.origin === "human", + ), + ) + if (filteredLegacy.length > 0) { + const autoEvalLagecyRuns = filteredLegacy.filter( + (run) => run?.evaluation_type === "single_model_test", + ) + + filteredRuns = [...filteredRuns, ...autoEvalLagecyRuns] + } + } else if (evalType === "auto") { + filteredRuns = runs.filter((run) => + run?.data?.steps.every( + (step) => + step.type !== "annotation" || + step.origin === "auto" || + step.origin === undefined, + ), + ) + if (filteredLegacy.length > 0) { + const autoEvalLagecyRuns = filteredLegacy.filter( + (run) => "aggregated_results" in run, + ) + + filteredRuns = [...filteredRuns, ...autoEvalLagecyRuns] + } + } else { + filteredRuns = [...filteredLegacy, ...runs] + } + + return filteredRuns.sort((a, b) => { + return b.createdAtTimestamp - a.createdAtTimestamp + }) + }, + [legacyEvaluations.data, runs, types, evalType], + ) + + /** + * Refetches both legacy and preview evaluations in parallel. + * Use this after mutations that affect evaluation data. + */ + const refetchAll = useCallback(async () => { + await Promise.all([legacyEvaluations.mutate(), previewEvaluations.swrData.mutate()]) + }, [legacyEvaluations, previewEvaluations]) + + /** + * Deletes evaluations by IDs, handling both legacy and preview evaluations. + * Determines which IDs correspond to legacy or preview runs, deletes them accordingly, and refetches all data. + * @param _ids - Single ID or array of IDs to delete + */ + const handleDeleteEvaluations = useCallback( + async (_ids: string[] | string) => { + const ids = Array.isArray(_ids) ? _ids : typeof _ids === "string" ? [_ids] : [] + const listOfLegacyEvals = + evalType === "auto" + ? legacyEvaluations.data?.autoEvals || [] + : legacyEvaluations.data?.humanEvals || [] + + // Determine which IDs are legacy evaluations + const legacyEvals = listOfLegacyEvals + .filter((e) => ids.includes(e.key || e.id)) + .map((e) => e.key || e.id) + + // IDs that are preview runs + const runsIds = ids.filter((id) => !legacyEvals.includes(id)) + try { + if (legacyEvals.length > 0) { + if (evalType === "auto") { + await deleteAutoEvaluations(ids) + } else { + await deleteHumanEvaluations(ids) + } + } + + if (runsIds.length > 0) { + await deleteRuns(runsIds) + } + await refetchAll() + } catch (error) { + console.error(error) + } + }, + [legacyEvaluations, refetchAll], + ) + + const mergedEvaluations = useMemo( + () => computeMergedEvaluations(evalType), + [computeMergedEvaluations, evalType], + ) + + return { + legacyEvaluations, + previewEvaluations, + mergedEvaluations, + isLoadingLegacy: legacyEvaluations.isLoading, + isLoadingPreview: previewEvaluations?.swrData?.isLoading ?? false, + refetch: refetchAll, + handleDeleteEvaluations, + } +} + +export default useEvaluations diff --git a/web/ee/src/lib/hooks/useInvocationResult/index.ts b/web/ee/src/lib/hooks/useInvocationResult/index.ts new file mode 100644 index 0000000000..118137f9ea --- /dev/null +++ b/web/ee/src/lib/hooks/useInvocationResult/index.ts @@ -0,0 +1,143 @@ +import {useMemo} from "react" + +import {useAtomValue} from "jotai" + +import {renderChatMessages} from "@/oss/components/EvalRunDetails/assets/renderChatMessages" +import {evalTypeAtom} from "@/oss/components/EvalRunDetails/state/evalType" +import {useRunId} from "@/oss/contexts/RunIdContext" +import {readInvocationResponse} from "@/oss/lib/helpers/traceUtils" + +import {getCurrentRunId} from "../useEvaluationRunData/assets/atoms/migrationHelper" +import {scenarioStatusAtomFamily} from "../useEvaluationRunData/assets/atoms/progress" +import {evalAtomStore} from "../useEvaluationRunData/assets/atoms/store" +import useEvalRunScenarioData from "../useEvaluationRunData/useEvalRunScenarioData" + +import type {UseInvocationResult, UseInvocationResultArgs} from "./types" + +export function useInvocationResult({ + scenarioId, + stepKey, + runId: maybeRunId, + editorType = "shared", + viewType = "single", +}: UseInvocationResultArgs): UseInvocationResult { + const store = evalAtomStore() + + // Use provided runId or fallback to current run context (memoized to prevent infinite loops) + const contextRunId = useRunId() + const runId = useMemo(() => { + if (maybeRunId) return maybeRunId + if (contextRunId) return contextRunId + try { + return getCurrentRunId() + } catch (error) { + console.warn("[useInvocationResult] No run ID available:", error) + return null + } + }, [maybeRunId, contextRunId]) + + const evalType = useAtomValue(evalTypeAtom) + // Call all hooks before any early returns + const data = useEvalRunScenarioData(scenarioId, runId || "") + // Read from the same global store that writes are going to + const status = useAtomValue( + useMemo( + () => scenarioStatusAtomFamily({scenarioId, runId: runId || ""}), + [scenarioId, runId], + ), + {store}, + ) as any + + // Early return if no runId is available + if (!runId) { + return { + trace: undefined, + value: undefined, + rawValue: undefined, + messageNodes: null, + status: undefined, + } + } + + const { + trace: _trace, + value: derivedVal, + rawValue, + } = readInvocationResponse({ + scenarioData: data, + stepKey, + forceTrace: status?.trace, + optimisticResult: status?.result, + scenarioId, + }) + + const trace = status?.trace || _trace + // For auto evaluation only + const errorMessage = useMemo(() => { + if (evalType !== "auto") return "" + const findInvocation = data?.invocationSteps?.find((d) => d.scenarioId === scenarioId) + return findInvocation?.error?.stacktrace ?? "" + }, [data, scenarioId, evalType]) + + const {messageNodes, value, hasError} = useMemo(() => { + // Determine chat vs primitive + let messageNodes: React.ReactNode[] | null = null + let value: string | object | undefined = undefined + let hasError = false + + if (trace?.exception) { + value = trace?.exception?.message + hasError = true + } else if (errorMessage) { + value = errorMessage + hasError = true + } else { + const processChat = (jsonStr: string) => { + try { + const arr = JSON.parse(jsonStr) + if ( + Array.isArray(arr) && + arr.every((m: any) => "role" in m && "content" in m) + ) { + return renderChatMessages({ + keyPrefix: `${scenarioId}-${stepKey}`, + rawJson: jsonStr, + view: viewType, + editorType, + }) + } + + return null + } catch (err) {} + } + + if (rawValue) { + if (typeof rawValue === "string") { + messageNodes = processChat(rawValue) + if (!messageNodes) value = rawValue + } else if ( + typeof rawValue === "object" && + "role" in rawValue && + "content" in rawValue + ) { + messageNodes = renderChatMessages({ + keyPrefix: `${scenarioId}-${stepKey}-${trace?.trace_id ?? ""}`, + rawJson: JSON.stringify([rawValue]), + view: viewType, + editorType, + }) + } else { + value = rawValue as any + } + } + + if (!messageNodes) { + value = value ?? derivedVal + } + } + + return {messageNodes, value, hasError} + }, [trace, errorMessage]) + + return {trace, value, rawValue, messageNodes, status, hasError} +} diff --git a/web/ee/src/lib/hooks/useInvocationResult/types.ts b/web/ee/src/lib/hooks/useInvocationResult/types.ts new file mode 100644 index 0000000000..32646e425f --- /dev/null +++ b/web/ee/src/lib/hooks/useInvocationResult/types.ts @@ -0,0 +1,18 @@ +import {ScenarioStatusMap} from "../useEvaluationRunData/types" + +export interface UseInvocationResultArgs { + scenarioId: string + stepKey: string + runId?: string // Optional: for multi-run support + editorType?: "simple" | "shared" + viewType?: "single" | "table" +} + +export interface UseInvocationResult { + trace?: any + value?: string | object + rawValue?: any + messageNodes: React.ReactNode[] | null + status?: ScenarioStatusMap[string] + hasError?: boolean +} diff --git a/web/ee/src/lib/hooks/usePreviewEvaluations/assets/utils.ts b/web/ee/src/lib/hooks/usePreviewEvaluations/assets/utils.ts new file mode 100644 index 0000000000..3ffc624ec7 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewEvaluations/assets/utils.ts @@ -0,0 +1,396 @@ +import {useCallback, useMemo} from "react" + +import {getDefaultStore} from "jotai" + +import {useAppId} from "@/oss/hooks/useAppId" +import {formatDay} from "@/oss/lib/helpers/dateTimeHelper" +import dayjs from "@/oss/lib/helpers/dateTimeHelper/dayjs" +import {RunIndex, StepMeta} from "@/oss/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" +import useEvaluators from "@/oss/lib/hooks/useEvaluators" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import { + EnrichedEvaluationRun, + EvaluationRun, + IEvaluationRunDataStep, +} from "@/oss/lib/hooks/usePreviewEvaluations/types" +import useStatelessVariants from "@/oss/lib/hooks/useStatelessVariants" +import {EnhancedObjectConfig} from "@/oss/lib/shared/variant/genericTransformer/types" +import {AgentaConfigPrompt, EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {WorkspaceMember, SnakeToCamelCaseKeys, PreviewTestSet} from "@/oss/lib/Types" +import {useAppList} from "@/oss/state/app/hooks" +import {transformedPromptsAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {useOrgData} from "@/oss/state/org" +import {getProjectValues} from "@/oss/state/project" + +export const enrichEvaluationRun = ({ + run: _run, + testsets, + variantsData, + evaluators, + members, + runIndex, + extras, + appNameById, + projectScope = false, +}: { + run: SnakeToCamelCaseKeys + testsets: PreviewTestSet[] + variantsData: any + evaluators: EvaluatorDto[] + members: WorkspaceMember[] + runIndex?: RunIndex + extras?: { + parametersByRevisionId?: Record + flagsByRevisionId?: Record + variantConfigs?: Record + } + appNameById?: Map + projectScope?: boolean +}) => { + const run: Partial = _run + // Convert snake_case keys to camelCase recursively + run.createdAtTimestamp = dayjs(run.createdAt, "YYYY/MM/DD H:mm:ssAZ").valueOf() + // Format creation date for display + run.createdAt = formatDay({date: run.createdAt, outputFormat: "DD MMM YYYY | h:mm a"}) + // Derive potential ids via runIndex – allow multiple + const testsetIds: string[] = [] + const revisionIds: string[] = [] + + if (runIndex) { + for (const meta of Object.values(runIndex.steps) as StepMeta[]) { + if (meta.refs?.testset) { + testsetIds.push(meta.refs.testset.id) + } + if (meta.refs?.applicationRevision) { + revisionIds.push(meta.refs.applicationRevision.id) + } + } + } + + const uniqueTestsetIds = Array.from(new Set(testsetIds)) + const uniqueRevisionIds = Array.from(new Set(revisionIds)) + + // Resolve testset objects + const resolvedTestsets = testsets + ? (uniqueTestsetIds + .flatMap((id) => + testsets + ?.filter((ts) => ts.id === id) + .map((ts) => ({ + ...ts, + name: ts.name, + createdAt: ts.created_at, + createdAtTimestamp: dayjs( + ts.created_at, + "YYYY/MM/DD H:mm:ssAZ", + ).valueOf(), + })), + ) + .filter(Boolean) as PreviewTestSet[]) + : [] + + // Support both shapes: array or { variants: [...] } + const variantList: EnhancedVariant>[] = Array.isArray( + variantsData, + ) + ? variantsData + : (variantsData?.variants as EnhancedVariant>[]) || + [] + + const configVariants: EnhancedVariant>[] = + extras?.variantConfigs + ? Object.entries(extras.variantConfigs) + .map(([key, config]) => { + if (!config) return null + const variantRef = config.variant_ref || {} + const applicationRef = config.application_ref || {} + const id = variantRef.id || key + if (!id) return null + return { + id, + variantId: variantRef.id || id, + variantName: + variantRef.slug || + variantRef.id || + variantRef.name || + config?.service_ref?.slug || + key, + name: + variantRef.slug || + variantRef.id || + variantRef.name || + config?.service_ref?.slug || + key, + configName: variantRef.slug || variantRef.name, + appId: applicationRef?.id, + appSlug: applicationRef?.slug, + appStatus: undefined, + uri: config.url, + revision: variantRef.version ?? null, + revisionLabel: variantRef.version ?? null, + createdAtTimestamp: run.createdAtTimestamp, + createdAt: run.createdAt, + configParams: config.params, + } as any + }) + .filter(Boolean) + : [] + + const variantMap = new Map() + variantList.forEach((variant: any) => { + if (!variant?.id) return + variantMap.set(String(variant.id), variant) + }) + configVariants.forEach((variant: any) => { + if (!variant?.id) return + const key = String(variant.id) + if (!variantMap.has(key)) { + variantMap.set(key, variant) + return + } + const existing = variantMap.get(key) + variantMap.set(key, { + ...existing, + ...variant, + variantName: variant.variantName || existing?.variantName, + configName: variant.configName || existing?.configName, + name: variant.name || existing?.name, + }) + }) + const combinedVariantList: EnhancedVariant>[] = + Array.from(variantMap.values()) + + const filteredVariants = combinedVariantList.filter((v) => uniqueRevisionIds.includes(v.id)) + + const fallbackVariants = + filteredVariants.length || !runIndex + ? [] + : Array.from(runIndex.invocationKeys) + .map((key) => { + const meta = runIndex.steps[key] + if (!meta) return null + const refs = meta.refs || {} + const application = + refs.application || refs.applicationRevision?.application || {} + const revision = refs.applicationRevision || {} + + const appId = + application?.id || + application?.app_id || + application?.application_id || + revision?.application_id || + undefined + + const variantName = + application?.name || application?.slug || refs.variant?.name || meta.key + + const revisionId = + revision?.id || revision?.revision_id || revision?.revisionId || meta.key + + const revisionLabel = + revision?.name || revision?.revision || revision?.version || undefined + + return { + id: revisionId, + variantId: revisionId, + appId, + appName: application?.name, + variantName, + revision: revisionLabel, + revisionLabel, + createdAt: run.createdAt, + createdAtTimestamp: run.createdAtTimestamp, + } + }) + .filter((item): item is Record => Boolean(item)) + + const projectId = getProjectValues().projectId + + const baseVariants = filteredVariants.length ? filteredVariants : [] + const combinedVariants = (baseVariants.length ? baseVariants : fallbackVariants).map( + (variant) => { + const mappedName = variant.appId ? appNameById?.get(variant.appId) : undefined + if (mappedName && (!variant.appName || variant.appName === variant.appId)) { + return { + ...variant, + appName: mappedName, + } + } + return variant + }, + ) as typeof fallbackVariants + + const normalizedVariants = combinedVariants + .map((variant) => { + const fallbackId = + variant.id || variant.variantId || (variant as any).revisionId || undefined + if (fallbackId && variant.id !== fallbackId) { + return { + ...variant, + id: fallbackId, + variantId: variant.variantId || fallbackId, + } + } + return variant.id + ? variant + : { + ...variant, + variantId: variant.variantId || fallbackId, + id: fallbackId, + } + }) + .filter((variant) => Boolean(variant.id)) + + const primaryVariant = normalizedVariants[0] + + const returnValue = { + ...run, + appId: (run as any).appId || primaryVariant?.appId, + appName: (run as any).appName || primaryVariant?.appName, + variants: normalizedVariants, + testsets: resolvedTestsets, + createdBy: members.find((member) => member.user.id === run.createdById), + parametersByRevisionId: extras?.parametersByRevisionId || {}, + flagsByRevisionId: extras?.flagsByRevisionId || {}, + } + + normalizedVariants.forEach((variant: any) => { + const revisionKey = variant.id || variant.variantId + if (!revisionKey) return + if (variant.configParams) { + returnValue.parametersByRevisionId[revisionKey] = + returnValue.parametersByRevisionId[revisionKey] || variant.configParams + } + if (!returnValue.appId && variant.appId) { + returnValue.appId = variant.appId + } + if (!returnValue.appName && variant.appName) { + returnValue.appName = variant.appName + } + }) + if (!returnValue.appName && returnValue.appId && appNameById) { + const mappedName = appNameById.get(returnValue.appId) + if (mappedName) { + returnValue.appName = mappedName + } + } + if (runIndex) { + // Find all annotation steps via index if available + const annotationSteps = Array.from(runIndex.annotationKeys) + .map((k) => { + // locate original step for richer data + return (run.data?.steps || []).find((s) => s.key === k) as + | IEvaluationRunDataStep + | undefined + }) + .filter(Boolean) + + // Extract all evaluator slugs or IDs from those steps + const evaluatorRefs = annotationSteps + .map((step) => step?.references?.evaluator?.id) + .filter((id): id is string => !!id) + // Match evaluator objects using slug or id + const matchedEvaluators = evaluatorRefs + .map((id: string) => evaluators?.find((e) => e.slug === id || e.id === id)) + .filter(Boolean) + + returnValue.evaluators = matchedEvaluators as EvaluatorDto[] + } + + return returnValue as EnrichedEvaluationRun +} + +const useEnrichEvaluationRun = ({ + evalType = "human", +}: { + evalType?: "human" | "auto" +} = {}): + | (( + run: SnakeToCamelCaseKeys, + testsetData?: PreviewTestSet[], + runIndex?: RunIndex, + ) => EnrichedEvaluationRun) + | undefined => { + const {selectedOrg} = useOrgData() + const members = selectedOrg?.default_workspace?.members || [] + const routeAppId = useAppId() + const isProjectScope = !routeAppId + const appList = useAppList() + const appNames = useMemo(() => { + return new Map((appList || []).map((item) => [item.app_id, item.app_name])) + }, [appList]) + + const {data: evaluators, isLoading: _loadingEvaluators} = useEvaluators({ + preview: true, + queries: {is_human: evalType === "human"}, + }) + const {revisions: variantsData, isLoading: _variantsLoading} = useStatelessVariants({ + lightLoading: true, + }) + const effectiveVariantsData = isProjectScope ? (variantsData ?? []) : variantsData + + const enrichRun = useCallback( + ( + run: SnakeToCamelCaseKeys, + testsetData?: PreviewTestSet[], + runIndex?: RunIndex, + options?: {variantConfigs?: Record}, + ) => { + // Derive transformed parameters and flags per revision on-demand from atoms + const store = getDefaultStore() + const revisionIds: string[] = runIndex + ? Array.from( + new Set( + Object.values(runIndex.steps) + .map((m: any) => m?.refs?.applicationRevision?.id) + .filter(Boolean) as string[], + ), + ) + : [] + + const parametersByRevisionId: Record = {} + const flagsByRevisionId: Record = {} + for (const rid of revisionIds) { + parametersByRevisionId[rid] = store.get( + transformedPromptsAtomFamily({revisionId: rid, useStableParams: true}), + ) + flagsByRevisionId[rid] = store.get(variantFlagsAtomFamily({revisionId: rid})) + } + + const result = enrichEvaluationRun({ + run, + testsets: testsetData || [], + variantsData: effectiveVariantsData || [], + evaluators: (evaluators as EvaluatorDto[]) || [], + members, + runIndex, + extras: { + parametersByRevisionId, + flagsByRevisionId, + variantConfigs: options?.variantConfigs, + }, + projectScope: isProjectScope, + appNameById: appNames, + }) as EnrichedEvaluationRun + + if (process.env.NODE_ENV !== "production") { + const variantSummary = (result?.variants || []).map((v: any) => ({ + id: v?.id, + variantId: v?.variantId, + name: v?.variantName ?? v?.name, + appStatus: v?.appStatus, + })) + } + + return result + }, + [effectiveVariantsData, evaluators, members, isProjectScope, appNames], + ) + + const evaluatorsReady = Array.isArray(evaluators) + + return !_variantsLoading && evaluatorsReady ? enrichRun : undefined +} + +export default useEnrichEvaluationRun diff --git a/web/ee/src/lib/hooks/usePreviewEvaluations/index.ts b/web/ee/src/lib/hooks/usePreviewEvaluations/index.ts new file mode 100644 index 0000000000..65438ca047 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewEvaluations/index.ts @@ -0,0 +1,459 @@ +import {useCallback, useEffect, useMemo} from "react" + +import {useAtomValue, useSetAtom} from "jotai" +import {atomFamily} from "jotai/utils" +import {atomWithQuery} from "jotai-tanstack-query" +import {useSWRConfig} from "swr" +import {v4 as uuidv4} from "uuid" + +import {useAppId} from "@/oss/hooks/useAppId" +import axios from "@/oss/lib/api/assets/axiosConfig" +import {EvaluationType} from "@/oss/lib/enums" +import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" +import useEvaluators from "@/oss/lib/hooks/useEvaluators" +import {EvaluationStatus, SnakeToCamelCaseKeys, TestSet} from "@/oss/lib/Types" +import {slugify} from "@/oss/lib/utils/slugify" +import {createEvaluationRunConfig} from "@/oss/services/evaluationRuns/api" +import {CreateEvaluationRunInput} from "@/oss/services/evaluationRuns/api/types" +import {fetchTestset} from "@/oss/services/testsets/api" +import {getProjectValues} from "@/oss/state/project" +import { + prefetchProjectVariantConfigs, + setProjectVariantReferencesAtom, +} from "@/oss/state/projectVariantConfig" +import {usePreviewTestsetsData, useTestsetsData} from "@/oss/state/testset" + +import {buildRunIndex} from "../useEvaluationRunData/assets/helpers/buildRunIndex" +import {getEvaluationRunScenariosKey} from "../useEvaluationRunScenarios" + +import useEnrichEvaluationRun from "./assets/utils" +import {collectProjectVariantReferences} from "./projectVariantConfigs" + +const EMPTY_RUNS: any[] = [] +interface PreviewEvaluationRunsData { + runs: SnakeToCamelCaseKeys[] + count: number +} + +interface PreviewEvaluationRunsQueryParams { + projectId?: string + appId?: string + searchQuery?: string + references: any[] + typesKey: string + debug: boolean + enabled: boolean +} + +const previewEvaluationRunsQueryAtomFamily = atomFamily((serializedParams: string) => + atomWithQuery(() => { + const params = JSON.parse(serializedParams) as PreviewEvaluationRunsQueryParams + const {projectId, appId, searchQuery, references, typesKey, debug, enabled} = params + + return { + queryKey: [ + "previewEvaluationRuns", + projectId ?? "none", + appId ?? "all", + typesKey, + searchQuery ?? "", + JSON.stringify(references ?? []), + ], + enabled, + refetchOnWindowFocus: false, + refetchOnReconnect: false, + queryFn: async () => { + if (!projectId) { + return {runs: [], count: 0} + } + + const payload: Record = { + run: {}, + } + payload.run.references = references ?? [] + if (searchQuery) { + payload.run.search = searchQuery + } + + const queryParams: Record = {project_id: projectId} + if (appId) queryParams.app_id = appId + + const response = await axios.post(`/preview/evaluations/runs/query`, payload, { + params: queryParams, + }) + + return { + runs: (response.data?.runs || []).map((run: EvaluationRun) => + snakeToCamelCaseKeys(run), + ), + count: response.data?.count || 0, + } + }, + } + }), +) + +interface PreviewEvaluationsQueryState { + data?: PreviewEvaluationRunsData + mutate: () => Promise + refetch: () => Promise + isLoading: boolean + isPending: boolean + isError: boolean + error: unknown +} +import {searchQueryAtom} from "./states/queryFilterAtoms" +import {EnrichedEvaluationRun, EvaluationRun} from "./types" + +const SCENARIOS_ENDPOINT = "/preview/evaluations/scenarios/" + +/** + * Custom hook to manage and enrich preview evaluation runs. + * Fetches preview runs via a shared atom query, enriches them with related metadata (testset, variant, evaluators), + * and sorts them by creation timestamp descending. + * + * @param skip - Optional flag to skip fetching preview evaluations. + * @returns Object containing SWR response, enriched runs, and a function to trigger new evaluation creation. + */ +const usePreviewEvaluations = ({ + skip, + types: propsTypes = [], + debug, + appId: appIdOverride, +}: { + skip?: boolean + types?: EvaluationType[] + debug?: boolean + appId?: string | null +} = {}): { + swrData: PreviewEvaluationsQueryState + createNewRun: (paramInputs: CreateEvaluationRunInput) => Promise + runs: EnrichedEvaluationRun[] +} => { + // atoms + const searchQuery = useAtomValue(searchQueryAtom) + const projectId = getProjectValues().projectId + + const debugEnabled = debug ?? process.env.NODE_ENV !== "production" + + const types = useMemo(() => { + return propsTypes.map((type) => { + switch (type) { + case EvaluationType.single_model_test: + case EvaluationType.human: + return EvaluationType.human + case EvaluationType.auto_exact_match: + case EvaluationType.automatic: + return EvaluationType.automatic + default: + return type + } + }) + }, [propsTypes]) + const {mutate: globalMutate} = useSWRConfig() + const routeAppId = useAppId() + const appId = (appIdOverride ?? routeAppId) || undefined + + const {data: humanEvaluators} = useEvaluators({ + preview: true, + queries: { + is_human: !types.includes(EvaluationType.automatic), + }, + }) + + const referenceFilters = useMemo(() => { + const filters: any[] = [] + if (appId) { + filters.push({ + application: {id: appId}, + }) + } + if (types.includes(EvaluationType.human)) { + if (Array.isArray(humanEvaluators) && humanEvaluators.length > 0) { + humanEvaluators.forEach((ev) => { + filters.push({ + evaluator: {id: ev.id}, + }) + }) + } else { + filters.push({ + evaluator: {}, + }) + } + } + return filters + }, [appId, humanEvaluators, types]) + + const typesKey = useMemo(() => types.slice().sort().join("|"), [types]) + const queryEnabled = !skip && Boolean(projectId) + + const serializedQueryParams = useMemo( + () => + JSON.stringify({ + projectId, + appId, + searchQuery, + references: referenceFilters, + typesKey, + debug: debugEnabled, + enabled: queryEnabled, + }), + [projectId, appId, searchQuery, referenceFilters, typesKey, debugEnabled, queryEnabled], + ) + + const evaluationRunsAtom = useMemo( + () => previewEvaluationRunsQueryAtomFamily(serializedQueryParams), + [serializedQueryParams], + ) + + const evaluationRunsQuery = useAtomValue(evaluationRunsAtom) + + const rawRuns = queryEnabled ? (evaluationRunsQuery.data?.runs ?? EMPTY_RUNS) : EMPTY_RUNS + + const evaluationRunsState = useMemo(() => { + const isPending = (evaluationRunsQuery as any).isPending ?? false + const isLoading = + (evaluationRunsQuery as any).isLoading ?? + (evaluationRunsQuery as any).isFetching ?? + isPending + const data = queryEnabled ? evaluationRunsQuery.data : {runs: [], count: 0} + return { + data, + mutate: async () => evaluationRunsQuery.refetch(), + refetch: evaluationRunsQuery.refetch, + isLoading, + isPending, + isError: queryEnabled ? ((evaluationRunsQuery as any).isError ?? false) : false, + error: queryEnabled ? evaluationRunsQuery.error : undefined, + } + }, [evaluationRunsQuery, queryEnabled]) + const setProjectVariantReferences = useSetAtom(setProjectVariantReferencesAtom) + + useEffect(() => { + if (!projectId) { + setProjectVariantReferences([]) + return + } + if (appId) { + setProjectVariantReferences([]) + return + } + const references = collectProjectVariantReferences(rawRuns, projectId) + setProjectVariantReferences(references) + prefetchProjectVariantConfigs(references) + }, [appId, projectId, rawRuns, setProjectVariantReferences]) + /** + * Hook to fetch testsets data. + */ + const {testsets} = useTestsetsData() + const {testsets: previewTestsets} = usePreviewTestsetsData() + const enrichRun = useEnrichEvaluationRun({ + evalType: types.includes(EvaluationType.automatic) ? "auto" : "human", + }) + + /** + * Helper to create scenarios for a given run and testset. + * Each CSV row becomes its own scenario. + */ + const createScenarios = useCallback( + async ( + runId: string, + testset: TestSet & {data: {testcaseIds?: string[]; testcases?: {id: string}[]}}, + ): Promise => { + if (!testset?.id) { + throw new Error(`Testset with id ${testset.id} not found.`) + } + + // 1. Build payload: each row becomes a scenario + const payload = { + scenarios: ( + testset.data.testcaseIds ?? + testset.data.testcases?.map((tc) => tc.id) ?? + [] + ).map((_id, index) => ({ + run_id: runId, + // meta: {index}, + })), + } + + // 2. Invoke the scenario endpoint + const response = await axios.post(SCENARIOS_ENDPOINT, payload) + + // Extract and return new scenario IDs + return response.data.scenarios.map((s: any) => s.id) + }, + [testsets, debug], + ) + + /** + * Helper to compute enriched and sorted runs (lazy) when accessed. + */ + const computeRuns = useCallback((): EnrichedEvaluationRun[] => { + if (!rawRuns.length || !enrichRun) return [] + const enriched: EnrichedEvaluationRun[] = rawRuns + .map((_run) => { + const runClone = structuredClone(_run) + const runIndex = buildRunIndex(runClone) + return enrichRun(runClone, previewTestsets?.testsets || [], runIndex) + }) + .filter((run): run is EnrichedEvaluationRun => Boolean(run)) + + // Sort enriched runs by timestamp, descending + return enriched.sort((a, b) => { + const tA = new Date(a.createdAtTimestamp || 0).getTime() + const tB = new Date(b.createdAtTimestamp || 0).getTime() + return tB - tA + }) + }, [rawRuns, previewTestsets, enrichRun, debug]) + + const createNewRun = useCallback( + async (paramInputs: CreateEvaluationRunInput) => { + // JIT migrate old testsets before creating a new run + if (!paramInputs.testset || !paramInputs.testset._id) { + throw new Error("Testset is required and must have an _id for migration.") + } + try { + // 1. Converts the old testset to the new format + const existingPreviewQuery = await axios.get( + `/preview/simple/testsets/${paramInputs.testset._id}`, + ) + const existingQuery = await fetchTestset(paramInputs.testset._id, false) + const existingPreview = existingPreviewQuery.data?.testset + const existing = existingQuery + let testset + if (!existingPreview) { + const result = await axios.post( + `/preview/simple/testsets/${paramInputs.testset._id}/transfer`, + ) + testset = result.data.testset + } else { + testset = existingPreview + } + + if (testset) { + paramInputs.testset = snakeToCamelCaseKeys(testset) + } + } catch (migrationErr: any) { + throw new Error( + `Failed to migrate testset before creating run: ${migrationErr?.message || migrationErr}`, + ) + } + + // 2. Creates the the payload schema + const params = createEvaluationRunConfig(paramInputs) + + // 3. Invokes run endpoint + const response = await axios.post("/preview/evaluations/runs/", params) + + // Extract the newly created runId + const runId = response.data.runs?.[0]?.id + if (!runId) { + throw new Error("createNewRun: runId not returned in response.") + } + // Now create scenarios for each row in the specified testset + if (!paramInputs.testset) { + throw new Error("Testset is required to create scenarios") + } + // 4. Creates the scenarios + const scenarioIds = await createScenarios(runId, paramInputs.testset) + + // Fire off input, invocation, and annotation steps together in one request (non-blocking) + try { + // const repeatId = uuidv4() + // const retryId = uuidv4() + // 5. First generate step keys & IDs per scenario + const revision = paramInputs.revisions?.[0] + const evaluators = paramInputs.evaluators || [] + const inputKey = slugify( + paramInputs.testset.name ?? paramInputs.testset.slug ?? "testset", + paramInputs.testset.id, + ) + const invocationKey = revision + ? slugify( + (revision as any).name ?? + (revision as any).variantName ?? + (revision as any)._parentVariant?.variantName ?? + "invocation", + revision.id, + ) + : "invocation" + + const scenarioStepsData = scenarioIds.map((scenarioId, index) => { + const hashId = uuidv4() + return { + testcaseId: + paramInputs.testset?.data?.testcaseIds?.[index] ?? + paramInputs.testset?.data?.testcases?.[index]?.id, + scenarioId, + hashId, + } + }) + + // 6. Build a single steps array combining input, invocation, and evaluator steps + const allSteps = scenarioStepsData.flatMap( + ({scenarioId, testcaseId, repeatId, retryIdInput, hashId}) => { + const base = { + testcase_id: testcaseId, + scenario_id: scenarioId, + run_id: runId, + } + const stepsArray: any[] = [ + { + ...base, + status: EvaluationStatus.SUCCESS, + step_key: inputKey, + }, + { + ...base, + step_key: invocationKey, + }, + ] + + evaluators.forEach((ev) => { + stepsArray.push({ + ...base, + step_key: `${invocationKey}.${ev.slug}`, + }) + }) + return stepsArray + }, + ) + // 7. Invoke the /results endpoint + await axios + .post(`/preview/evaluations/results/?project_id=${projectId}`, { + results: allSteps, + }) + .then((res) => { + // Revalidate scenarios data + globalMutate(getEvaluationRunScenariosKey(runId)) + }) + .catch((err) => { + console.error( + "[usePreviewEvaluations] createNewRun: failed to create steps", + err, + ) + }) + } catch (err) { + console.error("[usePreviewEvaluations] createNewRun: error scheduling steps", err) + } + // 8. Refresh SWR data for runs + await evaluationRunsState.mutate() + // Return both run response and scenario IDs + return { + run: response.data, + scenarios: scenarioIds, + } + }, + [debug, createScenarios, globalMutate, evaluationRunsState, projectId, appId], + ) + + return { + swrData: evaluationRunsState, + createNewRun, + get runs() { + return enrichRun ? computeRuns() : [] + }, + } +} + +export default usePreviewEvaluations diff --git a/web/ee/src/lib/hooks/usePreviewEvaluations/projectVariantConfigs.ts b/web/ee/src/lib/hooks/usePreviewEvaluations/projectVariantConfigs.ts new file mode 100644 index 0000000000..1961184254 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewEvaluations/projectVariantConfigs.ts @@ -0,0 +1,131 @@ +import {ProjectVariantConfigKey} from "@/oss/state/projectVariantConfig" + +interface InvocationReference { + appId?: string + appSlug?: string + revisionId?: string + revisionVersion?: number | null + variantSlug?: string + fallbackKey?: string +} + +const normalizeReference = (refs: any, fallbackKey?: string): InvocationReference | null => { + if (!refs) return null + + const applicationRevision = + refs.applicationRevision || refs.application_revision || refs.application_ref?.revision + const applicationRef = + refs.application || + applicationRevision?.application || + refs.application_ref || + refs.applicationRef + const variantRef = refs.variant || refs.variant_ref || refs.variantRef + + const appId = + applicationRef?.id || + applicationRevision?.application_id || + applicationRevision?.applicationId + const appSlug = applicationRef?.slug || applicationRef?.name + + const revisionId = + applicationRevision?.id || + applicationRevision?.revisionId || + applicationRevision?.revision_id || + variantRef?.id || + variantRef?.revisionId || + variantRef?.revision_id + const revisionVersion = + applicationRevision?.revision ?? + applicationRevision?.version ?? + variantRef?.version ?? + variantRef?.revision + let variantSlug = + variantRef?.slug || variantRef?.name || variantRef?.variantName || variantRef?.variant_name + + if (!variantSlug) { + variantSlug = + refs.application?.slug || + refs.application?.name || + refs.applicationRef?.slug || + refs.applicationRef?.name || + fallbackKey + } + + if (!appId && !appSlug) return null + + return { + appId, + appSlug, + revisionId: revisionId || fallbackKey, + revisionVersion, + variantSlug: variantSlug || fallbackKey, + fallbackKey, + } +} + +const extractInvocationReference = (run: any): InvocationReference | null => { + const steps: any[] = run?.data?.steps || [] + const invocationStep = steps.find((step: any) => { + if (step?.type === "invocation") return true + const refs = step?.references ?? step + return Boolean( + refs?.application || + refs?.applicationRevision || + refs?.application_revision || + refs?.applicationRef || + refs?.application_ref, + ) + }) + + if (!invocationStep) return null + const refs = invocationStep.references ?? invocationStep + return normalizeReference(refs, invocationStep.key) +} + +export const collectProjectVariantReferences = ( + runs: any[], + projectId?: string, +): ProjectVariantConfigKey[] => { + if (!Array.isArray(runs) || !projectId) return [] + const collected = new Map() + + runs.forEach((run) => { + const invocation = extractInvocationReference(run) + let reference: ProjectVariantConfigKey | undefined + + if (invocation) { + reference = { + projectId, + appId: invocation.appId, + appSlug: invocation.appSlug, + variantId: invocation.revisionId, + variantSlug: invocation.variantSlug, + variantVersion: invocation.revisionVersion ?? null, + } + } else if (Array.isArray((run as any)?.variants) && (run as any).variants.length) { + const variant = (run as any).variants[0] + reference = { + projectId, + appId: variant?.appId || variant?.app_id, + appSlug: variant?.appSlug || variant?.app_slug, + variantId: variant?.id || variant?.revisionId || variant?.revision_id, + variantSlug: + variant?.variantSlug || variant?.variantName || variant?.slug || variant?.name, + variantVersion: + (variant?.revision as number | null | undefined) ?? + (variant?.revisionLabel as number | string | null | undefined) ?? + null, + } + } + + if (!reference) return + if (!reference.variantId && !reference.variantSlug) return + + const key = JSON.stringify(reference) + if (!collected.has(key)) { + collected.set(key, reference) + } + }) + + return Array.from(collected.values()) +} diff --git a/web/ee/src/lib/hooks/usePreviewEvaluations/states/queryFilterAtoms.ts b/web/ee/src/lib/hooks/usePreviewEvaluations/states/queryFilterAtoms.ts new file mode 100644 index 0000000000..5a99f2bcf5 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewEvaluations/states/queryFilterAtoms.ts @@ -0,0 +1,7 @@ +import {atom} from "jotai" + +// search query atom +export const searchQueryAtom = atom("") + +// pagination atom +export const paginationAtom = atom({size: 20, page: 1}) diff --git a/web/ee/src/lib/hooks/usePreviewEvaluations/types.ts b/web/ee/src/lib/hooks/usePreviewEvaluations/types.ts new file mode 100644 index 0000000000..f8b684c348 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewEvaluations/types.ts @@ -0,0 +1,84 @@ +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {PreviewTestSet, SnakeToCamelCaseKeys, WorkspaceMember} from "@/oss/lib/Types" + +/** + * Interface representing a single evaluation run as returned from the backend API. + * Contains metadata and structured evaluation logic steps including input, + * invocation (application), and annotation (evaluation) stages. + */ + +export type EvaluationRunDataStep = + | { + /** First step: define the test input and optionally the testset variant/revision */ + key: string + type: "input" + /** References to testset and optionally its variant/revision */ + references: Record + } + | { + /** Invocation step: connects the application variant to the input */ + key: string + type: "invocation" + /** Defines which previous steps this step takes input from */ + inputs: {key: string}[] + /** References to application, variant, and revision IDs */ + references: Record + } + | { + /** Annotation step: applies an evaluator to the input + invocation results */ + key: string + type: "annotation" + /** Usually takes input from both the "input" and "invocation" steps */ + inputs: {key: string}[] + /** References to evaluator slug and evaluator variant ID */ + references: Record + } + +export type IEvaluationRunDataStep = SnakeToCamelCaseKeys +export interface EvaluationRun { + /** Unique identifier for the evaluation run */ + id: string + /** Display name for the run */ + name: string + /** Optional description text for the run */ + description: string + /** ISO timestamp of when the run was created */ + created_at: string + /** ID of the user who created the run */ + created_by_id: string + /** Optional metadata object (arbitrary key-value pairs) */ + meta: Record + /** Flags associated with the run (internal use) */ + flags: Record + /** Current status of the run (e.g., "pending", "completed") */ + status: string + data: { + /** Array of evaluation steps that define execution flow */ + steps: EvaluationRunDataStep[] + /** Mappings define how to extract values from steps for display or evaluation */ + mappings: { + /** Type of the mapping, determines what the value represents */ + kind: "input" | "ground_truth" | "application" | "evaluator" + /** Display name for the mapped value */ + name: string + /** Path reference to the data inside a step */ + step: { + /** The step key this mapping belongs to */ + key: string + /** Path within the step data (e.g., 'country' or 'data.outputs.metric') */ + path: string + } + }[] + } +} + +export interface EnrichedEvaluationRun extends SnakeToCamelCaseKeys { + /** All distinct testsets referenced in this run */ + testsets: PreviewTestSet[] + createdBy?: WorkspaceMember + createdAtTimestamp?: number + /** All distinct application revisions (variants) referenced */ + variants?: EnhancedVariant[] + evaluators?: EvaluatorDto[] +} diff --git a/web/ee/src/lib/hooks/usePreviewRunningEvaluations/index.ts b/web/ee/src/lib/hooks/usePreviewRunningEvaluations/index.ts new file mode 100644 index 0000000000..d73f381a85 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewRunningEvaluations/index.ts @@ -0,0 +1,58 @@ +import {atomFamily} from "jotai/utils" +import {atomWithQuery} from "jotai-tanstack-query" + +import axios from "@/oss/lib/api/assets/axiosConfig" +import {EvaluationStatus} from "@/oss/lib/Types" +import {getProjectValues} from "@/oss/state/project" + +import {EnrichedEvaluationRun} from "../usePreviewEvaluations/types" + +const REFETCH_INTERVAL = 10000 + +export const resourceStatusQueryFamily = atomFamily((id) => + atomWithQuery((get) => { + const projectId = getProjectValues().projectId + + return { + queryKey: ["resourceStatus", id, projectId], + queryFn: async () => { + const res = await axios.get( + `/preview/evaluations/runs/${id}?project_id=${projectId}`, + ) + return res.data + }, + + // Poll every 5s until success; then stop polling. + refetchInterval: (query) => { + const data = query.state.data as EnrichedEvaluationRun | undefined + + if ( + ![ + EvaluationStatus.PENDING, + EvaluationStatus.RUNNING, + EvaluationStatus.CANCELLED, + EvaluationStatus.INITIALIZED, + ].includes(data?.run?.status) + ) + return false // stop polling + return REFETCH_INTERVAL // keep polling + }, + + enabled: Boolean(id) && Boolean(projectId), + + // Avoid accidental refetches after success + refetchOnWindowFocus: false, + refetchOnReconnect: false, + + // Reasonable cache/stale settings + staleTime: 10_000, + gcTime: 5 * 60 * 1000, + } + }), +) + +// export const allResourceStatusesAtom = atom((get) => { +// const ids = get(runningEvaluationIdsAtom) +// const uniqueIds = Array.from(new Set(ids)) +// return uniqueIds.map((id) => get(resourceStatusQueryFamily(id))) +// }) diff --git a/web/ee/src/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom.ts b/web/ee/src/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom.ts new file mode 100644 index 0000000000..058724f731 --- /dev/null +++ b/web/ee/src/lib/hooks/usePreviewRunningEvaluations/states/runningEvalAtom.ts @@ -0,0 +1,10 @@ +import {atom} from "jotai" + +import {EnrichedEvaluationRun} from "../../usePreviewEvaluations/types" + +// Collect all the running evaluation ids +export const runningEvaluationIdsAtom = atom([]) + +// This atom collects the running evaluations a store it temporarily +// until we fix the issue on backend +export const tempEvaluationAtom = atom([]) diff --git a/web/ee/src/lib/hooks/useRunMetricsMap/index.ts b/web/ee/src/lib/hooks/useRunMetricsMap/index.ts new file mode 100644 index 0000000000..2595ab2e55 --- /dev/null +++ b/web/ee/src/lib/hooks/useRunMetricsMap/index.ts @@ -0,0 +1,171 @@ +import useSWR from "swr" + +import axios from "@/oss/lib/api/assets/axiosConfig" +import {METRICS_ENDPOINT, computeRunMetrics} from "@/oss/services/runMetrics/api" +import {BasicStats} from "@/oss/services/runMetrics/api/types" + +import type {MetricResponse} from "../useEvaluationRunMetrics/types" + +// Returns aggregated advanced stats per run +const fetchRunMetricsMap = async ( + runIds: string[], + evaluatorSlugs: Set | undefined, +): Promise>> => { + // POST to query endpoint with body { metrics: { run_ids: [...] }, windowing: {...} } + const endpoint = `${METRICS_ENDPOINT}query` + const body: Record = { + metrics: { + ...(Array.isArray(runIds) && runIds.length ? {run_ids: runIds} : {}), + }, + windowing: {}, + } + const res = await axios.post(endpoint, body) + + const rawMetrics: MetricResponse[] = Array.isArray(res.data?.metrics) ? res.data.metrics : [] + + // Process evaluator metrics to ensure they have the correct prefix important for auto eval + const processedMetrics = rawMetrics.map((metric) => { + if (!metric.data) return metric + + const processedData: Record = {} + + // add evaluator metrics to processed data + Object.entries(metric.data as Record>).forEach( + ([stepKey, stepData]) => { + const parts = stepKey.split(".") + if (parts.length === 1) { + const slug = parts[0] + if (evaluatorSlugs?.has(slug)) { + // This is an evaluator metric, ensure all keys are prefixed + const newStepData: Record = {} + Object.entries(stepData).forEach(([key, value]) => { + const prefixedKey = key.startsWith(`${slug}.`) ? key : `${slug}.${key}` + newStepData[prefixedKey] = value + }) + processedData[stepKey] = newStepData + } else { + // Keep non-evaluator data as is + processedData[stepKey] = stepData + } + } else { + // Keep invocation data as is + processedData[stepKey] = stepData + } + }, + ) + + return { + ...metric, + data: processedData, + } + }) + + // Helper to classify & flatten metric payload (mirrors fetchRunMetrics.worker) + const transformData = (data: Record): Record => { + const flat: Record = {} + Object.entries(data || {}).forEach(([stepKey, metrics]) => { + // Pass-through for analytics keys like ag.metrics.* + if (stepKey.startsWith("ag.")) { + const raw = metrics + let value: any = raw + if (typeof raw === "object" && raw !== null) { + if ("mean" in raw) value = (raw as any).mean + else if ("value" in raw) value = (raw as any).value + } + flat[stepKey] = value + return + } + + const isAnalyticsPath = stepKey.startsWith("attributes.ag.") + + const parts = stepKey.split(".") + const isInvocation = parts.length === 1 + const slug = isInvocation ? undefined : parts[1] + Object.entries(metrics as Record).forEach(([metricKey, raw]) => { + let value: any = raw + if (typeof raw === "object" && raw !== null) { + if ("mean" in raw) { + value = (raw as any).mean + } else if ("value" in raw) { + value = (raw as any).value + } else if ("freq" in raw || "uniq" in raw) { + const mapped: any = {...raw} + if (Array.isArray(mapped.freq)) { + mapped.frequency = mapped.freq.map((entry: any) => ({ + value: entry?.value, + count: entry?.count ?? entry?.frequency ?? 0, + })) + mapped.frequency.sort( + (a: any, b: any) => + b.count - a.count || (a.value === true ? -1 : 1), + ) + delete mapped.freq + mapped.rank = mapped.frequency + } + if (Array.isArray(mapped.uniq)) { + mapped.unique = mapped.uniq + delete mapped.uniq + } + if (!Array.isArray(mapped.unique) && Array.isArray(mapped.frequency)) { + mapped.unique = mapped.frequency.map((entry: any) => entry.value) + } + value = mapped + } + } + if (isAnalyticsPath) { + const canonicalKey = `${stepKey}.${metricKey}` + flat[canonicalKey] = value + // Legacy fallback for evaluator display (last segment only) + const legacyKey = metricKey + if (!(legacyKey in flat)) { + flat[legacyKey] = value + } + } else if (isInvocation) { + let newKey = metricKey + if (metricKey.startsWith("tokens.")) { + newKey = metricKey.slice(7) + "Tokens" // tokens.prompt -> promptTokens + } else if (metricKey.startsWith("cost")) { + newKey = "totalCost" + } + flat[newKey] = value + } else { + const pref = slug ? `${slug}.` : "" + flat[`${pref}${metricKey}`] = value + } + }) + }) + return flat + } + + const buckets: Record}[]> = {} + processedMetrics.forEach((m) => { + const metric = m + if (!metric.scenario_id || !metric.run_id) return + const key = metric.run_id + if (!buckets[key]) buckets[key] = [] + const flattened = transformData(metric.data as any) + buckets[key].push({data: flattened}) + }) + + const result: Record> = {} + Object.entries(buckets).forEach(([runId, entries]) => { + const agg = computeRunMetrics(entries) + result[runId] = agg + }) + + return result +} + +const useRunMetricsMap = ( + runIds: string[] | undefined, + evaluatorSlugs: Set | undefined, +) => { + const swrKey = runIds && runIds.length ? ["runMetricsMap", ...runIds] : null + const {data, error, isLoading} = useSWR>>( + swrKey, + () => fetchRunMetricsMap(runIds!, evaluatorSlugs!), + ) + return {data, isLoading, isError: !!error} +} + +export default useRunMetricsMap diff --git a/web/ee/src/lib/metricColumnFactory.tsx b/web/ee/src/lib/metricColumnFactory.tsx new file mode 100644 index 0000000000..72503b77c3 --- /dev/null +++ b/web/ee/src/lib/metricColumnFactory.tsx @@ -0,0 +1,112 @@ +import React from "react" + +import {ColumnsType} from "antd/es/table" + +import {MetricDetailsPopoverWrapper} from "@/oss/components/HumanEvaluations/assets/MetricDetailsPopover" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {buildMetricSorter} from "@/oss/lib/metricSorter" +import { + isSortableMetricType, + BasicStats, + canonicalizeMetricKey, + getMetricValueWithAliases, +} from "@/oss/lib/metricUtils" + +const resolveMetricStats = ( + metrics: Record | undefined, + candidates: (string | undefined)[], + fallbackSuffix?: string, +): BasicStats | undefined => { + if (!metrics) return undefined + const allCandidates = [...candidates] + if (fallbackSuffix) { + candidates.forEach((key) => { + if (!key || key.endsWith(fallbackSuffix)) return + allCandidates.push(`${key}.${fallbackSuffix}`) + }) + } + for (const key of allCandidates) { + if (!key) continue + if (metrics[key]) return metrics[key] + const canonical = canonicalizeMetricKey(key) + if (canonical !== key && metrics[canonical]) return metrics[canonical] + const alias = getMetricValueWithAliases(metrics, key) + if (alias) return alias + } + return undefined +} + +import {EvaluationRow} from "../components/HumanEvaluations/types" + +export interface BuildEvaluatorMetricColumnsParams { + evaluator: EvaluatorDto + runMetricsMap?: Record> + hidePrimitiveTable?: boolean + debug?: boolean +} + +export function buildEvaluatorMetricColumns({ + evaluator, + runMetricsMap, + hidePrimitiveTable = false, + debug = false, +}: BuildEvaluatorMetricColumnsParams): ColumnsType { + const metricKeys = Object.keys(evaluator.metrics || {}) + return metricKeys.map((metricKey) => { + const schemaType = evaluator.metrics?.[metricKey]?.type + const sortable = isSortableMetricType(schemaType) + + const analyticsCandidates = [ + `attributes.ag.data.outputs.${metricKey}`, + `attributes.ag.metrics.${metricKey}`, + ] + const baseCandidates = [ + `${evaluator.slug}.${metricKey}`, + metricKey, + ...analyticsCandidates, + ...analyticsCandidates.map((path) => `${evaluator.slug}.${path}`), + ] + + return { + key: `${evaluator.slug}:${metricKey}`, + dataIndex: metricKey, + title: ( +
    + {metricKey} +
    + ), + sorter: sortable + ? buildMetricSorter((row) => { + const runId = "id" in row ? row.id : (row as any).key + const metrics = runMetricsMap?.[runId] + return resolveMetricStats(metrics, baseCandidates) + }) + : undefined, + render: (_: any, record: EvaluationRow) => { + const hasEvaluator = Array.isArray((record as any).evaluators) + ? (record as any).evaluators.some( + (e: EvaluatorDto) => e.slug === evaluator.slug, + ) + : false + + const runMetric = + runMetricsMap?.[("id" in record ? record.id : (record as any).key) as string] + const stats = resolveMetricStats(runMetric, baseCandidates) + + return hasEvaluator ? ( + + ) : ( +
    + ) + }, + } as any + }) as ColumnsType +} diff --git a/web/ee/src/lib/metricSorter.ts b/web/ee/src/lib/metricSorter.ts new file mode 100644 index 0000000000..e6edc9bd70 --- /dev/null +++ b/web/ee/src/lib/metricSorter.ts @@ -0,0 +1,19 @@ +import {extractPrimitive, metricCompare} from "./metricUtils" + +/** + * Build an Ant Design-compatible sorter object that compares metric values in a row-agnostic way. + * Provide a getter that receives the table row (any shape) and returns the raw metric value. + * + * This isolates the shared compare logic so that different tables only need to supply + * their own way of fetching the raw value (from atoms, maps, etc.). When the metric + * payload shape changes we only have to update `extractPrimitive` / `metricCompare`. + */ +export function buildMetricSorter(getRaw: (row: RowType) => unknown) { + return { + compare: (a: RowType, b: RowType) => { + const primA = extractPrimitive(getRaw(a)) + const primB = extractPrimitive(getRaw(b)) + return metricCompare(primA, primB) + }, + } +} diff --git a/web/ee/src/lib/metricUtils.ts b/web/ee/src/lib/metricUtils.ts new file mode 100644 index 0000000000..65a4e2183e --- /dev/null +++ b/web/ee/src/lib/metricUtils.ts @@ -0,0 +1,278 @@ +/* + * Shared metric-handling utilities for Agenta Cloud front-end. + * --------------------------------------------------------------------------- + * These helpers consolidate common logic that previously lived in multiple + * table utilities (HumanEvaluations, VirtualizedScenarioTable, MetricCell …). + * Any future change to the metric data shape (e.g. new statistical fields) can + * now be implemented in a single place. + */ + +// --------------------------------------------------------------------------- +// Type definitions +// --------------------------------------------------------------------------- + +/** Simple histogram entry returned by backend */ +export interface FrequencyEntry { + value: T + count: number +} + +/** Stats object returned by backend `GET /runs/:id/metrics` */ +export interface BasicStats { + mean?: number + sum?: number + /** Ordered frequency list (most common first) */ + frequency?: FrequencyEntry[] + /** Total sample count */ + count?: number + // backend may add extra fields – index signature keeps type-safety while + // allowing unknown additions. + [key: string]: unknown +} + +/** Metric primitive or stats wrapper */ +export type MetricValue = BasicStats | unknown + +/** Union of recognised primitive metric types */ +export type PrimitiveMetricType = "number" | "boolean" | "string" | "array" | "object" | "null" + +/** + * An explicit metric type coming from evaluator schema can be either a single + * string or a JSON-Schema union array (e.g. ["string","null"]). + */ +export type SchemaMetricType = PrimitiveMetricType | PrimitiveMetricType[] + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +const METRIC_KEY_SYNONYMS: string[][] = [ + ["attributes.ag.metrics.costs.cumulative.total", "totalCost", "costs.total", "cost"], + ["attributes.ag.metrics.duration.cumulative", "duration", "duration.total"], + ["attributes.ag.metrics.tokens.cumulative.total", "totalTokens", "tokens.total", "tokens"], + ["attributes.ag.metrics.errors.cumulative", "errors"], +] + +const aliasToCanonical = new Map() +const canonicalToGroup = new Map() + +METRIC_KEY_SYNONYMS.forEach((group) => { + const [canonical] = group + canonicalToGroup.set(canonical, group) + group.forEach((alias) => { + aliasToCanonical.set(alias, canonical) + }) +}) + +/** + * Return the canonical metric key for the provided alias. If the key is not a + * recognised alias it is returned unchanged. + */ +export const canonicalizeMetricKey = (key: string): string => { + return aliasToCanonical.get(key) ?? key +} + +const resolveMetricCandidates = (key: string): string[] => { + const canonical = canonicalizeMetricKey(key) + const group = canonicalToGroup.get(canonical) + return group ? group : [canonical] +} + +/** + * Fetch a metric value from a flat metrics map using canonical aliases. + * Returns the first non-undefined candidate. + */ +export const getMetricValueWithAliases = ( + metrics: Record, + key: string, +): T | undefined => { + if (!metrics) return undefined + const candidates = resolveMetricCandidates(key) + for (const candidate of candidates) { + if (candidate in metrics && metrics[candidate] !== undefined) { + return metrics[candidate] as T + } + } + return undefined +} + +/** + * Helper used by table headers to provide a human friendly label for well known + * metrics regardless of whether we receive the legacy or the new analytics key. + */ +export const getMetricDisplayName = (key: string): string => { + const canonical = canonicalizeMetricKey(key) + switch (canonical) { + case "attributes.ag.metrics.costs.cumulative.total": + return "Cost (Total)" + case "attributes.ag.metrics.duration.cumulative": + return "Duration (Total)" + case "attributes.ag.metrics.tokens.cumulative.total": + return "Total tokens" + case "attributes.ag.metrics.errors.cumulative": + return "Errors" + default: { + const cleaned = canonical + .replace(/[_\.]/g, " ") + .replace(/\s+/g, " ") + .trim() + .toLowerCase() + return cleaned.replace(/\b\w/g, (c) => c.toUpperCase()) + } + } +} + +/** + * Extract a single primitive value from a metric payload. + * + * The backend may return either: + * • a plain primitive (number | string | boolean | array) + * • a {@link BasicStats} object containing statistical fields. + * + * This helper applies the heuristics used in several places: + * 1. mean + * 2. sum + * 3. first frequency value + * 4. fallback to raw object + */ +export function extractPrimitive(metric: MetricValue): T | undefined { + if (metric === null || metric === undefined) return undefined as any + + // Plain primitives / arrays are returned verbatim. + if (typeof metric !== "object" || Array.isArray(metric)) return metric as any + + const stats = metric as BasicStats + if (stats.mean !== undefined) return stats.mean as any + if (stats.sum !== undefined) return stats.sum as any + if (stats.frequency?.length) return stats.frequency[0].value as any + + // As a last resort return the object itself (caller decides what to do). + return metric as any +} + +/** + * Infer the metric primitive type when evaluator schema does not provide one. + * + * Mainly used by table renderers to decide formatting & sorter eligibility. + */ +export function inferMetricType(raw: unknown, schemaType?: SchemaMetricType): PrimitiveMetricType { + if (schemaType) { + // When evaluator schema provides a union array we choose the first non-null type. + if (Array.isArray(schemaType)) { + const withoutNull = schemaType.filter((t) => t !== "null") + return (withoutNull[0] ?? "string") as PrimitiveMetricType + } + return schemaType as PrimitiveMetricType + } + + if (raw === null) return "null" + if (Array.isArray(raw)) return "array" + const t = typeof raw + if (t === "number" || t === "boolean" || t === "string") return t + return "object" +} + +/** + * Determine if a column with the given metric type should expose sorting. + * + * Current UX policy: only numeric and boolean primitives are sortable. + */ +export function isSortableMetricType(metricType: SchemaMetricType | undefined): boolean { + if (!metricType) return true // fallback + + const types = Array.isArray(metricType) ? metricType : [metricType] + return !types.includes("string") && !types.includes("array") +} + +/** + * Generic comparator function used by AntD Table sorter. + * Returns negative / zero / positive like `Array.prototype.sort` expects. + */ +export function summarizeMetric( + stats: BasicStats | undefined, + schemaType?: SchemaMetricType, +): string | number | undefined { + if (!stats) return undefined + + // 1. mean for numeric metrics (latency etc.) + if (typeof (stats as any).mean === "number") { + return (stats as any).mean + } + + // 2. boolean metrics – proportion of true (percentage) + if (schemaType === "boolean" && Array.isArray((stats as any).frequency)) { + const trueEntry = (stats as any).frequency.find((f: any) => f.value === true) + const total = (stats as any).count ?? 0 + if (total) { + return ((trueEntry?.count ?? 0) / total) * 100 + } + } + + // 3. ranked categorical metrics – show top value and count + if (Array.isArray((stats as any).rank) && (stats as any).rank.length) { + const top = (stats as any).rank[0] + return `${top.value} (${top.count})` + } + + // 4. plain count fallback + if (typeof (stats as any).count === "number") { + return (stats as any).count + } + + return undefined +} + +export function metricCompare(a: unknown, b: unknown): number { + // undefined / null handling – push to bottom + if (a === undefined || a === null) return 1 + if (b === undefined || b === null) return -1 + + // Normalize boolean-like values so categorical metrics sort correctly. + // Accept true/false, "true"/"false" (case-insensitive), and 1/0. + const normalizeBool = (v: unknown): boolean | undefined => { + if (typeof v === "boolean") return v + if (typeof v === "number") { + if (v === 1) return true + if (v === 0) return false + return undefined + } + if (typeof v === "string") { + const s = v.trim().toLowerCase() + if (s === "true") return true + if (s === "false") return false + if (s === "1") return true + if (s === "0") return false + } + return undefined + } + + const boolA = normalizeBool(a) + const boolB = normalizeBool(b) + if (boolA !== undefined && boolB !== undefined) { + // false < true when sorting ascending + return Number(boolA) - Number(boolB) + } + + const numA = Number(a as any) + const numB = Number(b as any) + const bothNumeric = !Number.isNaN(numA) && !Number.isNaN(numB) + if (bothNumeric) return numA - numB + + return String(a).localeCompare(String(b)) +} + +/** + * Compute maximum width among children columns. Used when a metrics group is + * collapsed into one column. + */ +export function maxChildWidth( + children: {key?: string; dataIndex?: string; width?: number}[], + distMap: Record, + fallback = 160, +): number { + return Math.max( + ...children.map( + (ch) => distMap[ch.key ?? ch.dataIndex ?? ""]?.width ?? ch.width ?? fallback, + ), + ) +} diff --git a/web/ee/src/lib/metrics/utils.ts b/web/ee/src/lib/metrics/utils.ts new file mode 100644 index 0000000000..e3255eed26 --- /dev/null +++ b/web/ee/src/lib/metrics/utils.ts @@ -0,0 +1,93 @@ +import {canonicalizeMetricKey, getMetricDisplayName} from "../metricUtils" + +// Shared helpers for metric key humanisation and sorting +// ------------------------------------------------------ +// This centralises the logic used in various tables (virtualised scenario table, +// human-evaluation runs table, etc.) so we have a single source of truth when we +// add new invocation-level metrics. + +export interface MetricConfig { + /** Which field in BasicStats to use when sorting / displaying primary value */ + primary: string + /** Human-readable column title */ + label: string +} + +const TOKEN_ORDER = ["promptTokens", "completionTokens", "totalTokens"] as const + +/** + * Given a flattened invocation metric key (e.g. "latency", "totalCost", + * "duration.total", "promptTokens" …) return: + * 1. primary aggregation key to read from BasicStats + * 2. human-friendly title string used for column headers + */ +export const getMetricConfig = (key: string): MetricConfig => { + const canonical = canonicalizeMetricKey(key) + if (canonical === "attributes.ag.metrics.costs.cumulative.total") { + return {primary: "sum", label: getMetricDisplayName(canonical)} + } + if (canonical === "attributes.ag.metrics.duration.cumulative") { + return {primary: "mean", label: getMetricDisplayName(canonical)} + } + if (canonical === "attributes.ag.metrics.tokens.cumulative.total") { + return {primary: "sum", label: getMetricDisplayName(canonical)} + } + if (canonical === "attributes.ag.metrics.errors.cumulative") { + return {primary: "count", label: getMetricDisplayName(canonical)} + } + + if (canonical !== key) { + return getMetricConfig(canonical) + } + + // Common most-used names first for performance/readability + if (key === "latency") { + return {primary: "mean", label: "Latency (mean)"} + } + if (key === "totalCost" || key === "cost") { + return {primary: "sum", label: "Cost (total)"} + } + + // Token counts (camelCase like promptTokens -> "Prompt tokens (total)") + if (key.endsWith("Tokens")) { + const words = key + .replace(/Tokens$/, " tokens") + .replace(/([A-Z])/g, " $1") + .trim() + const capitalised = words.charAt(0).toUpperCase() + words.slice(1) + return {primary: "sum", label: `${capitalised} (total)`} + } + + // Dotted keys from step summariser e.g. duration.total => Duration (total) + if (key.includes(".")) { + const [base, sub] = key.split(".") + const capitalised = base.charAt(0).toUpperCase() + base.slice(1) + const primary = sub === "total" ? "sum" : sub + return {primary, label: `${capitalised} (${sub})`} + } + + // Fallback – treat as numeric mean + const capitalised = getMetricDisplayName(key) + const primary = key === "errors" ? "count" : "mean" + return {primary, label: `${capitalised} (${primary})`} +} + +/** + * Provide a stable sort priority for invocation metric keys so that tables show + * them in a predictable order: + * 0. cost + * 1. duration.* + * 2. token metrics (prompt, completion, total, then any other token key) + * 3. others alphabetical + */ +export const metricPriority = (key: string): [number, number] => { + const canonical = canonicalizeMetricKey(key) + const target = canonical ?? key + const lc = target.toLowerCase() + if (lc.includes("cost")) return [0, 0] + if (lc.includes("duration")) return [1, 0] + const tokenIdx = TOKEN_ORDER.indexOf(target as (typeof TOKEN_ORDER)[number]) + if (tokenIdx !== -1) return [2, tokenIdx] + if (target.endsWith("Tokens") || lc.includes("token")) return [2, 99] + return [3, 0] +} diff --git a/web/ee/src/lib/tableUtils.ts b/web/ee/src/lib/tableUtils.ts new file mode 100644 index 0000000000..fee6f354bc --- /dev/null +++ b/web/ee/src/lib/tableUtils.ts @@ -0,0 +1,36 @@ +/** + * Generic table-helper utilities shared between Scenario & Human-Evaluation tables. + * Keeping them here ensures we only tweak one place if the backend payload shape changes. + */ + +/** Lightweight lodash.get replacement for simple "a.b.c" paths */ +export function deepGet(obj: any, path: string): any { + if (!obj || typeof obj !== "object") return undefined + return path.split(".").reduce((acc: any, key: string) => (acc ? acc[key] : undefined), obj) +} + +/** + * Recursively collect dotted paths to every leaf value inside a nested object. + * Example: {a:{b:1,c:{d:2}}, e:3} -> ['a.b', 'a.c.d', 'e'] + */ +export function collectLeafPaths(obj: any, prefix = ""): string[] { + if (!obj || typeof obj !== "object") return [] + const paths: string[] = [] + Object.entries(obj).forEach(([k, v]) => { + const p = prefix ? `${prefix}.${k}` : k + if (v && typeof v === "object") { + paths.push(...collectLeafPaths(v, p)) + } else { + paths.push(p) + } + }) + return paths +} + +/** Build placeholder skeleton rows so the table height stays stable while data fetches. */ +export function buildSkeletonRows(count: number): {key: string; isSkeleton: true}[] { + return Array.from({length: count}, (_, idx) => ({ + key: `skeleton-${idx}`, + isSkeleton: true as const, + })) +} diff --git a/web/ee/src/lib/types_ee.ts b/web/ee/src/lib/types_ee.ts new file mode 100644 index 0000000000..92fb607149 --- /dev/null +++ b/web/ee/src/lib/types_ee.ts @@ -0,0 +1,165 @@ +import {GenericObject, RequestMetadata} from "@/oss/lib/Types" +import {Environment, IPromptRevisions} from "@/oss/lib/Types" + +export enum GenerationStatus { + UNSET = "UNSET", + OK = "OK", + ERROR = "ERROR", +} + +export enum GenerationKind { + TOOL = "TOOL", + CHAIN = "CHAIN", + LLM = "LLM", + WORKFLOW = "WORKFLOW", + RETRIEVER = "RETRIEVER", + EMBEDDING = "EMBEDDING", + AGENT = "AGENT", + UNKNOWN = "UNKNOWN", +} + +export interface Generation { + id: string + created_at: string + variant: { + variant_id: string + variant_name: string + revision: number + } + environment: string | null + status: GenerationStatus + error?: string + spankind: GenerationKind + metadata?: RequestMetadata + user_id?: string + children?: [] + parent_span_id?: string + name?: string + content: { + inputs: Record + internals: Record + outputs: string[] | Record + } +} + +export interface GenerationTreeNode { + title: React.ReactElement + key: string + children?: GenerationTreeNode[] +} + +export interface GenerationDetails extends Generation { + config: GenericObject +} + +export interface GenerationDashboardData { + data: { + timestamp: number | string + success_count: number + failure_count: number + cost: number + latency: number + total_tokens: number + prompt_tokens: number + completion_tokens: number + enviornment: string + variant: string + }[] + total_count: number + failure_rate: number + total_cost: number + avg_cost: number + avg_latency: number + total_tokens: number + avg_tokens: number +} + +export interface TracingDashboardData { + buckets: { + errors: { + costs: number + count: number + duration: number + tokens: number + } + timestamp: string + total: { + costs: number + count: number + duration: number + tokens: number + } + window: number + }[] + count: number + version: string +} + +export interface Trace extends Generation {} + +export interface TraceDetails extends GenerationDetails { + spans: Generation[] +} + +export interface DeploymentRevisionConfig { + config_name: string + current_version: number + parameters: Record +} + +export interface IEnvironmentRevision { + revision: number + modified_by: string + created_at: string +} + +export interface IPromptVersioning { + app_id: string + app_name: string + base_id: string + base_name: string + config_name: string + organization_id: string + parameters: Record + previous_variant_name: string | null + revision: number + revisions: [IPromptRevisions] + uri: string + user_id: string + variant_id: string + variant_name: string +} + +export interface DeploymentRevision { + created_at: string + deployed_app_variant_revision: string + deployment: string + id: string + deployed_variant_name: string | null + modified_by: string + revision: number + commit_message: string | null +} + +export interface DeploymentRevisions extends Environment { + revisions: DeploymentRevision[] +} + +export interface EvaluatorMappingInput { + inputs: Record + mapping: Record +} + +export interface EvaluatorMappingOutput { + outputs: Record +} + +export interface EvaluatorInputInterface { + inputs: Record + settings?: Record + credentials?: Record +} + +export interface EvaluatorOutputInterface { + outputs: Record +} diff --git a/web/ee/src/lib/workers/evalRunner/bulkWorker.ts b/web/ee/src/lib/workers/evalRunner/bulkWorker.ts new file mode 100644 index 0000000000..61211ad36c --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/bulkWorker.ts @@ -0,0 +1,143 @@ +/* + * Main-thread helper for the evalRunner bulk-fetch web-worker. + * Lazily spins up a single instance of the worker and multiplexes requests + * by a generated requestId. + */ + +import type {WorkerEvalContext} from "./workerFetch" + +import {serializeRunIndex} from "@/agenta-oss-common/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" +import {UseEvaluationRunScenarioStepsFetcherResult} from "@/agenta-oss-common/lib/hooks/useEvaluationRunScenarioSteps/types" + +type RawEntry = [string, UseEvaluationRunScenarioStepsFetcherResult] + +interface FetchBulkChunkMessage { + type: "chunk" + requestId: string + json: string // stringified RawEntry[] +} + +interface FetchBulkDoneMessage { + type: "done" + requestId: string +} + +interface FetchBulkErrorMessage { + type: "error" + requestId: string + error: string +} + +type WorkerMessage = FetchBulkChunkMessage | FetchBulkDoneMessage | FetchBulkErrorMessage + +interface Pending { + resolve: (v: Map) => void + reject: (e: unknown) => void + timer: ReturnType + buffer: Map + onChunk?: (chunk: Map) => void +} + +let worker: Worker | null = null +const pendings = new Map() + +function ensureWorker() { + if (worker) return + + // @ts-ignore + worker = new Worker(new URL("./fetchSteps.worker.ts", import.meta.url), {type: "module"}) + worker.onmessage = (event: MessageEvent) => { + const msg = event.data as WorkerMessage + const pending = pendings.get(msg.requestId) + if (!pending) return + + switch (msg.type) { + case "chunk": { + queueMicrotask(() => { + const entries: RawEntry[] = JSON.parse(msg.json) + const chunkMap = new Map() + for (const [id, data] of entries) { + pending.buffer.set(id, data) + chunkMap.set(id, data) + } + if (pending.onChunk) { + try { + pending.onChunk(chunkMap) + } catch (err) { + console.error("[bulkWorker] onChunk error", err) + } + } + }) + break + } + case "done": { + clearTimeout(pending.timer) + pendings.delete(msg.requestId) + pending.resolve(pending.buffer) + break + } + case "error": { + clearTimeout(pending.timer) + pendings.delete(msg.requestId) + console.error(`[bulkWorker] error from worker`, msg.error) + pending.reject(new Error(msg.error)) + break + } + default: + break + } + } +} + +const DEFAULT_WORKER_TIMEOUT_MS = 300_000 // 5 minutes + +export async function fetchStepsViaWorker({ + scenarioIds, + context, + timeoutMs = DEFAULT_WORKER_TIMEOUT_MS, + onChunk, +}: { + scenarioIds: string[] + context: WorkerEvalContext + timeoutMs?: number + onChunk?: (chunk: Map) => void +}): Promise> { + if (typeof Worker === "undefined") { + throw new Error("Web Workers not supported in this environment") + } + ensureWorker() + const requestId = (crypto.randomUUID?.() ?? Math.random().toString(36).slice(2)) as string + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + pendings.delete(requestId) + reject(new Error(`Worker request timed out after ${timeoutMs} ms`)) + }, timeoutMs) + pendings.set(requestId, { + resolve, + reject, + timer, + buffer: new Map(), + onChunk, + }) + worker!.postMessage({ + type: "fetch-bulk", + requestId, + scenarioIds, + context: { + apiUrl: context.apiUrl, + evaluators: context.evaluators, + jwt: context.jwt, + projectId: context.projectId, + runIndex: serializeRunIndex(context.runIndex), + members: context.members, + runId: context.runId, + mappings: context.mappings, + testsets: context.testsets, + variants: context.variants, + uriObject: context.uriObject, + parametersByRevisionId: context.parametersByRevisionId, + appType: context.appType, + }, + }) + }) +} diff --git a/web/ee/src/lib/workers/evalRunner/evalRunner.worker.ts b/web/ee/src/lib/workers/evalRunner/evalRunner.worker.ts new file mode 100644 index 0000000000..f4ed98c750 --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/evalRunner.worker.ts @@ -0,0 +1,259 @@ +// evalRunner.worker.ts + +import {snakeToCamelCaseKeys} from "@agenta/oss/src/lib/helpers/casing" +import {BaseResponse, EvaluationStatus} from "@agenta/oss/src/lib/Types" + +import { + updateScenarioStatusRemote, + upsertScenarioStep, +} from "@/oss/services/evaluations/workerUtils" +import {createScenarioMetrics, computeRunMetrics} from "@/oss/services/runMetrics/api" + +import {RunEvalMessage, ResultMessage, WorkerMessage} from "./types" + +async function updateScenarioStatus( + apiUrl: string, + jwt: string, + scenarioId: string, + status: EvaluationStatus, + projectId: string, +) { + await updateScenarioStatusRemote(apiUrl, jwt, scenarioId, status, projectId) +} + +const queue: RunEvalMessage[] = [] +let isProcessing = false +let MAX_CONCURRENT = 5 +let activeRequests = 0 + +// eslint-disable-next-line @typescript-eslint/no-unused-vars +let jwt: string | null = null + +self.onmessage = (event: MessageEvent) => { + const msg = event.data + switch (msg.type) { + case "UPDATE_JWT": + jwt = msg.jwt + break + case "run-invocation": + if (msg.jwt) jwt = msg.jwt + queue.push(msg) + if (!isProcessing) processQueue() + break + case "config": + MAX_CONCURRENT = msg.maxConcurrent + if (!isProcessing && queue.length > 0) processQueue() + break + } +} + +async function handleRequest(message: RunEvalMessage) { + const { + jwt, + invocationStepTarget, + scenarioId, + projectId, + runId, + appId, + requestBody, + invocationKey, + endpoint, + apiUrl, + } = message + try { + await updateScenarioStatus(apiUrl, jwt, scenarioId, EvaluationStatus.RUNNING, projectId) + const response = await fetch( + `${endpoint}?application_id=${appId}&project_id=${projectId}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + "ngrok-skip-browser-warning": "1", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify(requestBody), + }, + ) + + const _result = (await response.json()) as BaseResponse + const result = snakeToCamelCaseKeys(_result) + + const message: ResultMessage = { + type: "result", + scenarioId, + status: response.status === 200 ? EvaluationStatus.SUCCESS : EvaluationStatus.FAILURE, + invocationStepTarget, + invocationKey, + result: { + ...result, + requestBody, + endpoint, + }, + // @ts-ignore + error: response.status !== 200 ? result.detail.message : null, + } + + const tryCreateScenarioInvocationMetrics = async (result: any, error?: string | null) => { + const statsMap: Record = {} + + // 1. Flatten numeric leaves into dot-notation keys + const flattenNumeric = (obj: any, prefix = "", out: Record = {}) => { + if (!obj || typeof obj !== "object") return out + Object.entries(obj).forEach(([k, v]) => { + const path = prefix ? `${prefix}.${k}` : k + if (typeof v === "number") { + out[path] = v + } else if (v && typeof v === "object") { + flattenNumeric(v, path, out) + } + }) + return out + } + + if (error) { + const metricsAcc = result?.detail?.tree?.nodes?.[0]?.metrics?.acc + const flatMetrics = flattenNumeric({ + ...(metricsAcc || {}), + errors: 1, + }) + if (!Object.keys(flatMetrics).length) return + + // 2. Compute statistics for each metric + const statsMapRaw = computeRunMetrics([{data: flatMetrics}]) + // 3. If only one value, keep the mean instead of full stats object + Object.entries(statsMapRaw).forEach(([k, v]) => { + const stats = structuredClone(v) + if ("distribution" in stats) delete stats.distribution + if ("iqrs" in stats) delete stats.iqrs + if ("percentiles" in stats) delete stats.percentiles + if ("binSize" in stats) delete stats.binSize + statsMap[k] = stats + }) + } else { + const metricsAcc = result?.tree?.nodes?.[0]?.metrics?.acc + if (!metricsAcc) return + + const flatMetrics = flattenNumeric(metricsAcc) + if (!Object.keys(flatMetrics).length) return + + // 2. Compute statistics for each metric + const statsMapRaw = computeRunMetrics([{data: flatMetrics}]) + + // 3. If only one value, keep the mean instead of full stats object + Object.entries(statsMapRaw).forEach(([k, v]) => { + const stats = structuredClone(v) + if ("distribution" in stats) delete stats.distribution + if ("iqrs" in stats) delete stats.iqrs + if ("percentiles" in stats) delete stats.percentiles + if ("binSize" in stats) delete stats.binSize + statsMap[k] = stats + }) + } + + const stepKey = invocationKey ?? "invocation" + const nestedData = {[stepKey]: statsMap} + + try { + await createScenarioMetrics( + apiUrl, + jwt, + runId, + [{scenarioId, data: nestedData}], + projectId, + ) + } catch (err) { + console.error("INVOCATION METRICS FAILED:", err) + } + } + + if (response.status === 200) { + tryCreateScenarioInvocationMetrics(result) + try { + await upsertScenarioStep({ + apiUrl, + jwt, + runId, + scenarioId, + status: EvaluationStatus.SUCCESS, + projectId, + key: invocationKey ?? "invocation", + traceId: (_result as any)?.trace_id ?? null, + spanId: (_result as any)?.span_id ?? null, + references: {application: {id: appId}}, + }) + message.result.trace = result?.tree + } catch (err) {} + + postMessage(message) + } else { + tryCreateScenarioInvocationMetrics(result, _result?.detail?.message || _result) + updateScenarioStatus(apiUrl, jwt, scenarioId, EvaluationStatus.FAILURE, projectId) + const traceId = result?.detail?.traceId + const spanId = result?.detail?.spanId + + await upsertScenarioStep({ + apiUrl, + jwt, + runId, + scenarioId, + status: EvaluationStatus.FAILURE, + projectId, + key: invocationKey ?? "invocation", + traceId, + spanId, + references: {application: {id: appId}}, + }) + + postMessage(message) + } + } catch (err: any) { + await upsertScenarioStep({ + apiUrl, + jwt, + runId, + scenarioId, + status: EvaluationStatus.FAILURE, + projectId, + key: invocationKey ?? "invocation", + references: {application: {id: appId}}, + }) + const message: ResultMessage = { + type: "result", + scenarioId, + status: EvaluationStatus.FAILURE, + error: err.message ?? "Unknown error", + result: { + requestBody, + endpoint, + }, + invocationStepTarget, + invocationKey, + endpoint, + appId, + } + await updateScenarioStatus(apiUrl, jwt, scenarioId, EvaluationStatus.FAILURE, projectId) + + postMessage(message) + } +} + +async function processQueue() { + isProcessing = true + + while (queue.length > 0 || activeRequests > 0) { + while (activeRequests < MAX_CONCURRENT && queue.length > 0) { + const message = queue.shift()! + activeRequests++ + handleRequest(message).finally(() => { + activeRequests-- + if (!isProcessing && queue.length > 0) { + processQueue() + } + }) + } + // Wait a short time to yield control and allow activeRequests to update + await new Promise((resolve) => setTimeout(resolve, 10)) + } + + isProcessing = false +} diff --git a/web/ee/src/lib/workers/evalRunner/fetchRunMetrics.worker.ts b/web/ee/src/lib/workers/evalRunner/fetchRunMetrics.worker.ts new file mode 100644 index 0000000000..d60df729ab --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/fetchRunMetrics.worker.ts @@ -0,0 +1,151 @@ +/* +Web Worker: Fetch run-level metrics for a single evaluation run. +Receives a message of form: + { requestId: string, payload: { apiUrl: string; jwt: string; projectId: string; runId: string } } +Responds with: + { requestId, ok: true, data: metrics[] } or { requestId, ok:false, error } +*/ + +interface WorkerRequest { + requestId: string + payload: { + apiUrl: string + jwt: string + projectId: string + runId: string + evaluatorSlugs?: string[] + revisionSlugs?: string[] + } +} + +interface WorkerResponse { + requestId: string + ok: boolean + data?: any[] + stats?: Record + error?: string +} + +self.onmessage = async (e: MessageEvent) => { + const {requestId, payload} = e.data + try { + const {apiUrl, jwt, projectId, runId, evaluatorSlugs = [], revisionSlugs = []} = payload + const url = `${apiUrl}/preview/evaluations/metrics/query?project_id=${projectId}` + const body: Record = { + metrics: {run_ids: [runId]}, + windowing: {}, + } + const resp = await fetch(url, { + method: "POST", + headers: { + Authorization: jwt ? `Bearer ${jwt}` : "", + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }) + if (!resp.ok) throw new Error(`fetch ${resp.status}`) + const json = (await resp.json()) as {metrics?: any[]} + const camel = Array.isArray(json.metrics) ? json.metrics.map((m) => m) : [] + + // Utility to extract slug and category from stepKey + const classifyKey = ( + key: string, + ): {type: "invocation" | "evaluator" | "revision"; slug?: string} => { + const parts = key.split(".") + if (parts.length === 1 && !evaluatorSlugs.includes(parts[0])) + return {type: "invocation"} + const slug = parts[1] + if (evaluatorSlugs.includes(slug)) return {type: "evaluator", slug} + if (revisionSlugs.includes(slug)) return {type: "revision", slug} + // default treat as evaluator + return {type: "evaluator", slug: slug ?? parts[0]} + } + const transformData = (data: Record): Record => { + const flat: Record = {} + Object.entries(data || {}).forEach(([stepKey, metrics]) => { + // // Pass-through for analytics keys like ag.metrics.* + // if (stepKey.startsWith("ag.")) { + // const raw = metrics + // let value: any = raw + // if (typeof raw === "object" && raw !== null) { + // if ("mean" in raw) value = (raw as any).mean + // else if ("value" in raw) value = (raw as any).value + // } + // flat[stepKey] = value + // return + // } + + const {type, slug} = classifyKey(stepKey) + Object.entries(metrics as Record).forEach(([metricKey, raw]) => { + let value: any = structuredClone(raw) + if (typeof raw === "object" && raw !== null) { + if ("mean" in raw) { + value = (raw as any).mean + } else if ("freq" in raw) { + value.frequency = raw.freq + // value.rank = raw.freq + value.unique = raw.uniq + + delete value.freq + delete value.uniq + } else if ("value" in raw) { + value = (raw as any).value + } + } + if (stepKey.startsWith("attributes.ag.")) { + const normalizedKey = `${stepKey}.${metricKey}` + flat[normalizedKey] = value + return + } + // Map invocation-level metrics + if (type === "invocation") { + let newKey = metricKey + if (metricKey.startsWith("tokens.")) { + newKey = metricKey.slice(7) + "Tokens" // tokens.prompt -> promptTokens + } else if (metricKey.startsWith("cost")) { + newKey = "totalCost" // cost or costs.total -> totalCost + } + flat[newKey] = value + } else { + const pref = slug ? `${slug}.` : "" + flat[`${pref}${metricKey}`] = value + } + }) + }) + return flat + } + + camel.forEach((entry: any) => { + // removing the run level metrics from the scenario metrics + if (!entry?.scenario_id) { + // Object.entries(entry.data || {}).forEach(([stepKey, metrics]) => { + // const {type, slug} = classifyKey(stepKey) + // Object.entries(metrics as Record).forEach(([metricKey, raw]) => { + // let value: any = raw + // if (typeof raw === "object" && raw !== null) { + // if ("freq" in raw) { + // value.frequency = raw.freq + // value.rank = raw.freq + // delete value.freq + // entry.data[`${slug}.${metricKey}`] = value + // } + // } + // }) + // }) + return + } + entry.data = transformData(entry.data || {}) + }) + + // Dynamically import to keep worker bundle lean until needed + const {computeRunMetrics} = await import("@/oss/services/runMetrics/api") + const stats = computeRunMetrics(camel.map((m: any) => ({data: m.data || {}}))) + const res: WorkerResponse = {requestId, ok: true, data: camel, stats} + // @ts-ignore + self.postMessage(res) + } catch (err: any) { + const res: WorkerResponse = {requestId, ok: false, error: err.message || "unknown"} + // @ts-ignore + self.postMessage(res) + } +} diff --git a/web/ee/src/lib/workers/evalRunner/fetchSteps.worker.ts b/web/ee/src/lib/workers/evalRunner/fetchSteps.worker.ts new file mode 100644 index 0000000000..44f315d01d --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/fetchSteps.worker.ts @@ -0,0 +1,75 @@ +// Web Worker for bulk scenario steps fetching & enrichment +// Receives {type: "fetch-bulk", requestId, scenarioIds, context} +// Responds with {type: "result", requestId, entries: [ [scenarioId, enrichedResult] ] } + +import {fetchScenarioStepsBulkWorker} from "./workerFetch" +import type {WorkerEvalContext} from "./workerFetch" + +export interface FetchBulkMessage { + type: "fetch-bulk" + requestId: string + scenarioIds: string[] + context: WorkerEvalContext +} + +export interface FetchBulkChunkMessage { + type: "chunk" + requestId: string + json: string // stringified RawEntry[] +} + +export interface FetchBulkDoneMessage { + type: "done" + requestId: string +} + +type OutgoingMessage = + | FetchBulkChunkMessage + | FetchBulkDoneMessage + | { + type: "error" + requestId: string + error: string + } + +self.onmessage = (event: MessageEvent) => { + const msg = event.data + if (msg.type !== "fetch-bulk") return + + const {requestId, scenarioIds, context} = msg + + fetchScenarioStepsBulkWorker(scenarioIds, context) + .then((map) => { + const CHUNK_SIZE = 200 + + const entries = Array.from(map.entries()) + ;(async () => { + for (let i = 0; i < entries.length; i += CHUNK_SIZE) { + const chunkEntries = entries.slice(i, i + CHUNK_SIZE) + const msg: OutgoingMessage = { + type: "chunk", + requestId, + json: JSON.stringify(chunkEntries), + } + // @ts-ignore + self.postMessage(msg) + + // allow main thread to breathe + await new Promise((r) => setTimeout(r, 300)) // ~1 frame @60fps + } + const done: OutgoingMessage = {type: "done", requestId} + // @ts-ignore + self.postMessage(done) + })() + }) + .catch((err) => { + // Post error back so main thread can handle + const errorMsg: OutgoingMessage = { + type: "error", + requestId, + error: err && err.message ? err.message : String(err ?? "unknown"), + } + // @ts-ignore + self.postMessage(errorMsg) + }) +} diff --git a/web/ee/src/lib/workers/evalRunner/pureEnrichment.ts b/web/ee/src/lib/workers/evalRunner/pureEnrichment.ts new file mode 100644 index 0000000000..d9f8436396 --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/pureEnrichment.ts @@ -0,0 +1,610 @@ +/* + * Worker-friendly clone of `hooks/useEvaluationRunData/assets/enrichment`. + * It removes React / cookie / axios dependencies and relies solely on data + * passed from the main thread via the worker context. + */ + +import {uuidToTraceId, uuidToSpanId} from "@/oss/lib/hooks/useAnnotations/assets/helpers" +import {transformApiData} from "@/oss/lib/hooks/useAnnotations/assets/transformer" +import type {AnnotationDto} from "@/oss/lib/hooks/useAnnotations/types" +import type {RunIndex} from "@/oss/lib/hooks/useEvaluationRunData/assets/helpers/buildRunIndex" +import type { + IStepResponse, + StepResponseStep, + UseEvaluationRunScenarioStepsFetcherResult, +} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" +import type {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {constructPlaygroundTestUrl} from "@/oss/lib/shared/variant/stringUtils" +import type {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import type {PreviewTestSet, WorkspaceMember} from "@/oss/lib/Types" + +function collectTraceIds({steps, invocationKeys}: {steps: any[]; invocationKeys: Set}) { + const traceIds: string[] = [] + steps.forEach((st: any) => { + if (invocationKeys.has(st.stepKey) && st.traceId) traceIds.push(st.traceId) + }) + return traceIds +} + +function buildAnnotationLinks({ + annotationSteps, + uuidToTraceId: toTrace, + uuidToSpanId: toSpan, +}: { + annotationSteps: any[] + uuidToTraceId: (uuid: string) => string | undefined + uuidToSpanId: (uuid: string) => string | undefined +}) { + return annotationSteps + .filter((s) => s.traceId) + .map((s) => ({trace_id: toTrace(s.traceId) || s.traceId, span_id: toSpan(s.traceId)})) +} + +export function buildAnnotationMap({ + rawAnnotations, + members, +}: { + rawAnnotations: any[] + members?: any[] +}): Map { + const map = new Map() + if (!rawAnnotations?.length) return map + const normalized = rawAnnotations.map((ann: any) => + transformApiData({data: ann, members: members || []}), + ) + normalized.forEach((a: any) => { + if (a?.trace_id) map.set(a.trace_id, a) + }) + return map +} + +/** Simple dot-path resolver ("a.b.c") */ +export function resolvePath(obj: any, path: string): any { + return path.split(".").reduce((o: any, key: string) => (o ? o[key] : undefined), obj) +} + +export function computeInputsAndGroundTruth({ + testcase, + mappings, + inputKey, + inputParamNames, +}: { + testcase: any + mappings: any[] + inputKey: string + inputParamNames: string[] +}) { + const isRevisionKnown = Array.isArray(inputParamNames) && inputParamNames.length > 0 + + // Heuristic fallback names for ground truth columns when revision input params are unknown + const GT_NAMES = new Set(["correct_answer", "expected_output", "ground_truth", "label"]) + + const inputMappings = (mappings ?? []).filter((m) => { + if (m.step.key !== inputKey) return false + const name = m.column?.name + if (isRevisionKnown) return inputParamNames.includes(name) + // Fallback: treat testset columns not matching GT names as inputs + return m.column?.kind === "testset" && !GT_NAMES.has(name) + }) + + const groundTruthMappings = (mappings ?? []).filter((m) => { + if (m.step.key !== inputKey) return false + const name = m.column?.name + if (isRevisionKnown) return m.column?.kind === "testset" && !inputParamNames.includes(name) + // Fallback: treat well-known GT names as ground truth + return m.column?.kind === "testset" && GT_NAMES.has(name) + }) + + const objFor = (filtered: any[]) => + filtered.reduce((acc: any, m: any) => { + let val = resolvePath(testcase, m.step.path) + if (val === undefined && m.step.path.startsWith("data.")) { + val = resolvePath(testcase, m.step.path.slice(5)) + } + if (val !== undefined) acc[m.column?.name || m.name] = val + return acc + }, {}) + + let inputs = objFor(inputMappings) + let groundTruth = objFor(groundTruthMappings) + + // Fallback: if no mappings for inputs, derive directly from testcase.data keys + if (!Object.keys(inputs).length && testcase && typeof testcase === "object") { + const dataObj = (testcase as any).data || {} + if (dataObj && typeof dataObj === "object") { + Object.keys(dataObj).forEach((k) => { + if (!GT_NAMES.has(k) && k !== "messages") { + inputs[k] = dataObj[k] + } + }) + // Ground truth fallback: pick a known GT field if present + if (!Object.keys(groundTruth).length) { + for (const name of Array.from(GT_NAMES)) { + if (name in dataObj) { + ;(groundTruth as any)[name] = dataObj[name] + break + } + } + } + } + } + + return {inputs, groundTruth} +} + +export function identifyScenarioSteps({ + steps, + runIndex, + evaluators, +}: { + steps: StepResponseStep[] + runIndex?: {inputKeys: Set; invocationKeys: Set; steps: Record} + evaluators: EvaluatorDto[] +}) { + const inputSteps = steps.filter((s) => runIndex?.inputKeys?.has(s.stepKey)) + + const invocationKeys = runIndex?.invocationKeys ?? new Set() + const invocationSteps = steps.filter((s) => invocationKeys.has(s.stepKey)) + + const annotationSteps = steps.filter((s) => { + const keyParts = (s.stepKey || "").split(".") + const evaluatorSlug = keyParts.length > 1 ? keyParts[keyParts.length - 1] : undefined + return evaluatorSlug ? evaluators.some((e) => e.slug === evaluatorSlug) : false + }) + + return {inputSteps, invocationSteps, annotationSteps} +} + +export function deriveTestsetAndRevision({ + inputSteps, + invocationSteps, + runIndex, + testsets, + variants, +}: { + inputSteps: any[] + invocationSteps: any[] + runIndex?: {steps: Record} + testsets: PreviewTestSet[] + variants: EnhancedVariant[] +}): {testsets: PreviewTestSet[]; revisions: EnhancedVariant[]} { + const referencedTestsetIds = new Set() + const referencedRevisionIds = new Set() + + if (runIndex) { + inputSteps.forEach((step) => { + const meta = runIndex.steps[step.stepKey] + const tsId = meta?.refs?.testset?.id + if (tsId) referencedTestsetIds.add(tsId) + }) + invocationSteps.forEach((step) => { + const meta = runIndex.steps[step.stepKey] + const revId = meta?.refs?.applicationRevision?.id + if (revId) referencedRevisionIds.add(revId) + }) + } + + const resolvedTestsets = testsets.filter((t: any) => { + const id = (t as any).id ?? (t as any)._id + return referencedTestsetIds.has(id as string) + }) + const resolvedRevisions = variants.filter((v) => referencedRevisionIds.has(v.id)) + + return {testsets: resolvedTestsets, revisions: resolvedRevisions} +} + +export function enrichInputSteps({ + inputSteps, + testsets, + revisions, + mappings, +}: { + inputSteps: any[] + testsets?: any[] + revisions?: any[] + mappings?: any +}) { + const findTestsetForTestcase = (tcId: string) => + testsets?.find( + (ts: any) => + Array.isArray(ts.data?.testcases) && + ts.data.testcases.some((tc: any) => tc.id === tcId), + ) + + const enrichStep = (step: any) => { + const ts = findTestsetForTestcase(step.testcaseId) + + let inputs = step.inputs ? {...step.inputs} : {} + const groundTruth = step.groundTruth ?? {} + + const canComputeFromTestset = + mappings && Array.isArray(testsets) && testsets.length > 0 && ts + if (canComputeFromTestset) { + const testcase = ts?.data?.testcases?.find((tc: any) => tc.id === step.testcaseId) + if (testcase) { + // We no longer rely on revision.inputParams in worker context. + // Passing an empty list will trigger heuristic fallback in computeInputsAndGroundTruth. + const inputParamNames: string[] = [] + const computed = computeInputsAndGroundTruth({ + testcase, + mappings, + inputKey: step.stepKey, + inputParamNames, + }) + for (const [k, v] of Object.entries(computed.inputs)) { + if (!(k in inputs)) (inputs as Record)[k] = v + } + } + } + + const testcase = testsets + ?.flatMap((t: any) => t.data?.testcases || []) + .find((tc: any) => tc.id === step.testcaseId) + return {...step, inputs, groundTruth, testcase} + } + + const richInputSteps = inputSteps.map((s) => enrichStep(s)) + return {richInputSteps, richInputStep: richInputSteps[0]} +} + +export const prepareRequest = ({ + revision, + inputParametersDict, + uriObject, + precomputedParameters, + appType, +}: { + revision: EnhancedVariant + inputParametersDict: Record + uriObject?: {runtimePrefix: string; routePath?: string} + /** Parameters computed on main thread via transformedPromptsAtomFamily({useStableParams: true}) */ + precomputedParameters?: any +}) => { + if (!revision || !inputParametersDict) return null + + // We no longer store chat flags on the revision; infer from inputs + const isChatVariant = Object.prototype.hasOwnProperty.call( + inputParametersDict || {}, + "messages", + ) + const isCustomVariant = !!appType && appType === "custom" + + const mainInputParams: Record = {} + const secondaryInputParams: Record = {} + + // Derive splitting without relying on deprecated revision.inputParams: + // - messages => top-level (main) param for chat variants + // - everything else => goes under `inputs` + Object.keys(inputParametersDict).forEach((key) => { + const val = inputParametersDict[key] + if (key === "messages") { + mainInputParams[key] = val + } else { + secondaryInputParams[key] = val + } + }) + + // Start from stable precomputed parameters (main-thread transformed prompts) + const baseParams = (precomputedParameters as Record) || {} + const requestBody: Record = { + ...baseParams, + ...mainInputParams, + } + + if (isCustomVariant) { + for (const key of Object.keys(inputParametersDict)) { + if (key !== "inputs") requestBody[key] = inputParametersDict[key] + } + } else { + requestBody["inputs"] = {...(requestBody["inputs"] || {}), ...secondaryInputParams} + } + + if (isChatVariant) { + if (typeof requestBody["messages"] === "string") { + try { + requestBody["messages"] = JSON.parse(requestBody["messages"]) + } catch { + throw new Error("content not valid for messages") + } + } + } + + // Ensure we never crash on missing uriObject; default to empty values + const safeUri = uriObject ?? {runtimePrefix: "", routePath: ""} + + return { + requestBody, + endpoint: constructPlaygroundTestUrl(safeUri, "/test", true), + } +} + +export function buildInvocationParameters({ + invocationSteps, + inputSteps, + uriObject, + parametersByRevisionId, + appType, +}: { + invocationSteps: (IStepResponse & {revision?: any})[] + inputSteps: (IStepResponse & {inputs?: Record})[] + uriObject?: {runtimePrefix: string; routePath?: string} + /** Map of revisionId -> transformed prompts (stable) */ + parametersByRevisionId?: Record +}) { + const map: Record = {} + invocationSteps.forEach((step) => { + const revision = (step as any).revision + const matchInput = inputSteps.find((r) => r.testcaseId === step.testcaseId && r.inputs) + if (step.status !== "success") { + const pre = revision?.id ? parametersByRevisionId?.[revision.id] : undefined + + const params = prepareRequest({ + revision, + inputParametersDict: matchInput?.inputs ?? {}, + uriObject, + precomputedParameters: pre?.ag_config ? pre : pre, + appType, + }) + map[step.stepKey] = params + ;(step as any).invocationParameters = params + } else { + map[step.stepKey] = undefined + ;(step as any).invocationParameters = undefined + } + }) + return map +} + +// ------------------- public worker-friendly funcs ------------------- + +export function computeTraceAndAnnotationRefs({ + steps, + runIndex, + evaluators, +}: { + steps: StepResponseStep[] + runIndex?: {invocationKeys: Set; annotationKeys: Set} + evaluators: EvaluatorDto[] +}) { + const invocationKeys = runIndex?.invocationKeys ?? new Set() + const annotationKeys = runIndex?.annotationKeys ?? new Set() + + const traceIds = collectTraceIds({steps, invocationKeys}) + + // simple evaluator-based identification + const annotationSteps = steps.filter((s) => annotationKeys.has(s.stepKey)) + + const annotationLinks = buildAnnotationLinks({ + annotationSteps, + uuidToTraceId, + uuidToSpanId, + }) + return {traceIds, annotationSteps, annotationLinks} +} + +export async function fetchTraceAndAnnotationMaps({ + traceIds, + annotationLinks, + members, + invocationSteps, + apiUrl, + jwt, + projectId, +}: { + traceIds: string[] + annotationLinks: {trace_id: string; span_id?: string}[] + members: WorkspaceMember[] + invocationSteps: any[] + apiUrl: string + jwt: string + projectId: string +}): Promise<{traceMap: Map; annotationMap: Map}> { + const traceMap = new Map() + const annotationMap = new Map() + + if (traceIds.length) { + try { + const filtering = JSON.stringify({ + conditions: [{key: "tree.id", operator: "in", value: traceIds}], + }) + const params = new URLSearchParams() + params.append("filtering", filtering) + params.append("project_id", projectId) + const resp = await fetch(`${apiUrl}/observability/v1/traces?${params.toString()}`, { + headers: {Authorization: `Bearer ${jwt}`}, + }) + if (resp.ok) { + const data = await resp.json() + const trees = data?.trees || [] + trees.forEach((t: any) => { + if (t.tree?.id) traceMap.set(t.tree.id, t) + }) + } + } catch (err) { + console.error("[pureEnrichment] trace fetch error", err) + } + } + + if (annotationLinks && annotationLinks.length > 0) { + try { + const resp = await fetch( + `${apiUrl}/preview/annotations/query?project_id=${projectId}`, + { + method: "POST", + headers: {"Content-Type": "application/json", Authorization: `Bearer ${jwt}`}, + body: JSON.stringify({annotation_links: annotationLinks}), + }, + ) + if (resp.ok) { + const data = await resp.json() + const annMap = buildAnnotationMap({ + rawAnnotations: data?.annotations || [], + members, + }) + annMap.forEach((v, k) => annotationMap.set(k, v)) + } + } catch (err) { + console.error("[pureEnrichment] annotation fetch error", err) + } + } + + return {traceMap, annotationMap} +} + +// ------------------- pure implementations ------------------- + +export function buildScenarioCore({ + steps, + runIndex, + evaluators, + testsets, + variants, + mappings, + uriObject, + parametersByRevisionId, + appType, +}: { + steps: StepResponseStep[] + runIndex?: RunIndex + evaluators: EvaluatorDto[] + testsets: PreviewTestSet[] + variants: EnhancedVariant[] + mappings?: unknown + uriObject?: {runtimePrefix: string; routePath?: string} + parametersByRevisionId?: Record + appType?: string +}): UseEvaluationRunScenarioStepsFetcherResult { + const {inputSteps, invocationSteps, annotationSteps} = identifyScenarioSteps({ + steps, + runIndex, + evaluators, + }) + + const {testsets: derivedTestsets, revisions} = deriveTestsetAndRevision({ + inputSteps, + invocationSteps, + runIndex, + testsets, + variants, + }) + + const {richInputSteps: enrichedInputSteps} = enrichInputSteps({ + inputSteps, + testsets: derivedTestsets, + revisions, + mappings, + }) + + // Attach revision object to each invocation step + const revisionMap: Record = {} + revisions.forEach((rev: any) => { + revisionMap[rev.id] = rev + }) + const enrichedInvocationSteps = invocationSteps.map((inv) => { + let revObj: any + if (runIndex) { + const meta = (runIndex as any).steps?.[inv.stepKey] + const revId = meta?.refs?.applicationRevision?.id + if (revId) revObj = revisionMap[revId] + } + return revObj ? {...inv, revision: revObj} : inv + }) + + buildInvocationParameters({ + invocationSteps: enrichedInvocationSteps, + inputSteps: enrichedInputSteps, + uriObject, + parametersByRevisionId, + appType, + }) + + return { + inputSteps: enrichedInputSteps, + invocationSteps: enrichedInvocationSteps, + annotationSteps, + } +} + +export function decorateScenarioResult({ + result, + traceMap, + annotationMap, + runIndex, + uuidToTraceId: _uuidToTraceId, +}: { + result: any + traceMap: Map + annotationMap: Map + runIndex?: {invocationKeys: Set; annotationKeys: Set; inputKeys?: Set} + uuidToTraceId: (uuid: string) => string | undefined +}) { + const invocationKeys = runIndex?.invocationKeys ?? new Set() + result.steps?.forEach((st: any) => { + const rawTrace = st.traceId ?? st.trace_id + const traceKey = rawTrace + const traceHex = rawTrace?.includes("-") ? _uuidToTraceId(rawTrace) : rawTrace + + // Invocation steps + if (invocationKeys.has(st.stepKey) || Boolean(st.references?.application)) { + st.isInvocation = true + if (traceKey) { + const tw = traceMap.get(traceKey) + if (tw) { + st.trace = tw.trees ? tw.trees[0] : tw + } + } + } + + // Annotation steps + if (runIndex?.annotationKeys?.has(st.stepKey)) { + if (traceHex) { + st.annotation = annotationMap.get(traceHex) + const tw = traceMap.get(traceKey) + if (tw) { + st.trace = tw.trees ? tw.trees[0] : tw + } + } + } + + // Input steps + if (runIndex?.inputKeys?.has(st.stepKey) && Array.isArray(result.inputSteps)) { + const enriched = result.inputSteps.find( + (inp: any) => inp.stepKey === st.stepKey && inp.inputs, + ) + if (enriched) { + st.inputs = enriched.inputs + st.groundTruth = enriched.groundTruth + if (st.testcaseId && enriched.testcase) { + st.testcase = enriched.testcase + } + } + } + }) + + // Ensure invocationSteps have trace + if (Array.isArray(result.invocationSteps)) { + result.invocationSteps.forEach((inv: any) => { + if (!inv.trace) { + const tid = inv.traceId || inv.trace_id + const tw = tid ? traceMap.get(tid) : undefined + if (tw) { + inv.trace = tw.trees ? tw.trees[0] : tw + } + } + }) + } + // Propagate testcase objects + if (Array.isArray(result.inputSteps)) { + result.inputSteps.forEach((inp: any) => { + if (inp.testcaseId && inp.testcase) { + const testcaseMap: Record = {} + testcaseMap[inp.testcaseId] = inp.testcase + result.steps?.forEach((st: any) => { + if (st.testcaseId && testcaseMap[st.testcaseId]) { + st.testcase = testcaseMap[st.testcaseId] + } + }) + } + }) + } +} diff --git a/web/ee/src/lib/workers/evalRunner/runMetricsWorker.ts b/web/ee/src/lib/workers/evalRunner/runMetricsWorker.ts new file mode 100644 index 0000000000..4d9c4bbdea --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/runMetricsWorker.ts @@ -0,0 +1,78 @@ +/* +Main-thread helper to communicate with fetchRunMetrics.worker.ts. +Ensures single worker instance and multiplexes requests by requestId. +*/ + +// import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" + +interface FetchResultMessage { + requestId: string + ok: boolean + data?: any[] + stats?: Record + error?: string +} + +interface Pending { + resolve: (v: {metrics: any[]; stats: Record}) => void + reject: (e: unknown) => void + timer: ReturnType +} + +let worker: Worker | null = null +const pendings = new Map() + +function ensureWorker() { + if (worker) return + // @ts-ignore + worker = new Worker(new URL("./fetchRunMetrics.worker.ts", import.meta.url), {type: "module"}) + worker.onmessage = (event: MessageEvent) => { + const msg = event.data + const pending = pendings.get(msg.requestId) + if (!pending) return + clearTimeout(pending.timer) + pendings.delete(msg.requestId) + if (!msg.ok) { + pending.reject(new Error(msg.error || "worker error")) + return + } + pending.resolve({metrics: msg.data || [], stats: msg.stats || {}}) + } +} + +export async function fetchRunMetricsViaWorker( + runId: string, + context: { + apiUrl: string + jwt: string + projectId: string + evaluatorSlugs: string[] + revisionSlugs: string[] + }, + timeoutMs = 30000, +): Promise<{metrics: any[]; stats: Record}> { + if (typeof Worker === "undefined") throw new Error("Workers unsupported") + ensureWorker() + const requestId = (crypto.randomUUID?.() ?? Math.random().toString(36).slice(2)) as string + + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + console.error(`[runMetricsWorker] Timeout for runId: ${runId}, requestId: ${requestId}`) + pendings.delete(requestId) + reject(new Error("Worker timeout")) + }, timeoutMs) + + pendings.set(requestId, { + resolve: (result) => { + resolve(result) + }, + reject: (error) => { + console.error(`[runMetricsWorker] Error for runId: ${runId}:`, error) + reject(error) + }, + timer, + }) + + worker!.postMessage({requestId, payload: {...context, runId}}) + }) +} diff --git a/web/ee/src/lib/workers/evalRunner/scenarioListWorker.ts b/web/ee/src/lib/workers/evalRunner/scenarioListWorker.ts new file mode 100644 index 0000000000..fce789f9a5 --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/scenarioListWorker.ts @@ -0,0 +1,116 @@ +/* +Web Worker: Fetch full scenario list for a preview evaluation run in the background. +It expects a message of shape: +{ + requestId: string; + payload: { + apiUrl: string; + jwt: string; + projectId: string; + runId: string; + } +} +It will paginate through the /preview/evaluations/scenarios/ endpoint and post back: +{ requestId, ok: true, data: scenarios[] } or { requestId, ok:false, error } +*/ + +import type {IScenario} from "@/oss/lib/hooks/useEvaluationRunScenarios/types" + +interface WorkerRequest { + requestId: string + payload: { + apiUrl: string + jwt: string + projectId: string + runId: string + } +} + +interface WorkerResponse { + requestId: string + ok: boolean + data?: IScenario[] + error?: string +} + +// Backend supports cursor-based pagination (windowing with `next`) but not +// an explicit numeric `offset`. Fetch scenarios in smaller batches to +// reduce main-thread work when large evaluations load. +const PAGE_SIZE = 100 + +interface FetchArgs { + apiUrl: string + jwt: string + projectId: string + runId: string + next?: string | null + limit: number +} + +async function fetchPage({ + apiUrl, + jwt, + projectId, + runId, + next, + limit, +}: FetchArgs): Promise<{scenarios: IScenario[]; next?: string}> { + // POST to query endpoint + const url = `${apiUrl}/preview/evaluations/scenarios/query?project_id=${encodeURIComponent(projectId)}` + const body: Record = { + scenario: { + ...(runId ? {run_ids: [runId]} : {}), + }, + windowing: { + limit, + ...(next ? {next} : {}), + }, + } + + const res = await fetch(url, { + method: "POST", + headers: { + Authorization: `Bearer ${jwt}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }) + if (!res.ok) throw new Error(`fetch ${res.status}`) + const json = (await res.json()) as {scenarios?: IScenario[]; next?: string} + return {scenarios: json.scenarios ?? [], next: json.next} +} + +self.onmessage = async (e: MessageEvent) => { + const {requestId, payload} = e.data + try { + const scenarios: IScenario[] = [] + let next: string | null | undefined = null + let _batch = 0 + do { + const page = await fetchPage({ + ...payload, + next, + limit: PAGE_SIZE, + }) + scenarios.push(...page.scenarios) + _batch += 1 + next = page.next ?? null + } while (next) + + // Deduplicate scenarios by id in case backend returned duplicates + const seen = new Set() + const uniqueScenarios = scenarios.filter((s) => { + if (seen.has(s.id)) return false + seen.add(s.id) + return true + }) + + const resp: WorkerResponse = {requestId, ok: true, data: uniqueScenarios} + // @ts-ignore + self.postMessage(resp) + } catch (err: any) { + const resp: WorkerResponse = {requestId, ok: false, error: err.message || "unknown"} + // @ts-ignore + self.postMessage(resp) + } +} diff --git a/web/ee/src/lib/workers/evalRunner/types.ts b/web/ee/src/lib/workers/evalRunner/types.ts new file mode 100644 index 0000000000..1b98796efd --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/types.ts @@ -0,0 +1,39 @@ +import {EvaluationStatus} from "@/oss/lib/Types" + +import {IStepResponse} from "../../hooks/useEvaluationRunScenarioSteps/types" + +export interface RunEvalMessage { + type: "run-invocation" + jwt: string + appId: string + scenarioId: string + runId: string + apiUrl: string + requestBody: Record + projectId: string + endpoint: string + invocationKey?: string + invocationStepTarget?: IStepResponse +} + +export interface ResultMessage { + type: "result" + scenarioId: string + status: EvaluationStatus + result?: any + error?: string + invocationStepTarget?: IStepResponse + invocationKey?: string +} + +export interface JwtUpdateMessage { + type: "UPDATE_JWT" + jwt: string +} + +export interface ConfigMessage { + type: "config" + maxConcurrent: number +} + +export type WorkerMessage = RunEvalMessage | ConfigMessage | JwtUpdateMessage diff --git a/web/ee/src/lib/workers/evalRunner/workerFetch.ts b/web/ee/src/lib/workers/evalRunner/workerFetch.ts new file mode 100644 index 0000000000..0848d2b1a2 --- /dev/null +++ b/web/ee/src/lib/workers/evalRunner/workerFetch.ts @@ -0,0 +1,298 @@ +/* + * Web-worker compatible utilities for fetching & enriching scenario steps in bulk. + * + * These functions mirror the logic in `fetchScenarioStepsBulk` but avoid any + * main-thread specifics (Jotai atoms, React hooks). They can be executed inside + * a dedicated Web Worker to offload CPU-heavy enrichment for thousands of + * scenarios. + */ + +import {snakeToCamelCaseKeys} from "@/oss/lib/helpers/casing" +import {uuidToTraceId} from "@/oss/lib/hooks/useAnnotations/assets/helpers" // relative to this file +import type { + IStepResponse, + StepResponse, + StepResponseStep, + UseEvaluationRunScenarioStepsFetcherResult, +} from "@/oss/lib/hooks/useEvaluationRunScenarioSteps/types" +import {PreviewTestCase, PreviewTestSet} from "@/oss/lib/Types" + +import { + deserializeRunIndex, + RunIndex, +} from "../../hooks/useEvaluationRunData/assets/helpers/buildRunIndex" +import {EvalRunDataContextType} from "../../hooks/useEvaluationRunData/types" + +import { + buildScenarioCore, + computeTraceAndAnnotationRefs, + decorateScenarioResult, + fetchTraceAndAnnotationMaps, +} from "./pureEnrichment" + +export const DEFAULT_BATCH_SIZE = 100 +export const DEFAULT_BATCH_CONCURRENCY = 2 + +/** + * Simplified, serialisable context passed from main thread to the worker. + * (It extends the original `EvalRunDataContextType` but removes any functions + * and non-cloneable structures.) + */ +export interface WorkerEvalContext extends Omit { + runIndex: RunIndex + jwt: string + apiUrl: string + projectId: string + /** IDs of variants that are chat-based (hasMessages in request schema) */ + chatVariantIds?: string[] + uriObject?: {runtimePrefix: string; routePath?: string} + /** Stable transformed parameters keyed by revision id */ + parametersByRevisionId?: Record +} + +// ------------- helpers ------------- +function chunkArray(arr: T[], size: number): T[][] { + return Array.from({length: Math.ceil(arr.length / size)}, (_, i) => + arr.slice(i * size, i * size + size), + ) +} + +/** + * Fetch & enrich steps for one batch of scenarios. + * Pure function without side-effects beyond network requests. + */ +async function processScenarioBatchWorker( + scenarioIds: string[], + context: WorkerEvalContext, +): Promise> { + const {runId, members, jwt, apiUrl, projectId, appType} = context + + // Validate required parameters + if (!runId || !projectId || !jwt || !apiUrl) { + throw new Error("Missing required parameters for worker fetch") + } + + // Validate scenario IDs and filter out skeleton/placeholder IDs + const validScenarioIds = scenarioIds.filter((id) => { + if (!id || typeof id !== "string") return false + + // Skip skeleton/placeholder IDs gracefully + if (id.startsWith("skeleton-") || id.startsWith("placeholder-")) { + return false + } + + const uuidRegex = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i + return uuidRegex.test(id) + }) + + if (validScenarioIds.length === 0) { + return new Map() + } + + // POST to results query endpoint with body { result: { run_id, run_ids, scenario_ids }, windowing: {} } + const resultsUrl = `${apiUrl}/preview/evaluations/results/query?project_id=${encodeURIComponent( + projectId, + )}` + const body: Record = { + result: { + run_id: runId, + run_ids: [runId], + scenario_ids: validScenarioIds, + }, + windowing: {}, + } + + const resp = await fetch(resultsUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: jwt ? `Bearer ${jwt}` : "", + }, + credentials: "include", + body: JSON.stringify(body), + }) + + if (!resp.ok) { + throw new Error(`Worker fetch failed ${resp.status}`) + } + + const raw = (await resp.json()) as StepResponse + + // Convert to camelCase once + const camelStepsAll = (raw.results ?? []).map((st) => + snakeToCamelCaseKeys(st), + ) + + // Group steps by scenarioId + const perScenarioSteps = new Map() + for (const step of camelStepsAll) { + const sid = (step as any).scenarioId as string + if (!perScenarioSteps.has(sid)) perScenarioSteps.set(sid, []) + perScenarioSteps.get(sid)!.push(step) + } + + // Collect testcase ids + const testcaseIds = new Set() + for (const [_, stepsArr] of perScenarioSteps.entries()) { + for (const s of stepsArr) { + if (s.testcaseId) testcaseIds.add(s.testcaseId) + } + } + + // Fetch testcase data (updated endpoint) + const testcaseResp = await fetch( + `${apiUrl}/preview/testcases/query?project_id=${encodeURIComponent(projectId)}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: jwt ? `Bearer ${jwt}` : "", + }, + credentials: "include", + body: JSON.stringify({testcase_ids: Array.from(testcaseIds)}), + }, + ) + const testcases = (await testcaseResp.json()) as {count: number; testcases: PreviewTestCase[]} + + // Group testcases by their testset_id for easier lookup + const testcasesByTestsetId = (testcases.testcases || []).reduce( + (acc, testcase) => { + if (!acc[testcase.testset_id]) { + acc[testcase.testset_id] = [] + } + acc[testcase.testset_id].push(testcase) + return acc + }, + {} as Record, + ) + + // Update testsets with their matching testcases + const updatedTestsets = context.testsets?.map((testset) => { + const matchingTestcases = testcasesByTestsetId[testset.id] || [] + + if (matchingTestcases.length > 0) { + return { + ...testset, + data: { + ...testset.data, + testcase_ids: matchingTestcases?.map((tc) => tc.id), + testcases: matchingTestcases, + }, + } + } + + // Return testset as is if no matching testcases found + return testset + }) as PreviewTestSet[] + + // Update the context with the new testsets which have the fetched testcases + context.testsets = updatedTestsets + + const scenarioMap = new Map() + + const runIndex = deserializeRunIndex(context.runIndex) + for (const [sid, stepsArr] of perScenarioSteps.entries()) { + const core = buildScenarioCore({ + steps: stepsArr, + runIndex: runIndex, + evaluators: context.evaluators, + testsets: context.testsets, + variants: context.variants, + mappings: context.mappings, + uriObject: context.uriObject, + parametersByRevisionId: context.parametersByRevisionId, + appType: appType, + }) + + const result: UseEvaluationRunScenarioStepsFetcherResult = { + ...core, + steps: stepsArr, + count: stepsArr.length, + next: undefined, + mappings: context.mappings, + } as any + scenarioMap.set(sid, result) + } + + // Enrich traces / annotations + const {traceIds, annotationLinks} = computeTraceAndAnnotationRefs({ + steps: camelStepsAll, + runIndex: runIndex, + evaluators: context.evaluators || [], + }) + + const invocationStepsList = (raw.steps ?? []).filter((s: any) => + runIndex?.invocationKeys?.has?.(s.stepKey), + ) + + const {traceMap, annotationMap} = await fetchTraceAndAnnotationMaps({ + traceIds, + annotationLinks, + members, + invocationSteps: invocationStepsList, + apiUrl, + jwt, + projectId, + }) + + for (const result of scenarioMap.values()) { + decorateScenarioResult({ + result, + traceMap, + annotationMap, + runIndex: runIndex, + uuidToTraceId, + }) + } + + return scenarioMap +} + +/** + * Process all batches with limited concurrency. Returns a merged Map. + */ +async function processAllBatchesWorker( + scenarioIds: string[], + context: WorkerEvalContext, + concurrency: number, + batchSize: number, +): Promise> { + const batches = chunkArray(scenarioIds, batchSize) + const results: Map[] = [] + let idx = 0 + while (idx < batches.length) { + const running = batches + .slice(idx, idx + concurrency) + .map((batch) => processScenarioBatchWorker(batch, context)) + const batchResults = await Promise.all(running) + results.push(...batchResults) + idx += concurrency + } + + return mergeMaps(results) +} + +// Helper: merge many Maps into one. +function mergeMaps(maps: Map[]): Map { + const merged = new Map() + for (const m of maps) { + for (const [k, v] of m) merged.set(k, v) + } + return merged +} + +/** + * Public API for worker usage. Returns a serialisable array of entries. + */ +export async function fetchScenarioStepsBulkWorker( + scenarioIds: string[], + context: WorkerEvalContext, + options?: {batchSize?: number; concurrency?: number}, +): Promise> { + if (scenarioIds.length === 0) + return new Map() + const batchSize = options?.batchSize ?? DEFAULT_BATCH_SIZE + const concurrency = options?.concurrency ?? DEFAULT_BATCH_CONCURRENCY + const map = await processAllBatchesWorker(scenarioIds, context, concurrency, batchSize) + return map +} diff --git a/web/ee/src/pages/_app.tsx b/web/ee/src/pages/_app.tsx new file mode 100644 index 0000000000..92dfb3e135 --- /dev/null +++ b/web/ee/src/pages/_app.tsx @@ -0,0 +1,11 @@ +import "@ant-design/v5-patch-for-react-19" +import "@/oss/styles/globals.css" +import "@/oss/assets/custom-resize-handle.css" +import "react-resizable/css/styles.css" +import "@ag-grid-community/styles/ag-grid.css" +import "@ag-grid-community/styles/ag-theme-alpine.css" +import "jotai-devtools/styles.css" + +import AppPage from "@/oss/components/pages/_app" + +export default AppPage diff --git a/web/ee/src/pages/_document.tsx b/web/ee/src/pages/_document.tsx new file mode 100644 index 0000000000..9351309131 --- /dev/null +++ b/web/ee/src/pages/_document.tsx @@ -0,0 +1,3 @@ +import _DocumentPage from "@agenta/oss/src/pages/_document" + +export default _DocumentPage diff --git a/web/ee/src/pages/auth/[[...path]].tsx b/web/ee/src/pages/auth/[[...path]].tsx new file mode 100644 index 0000000000..7e5a27ad9a --- /dev/null +++ b/web/ee/src/pages/auth/[[...path]].tsx @@ -0,0 +1,3 @@ +import Auth from "@agenta/oss/src/pages/auth/[[...path]]" + +export default Auth diff --git a/web/ee/src/pages/auth/callback/[[...callback]].tsx b/web/ee/src/pages/auth/callback/[[...callback]].tsx new file mode 100644 index 0000000000..5751a67bb0 --- /dev/null +++ b/web/ee/src/pages/auth/callback/[[...callback]].tsx @@ -0,0 +1,3 @@ +import Callback from "@agenta/oss/src/pages/auth/callback/[[...callback]]" + +export default Callback diff --git a/web/ee/src/pages/post-signup/index.tsx b/web/ee/src/pages/post-signup/index.tsx new file mode 100644 index 0000000000..3e73089f76 --- /dev/null +++ b/web/ee/src/pages/post-signup/index.tsx @@ -0,0 +1,5 @@ +import PostSignupForm from "@/oss/components/PostSignupForm/PostSignupForm" + +export default function Apps() { + return +} diff --git a/web/ee/src/pages/w/[workspace_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/index.tsx new file mode 100644 index 0000000000..24b7e01d2e --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/index.tsx @@ -0,0 +1,3 @@ +import WorkspaceRedirect from "@/oss/components/pages/WorkspaceRedirect" + +export default WorkspaceRedirect diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/deployments/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/deployments/index.tsx new file mode 100644 index 0000000000..15d4a5ea5f --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/deployments/index.tsx @@ -0,0 +1,3 @@ +import DeploymentPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/deployments" + +export default DeploymentPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/endpoints/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/endpoints/index.tsx new file mode 100644 index 0000000000..5dccd2cacb --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/endpoints/index.tsx @@ -0,0 +1,5 @@ +import AppEndpointsPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/endpoints" +import {createParams} from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/endpoints" + +export {createParams} +export default AppEndpointsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/human_a_b_testing/[evaluation_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/human_a_b_testing/[evaluation_id]/index.tsx new file mode 100644 index 0000000000..76e6526898 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/human_a_b_testing/[evaluation_id]/index.tsx @@ -0,0 +1,115 @@ +import {useEffect, useState} from "react" + +import {useAtom, useAtomValue} from "jotai" +import dynamic from "next/dynamic" +import {useRouter} from "next/router" + +// Avoid SSR for this heavy component to prevent server-side ReferenceErrors from client-only libs +const ABTestingEvaluationTable = dynamic( + () => import("@/oss/components/EvaluationTable/ABTestingEvaluationTable"), + {ssr: false}, +) +import useURL from "@/oss/hooks/useURL" +import {evaluationAtom, evaluationScenariosAtom} from "@/oss/lib/atoms/evaluation" +import {getTestsetChatColumn} from "@/oss/lib/helpers/testset" +import {useBreadcrumbsEffect} from "@/oss/lib/hooks/useBreadcrumbs" +import type {Evaluation} from "@/oss/lib/Types" +import { + fetchLoadEvaluation, + fetchAllLoadEvaluationsScenarios, +} from "@/oss/services/human-evaluations/api" +import {fetchTestset} from "@/oss/services/testsets/api" +import {projectIdAtom} from "@/oss/state/project" +import {variantsAtom} from "@/oss/state/variant/atoms/fetcher" + +export default function Evaluation() { + const router = useRouter() + const projectId = useAtomValue(projectIdAtom) + const evaluationTableId = router.query.evaluation_id + ? router.query.evaluation_id.toString() + : "" + const [evaluationScenarios, setEvaluationScenarios] = useAtom(evaluationScenariosAtom) + const [evaluation, setEvaluation] = useAtom(evaluationAtom) + const [isLoading, setIsLoading] = useState(true) + const appId = router.query.app_id as string + const columnsCount = 2 + const {baseAppURL} = useURL() + // variants from global store + const variantsStore = useAtomValue(variantsAtom) + + useEffect(() => { + if (!evaluation || !projectId) { + return + } + const init = async () => { + setIsLoading(true) + try { + const data = await fetchAllLoadEvaluationsScenarios(evaluationTableId, evaluation) + setEvaluationScenarios(data) + } finally { + setTimeout(() => setIsLoading(false), 1000) + } + } + init() + }, [evaluation, projectId]) + + useEffect(() => { + if (!evaluationTableId) { + return + } + const init = async () => { + const evaluation: Evaluation = await fetchLoadEvaluation(evaluationTableId) + const backendVariants = variantsStore + const testset = await fetchTestset(evaluation.testset._id) + // Create a map for faster access to first array elements + const backendVariantsMap = new Map() + backendVariants.forEach((obj) => backendVariantsMap.set(obj.variantId, obj)) + + // Update variants in second object + evaluation.variants = evaluation.variants.map((variant) => { + const backendVariant = backendVariantsMap.get(variant.variantId) + return backendVariant ? backendVariant : variant + }) + evaluation.testset = { + ...evaluation.testset, + ...testset, + testsetChatColumn: getTestsetChatColumn(testset.csvdata), + } + setEvaluation(evaluation) + } + + init() + }, [evaluationTableId]) + + // breadcrumbs + useBreadcrumbsEffect( + { + breadcrumbs: { + appPage: { + label: "human ab testing", + href: `${baseAppURL}/${appId}/evaluations?selectedEvaluation=human_ab_testing`, + }, + "eval-detail": { + label: evaluationTableId, + value: evaluationTableId, + }, + }, + type: "append", + condition: !!evaluationTableId, + }, + [evaluationTableId], + ) + + return ( +
    + {evaluationTableId && evaluationScenarios && evaluation && ( + + )} +
    + ) +} diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/index.tsx new file mode 100644 index 0000000000..5f9c0ce406 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/index.tsx @@ -0,0 +1,7 @@ +import EvaluationsView from "@/oss/components/pages/evaluations/EvaluationsView" + +const AppEvaluationsPage = () => { + return +} + +export default AppEvaluationsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/[evaluation_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/[evaluation_id]/index.tsx new file mode 100644 index 0000000000..761464975d --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/[evaluation_id]/index.tsx @@ -0,0 +1,7 @@ +import EvalRunDetailsPage from "@/oss/components/EvalRunDetails" + +const EvaluationPage = () => { + return +} + +export default EvaluationPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/compare/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/compare/index.tsx new file mode 100644 index 0000000000..9a24e505d7 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/results/compare/index.tsx @@ -0,0 +1,7 @@ +import EvaluationCompare from "@/oss/components/pages/evaluations/evaluationCompare/EvaluationCompare" + +const EvaluationCompareDetails = () => { + return +} + +export default EvaluationCompareDetails diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/single_model_test/[evaluation_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/single_model_test/[evaluation_id]/index.tsx new file mode 100644 index 0000000000..209e1772ec --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/evaluations/single_model_test/[evaluation_id]/index.tsx @@ -0,0 +1,7 @@ +import EvalRunDetailsPage from "@/oss/components/EvalRunDetails" + +const EvaluationPage = () => { + return +} + +export default EvaluationPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/overview/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/overview/index.tsx new file mode 100644 index 0000000000..cc56265403 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/overview/index.tsx @@ -0,0 +1,3 @@ +import OverviewPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/overview" + +export default OverviewPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/playground/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/playground/index.tsx new file mode 100644 index 0000000000..47c2fd3659 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/playground/index.tsx @@ -0,0 +1,3 @@ +import PlaygroundPage from "@agenta/oss/src/components/PlaygroundRouter" + +export default PlaygroundPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/traces/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/traces/index.tsx new file mode 100644 index 0000000000..128eb8aa0f --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/traces/index.tsx @@ -0,0 +1,3 @@ +import TracesPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/traces" + +export default TracesPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/variants/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/variants/index.tsx new file mode 100644 index 0000000000..96ba4c5973 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/variants/index.tsx @@ -0,0 +1,3 @@ +import VariantsPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps/[app_id]/variants" + +export default VariantsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/index.tsx new file mode 100644 index 0000000000..df81bc9d9d --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/apps/index.tsx @@ -0,0 +1,3 @@ +import AppsPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/apps" + +export default AppsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/index.tsx new file mode 100644 index 0000000000..b01b145bb4 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/index.tsx @@ -0,0 +1,7 @@ +import EvaluationsView from "@/oss/components/pages/evaluations/EvaluationsView" + +const ProjectEvaluationsPage = () => { + return +} + +export default ProjectEvaluationsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/[evaluation_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/[evaluation_id]/index.tsx new file mode 100644 index 0000000000..91c5d13e40 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/[evaluation_id]/index.tsx @@ -0,0 +1,7 @@ +import EvalRunDetailsPage from "@/oss/components/EvalRunDetails" + +const ProjectAutoEvaluationPage = () => { + return +} + +export default ProjectAutoEvaluationPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/compare/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/compare/index.tsx new file mode 100644 index 0000000000..4fc96755ce --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/results/compare/index.tsx @@ -0,0 +1,7 @@ +import EvaluationCompare from "@/oss/components/pages/evaluations/evaluationCompare/EvaluationCompare" + +const ProjectEvaluationCompareDetails = () => { + return +} + +export default ProjectEvaluationCompareDetails diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/single_model_test/[evaluation_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/single_model_test/[evaluation_id]/index.tsx new file mode 100644 index 0000000000..67c0827984 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/evaluations/single_model_test/[evaluation_id]/index.tsx @@ -0,0 +1,7 @@ +import EvalRunDetailsPage from "@/oss/components/EvalRunDetails" + +const ProjectHumanEvaluationPage = () => { + return +} + +export default ProjectHumanEvaluationPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/observability/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/observability/index.tsx new file mode 100644 index 0000000000..73d6cb12eb --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/observability/index.tsx @@ -0,0 +1,3 @@ +import ObservabilityPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/observability" + +export default ObservabilityPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/settings/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/settings/index.tsx new file mode 100644 index 0000000000..2ce2ce1d4a --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/settings/index.tsx @@ -0,0 +1,3 @@ +import SettingsPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/settings" + +export default SettingsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/share/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/share/index.tsx new file mode 100644 index 0000000000..8a46c8ffae --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/share/index.tsx @@ -0,0 +1,86 @@ +import {useEffect, useRef} from "react" + +import {useAtomValue} from "jotai" +import {useRouter} from "next/router" + +import ProtectedRoute from "@/oss/components/ProtectedRoute/ProtectedRoute" +import ContentSpinner from "@/oss/components/Spinner/ContentSpinner" +import useURL from "@/oss/hooks/useURL" +import {EvaluationType} from "@/oss/lib/enums" +import {getAllVariantParameters} from "@/oss/lib/helpers/variantHelper" +import {GenericObject, Variant} from "@/oss/lib/Types" +import {createNewEvaluation} from "@/oss/services/human-evaluations/api" +import {useOrgData} from "@/oss/state/org" +import {variantsAtom} from "@/oss/state/variant/atoms/fetcher" + +const EvaluationShare: React.FC = () => { + const router = useRouter() + const {changeSelectedOrg, selectedOrg, loading} = useOrgData() + const called = useRef(false) + const {baseAppURL} = useURL() + + useEffect(() => { + const {app, org, variants: variantIds, testset, type} = router.query + + //1. check all the required params are present + if (app && org && testset && type && Array.isArray(variantIds) && !loading) { + const executor = async () => { + //make sure this is only called once + if (called.current) { + return + } + called.current = true + + // variants from global store + const allVariants = useAtomValue(variantsAtom) + const variants = variantIds + .map((id) => allVariants.find((item) => item.variantId === id)) + .filter((item) => item !== undefined) as Variant[] + + //get the inputs for each variant + const results = await Promise.all( + variants.map((variant) => + getAllVariantParameters(app as string, variant).then((data) => ({ + variantName: variant.variantName, + inputs: data?.inputs.map((inputParam) => inputParam.name) || [], + })), + ), + ) + const inputs: Record = results.reduce( + (acc: GenericObject, result) => { + acc[result.variantName] = result.inputs + return acc + }, + {}, + ) + + //create the evaluation + const evalId = await createNewEvaluation({ + variant_ids: variantIds, + inputs: inputs[variants[0].variantName], + evaluationType: type as EvaluationType, + evaluationTypeSettings: {}, + llmAppPromptTemplate: "", + selectedCustomEvaluationID: "", + testsetId: testset as string, + }) + + //redirect to the evaluation detail page once all work is done + router.push(`${baseAppURL}/${app}/annotations/${type}/${evalId}`) + } + + if (selectedOrg?.id !== org) { + //2. change the selected org to the one in the query + changeSelectedOrg(org as string, () => { + executor() + }) + } else { + executor() + } + } + }, [router.query, loading]) + + return +} + +export default () => diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/[testset_id]/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/[testset_id]/index.tsx new file mode 100644 index 0000000000..157e0d2acc --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/[testset_id]/index.tsx @@ -0,0 +1,3 @@ +import TestsetPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/testsets/[testset_id]" + +export default TestsetPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/index.tsx new file mode 100644 index 0000000000..0f0dad4db6 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/[project_id]/testsets/index.tsx @@ -0,0 +1,3 @@ +import TestsetsPage from "@agenta/oss/src/pages/w/[workspace_id]/p/[project_id]/testsets" + +export default TestsetsPage diff --git a/web/ee/src/pages/w/[workspace_id]/p/index.tsx b/web/ee/src/pages/w/[workspace_id]/p/index.tsx new file mode 100644 index 0000000000..06971bcc49 --- /dev/null +++ b/web/ee/src/pages/w/[workspace_id]/p/index.tsx @@ -0,0 +1,3 @@ +import WorkspaceProjectRedirect from "@/oss/components/pages/WorkspaceProjectRedirect" + +export default WorkspaceProjectRedirect diff --git a/web/ee/src/pages/w/index.tsx b/web/ee/src/pages/w/index.tsx new file mode 100644 index 0000000000..cea69f3950 --- /dev/null +++ b/web/ee/src/pages/w/index.tsx @@ -0,0 +1,3 @@ +import WorkspaceSelection from "@/oss/components/pages/WorkspaceSelection" + +export default WorkspaceSelection diff --git a/web/ee/src/pages/workspaces/accept.tsx b/web/ee/src/pages/workspaces/accept.tsx new file mode 100644 index 0000000000..105b9dd603 --- /dev/null +++ b/web/ee/src/pages/workspaces/accept.tsx @@ -0,0 +1,3 @@ +import Accept from "@agenta/oss/src/pages/workspaces/accept" + +export default Accept diff --git a/web/ee/src/services/billing/index.tsx b/web/ee/src/services/billing/index.tsx new file mode 100644 index 0000000000..7adf4fda77 --- /dev/null +++ b/web/ee/src/services/billing/index.tsx @@ -0,0 +1,58 @@ +// Re-export the new atom-based billing hooks and actions +export { + useUsageData, + useSubscriptionData, + usePricingPlans, + useSubscriptionActions, + useBilling, +} from "../../state/billing" + +// Legacy function exports for backward compatibility +// These now use direct API calls for backward compatibility +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {getProjectValues} from "@/oss/state/project" + +/** + * @deprecated Use useSubscriptionActions().switchSubscription instead + * Legacy function for switching subscription plans + */ +export const switchSubscription = async (payload: {plan: string}) => { + const {projectId} = getProjectValues() + const response = await axios.post( + `${getAgentaApiUrl()}/billing/plans/switch?plan=${payload.plan}&project_id=${projectId}`, + ) + return response +} + +/** + * @deprecated Use useSubscriptionActions().cancelSubscription instead + * Legacy function for canceling subscription + */ +export const cancelSubscription = async () => { + const {projectId} = getProjectValues() + const response = await axios.post( + `${getAgentaApiUrl()}/billing/subscription/cancel?project_id=${projectId}`, + ) + return response +} + +/** + * @deprecated Use useSubscriptionActions().checkoutSubscription instead + * Legacy function for creating new subscription checkout + */ +export const checkoutNewSubscription = async (payload: {plan: string; success_url: string}) => { + const response = await axios.post( + `${getAgentaApiUrl()}/billing/stripe/checkouts/?plan=${payload.plan}&success_url=${payload.success_url}`, + ) + return response +} + +/** + * @deprecated Use useSubscriptionActions().editSubscription instead + * Legacy function for editing subscription info + */ +export const editSubscriptionInfo = async () => { + const response = await axios.post(`${getAgentaApiUrl()}/billing/stripe/portals/`) + return response +} diff --git a/web/ee/src/services/billing/types.d.ts b/web/ee/src/services/billing/types.d.ts new file mode 100644 index 0000000000..22489b0fb0 --- /dev/null +++ b/web/ee/src/services/billing/types.d.ts @@ -0,0 +1,45 @@ +export type Plan = "cloud_v0_hobby" | "cloud_v0_pro" | "cloud_v0_business" | "cloud_v0_enterprise" + +export interface SubscriptionType { + plan: Plan + period_start: number + period_end: number + free_trial: boolean +} + +interface UsageKeyType { + value: number + limit: number | null + free: number + monthly: boolean + strict: boolean +} + +export interface DataUsageType { + traces: UsageKeyType + users: UsageKeyType + prompts: UsageKeyType + jobs: UsageKeyType +} + +interface PriceInfo { + base?: { + amount: number + currency: string + starting_at?: boolean + } + users?: { + tiers: {limit?: number; amount: number; rate?: number}[] + } + traces?: { + tiers: {limit?: number; amount: number; rate?: number}[] + } +} + +export interface BillingPlan { + title: string + description: string + price?: PriceInfo + features: string[] + plan: Plan +} diff --git a/web/ee/src/services/evaluationRuns/api/index.ts b/web/ee/src/services/evaluationRuns/api/index.ts new file mode 100644 index 0000000000..b67ed17d23 --- /dev/null +++ b/web/ee/src/services/evaluationRuns/api/index.ts @@ -0,0 +1,332 @@ +import {getDefaultStore} from "jotai" + +import {getMetricsFromEvaluator} from "@/oss/components/pages/observability/drawer/AnnotateDrawer/assets/transforms" +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {extractInputKeysFromSchema} from "@/oss/lib/shared/variant/inputHelpers" +import {getRequestSchema} from "@/oss/lib/shared/variant/openapiUtils" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {slugify} from "@/oss/lib/utils/slugify" +import {getAppValues} from "@/oss/state/app" +import {stablePromptVariablesAtomFamily} from "@/oss/state/newPlayground/core/prompts" +import {variantFlagsAtomFamily} from "@/oss/state/newPlayground/core/variantFlags" +import {appSchemaAtom, appUriInfoAtom} from "@/oss/state/variant/atoms/fetcher" + +import {CreateEvaluationRunInput, TestSet} from "./types" + +const extractColumnsFromTestset = (testset?: TestSet): string[] => { + if (!testset) return [] + + const columns = new Set() + + const addColumnsFromObject = (obj?: Record) => { + if (!obj || typeof obj !== "object") return + Object.keys(obj).forEach((key) => { + if (!key || typeof key !== "string") return + if (key.startsWith("__")) return + columns.add(key) + }) + } + + const csvRows = (testset as any)?.csvdata + if (Array.isArray(csvRows) && csvRows.length > 0) { + addColumnsFromObject(csvRows[0] as Record) + } + + const data = (testset as any)?.data + if (data) { + const testcases = data.testcases || data.testCases + if (Array.isArray(testcases) && testcases.length > 0) { + addColumnsFromObject( + (testcases[0] && (testcases[0].data || testcases[0])) as Record, + ) + } + + const columnsList = data.columns || data.columnNames + if (Array.isArray(columnsList)) { + columnsList.forEach((col: any) => { + if (typeof col === "string" && col && !col.startsWith("__")) { + columns.add(col) + } + }) + } + } + + return Array.from(columns) +} + +/** + * Constructs the input step for a given testset, pulling variantId and revisionId + * directly from the testset object. Any undefined reference keys are omitted. + */ + +const buildInputStep = (testset?: TestSet) => { + if (!testset) return + const inputKey = slugify(testset.name ?? (testset as any).slug ?? "testset", testset.id) + if (!testset) { + return + } + + const references: Record = { + testset: {id: testset.id}, + } + + // TODO: after new testsets + // if (testset.variantId) { + // references.testset_variant = {id: testset.variantId} + // } + // if (testset.revisionId) { + // references.testset_revision = {id: testset.revisionId} + // } + + return { + key: inputKey, + type: "input", + origin: "auto", + references, + } +} + +/** + * Constructs the invocation step for a given revision. + * Only includes reference keys if their IDs are defined. + */ +const buildInvocationStep = (revision: EnhancedVariant, inputKey: string) => { + const invocationKey = slugify( + (revision as any).name ?? (revision as any).variantName ?? "invocation", + revision.id, + ) + const references: Record = {} + + const {currentApp} = getAppValues() + const appId = currentApp?.app_id as string + references.application = {id: appId} + + if (revision.variantId !== undefined) { + references.application_variant = {id: revision.variantId} + } + if (revision.id !== undefined) { + references.application_revision = {id: revision.id} + } + return { + key: invocationKey, + type: "invocation", + origin: "human", + references, + inputs: [{key: inputKey}], + } +} + +/** + * Constructs annotation steps for all evaluators. + * Uses each evaluator's slug and id for references. + */ +const buildAnnotationStepsFromEvaluators = ( + evaluators: EvaluatorDto[] | undefined, + inputKey: string, + invocationKey: string, +) => { + if (!evaluators) return [] + return evaluators.map((evaluator) => { + const references: Record = {} + if (evaluator.slug !== undefined) { + references.evaluator = {id: evaluator.id} + } + + // TODO: Enable when we have this information + // if (evaluator.id !== undefined) { + // references.evaluator_variant = {id: evaluator.id} + // } + return { + key: `${invocationKey}.${evaluator.slug}`, + references, + type: "annotation", + origin: "human", + inputs: [{key: inputKey}, {key: invocationKey}], + } + }) +} + +/** + * Constructs the array of mappings for extracting data from steps. + * Uses the revision's inputParams to generate "input" mappings automatically. + * + * @param revision - The EnhancedVariant object containing inputParams. + * @param correctAnswerColumn - The property name in the input step for ground truth. + * @param evaluators - Optional list of evaluators to generate evaluator mappings. + * @param testset - The testset object to conditionally add mappings based on variantId and revisionId. + * @returns An array of mapping objects. + */ +const buildMappings = ( + revision: EnhancedVariant, + correctAnswerColumn: string, + evaluators: EvaluatorDto[] | undefined, + testset?: TestSet, +) => { + const testsetKey = testset + ? slugify(testset.name ?? (testset as any).slug ?? "testset", testset.id) + : "input" + const invocationKey = slugify( + (revision as any).name ?? + (revision as any).variantName ?? + ((revision as any)._parentVariant as any)?.variantName ?? + "invocation", + revision.id, + ) + const mappings: { + column: {kind: "testset" | "invocation" | "evaluator"; name: string} + step: {key: string; path: string} + }[] = [] + const pushedTestsetColumns = new Set() + + // Generate input mappings aligned with Playground (schema + initial prompt vars for custom; prompt tokens for non-custom) + { + const store = getDefaultStore() + const flags = store.get(variantFlagsAtomFamily({revisionId: revision.id})) as any + const isCustom = Boolean(flags?.isCustom) + const spec = store.get(appSchemaAtom) as any + const routePath = store.get(appUriInfoAtom)?.routePath || "" + + let variableNames: string[] = [] + if (isCustom) { + // Custom workflows: strictly use schema-defined input keys + variableNames = spec ? extractInputKeysFromSchema(spec as any, routePath) : [] + } else { + // Non-custom: use stable variables from saved parameters (ignore live prompt edits) + variableNames = store.get(stablePromptVariablesAtomFamily(revision.id)) || [] + } + + variableNames.forEach((name) => { + if (!name || typeof name !== "string") return + pushedTestsetColumns.add(name) + mappings.push({ + column: {kind: "testset", name}, + step: {key: testsetKey, path: `data.${name}`}, + }) + }) + + const req = spec ? (getRequestSchema as any)(spec, {routePath}) : undefined + if (req?.properties?.messages && !pushedTestsetColumns.has("messages")) { + pushedTestsetColumns.add("messages") + mappings.push({ + column: {kind: "testset", name: "messages"}, + step: {key: testsetKey, path: "data.messages"}, + }) + } + } + + if (testset && pushedTestsetColumns.size === 0) { + const normalizedCorrectAnswer = (correctAnswerColumn || "") + .replace(/[\W_]/g, "") + .toLowerCase() + const derivedColumns = extractColumnsFromTestset(testset) + derivedColumns.forEach((name) => { + if (!name || typeof name !== "string") return + const normalized = name.trim() + if (!normalized || normalized.startsWith("__")) return + const normalizedSafe = normalized.replace(/[\W_]/g, "").toLowerCase() + if (normalizedSafe === normalizedCorrectAnswer) return + if (normalizedSafe.includes("correctanswer")) return + if (normalizedSafe.startsWith("testcase") || normalizedSafe.includes("dedup")) return + if (pushedTestsetColumns.has(name) || pushedTestsetColumns.has(normalizedSafe)) return + pushedTestsetColumns.add(name) + pushedTestsetColumns.add(normalizedSafe) + mappings.push({ + column: {kind: "testset", name}, + step: {key: testsetKey, path: `data.${name}`}, + }) + }) + } + + // Application output mapping should use canonical column name "outputs" to align with backend + mappings.push({ + column: {kind: "invocation", name: "outputs"}, + step: {key: invocationKey, path: "attributes.ag.data.outputs"}, + }) + + // Add mappings for testset variantId and revisionId if available + // Additional metadata mappings if available + if (testset?.variantId !== undefined) { + mappings.push({ + column: {kind: "testset", name: "testset_variant_id"}, + step: {key: testsetKey, path: "data.variantId"}, + }) + } + if (testset?.revisionId !== undefined) { + mappings.push({ + column: {kind: "testset", name: "testset_revision_id"}, + step: {key: testsetKey, path: "data.revisionId"}, + }) + } + + // Evaluator output mappings generated dynamically per evaluator + if (evaluators && evaluators.length > 0) { + evaluators?.forEach((evaluator) => { + const metrics = getMetricsFromEvaluator(evaluator) + Object.keys(metrics).forEach((key) => { + mappings.push({ + column: {kind: "evaluator", name: `${evaluator.slug}.${key}`}, + step: {key: `${invocationKey}.${evaluator.slug}`, path: `data.outputs.${key}`}, + }) + }) + }) + } + + return mappings +} + +/** + * Builds the payload required for submitting multiple evaluation runs to the backend. + * Each revision will be wrapped in its own run configuration. + * This function returns an object with a `runs` array that can be sent to + * the POST `/preview/evaluations/runs/` endpoint. + * + * @param name - Base name used in each run + * @param testset - The test set being used in this evaluation (must include variantId & revisionId). + * @param revisions - List of enhanced variant revisions; one run will be generated per revision. + * @param evaluators - List of available evaluators used in annotation. + * @param correctAnswerColumn - The property name in the input step that holds the ground truth value. + * @param meta - Optional metadata object to attach to each run. + * @returns Object containing `runs` array, ready to be POSTed to the backend. + */ +export const createEvaluationRunConfig = ({ + name, + testset, + revisions, + evaluators, + correctAnswerColumn, + meta = {}, // Default to empty object if not provided +}: CreateEvaluationRunInput) => { + // Pre-build the input step (which now includes variantId & revisionId) and mappings + const inputStep = buildInputStep(testset) + const inputKey = slugify(testset?.name ?? (testset as any)?.slug ?? "testset", testset!.id) + const invocationKeysCache: Record = {} + + // Create one run configuration per revision + const runs = revisions.map((revision) => { + const invocationKey = + invocationKeysCache[revision.id] ?? + slugify( + (revision as any).name ?? (revision as any).variantName ?? "invocation", + revision.id, + ) + + invocationKeysCache[revision.id] = invocationKey + + const steps = [ + inputStep, + buildInvocationStep(revision, inputKey), + ...buildAnnotationStepsFromEvaluators(evaluators, inputKey, invocationKey), + ] + // Build mappings for this revision, passing testset as well + const mappings = buildMappings(revision, correctAnswerColumn, evaluators, testset) + return { + key: `evaluation-${revision.variantId}`, + name: `${name}`, + description: "auto-generated evaluation run", + meta, // Include the passed-in meta object + data: {steps, mappings}, + } + }) + + return {runs} +} diff --git a/web/ee/src/services/evaluationRuns/api/types.ts b/web/ee/src/services/evaluationRuns/api/types.ts new file mode 100644 index 0000000000..8c5b6d95ec --- /dev/null +++ b/web/ee/src/services/evaluationRuns/api/types.ts @@ -0,0 +1,18 @@ +import {EvaluatorDto} from "@/oss/lib/hooks/useEvaluators/types" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import type {TestSet as BaseTestSet} from "@/oss/lib/Types" + +// Extend the base TestSet to include optional variantId and revisionId +export interface TestSet extends BaseTestSet { + variantId?: string + revisionId?: string +} + +export interface CreateEvaluationRunInput { + name: string + testset: TestSet | testset | undefined + revisions: EnhancedVariant[] + evaluators?: EvaluatorDto[] + correctAnswerColumn: string + meta?: Record // Optional meta object to include in each run +} diff --git a/web/ee/src/services/evaluationRuns/utils.ts b/web/ee/src/services/evaluationRuns/utils.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/ee/src/services/evaluations/api/index.ts b/web/ee/src/services/evaluations/api/index.ts new file mode 100644 index 0000000000..926b5bb108 --- /dev/null +++ b/web/ee/src/services/evaluations/api/index.ts @@ -0,0 +1,352 @@ +import uniqBy from "lodash/uniqBy" +import {v4 as uuidv4} from "uuid" + +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getTagColors} from "@/oss/lib/helpers/colors" +import {calcEvalDuration} from "@/oss/lib/helpers/evaluate" +import {isDemo, stringToNumberInRange} from "@/oss/lib/helpers/utils" +import { + ComparisonResultRow, + EvaluationStatus, + Evaluator, + EvaluatorConfig, + KeyValuePair, + LLMRunRateLimit, + TestSet, + _Evaluation, + _EvaluationScenario, +} from "@/oss/lib/Types" +import aiImg from "@/oss/media/artificial-intelligence.png" +import bracketCurlyImg from "@/oss/media/bracket-curly.png" +import codeImg from "@/oss/media/browser.png" +import webhookImg from "@/oss/media/link.png" +import regexImg from "@/oss/media/programming.png" +import exactMatchImg from "@/oss/media/target.png" +import similarityImg from "@/oss/media/transparency.png" +import {fetchTestset} from "@/oss/services/testsets/api" +import {getProjectValues} from "@/oss/state/project" +import {assertValidId, isValidId} from "@/oss/lib/helpers/serviceValidations" + +//Prefix convention: +// - fetch: GET single entity from server +// - fetchAll: GET all entities from server +// - create: POST data to server +// - update: PUT data to server +// - delete: DELETE data from server + +const evaluatorIconsMap = { + auto_exact_match: exactMatchImg, + auto_similarity_match: similarityImg, + auto_regex_test: regexImg, + field_match_test: exactMatchImg, + auto_webhook_test: webhookImg, + auto_ai_critique: aiImg, + auto_custom_code_run: codeImg, + auto_json_diff: bracketCurlyImg, + auto_semantic_similarity: similarityImg, + auto_contains_json: bracketCurlyImg, + rag_faithfulness: codeImg, + rag_context_relevancy: codeImg, +} + +//Evaluators +// export const fetchAllEvaluators = async () => { +// const tagColors = getTagColors() +// const {projectId} = getProjectValues() + +// const response = await axios.get(`/evaluators?project_id=${projectId}`) +// const evaluators = (response.data || []) +// .filter((item: Evaluator) => !item.key.startsWith("human")) +// .filter((item: Evaluator) => isDemo() || item.oss) +// .map((item: Evaluator) => ({ +// ...item, +// icon_url: evaluatorIconsMap[item.key as keyof typeof evaluatorIconsMap], +// color: tagColors[stringToNumberInRange(item.key, 0, tagColors.length - 1)], +// })) as Evaluator[] + +// return evaluators +// } + +// Evaluator Configs +export const fetchAllEvaluatorConfigs = async ( + appId?: string | null, + projectIdOverride?: string | null, +) => { + const tagColors = getTagColors() + const {projectId: projectIdFromStore} = getProjectValues() + const projectId = projectIdOverride ?? projectIdFromStore + + if (!projectId) { + return [] as EvaluatorConfig[] + } + + const response = await axios.get("/evaluators/configs", { + params: { + project_id: projectId, + ...(appId ? {app_id: appId} : {}), + }, + }) + const evaluatorConfigs = (response.data || []).map((item: EvaluatorConfig) => ({ + ...item, + icon_url: evaluatorIconsMap[item.evaluator_key as keyof typeof evaluatorIconsMap], + color: tagColors[stringToNumberInRange(item.evaluator_key, 0, tagColors.length - 1)], + })) as EvaluatorConfig[] + return evaluatorConfigs +} + +export type CreateEvaluationConfigData = Omit +export const createEvaluatorConfig = async ( + appId: string | null | undefined, + config: CreateEvaluationConfigData, +) => { + const {projectId} = getProjectValues() + void appId + + return axios.post(`/evaluators/configs?project_id=${projectId}`, { + ...config, + }) +} + +export const updateEvaluatorConfig = async ( + configId: string, + config: Partial, +) => { + const {projectId} = getProjectValues() + + return axios.put(`/evaluators/configs/${configId}?project_id=${projectId}`, config) +} + +export const deleteEvaluatorConfig = async (configId: string) => { + const {projectId} = getProjectValues() + + return axios.delete(`/evaluators/configs/${configId}?project_id=${projectId}`) +} + +// Evaluations +const evaluationTransformer = (item: any) => ({ + id: item.id, + appId: item.app_id, + created_at: item.created_at, + updated_at: item.updated_at, + duration: calcEvalDuration(item), + status: item.status, + testset: { + id: item.testset_id, + name: item.testset_name, + }, + user: { + id: item.user_id, + username: item.user_username, + }, + variants: item.variant_ids.map((id: string, ix: number) => ({ + variantId: id, + variantName: item.variant_names[ix], + })), + aggregated_results: item.aggregated_results || [], + revisions: item.revisions, + variant_revision_ids: item.variant_revision_ids, + variant_ids: item.variant_ids, + average_cost: item.average_cost, + total_cost: item.total_cost, + average_latency: item.average_latency, +}) +export const fetchAllEvaluations = async (appId: string) => { + const {projectId} = getProjectValues() + + const response = await axios.get(`/evaluations?project_id=${projectId}`, { + params: {app_id: appId}, + }) + return response.data.map(evaluationTransformer) as _Evaluation[] +} + +export const fetchEvaluation = async (evaluationId: string) => { + if (!isValidId(evaluationId)) { + throw new Error("Invalid evaluationId parameter") + } + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId) + + const response = await axios.get(`/evaluations/${encodeURIComponent(id)}`, { + params: {project_id: projectId}, + }) + return evaluationTransformer(response.data) as _Evaluation +} + +export const fetchEvaluationStatus = async (evaluationId: string) => { + if (!isValidId(evaluationId)) { + throw new Error("Invalid evaluationId parameter") + } + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId) + + const response = await axios.get(`/evaluations/${encodeURIComponent(id)}/status`, { + params: {project_id: projectId}, + }) + return response.data as {status: _Evaluation["status"]} +} + +export type CreateEvaluationData = + | { + testset_id: string + variant_ids?: string[] + evaluators_configs: string[] + rate_limit: LLMRunRateLimit + lm_providers_keys?: KeyValuePair + correct_answer_column: string + } + | { + testset_id: string + revisions_ids?: string[] + evaluators_configs: string[] + rate_limit: LLMRunRateLimit + lm_providers_keys?: KeyValuePair + correct_answer_column: string + name: string + } +export const createEvaluation = async (appId: string, evaluation: CreateEvaluationData) => { + const {projectId} = getProjectValues() + + // TODO: new AUTO-EVAL trigger + return await axios.post(`/evaluations/preview/start?project_id=${projectId}`, { + ...evaluation, + app_id: appId, + }) + // return await axios.post(`/evaluations?project_id=${projectId}`, {...evaluation, app_id: appId}) +} + +export const deleteEvaluations = async (evaluationsIds: string[]) => { + const {projectId} = getProjectValues() + + return axios.delete(`/evaluations?project_id=${projectId}`, { + data: {evaluations_ids: evaluationsIds}, + }) +} + +// Evaluation Scenarios +export const fetchAllEvaluationScenarios = async (evaluationId: string) => { + if (!isValidId(evaluationId)) { + throw new Error("Invalid evaluationId parameter") + } + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId) + + const [{data: evaluationScenarios}, evaluation] = await Promise.all([ + axios.get(`/evaluations/${encodeURIComponent(id)}/evaluation_scenarios`, { + params: {project_id: projectId}, + }), + fetchEvaluation(id), + ]) + + evaluationScenarios.forEach((scenario: _EvaluationScenario) => { + scenario.evaluation = evaluation + scenario.evaluators_configs = evaluation.aggregated_results.map( + (item) => item.evaluator_config, + ) + }) + return evaluationScenarios as _EvaluationScenario[] +} + +export const updateScenarioStatus = async ( + scenario: _EvaluationScenario, + status: EvaluationStatus, +) => { + const {projectId} = getProjectValues() + return axios.patch(`/preview/evaluations/scenarios/?project_id=${projectId}`, { + scenarios: [{...scenario, status}], + }) +} + +// Comparison +export const fetchAllComparisonResults = async (evaluationIds: string[]) => { + // Defensive check: Only accept valid UUIDs + const validIds = evaluationIds.filter((id) => isValidId(id)) + if (validIds.length === 0) { + throw new Error("No valid evaluation IDs provided") + } + const scenarioGroups = await Promise.all(validIds.map(fetchAllEvaluationScenarios)) + const testset: TestSet = await fetchTestset(scenarioGroups[0][0].evaluation?.testset?.id) + + const inputsNameSet = new Set() + scenarioGroups.forEach((group) => { + group.forEach((scenario) => { + scenario.inputs.forEach((input) => inputsNameSet.add(input.name)) + }) + }) + + const rows: ComparisonResultRow[] = [] + const inputNames = Array.from(inputsNameSet) + const inputValuesSet = new Set() + const variants = scenarioGroups.map((group) => group[0].evaluation.variants[0]) + const correctAnswers = uniqBy( + scenarioGroups.map((group) => group[0].correct_answers).flat(), + "key", + ) + + for (const data of testset.csvdata) { + const inputValues = inputNames + .filter((name) => data[name] !== undefined) + .map((name) => ({name, value: data[name]})) + const inputValuesStr = inputValues.map((ip) => ip.value).join("") + if (inputValuesSet.has(inputValuesStr)) continue + else inputValuesSet.add(inputValuesStr) + + rows.push({ + id: inputValuesStr, + rowId: uuidv4(), + inputs: inputNames + .map((name) => ({name, value: data[name]})) + .filter((ip) => ip.value !== undefined), + ...correctAnswers.reduce((acc, curr) => { + return {...acc, [`correctAnswer_${curr?.key}`]: data[curr?.key!]} + }, {}), + variants: variants.map((variant, ix) => { + const group = scenarioGroups[ix] + const scenario = group.find((scenario) => + scenario.inputs.every((input) => + inputValues.some( + (ip) => ip.name === input.name && ip.value === input.value, + ), + ), + ) + return { + variantId: variant.variantId, + variantName: variant.variantName, + output: scenario?.outputs[0] || { + result: {type: "string", value: "", error: null}, + }, + evaluationId: scenario?.evaluation.id || "", + evaluatorConfigs: (scenario?.evaluators_configs || []).map((config) => ({ + evaluatorConfig: config, + result: scenario?.results.find( + (result) => result.evaluator_config === config.id, + )?.result || {type: "string", value: "", error: null}, // Adjust this line + })), + } + }), + }) + } + + return { + rows, + testset, + evaluations: scenarioGroups.map((group) => group[0].evaluation), + } +} + +// Evaluation IDs by resource +export const fetchEvaluatonIdsByResource = async ({ + resourceIds, + resourceType, +}: { + resourceIds: string[] + resourceType: "testset" | "evaluator_config" | "variant" +}) => { + const {projectId} = getProjectValues() + + return axios.get(`/evaluations/by_resource?project_id=${projectId}`, { + params: {resource_ids: resourceIds, resource_type: resourceType}, + paramsSerializer: { + indexes: null, //no brackets in query params + }, + }) +} diff --git a/web/ee/src/services/evaluations/api_ee/index.ts b/web/ee/src/services/evaluations/api_ee/index.ts new file mode 100644 index 0000000000..4ae4376f4d --- /dev/null +++ b/web/ee/src/services/evaluations/api_ee/index.ts @@ -0,0 +1,44 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {getProjectValues} from "@/oss/state/project" + +//Prefix convention: +// - fetch: GET single entity from server +// - fetchAll: GET all entities from server +// - create: POST data to server +// - update: PUT data to server +// - delete: DELETE data from server + +import { + EvaluatorInputInterface, + EvaluatorMappingInput, + EvaluatorMappingOutput, + EvaluatorOutputInterface, +} from "../../../lib/types_ee" + +export const createEvaluatorDataMapping = async ( + config: EvaluatorMappingInput, +): Promise => { + const {projectId} = getProjectValues() + + const response = await axios.post( + `${getAgentaApiUrl()}/evaluators/map?project_id=${projectId}`, + {...config}, + ) + return response.data +} + +export const createEvaluatorRunExecution = async ( + evaluatorKey: string, + config: EvaluatorInputInterface, +): Promise => { + const {projectId} = getProjectValues() + + const response = await axios.post( + `${getAgentaApiUrl()}/evaluators/${evaluatorKey}/run?project_id=${projectId}`, + { + ...config, + }, + ) + return response.data +} diff --git a/web/ee/src/services/evaluations/workerUtils.ts b/web/ee/src/services/evaluations/workerUtils.ts new file mode 100644 index 0000000000..f48a2cba9e --- /dev/null +++ b/web/ee/src/services/evaluations/workerUtils.ts @@ -0,0 +1,157 @@ +import {EvaluationStatus} from "@/oss/lib/Types" + +/** + * Update scenario status from a WebWorker / non-axios context. + */ +export async function updateScenarioStatusRemote( + apiUrl: string, + jwt: string, + scenarioId: string, + status: EvaluationStatus, + projectId: string, + runId?: string, +): Promise { + try { + // 1. Query results to validate scenario context (scenarios GET is deprecated) + const res = await fetch( + `${apiUrl}/preview/evaluations/results/query?project_id=${projectId}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({ + result: { + scenario_ids: [scenarioId], + ...(runId ? {run_ids: [runId]} : {}), + }, + windowing: {}, + }), + }, + ) + let scenarioFull: any | null = null + if (res.ok) { + // We no longer rely on the scenario payload; server requires id for PATCH + // Keep minimal object; if server returns extra data in future, parse here + scenarioFull = {id: scenarioId} + } + if (!scenarioFull) scenarioFull = {id: scenarioId} + scenarioFull.status = status + await fetch(`${apiUrl}/preview/evaluations/scenarios/?project_id=${projectId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({scenarios: [scenarioFull]}), + }) + } catch { + /* swallow */ + } +} + +/** + * Upsert (create or update) a generic scenario step. Can be used for invocation or annotation steps. + */ +export async function upsertScenarioStep(params: { + apiUrl: string + jwt: string + runId: string + scenarioId: string + status: EvaluationStatus + projectId: string + key: string + traceId?: string | null + spanId?: string | null + references?: Record +}): Promise { + const { + apiUrl, + jwt, + runId, + scenarioId, + status, + projectId, + key, + traceId, + spanId, + references = {}, + } = params + try { + const res = await fetch( + `${apiUrl}/preview/evaluations/results/query?project_id=${projectId}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({ + result: { + run_ids: [runId], + scenario_ids: [scenarioId], + step_keys: [key], + }, + windowing: {}, + }), + }, + ) + if (res.ok) { + const data = await res.json() + const list = Array.isArray(data.results) + ? data.results + : Array.isArray(data.steps) + ? data.steps + : [] + const existing = list.find((s: any) => s.step_key === key || s.stepKey === key) + if (existing) { + const updated = { + ...existing, + status, + trace_id: traceId, + span_id: spanId, + references: {...((existing as any)?.references || {}), ...references}, + } + await fetch(`${apiUrl}/preview/evaluations/results/?project_id=${projectId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + // API expects bulk-style body: { results: [ { id, ...fields } ] } + body: JSON.stringify({results: [updated]}), + }) + return + } + } + } catch { + /* fallthrough to creation */ + } + + const body = { + results: [ + { + status, + step_key: key, + trace_id: traceId, + span_id: spanId, + scenario_id: scenarioId, + run_id: runId, + references, + }, + ], + } + try { + await fetch(`${apiUrl}/preview/evaluations/results/?project_id=${projectId}`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify(body), + }) + } catch { + /* ignore */ + } +} diff --git a/web/ee/src/services/human-evaluations/api/index.ts b/web/ee/src/services/human-evaluations/api/index.ts new file mode 100644 index 0000000000..20828c1b52 --- /dev/null +++ b/web/ee/src/services/human-evaluations/api/index.ts @@ -0,0 +1,353 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" +import {EvaluationFlow, EvaluationType} from "@/oss/lib/enums" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {assertValidId} from "@/oss/lib/helpers/serviceValidations" +import { + abTestingEvaluationTransformer, + fromEvaluationResponseToEvaluation, + fromEvaluationScenarioResponseToEvaluationScenario, + singleModelTestEvaluationTransformer, +} from "@/oss/lib/transformers" +import { + EvaluationResponseType, + Evaluation, + GenericObject, + CreateCustomEvaluation, + ExecuteCustomEvalCode, + AICritiqueCreate, +} from "@/oss/lib/Types" +import {getProjectValues} from "@/oss/state/project" + +//Prefix convention: +// - fetch: GET single entity from server +// - fetchAll: GET all entities from server +// - create: POST data to server +// - update: PUT data to server +// - delete: DELETE data from server + +export const fetchAllLoadEvaluations = async ( + appId: string, + projectId: string, + ignoreAxiosError = false, +) => { + const app = assertValidId(appId, "appId") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get(`${getAgentaApiUrl()}/human-evaluations`, { + params: {project_id: project, app_id: app}, + _ignoreError: ignoreAxiosError, + } as any) + return response.data +} + +export const fetchLoadEvaluation = async (evaluationId: string) => { + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId, "evaluationId") + const project = assertValidId(projectId, "projectId") + try { + return await axios + .get(`${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent(id)}`, { + params: {project_id: project}, + }) + .then((responseData) => { + return fromEvaluationResponseToEvaluation(responseData.data) + }) + } catch (error) { + if (axios.isCancel?.(error) || (error as any)?.code === "ERR_CANCELED") { + return null + } + console.error(`Error fetching evaluation ${id}:`, error) + return null + } +} + +export const deleteEvaluations = async (ids: string[]) => { + const {projectId} = getProjectValues() + const project = assertValidId(projectId, "projectId") + + const response = await axios({ + method: "delete", + url: `${getAgentaApiUrl()}/human-evaluations`, + params: {project_id: project}, + data: {evaluations_ids: ids}, + }) + return response.data +} + +export const fetchAllLoadEvaluationsScenarios = async ( + evaluationTableId: string, + evaluation: Evaluation, +) => { + const {projectId} = getProjectValues() + const tableId = assertValidId(evaluationTableId, "evaluationTableId") + const project = assertValidId(projectId, "projectId") + + return await axios + .get( + `${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent( + tableId, + )}/evaluation_scenarios`, + {params: {project_id: project}}, + ) + .then((responseData) => { + const evaluationsRows = responseData.data.map((item: any) => { + return fromEvaluationScenarioResponseToEvaluationScenario(item, evaluation) + }) + + return evaluationsRows + }) +} + +export const createNewEvaluation = async ( + { + appId, + variant_ids, + evaluationType, + evaluationTypeSettings, + inputs, + llmAppPromptTemplate, + selectedCustomEvaluationID, + testsetId, + }: { + appId: string + variant_ids: string[] + evaluationType: string + evaluationTypeSettings: Partial + inputs: string[] + llmAppPromptTemplate?: string + selectedCustomEvaluationID?: string + testsetId: string + }, + ignoreAxiosError = false, +) => { + const app = assertValidId(appId, "appId") + const testset = assertValidId(testsetId, "testsetId") + const customId = selectedCustomEvaluationID + ? assertValidId(selectedCustomEvaluationID, "customEvaluationId") + : undefined + + const data = { + variant_ids, + inputs: inputs, + app_id: app, + evaluation_type: evaluationType, + evaluation_type_settings: { + ...evaluationTypeSettings, + custom_code_evaluation_id: customId, + llm_app_prompt_template: llmAppPromptTemplate, + }, + testset_id: testset, + status: EvaluationFlow.EVALUATION_INITIALIZED, + } + + const {projectId} = getProjectValues() + const project = assertValidId(projectId, "projectId") + + const response = await axios.post(`${getAgentaApiUrl()}/human-evaluations`, data, { + params: {project_id: project}, + _ignoreError: ignoreAxiosError, + } as any) + return response.data.id +} + +export const updateEvaluation = async (evaluationId: string, data: GenericObject) => { + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId, "evaluationId") + const project = assertValidId(projectId, "projectId") + + const response = await axios.put( + `${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent(id)}`, + data, + {params: {project_id: project}}, + ) + return response.data +} + +export const updateEvaluationScenario = async ( + evaluationTableId: string, + evaluationScenarioId: string, + data: GenericObject, + evaluationType: EvaluationType, +) => { + const {projectId} = getProjectValues() + const tableId = assertValidId(evaluationTableId, "evaluationTableId") + const scenarioId = assertValidId(evaluationScenarioId, "evaluationScenarioId") + const project = assertValidId(projectId, "projectId") + + const response = await axios.put( + `${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent( + tableId, + )}/evaluation_scenario/${encodeURIComponent(scenarioId)}/${encodeURIComponent( + evaluationType, + )}`, + data, + {params: {project_id: project}}, + ) + return response.data +} + +export const createEvaluationScenario = async (evaluationTableId: string, data: GenericObject) => { + const {projectId} = getProjectValues() + const tableId = assertValidId(evaluationTableId, "evaluationTableId") + const project = assertValidId(projectId, "projectId") + + const response = await axios.post( + `${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent(tableId)}/evaluation_scenario`, + data, + {params: {project_id: project}}, + ) + return response.data +} + +export const createEvaluateAICritiqueForEvalScenario = async ( + data: AICritiqueCreate, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + const project = assertValidId(projectId, "projectId") + + const response = await axios.post( + `${getAgentaApiUrl()}/human-evaluations/evaluation_scenario/ai_critique`, + data, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const fetchEvaluationResults = async (evaluationId: string, ignoreAxiosError = false) => { + const {projectId} = getProjectValues() + const id = assertValidId(evaluationId, "evaluationId") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get( + `${getAgentaApiUrl()}/human-evaluations/${encodeURIComponent(id)}/results`, + { + params: {project_id: project}, + _ignoreError: ignoreAxiosError, + } as any, + ) + return response.data as EvaluationResponseType +} + +export const fetchEvaluationScenarioResults = async (evaluation_scenario_id: string) => { + const {projectId} = getProjectValues() + const scenarioId = assertValidId(evaluation_scenario_id, "evaluation_scenario_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get( + `${getAgentaApiUrl()}/human-evaluations/evaluation_scenario/${encodeURIComponent( + scenarioId, + )}/score`, + {params: {project_id: project}}, + ) + return response +} + +export const createCustomCodeEvaluation = async ( + payload: CreateCustomEvaluation, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + const project = assertValidId(projectId, "projectId") + + const response = await axios.post( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation`, + payload, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const updateCustomEvaluationDetail = async ( + id: string, + payload: CreateCustomEvaluation, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + const customId = assertValidId(id, "custom_evaluation_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.put( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation/${encodeURIComponent(customId)}`, + payload, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const fetchCustomEvaluations = async (app_id: string, ignoreAxiosError = false) => { + const {projectId} = getProjectValues() + const appId = assertValidId(app_id, "app_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation/list/${encodeURIComponent( + appId, + )}`, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const fetchCustomEvaluationDetail = async (id: string, ignoreAxiosError = false) => { + const {projectId} = getProjectValues() + const customId = assertValidId(id, "custom_evaluation_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation/${encodeURIComponent(customId)}`, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response.data +} + +export const fetchCustomEvaluationNames = async (app_id: string, ignoreAxiosError = false) => { + const {projectId} = getProjectValues() + const appId = assertValidId(app_id, "app_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.get( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation/${encodeURIComponent( + appId, + )}/names`, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const createExecuteCustomEvaluationCode = async ( + payload: ExecuteCustomEvalCode, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + const project = assertValidId(projectId, "projectId") + const evalId = assertValidId(payload.evaluation_id, "evaluation_id") + + const response = await axios.post( + `${getAgentaApiUrl()}/human-evaluations/custom_evaluation/execute/${encodeURIComponent( + evalId, + )}`, + payload, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} + +export const updateEvaluationScenarioScore = async ( + evaluation_scenario_id: string, + score: number, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + const scenarioId = assertValidId(evaluation_scenario_id, "evaluation_scenario_id") + const project = assertValidId(projectId, "projectId") + + const response = await axios.put( + `${getAgentaApiUrl()}/human-evaluations/evaluation_scenario/${encodeURIComponent( + scenarioId, + )}/score`, + {score}, + {params: {project_id: project}, _ignoreError: ignoreAxiosError} as any, + ) + return response +} diff --git a/web/ee/src/services/human-evaluations/hooks/useEvaluationResults.ts b/web/ee/src/services/human-evaluations/hooks/useEvaluationResults.ts new file mode 100644 index 0000000000..609baeb737 --- /dev/null +++ b/web/ee/src/services/human-evaluations/hooks/useEvaluationResults.ts @@ -0,0 +1,26 @@ +import type {SWRConfiguration} from "swr" +import useSWR from "swr" + +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {getProjectValues} from "@/oss/state/project" + +interface UseEvaluationResultsOptions extends SWRConfiguration { + evaluationId?: string +} + +export const useEvaluationResults = ({evaluationId, ...rest}: UseEvaluationResultsOptions = {}) => { + const {projectId} = getProjectValues() + + const swr = useSWR( + evaluationId && projectId + ? `${getAgentaApiUrl()}/human-evaluations/${evaluationId}/results?project_id=${projectId}` + : null, + { + ...rest, + revalidateOnFocus: false, + shouldRetryOnError: false, + }, + ) + + return swr +} diff --git a/web/ee/src/services/observability/api/helper.ts b/web/ee/src/services/observability/api/helper.ts new file mode 100644 index 0000000000..11b6616843 --- /dev/null +++ b/web/ee/src/services/observability/api/helper.ts @@ -0,0 +1,61 @@ +import {GenerationDashboardData, TracingDashboardData} from "@/oss/lib/types_ee" +import dayjs from "dayjs" + +export const normalizeDurationSeconds = (d = 0) => d / 1_000 + +export const formatTick = (ts: number | string, range: string) => + dayjs(ts).format(range === "24_hours" ? "h:mm a" : range === "7_days" ? "ddd" : "D MMM") + +export function tracingToGeneration( + tracing: TracingDashboardData, + range: string, +): GenerationDashboardData { + const buckets = tracing.buckets ?? [] + + let successCount = 0 + let errorCount = 0 + let totalCost = 0 + let totalTokens = 0 + let totalSuccessDuration = 0 + + const data = buckets.map((b) => { + const succC = b.total?.count ?? 0 + const errC = b.errors?.count ?? 0 + + const succCost = b.total?.costs ?? 0 + const errCost = b.errors?.costs ?? 0 + + const succTok = b.total?.tokens ?? 0 + const errTok = b.errors?.tokens ?? 0 + + const succDurS = normalizeDurationSeconds(b.total?.duration ?? 0) + + successCount += succC + errorCount += errC + totalCost += succCost + errCost + totalTokens += succTok + errTok + totalSuccessDuration += succDurS + + return { + timestamp: formatTick(b.timestamp, range), + success_count: succC, + failure_count: errC, + cost: succCost + errCost, + latency: succC ? succDurS / Math.max(succC, 1) : 0, // avg latency per success in the bucket + total_tokens: succTok + errTok, + } + }) + + const totalCount = successCount + errorCount + + return { + data, + total_count: totalCount, + failure_rate: totalCount ? errorCount / totalCount : 0, + total_cost: totalCost, + avg_cost: totalCount ? totalCost / totalCount : 0, + avg_latency: successCount ? totalSuccessDuration / successCount : 0, + total_tokens: totalTokens, + avg_tokens: totalCount ? totalTokens / totalCount : 0, + } +} diff --git a/web/ee/src/services/observability/api/index.ts b/web/ee/src/services/observability/api/index.ts new file mode 100644 index 0000000000..2dabfe1c40 --- /dev/null +++ b/web/ee/src/services/observability/api/index.ts @@ -0,0 +1,168 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" +import {delay, pickRandom} from "@/oss/lib/helpers/utils" +import {GenericObject, WithPagination} from "@/oss/lib/Types" +import {Generation, GenerationDetails, Trace, TracingDashboardData} from "@/oss/lib/types_ee" +import {getProjectValues} from "@/oss/state/project" + +import {tracingToGeneration} from "./helper" +import {ObservabilityMock} from "./mock" + +//Prefix convention: +// - fetch: GET single entity from server +// - fetchAll: GET all entities from server +// - create: POST data to server +// - update: PUT data to server +// - delete: DELETE data from server + +const mock = false + +interface TableParams { + pagination?: { + page: number + pageSize: number + } + sorters?: GenericObject + filters?: GenericObject +} + +function tableParamsToApiParams(options?: Partial) { + const {page = 1, pageSize = 20} = options?.pagination || {} + const res: GenericObject = {page, pageSize} + if (options?.sorters) { + Object.entries(options.sorters).forEach( + ([key, val]) => (res[key] = val === "ascend" ? "asc" : "desc"), + ) + } + if (options?.filters) { + Object.entries(options.filters).forEach(([key, val]) => (res[key] = val)) + } + return res +} + +const generations = pickRandom(ObservabilityMock.generations, 100).map((item, ix) => ({ + ...item, + id: ix + 1 + "", +})) + +export const fetchAllGenerations = async (appId: string, options?: Partial) => { + const {projectId} = getProjectValues() + + const params = tableParamsToApiParams(options) + if (mock) { + const {page, pageSize} = params + await delay(800) + return { + data: generations.slice((page - 1) * pageSize, page * pageSize), + total: generations.length, + page, + pageSize, + } as WithPagination + } + + const response = await axios.get(`/observability/spans?project_id=${projectId}`, { + params: {app_id: appId, type: "generation", ...params}, + }) + return response.data as WithPagination +} + +export const fetchGeneration = async (generationId: string) => { + const {projectId} = getProjectValues() + + if (mock) { + await delay(800) + const generation = generations.find((item) => item.id === generationId) + if (!generation) throw new Error("not found!") + + return { + ...generation, + ...ObservabilityMock.generationDetail, + } as GenerationDetails + } + + const response = await axios.get( + `/observability/spans/${generationId}?project_id=${projectId}`, + { + params: {type: "generation"}, + }, + ) + return response.data as GenerationDetails +} + +export const fetchGenerationsDashboardData = async ( + appId: string | null | undefined, + _options: { + range: string + environment?: string + variant?: string + projectId?: string + signal?: AbortSignal + }, +) => { + const {projectId: propsProjectId, signal, ...options} = _options + const {projectId: _projectId} = getProjectValues() + const projectId = propsProjectId || _projectId + + const {range} = options + + if (signal?.aborted) { + throw new DOMException("Aborted", "AbortError") + } + + const responseTracing = await axios.post( + `/preview/tracing/spans/analytics?project_id=${projectId}`, + { + focus: "trace", + interval: 720, + filter: { + conditions: [ + { + field: "references", + operator: "in", + value: [ + { + id: appId, + }, + ], + }, + ], + }, + }, + ) + + const valTracing = responseTracing.data as TracingDashboardData + return tracingToGeneration(valTracing, range) +} + +export const deleteGeneration = async ( + generationIds: string[], + type = "generation", + ignoreAxiosError = true, +) => { + const {projectId} = getProjectValues() + + await axios.delete(`/observability/spans?project_id=${projectId}`, { + data: generationIds, + _ignoreError: ignoreAxiosError, + } as any) + return true +} + +export const fetchAllTraces = async (appId: string, options?: Partial) => { + const {projectId} = getProjectValues() + + const params = tableParamsToApiParams(options) + if (mock) { + const {page, pageSize} = params + await delay(800) + return { + data: generations.slice((page - 1) * pageSize, page * pageSize), + total: generations.length, + page, + pageSize, + } as WithPagination + } + const response = await axios.get(`/observability/traces?project_id=${projectId}`, { + params: {app_id: appId, type: "generation", ...params}, + }) + return response.data as WithPagination +} diff --git a/web/ee/src/services/observability/api/mock.ts b/web/ee/src/services/observability/api/mock.ts new file mode 100644 index 0000000000..a13e172d46 --- /dev/null +++ b/web/ee/src/services/observability/api/mock.ts @@ -0,0 +1,148 @@ +import dayjs from "dayjs" + +import {randNum} from "@/oss/lib/helpers/utils" +import { + Generation, + GenerationKind, + GenerationDashboardData, + GenerationStatus, +} from "@/oss/lib/types_ee" + +const generations: Generation[] = [ + { + id: "1", + created_at: "2021-10-01T00:00:00Z", + variant: { + variant_id: "1", + variant_name: "default", + revision: 1, + }, + environment: "production", + status: GenerationStatus.OK, + spankind: GenerationKind.LLM, + metadata: { + cost: 0.0001, + latency: 0.32, + usage: { + total_tokens: 72, + prompt_tokens: 25, + completion_tokens: 47, + }, + }, + user_id: "u-8k3j4", + content: { + inputs: [ + {input_name: "country", input_value: "Pakistan"}, + {input_name: "criteria", input_value: "Most population"}, + ], + outputs: ["The most populous city in Pakistan is Karachi"], + internals: [], + }, + }, + { + id: "2", + created_at: "2023-10-01T00:00:00Z", + variant: { + variant_id: "2", + variant_name: "test", + revision: 1, + }, + environment: "staging", + status: GenerationStatus.ERROR, + spankind: GenerationKind.LLM, + metadata: { + cost: 0.0004, + latency: 0.845, + usage: { + total_tokens: 143, + prompt_tokens: 25, + completion_tokens: 118, + }, + }, + user_id: "u-8k3j4", + content: { + inputs: [], + outputs: [], + internals: [], + }, + }, + { + id: "3", + created_at: "2024-10-01T00:00:00Z", + variant: { + variant_id: "1", + variant_name: "default", + revision: 2, + }, + environment: "development", + status: GenerationStatus.OK, + spankind: GenerationKind.LLM, + metadata: { + cost: 0.0013, + latency: 0.205, + usage: { + total_tokens: 61, + prompt_tokens: 25, + completion_tokens: 36, + }, + }, + user_id: "u-7tij2", + content: { + inputs: [], + outputs: [], + internals: [], + }, + }, +] + +const generationDetail = { + content: { + inputs: [ + {input_name: "country", input_value: "Pakistan"}, + {input_name: "criteria", input_value: "Most population"}, + ], + outputs: ["The most populous city in Pakistan is Karachi"], + internals: [], + }, + config: { + system: "You are an expert in geography.", + user: "What is the city of {country} with the criteria {criteria}?", + variables: [ + {name: "country", type: "string"}, + {name: "criteria", type: "string"}, + ], + temperature: 0.7, + model: "gpt-3.5-turbo", + max_tokens: 100, + top_p: 0.9, + frequency_penalty: 0.5, + presence_penalty: 0, + }, +} + +const dashboardData = (count = 300): GenerationDashboardData["data"] => { + return Array(count) + .fill(true) + .map(() => { + const totalTokens = randNum(0, 600) + const promptTokens = randNum(0, 150) + return { + timestamp: randNum(dayjs().subtract(30, "days").valueOf(), dayjs().valueOf()), // b/w last 30 days + success_count: randNum(0, 20), + failure_count: randNum(0, 5), + latency: Math.random() * 1.5, + cost: Math.random() * 0.01, + total_tokens: totalTokens, + prompt_tokens: promptTokens, + completion_tokens: totalTokens - promptTokens, + enviornment: ["production", "staging", "development"][randNum(0, 2)], + variant: "default", + } + }) +} + +export const ObservabilityMock = { + generations, + generationDetail, + dashboardData, +} diff --git a/web/ee/src/services/promptVersioning/api/index.ts b/web/ee/src/services/promptVersioning/api/index.ts new file mode 100644 index 0000000000..d51cd8ac75 --- /dev/null +++ b/web/ee/src/services/promptVersioning/api/index.ts @@ -0,0 +1,41 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {getProjectValues} from "@/oss/state/project" + +//Prefix convention: +// - fetch: GET single entity from server +// - fetchAll: GET all entities from server +// - create: POST data to server +// - update: PUT data to server +// - delete: DELETE data from server + +// versioning +export const fetchAllPromptVersioning = async (variantId: string, ignoreAxiosError = false) => { + const {projectId} = getProjectValues() + + const {data} = await axios.get( + `${getAgentaApiUrl()}/variants/${variantId}/revisions?project_id=${projectId}`, + { + _ignoreError: ignoreAxiosError, + } as any, + ) + + return data +} + +export const fetchPromptRevision = async ( + variantId: string, + revisionNumber: number, + ignoreAxiosError = false, +) => { + const {projectId} = getProjectValues() + + const {data} = await axios.get( + `${getAgentaApiUrl()}/variants/${variantId}/revisions/${revisionNumber}?project_id=${projectId}`, + { + _ignoreError: ignoreAxiosError, + } as any, + ) + + return data +} diff --git a/web/ee/src/services/runMetrics/api/assets/contants.ts b/web/ee/src/services/runMetrics/api/assets/contants.ts new file mode 100644 index 0000000000..f1d8278bd0 --- /dev/null +++ b/web/ee/src/services/runMetrics/api/assets/contants.ts @@ -0,0 +1,18 @@ +export const PERCENTILE_STOPS = [ + 0.05, 0.1, 0.5, 1, 2.5, 5, 10, 12.5, 20, 25, 30, 37.5, 40, 50, 60, 62.5, 70, 75, 80, 87.5, 90, + 95, 97.5, 99, 99.5, 99.9, 99.95, +] + +// Inter-quartile ranges aligned with backend mapping +export const iqrsLevels: Record = { + iqr25: ["p37.5", "p62.5"], + iqr50: ["p25", "p75"], + iqr60: ["p20", "p80"], + iqr75: ["p12.5", "p87.5"], + iqr80: ["p10", "p90"], + iqr90: ["p5", "p95"], + iqr95: ["p2.5", "p97.5"], + iqr98: ["p1", "p99"], + iqr99: ["p0.5", "p99.5"], + "iqr99.9": ["p0.05", "p99.95"], +} diff --git a/web/ee/src/services/runMetrics/api/index.ts b/web/ee/src/services/runMetrics/api/index.ts new file mode 100644 index 0000000000..2ba88cad4a --- /dev/null +++ b/web/ee/src/services/runMetrics/api/index.ts @@ -0,0 +1,696 @@ +import {iqrsLevels, PERCENTILE_STOPS} from "./assets/contants" +import {BasicStats} from "./types" + +export const METRICS_ENDPOINT = "/preview/evaluations/metrics/" + +const fetchJSON = async (url: string, options: RequestInit) => { + const res = await fetch(url, options) + if (!res.ok) throw new Error(res.statusText) + return res.json() +} + +// /** +// * Create a new run-level metric entry. +// * +// * @param apiUrl The URL of the API service to create the metric against. +// * @param jwt The JWT token to authenticate the request. +// * @param runId The UUID of the evaluation run to associate with the metric. +// * @param data A dictionary of string keys to numeric values representing the +// * metric data. +// * +// * @returns The newly created metric object (snake_case). +// */ +// export const createRunMetrics = async ( +// apiUrl: string, +// jwt: string, +// runId: string, +// data: Record, +// projectId: string, +// ) => { +// const payload = {metrics: [{run_id: runId, data}]} +// return fetchJSON(`${apiUrl}${METRICS_ENDPOINT}?project_id=${projectId}`, { +// method: "POST", +// headers: { +// "Content-Type": "application/json", +// Authorization: `Bearer ${jwt}`, +// }, +// body: JSON.stringify(payload), +// }) +// } + +/** + * Creates a new run-level metric or updates an existing one. + * + * This function will first attempt to fetch the existing metric associated + * with the given runId. If a metric is found, it will be updated with the + * new data. If no existing metric is found, a new metric entry will be + * created. + * + * @param apiUrl The base URL of the API service. + * @param jwt The JWT token used for authenticating the request. + * @param runId The UUID of the evaluation run to associate with the metrics. + * @param data A dictionary of string keys to numeric values representing the + * metric data. + * + * @returns The newly created or updated metric object (snake_case). + */ +// export const upsertRunMetrics = async ( +// apiUrl: string, +// jwt: string, +// runId: string, +// data: Record, +// projectId: string, +// ) => { +// try { +// const params = new URLSearchParams({ +// run_ids: runId, +// }) +// const res = await fetchJSON(`${apiUrl}${METRICS_ENDPOINT}?${params.toString()}`, { +// headers: {Authorization: `Bearer ${jwt}`}, +// }) +// const existing = Array.isArray(res.metrics) ? res.metrics[0] : undefined +// if (existing) { +// const merged = {...(existing.data || {}), ...data} +// return updateMetric(apiUrl, jwt, existing.id, { +// data: merged, +// status: existing.status || "finished", +// tags: existing.tags, +// meta: existing.meta, +// }) +// } +// } catch { +// /* ignore lookup errors and fall back to creation */ +// } +// return createRunMetrics(apiUrl, jwt, runId, data, projectId) +// } + +/** + * Create or update scenario-level metrics for a specific evaluation run. + * + * This function takes a list of scenario metric entries and attempts to + * either create new metrics or update existing ones based on the provided + * runId and scenarioId. If a metric already exists for a given scenario, + * it is updated with the new data. If no existing metric is found, a new + * metric entry is created. + * + * @param apiUrl The base URL of the API service. + * @param jwt The JWT token used for authenticating the request. + * @param runId The UUID of the evaluation run to associate with the metrics. + * @param entries An array of objects containing scenarioId and data to + * be stored as metrics. + * + * @returns A promise that resolves when all create or update operations + * have been completed. + */ +export const createScenarioMetrics = async ( + apiUrl: string, + jwt: string, + runId: string, + entries: {scenarioId: string; data: Record}[], + projectId: string, +) => { + const toCreate: {run_id: string; scenario_id: string; data: Record}[] = [] + const toUpdate: { + id: string + data: Record + status?: string + tags?: Record + meta?: Record + }[] = [] + + const queryUrl = `${apiUrl}${METRICS_ENDPOINT}query?project_id=${projectId}` + const existingByScenario: Record = {} + + try { + const payload = { + metrics: { + run_ids: [runId], + scenario_ids: entries.map((entry) => entry.scenarioId), + }, + windowing: {}, + } + + const queryResponse = await fetchJSON(queryUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify(payload), + }) + + const existingMetrics = Array.isArray(queryResponse?.metrics) ? queryResponse.metrics : [] + + existingMetrics.forEach((metric: any) => { + const scenarioId = metric?.scenario_id || metric?.scenarioId + if (scenarioId) { + existingByScenario[scenarioId] = metric + } + }) + } catch (error) { + console.warn("[createScenarioMetrics] Failed to query existing metrics", error) + } + + for (const entry of entries) { + const existing = existingByScenario[entry.scenarioId] + if (existing) { + const mergedData = { + ...(existing.data || {}), + ...entry.data, + } + if (existing.id) { + toUpdate.push({ + id: existing.id, + data: mergedData, + status: existing.status, + tags: existing.tags, + meta: existing.meta, + }) + continue + } + } + toCreate.push({run_id: runId, scenario_id: entry.scenarioId, data: entry.data}) + } + + const promises: Promise[] = [] + if (toCreate.length) { + promises.push( + fetchJSON(`${apiUrl}${METRICS_ENDPOINT}?project_id=${projectId}`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({metrics: toCreate}), + }), + ) + } + if (toUpdate.length) { + promises.push( + fetchJSON(`${apiUrl}${METRICS_ENDPOINT}?project_id=${projectId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({metrics: toUpdate}), + }), + ) + } + return Promise.all(promises) +} + +/** + * Update a single metric entry. + * + * @param apiUrl The URL of the API service to create the metric against. + * @param jwt The JWT token to authenticate the request. + * @param metricId The UUID of the metric to update. + * @param changes A dictionary of changes to apply to the metric. + * + * @returns The updated metric object (snake_case). + */ +export const updateMetric = async ( + apiUrl: string, + jwt: string, + metricId: string, + changes: { + data?: Record + status?: string + tags?: Record + meta?: Record + }, + projectId: string, +) => { + const payload = {metric: {id: metricId, ...changes}} + return fetchJSON(`${apiUrl}${METRICS_ENDPOINT}${metricId}?project_id=${projectId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify(payload), + }) +} + +/** + * Update multiple metric entries. + * + * @param apiUrl The URL of the API service to update the metrics against. + * @param jwt The JWT token to authenticate the request. + * @param metrics An array of metric objects to update. Each object should contain + * at least an 'id' property and may contain additional properties + * to update ('data', 'status', 'tags', 'meta'). + * + * @returns An array of the updated metric objects (snake_case). + */ +export const updateMetrics = async ( + apiUrl: string, + jwt: string, + metrics: { + id: string + data?: Record + status?: string + tags?: Record + meta?: Record + }[], + projectId: string, +) => { + return fetchJSON(`${apiUrl}${METRICS_ENDPOINT}?project_id=${projectId}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${jwt}`, + }, + body: JSON.stringify({metrics}), + }) +} + +// --- Statistics helpers -------------------------------------------------- + +/** + * Calculates the p-th percentile of a sorted array of numbers. + * + * @param sorted - An array of numbers sorted in ascending order. + * @param p - The percentile to calculate (between 0 and 100). + * @returns The calculated percentile value. + * If the array is empty, returns 0. + */ +function percentile(sorted: number[], p: number): number { + if (sorted.length === 0) return 0 + const idx = (p / 100) * (sorted.length - 1) + const lower = Math.floor(idx) + const upper = Math.ceil(idx) + if (lower === upper) return sorted[lower] + const weight = idx - lower + return sorted[lower] * (1 - weight) + sorted[upper] * weight +} + +// Helper: round to 'p' decimal places (default 6) and coerce back to number +// Smart rounding: for numbers < 0.001 use significant–figure precision to +// avoid long binary tails; otherwise use fixed decimal rounding. +const round = (v: number, p = 6, sig = 6): number => { + if (Number.isNaN(v)) return v + const abs = Math.abs(v) + if (abs !== 0 && abs < 1e-3) { + return Number(v.toPrecision(sig)) + } + return Number(v.toFixed(p)) +} + +/** + * Builds a histogram distribution from an array of numbers. + * + * This function calculates a histogram by determining the optimal number of bins + * based on the square root of the number of input values. It then computes the + * bin size and assigns each number to a bin. The resulting histogram is returned + * as an array of objects, each containing a bin start value and the count of + * numbers in that bin. + * + * @param values - An array of numbers to create the distribution from. + * @returns An array of objects where each object represents a bin with the + * 'value' as the bin start and 'count' as the number of elements + * in that bin. If all values are the same, returns a single bin + * with the value and the count of elements. + */ +function buildDistribution(values: number[]): {value: number; count: number}[] { + if (!values.length) return [] + + const n = values.length + const bins = Math.ceil(Math.sqrt(n)) + const min = Math.min(...values) + const max = Math.max(...values) + + if (min === max) { + return [{value: round(min, 6), count: n}] + } + + const binSize = (max - min) / bins + // precision = number of decimal places required to keep bin starts stable + const precision = binSize ? Math.max(0, -Math.floor(Math.log10(binSize))) : 0 + + const hist = new Map() + + values.forEach((v) => { + let binIndex = Math.floor((v - min) / binSize) + if (binIndex === bins) binIndex -= 1 // edge case when v === max + const binStart = Number((min + binIndex * binSize).toFixed(precision)) + hist.set(binStart, (hist.get(binStart) ?? 0) + 1) + }) + + return Array.from(hist.entries()) + .sort((a, b) => a[0] - b[0]) + .map(([value, count]) => ({value, count})) +} + +/** + * Computes various statistical measures for a given array of numbers. + * + * @param values - An array of numbers for which statistics are to be computed. + * @returns An object containing the following statistical measures: + * - count: The number of elements in the array. + * - sum: The total sum of the elements. + * - mean: The average value of the elements. + * - min: The minimum value in the array. + * - max: The maximum value in the array. + * - range: The difference between the maximum and minimum values. + * - distribution: A histogram representation of the values. + * - percentiles: An object containing percentile values for defined stops. + * - iqrs: An object containing inter-quartile ranges as per backend mapping. + */ +function computeStats(values: number[]): BasicStats { + const count = values.length + if (count === 0) { + return { + count: 0, + sum: 0, + mean: 0, + min: 0, + max: 0, + range: 0, + distribution: [], + percentiles: {}, + iqrs: {}, + } + } + + const sorted = [...values].sort((a, b) => a - b) + const sum = values.reduce((acc, v) => acc + v, 0) + const mean = sum / count + const min = sorted[0] + const max = sorted[sorted.length - 1] + const range = max - min + + // Percentiles with rounded output + const percentiles: Record = {} + PERCENTILE_STOPS.forEach((p) => { + percentiles[`p${p}`] = round(percentile(sorted, p), 4) + }) + + const iqrs: Record = {} + Object.entries(iqrsLevels).forEach(([label, [low, high]]) => { + iqrs[label] = round(percentiles[high] - percentiles[low], 4) + }) + + const distribution = buildDistribution(values) + const bins = distribution.length + const binSize = bins ? (range !== 0 ? range / bins : 1) : undefined + + return { + count, + sum: round(sum, 6), + mean: round(mean, 6), + min: round(min, 6), + max: round(max, 6), + range: round(range, 6), + distribution, + percentiles, + iqrs, + binSize: binSize !== undefined ? round(binSize, 6) : undefined, + } +} + +// --- Additional helpers for non-numeric metrics ------------------------- + +// Count of values +function count(values: unknown[]): number { + return values.length +} + +// Build frequency list [{value,count}] +function buildFrequency(values: unknown[]): {value: any; count: number}[] { + const freqMap = new Map() + values.forEach((v) => freqMap.set(v, (freqMap.get(v) ?? 0) + 1)) + return Array.from(freqMap.entries()).map(([value, count]) => ({value, count})) +} + +function buildRank(values: unknown[], topK = 10): {value: any; count: number}[] { + return buildFrequency(values) + .sort((a, b) => b.count - a.count) + .slice(0, topK) +} + +function processBinary(values: (boolean | null)[]): BasicStats { + const filtered = values.map((v) => (v === null || v === undefined ? null : v)) + return { + count: count(filtered), + frequency: buildFrequency(filtered), + unique: Array.from(new Set(filtered)), + rank: buildRank(filtered), + } +} + +function processClass(values: (string | number | boolean | null)[]): BasicStats { + return { + count: count(values), + frequency: buildFrequency(values), + unique: Array.from(new Set(values)), + rank: buildRank(values), + } +} + +function processLabels(values: ((string | number | boolean | null)[] | null)[]): BasicStats { + // Flatten labels list + const flat: (string | number | boolean | null)[] = [] + values.forEach((arr) => { + if (Array.isArray(arr)) flat.push(...arr) + else flat.push(null) + }) + // Additionally compute distribution of label counts per record + // const labelCounts = values.map((arr) => (Array.isArray(arr) ? arr.length : 0)) + // const distStats = computeStats(labelCounts) + // const labelValueDistribution = buildFrequency(flat).map((f) => ({ + // value: f.value, + // count: f.count, + // })) + const returnData = { + count: count(flat), + frequency: buildFrequency(flat), + unique: Array.from(new Set(flat)), + rank: buildRank(flat), + } + return returnData +} + +// TODO: Clean this up Ashraf +// Implemented this to handle boolean metric for auto eval +interface BoolCount { + count: number + value: boolean +} +interface ItemShape { + rank?: BoolCount[] + frequency?: BoolCount[] + count?: number // not required for aggregation + unique?: boolean[] // not required for aggregation +} + +interface Summary { + rank: BoolCount[] + count: number + unique: boolean[] + frequency: BoolCount[] +} + +export function aggregateBooleanSummaryByVote(items: ItemShape[]): Summary { + let totalItems = 0 + let votesTrue = 0 + let votesFalse = 0 + + for (const item of items) { + // Prefer rank if present, else fall back to frequency + const source = (item.rank?.length ? item.rank : item.frequency) ?? [] + + if (!source.length) continue + + // Pick the winner for THIS item: + // - If item.rank was provided, assume it's already sorted (winner is source[0]) + // - Otherwise, find the max by count from frequency + let winner: BoolCount | undefined + + if (item.rank?.length) { + winner = source[0] + } else { + winner = source.reduce((best, cur) => { + if (!best) return cur + if (cur.count > best.count) return cur + if (cur.count === best.count) { + // Tie-break: prefer the one that appears first (stable), or prefer true. + // To prefer true on ties, use the following line instead: + // return cur.value === true ? cur : best; + return best + } + return best + }, undefined) + } + + if (winner && typeof winner.value === "boolean") { + totalItems += 1 // this item contributes exactly one vote + if (winner.value) votesTrue += 1 + else votesFalse += 1 + } + } + + // Build totals; keep rank/frequency consistent and sorted by count desc (tie: true first) + const totals: BoolCount[] = [ + {value: true, count: votesTrue}, + {value: false, count: votesFalse}, + ].sort((a, b) => b.count - a.count || (a.value === true ? -1 : 1)) + + return { + rank: totals, + count: totalItems, // <= items.length + unique: [true, false], + frequency: totals, + } +} + +// ------------------------------------------------------------------------ + +/** + * Computes a map of metrics to their computed statistics, given a list of + * objects with `data` properties containing key-value pairs of metric names + * to their respective values. + * + * It will group values by metric key, and compute the following statistics + * for each key: + * + * - `count`: The number of values. + * - `sum`: The sum of all values. + * - `mean`: The mean of all values. + * - `min`: The minimum value. + * - `max`: The maximum value. + * - `range`: The difference between the maximum and minimum values. + * - `distribution`: An array of 11 values representing the distribution of + * values between the minimum and maximum. + * - `percentiles`: An object with keys `pX` where `X` is a percentile (e.g. + * `p25`, `p50`, `p75`), and values that are the corresponding percentiles + * of the values. + * - `iqrs`: An object with keys that are the names of interquartile ranges + * (e.g. `iqr25`, `iqr50`, `iqr75`), and values that are the corresponding + * interquartile ranges of the values. + * + * @param metrics An array of objects with `data` properties containing key-value pairs of metric names to their respective values. + * @returns An object with metric names as keys, and their computed statistics as values. + */ +export const computeRunMetrics = (metrics: {data: Record}[]): Record => { + if (!metrics?.length) return {} + + const result: Record = {} + const valueBuckets: Record = {} + + metrics.forEach((m) => { + Object.entries(m.data || {}).forEach(([k, v]) => { + if (v !== undefined) { + valueBuckets[k] = valueBuckets[k] || [] + valueBuckets[k].push(v) + } + }) + }) + + // Process non-special keys + Object.entries(valueBuckets).forEach(([k, values]) => { + const allNumbers = values.every((v) => typeof v === "number" && !isNaN(v)) + const allBooleans = values.every((v) => typeof v === "boolean" || v === null) + const proccesdBooleans = values.every( + (v) => v?.unique?.length && typeof v?.unique?.[0] === "boolean", + ) + const allArrays = values.every((v) => Array.isArray(v)) + const allStatsObjects = values.every( + (v) => + v && + typeof v === "object" && + !Array.isArray(v) && + ("mean" in (v as any) || + "sum" in (v as any) || + "count" in (v as any) || + "frequency" in (v as any) || + "rank" in (v as any)), + ) + + if (allNumbers) { + result[k] = computeStats(values as number[]) + } else if (allBooleans) { + result[k] = processBinary(values as (boolean | null)[]) + } else if (proccesdBooleans) { + result[k] = aggregateBooleanSummaryByVote(values) + } else if (allArrays) { + result[k] = processLabels(values as any[][]) // treat as labels metric + } else if (allStatsObjects) { + const merged = values.reduce((acc: any, current: any) => { + if (!acc) return current + const next: any = {...acc} + if (typeof current.mean === "number") next.mean = current.mean + if (typeof current.sum === "number") next.sum = current.sum + if (typeof current.count === "number") { + next.count = (next.count ?? 0) + (current.count ?? 0) + } + if (Array.isArray(current.frequency)) next.frequency = current.frequency + if (Array.isArray(current.rank)) next.rank = current.rank + if (Array.isArray(current.unique)) next.unique = current.unique + if (Array.isArray(current.distribution)) next.distribution = current.distribution + if (current.percentiles) next.percentiles = current.percentiles + if (current.iqrs) next.iqrs = current.iqrs + if (typeof current.binSize === "number") next.binSize = current.binSize + return next + }, null) + const finalStats = merged ?? values[0] + if (finalStats && Array.isArray(finalStats.frequency)) { + finalStats.frequency = finalStats.frequency.map((entry: any) => ({ + value: entry?.value, + count: entry?.count ?? entry?.frequency ?? 0, + })) + finalStats.frequency.sort( + (a: any, b: any) => b.count - a.count || (a.value === true ? -1 : 1), + ) + finalStats.rank = finalStats.frequency + if (!Array.isArray(finalStats.unique) || !finalStats.unique.length) { + finalStats.unique = finalStats.frequency.map((entry: any) => entry.value) + } + } + result[k] = finalStats + } else { + // Default to class metric for strings / mixed primitives + // result[k] = processClass(values as any[]) + } + }) + + return result +} + +export interface MetricDistribution { + distribution: {value: number; count: number}[] + mean: number + min: number + max: number + binSize: number +} + +export const computeMetricDistribution = ( + values: number[], + stats?: BasicStats, +): MetricDistribution | undefined => { + let computed = stats + if (!computed) { + if (!values.length) return undefined + const tmpKey = "__metric" + const agg = computeRunMetrics(values.map((v) => ({data: {[tmpKey]: v}}))) + computed = agg[tmpKey] + } + if (!computed?.distribution || !computed.distribution.length) { + return computed + } + let binSize = computed.binSize + if (binSize === undefined) { + const bins = computed.distribution.length + const range = computed.range ?? (computed.max ?? 0) - (computed.min ?? 0) + binSize = bins ? (range !== 0 ? range / bins : 1) : 1 + } + return { + distribution: computed.distribution, + mean: computed.mean ?? 0, + min: computed.min ?? 0, + max: computed.max ?? 0, + binSize, + } +} diff --git a/web/ee/src/services/runMetrics/api/types.ts b/web/ee/src/services/runMetrics/api/types.ts new file mode 100644 index 0000000000..97a59c2a22 --- /dev/null +++ b/web/ee/src/services/runMetrics/api/types.ts @@ -0,0 +1,22 @@ +// Aggregated statistics for a metric. +// Only a subset of these properties will be present depending on the metric type. +export interface BasicStats { + // Always present --------------------------------------------------------- + count: number + + // Numeric metrics ------------------------------------------------------- + sum?: number + mean?: number + min?: number + max?: number + range?: number + distribution?: {value: number; count: number}[] + percentiles?: Record + iqrs?: Record + binSize?: number + + // Categorical / binary metrics ----------------------------------------- + frequency?: {value: string | number | boolean | null; count: number}[] + unique?: (string | number | boolean | null)[] + rank?: {value: string | number | boolean | null; count: number}[] +} diff --git a/web/ee/src/services/variantConfigs/api/index.ts b/web/ee/src/services/variantConfigs/api/index.ts new file mode 100644 index 0000000000..1190123367 --- /dev/null +++ b/web/ee/src/services/variantConfigs/api/index.ts @@ -0,0 +1,77 @@ +import axios from "@/oss/lib/api/assets/axiosConfig" + +export interface VariantReferenceRequest { + projectId: string + application: { + id?: string + slug?: string + } + variant: { + id?: string + slug?: string + version?: number | null + } +} + +export interface VariantConfigResponse { + params?: Record + url?: string | null + application_ref?: { + id?: string + slug?: string + } + variant_ref?: { + id?: string + slug?: string + version?: number | null + } + service_ref?: { + id?: string + slug?: string + version?: number | null + } +} + +const isEmpty = (obj: Record) => + Object.values(obj).every((value) => value === undefined || value === null) + +export const fetchVariantConfig = async ({ + projectId, + application, + variant, +}: VariantReferenceRequest): Promise => { + if (!projectId) { + throw new Error("Project id is required to fetch variant config") + } + + const payload: Record = {} + + if (!isEmpty(application)) { + payload.application_ref = application + } + + if (!isEmpty(variant)) { + payload.variant_ref = variant + } + + if (!payload.variant_ref) { + throw new Error("Variant reference is required to fetch variant config") + } + + try { + const response = await axios.post( + `/variants/configs/fetch?project_id=${projectId}`, + payload, + { + _ignoreError: true, + } as any, + ) + + return (response.data as VariantConfigResponse) ?? null + } catch (error: any) { + if (error?.response?.status === 404) { + return null + } + throw error + } +} diff --git a/web/ee/src/state/billing/atoms.ts b/web/ee/src/state/billing/atoms.ts new file mode 100644 index 0000000000..a6e2fc02b8 --- /dev/null +++ b/web/ee/src/state/billing/atoms.ts @@ -0,0 +1,239 @@ +import {atom} from "jotai" +import {atomWithMutation, atomWithQuery} from "jotai-tanstack-query" + +import axios from "@/oss/lib/api/assets/axiosConfig" +import {getAgentaApiUrl} from "@/oss/lib/helpers/api" +import {User} from "@/oss/lib/Types" +import {selectedOrgIdAtom} from "@/oss/state/org" +import {profileQueryAtom} from "@/oss/state/profile/selectors/user" +import {projectIdAtom} from "@/oss/state/project" + +import {BillingPlan, DataUsageType, SubscriptionType} from "../../services/billing/types" + +/** + * Query atom for fetching billing usage data + * Only enabled when user is authenticated and project is not default + */ +export const usageQueryAtom = atomWithQuery((get) => { + const profileQuery = get(profileQueryAtom) + const user = profileQuery.data as User | undefined + const projectId = get(projectIdAtom) + + return { + queryKey: ["billing", "usage", projectId, user?.id], + queryFn: async () => { + const response = await axios.get( + `${getAgentaApiUrl()}/billing/usage?project_id=${projectId}`, + ) + return response.data as DataUsageType + }, + staleTime: 1000 * 60 * 2, // 2 minutes + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: true, + enabled: !!user && !!projectId, + retry: (failureCount, error) => { + // Don't retry on client errors + if ((error as any)?.response?.status >= 400 && (error as any)?.response?.status < 500) { + return false + } + return failureCount < 2 + }, + } +}) + +/** + * Query atom for fetching subscription data + * Only enabled when user is authenticated and project is not default + */ +export const subscriptionQueryAtom = atomWithQuery((get) => { + const profileQuery = get(profileQueryAtom) + const user = profileQuery.data as User | undefined + const projectId = get(projectIdAtom) + const orgId = get(selectedOrgIdAtom) + + return { + queryKey: ["billing", "subscription", projectId, user?.id, orgId], + queryFn: async () => { + const response = await axios.get( + `${getAgentaApiUrl()}/billing/subscription?project_id=${projectId}`, + ) + return response.data as SubscriptionType + }, + staleTime: 1000 * 60 * 5, // 5 minutes + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: true, + enabled: !!orgId && !!user && !!projectId, + retry: (failureCount, error) => { + // Don't retry on client errors + if ((error as any)?.response?.status >= 400 && (error as any)?.response?.status < 500) { + return false + } + return failureCount < 2 + }, + } +}) + +/** + * Query atom for fetching pricing plans + * Only enabled when user is authenticated and project is not default + */ +export const pricingPlansQueryAtom = atomWithQuery((get) => { + const profileQuery = get(profileQueryAtom) + const user = profileQuery.data as User | undefined + const projectId = get(projectIdAtom) + + return { + queryKey: ["billing", "plans", projectId, user?.id], + queryFn: async () => { + const response = await axios.get( + `${getAgentaApiUrl()}/billing/plans?project_id=${projectId}`, + ) + return response.data as BillingPlan[] + }, + staleTime: 1000 * 60 * 10, // 10 minutes - plans don't change often + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: true, + enabled: !!user && !!projectId, + retry: (failureCount, error) => { + // Don't retry on client errors + if ((error as any)?.response?.status >= 400 && (error as any)?.response?.status < 500) { + return false + } + return failureCount < 2 + }, + } +}) + +/** + * Mutation atom for switching subscription plans + */ +export const switchSubscriptionMutationAtom = atomWithMutation(() => ({ + mutationFn: async (payload: {plan: string}) => { + const store = await import("jotai").then((m) => m.getDefaultStore()) + const projectId = store.get(projectIdAtom) + + const response = await axios.post( + `${getAgentaApiUrl()}/billing/plans/switch?plan=${payload.plan}&project_id=${projectId}`, + ) + return response.data + }, + onSuccess: () => { + // Subscription data will be invalidated by the hook + }, +})) + +/** + * Mutation atom for canceling subscription + */ +export const cancelSubscriptionMutationAtom = atomWithMutation(() => ({ + mutationFn: async () => { + const store = await import("jotai").then((m) => m.getDefaultStore()) + const projectId = store.get(projectIdAtom) + + const response = await axios.post( + `${getAgentaApiUrl()}/billing/subscription/cancel?project_id=${projectId}`, + ) + return response.data + }, + onSuccess: () => { + // Subscription data will be invalidated by the hook + }, +})) + +/** + * Mutation atom for creating new subscription checkout + */ +export const checkoutSubscriptionMutationAtom = atomWithMutation(() => ({ + mutationFn: async (payload: {plan: string; success_url: string}) => { + const response = await axios.post( + `${getAgentaApiUrl()}/billing/stripe/checkouts/?plan=${payload.plan}&success_url=${payload.success_url}`, + ) + return response.data + }, +})) + +/** + * Mutation atom for editing subscription info (Stripe portal) + */ +export const editSubscriptionMutationAtom = atomWithMutation(() => ({ + mutationFn: async () => { + const response = await axios.post(`${getAgentaApiUrl()}/billing/stripe/portals/`) + return response.data + }, +})) + +/** + * Action atom for switching subscription with automatic data refresh + */ +export const switchSubscriptionAtom = atom(null, async (get, set, payload: {plan: string}) => { + const switchMutation = get(switchSubscriptionMutationAtom) + + try { + const result = await switchMutation.mutateAsync(payload) + + // Refetch subscription and usage data after successful switch + set(subscriptionQueryAtom) + set(usageQueryAtom) + + return result + } catch (error) { + console.error("Failed to switch subscription:", error) + throw error + } +}) + +/** + * Action atom for canceling subscription with automatic data refresh + */ +export const cancelSubscriptionAtom = atom(null, async (get, set) => { + const cancelMutation = get(cancelSubscriptionMutationAtom) + + try { + const result = await cancelMutation.mutateAsync() + + // Refetch subscription and usage data after successful cancellation + set(subscriptionQueryAtom) + set(usageQueryAtom) + + return result + } catch (error) { + console.error("Failed to cancel subscription:", error) + throw error + } +}) + +/** + * Action atom for checkout with no automatic refresh (redirect expected) + */ +export const checkoutSubscriptionAtom = atom( + null, + async (get, set, payload: {plan: string; success_url: string}) => { + const checkoutMutation = get(checkoutSubscriptionMutationAtom) + + try { + const result = await checkoutMutation.mutateAsync(payload) + return result + } catch (error) { + console.error("Failed to create checkout:", error) + throw error + } + }, +) + +/** + * Action atom for editing subscription info (Stripe portal) + */ +export const editSubscriptionAtom = atom(null, async (get, set) => { + const editMutation = get(editSubscriptionMutationAtom) + + try { + const result = await editMutation.mutateAsync() + return result + } catch (error) { + console.error("Failed to open subscription portal:", error) + throw error + } +}) diff --git a/web/ee/src/state/billing/hooks.ts b/web/ee/src/state/billing/hooks.ts new file mode 100644 index 0000000000..050234db73 --- /dev/null +++ b/web/ee/src/state/billing/hooks.ts @@ -0,0 +1,137 @@ +import {useCallback} from "react" + +import {useAtom, useAtomValue} from "jotai" + +import { + usageQueryAtom, + subscriptionQueryAtom, + pricingPlansQueryAtom, + switchSubscriptionAtom, + cancelSubscriptionAtom, + checkoutSubscriptionAtom, + editSubscriptionAtom, +} from "./atoms" + +/** + * Hook for managing billing usage data + * Provides the same interface as the original SWR-based useUsageData hook + */ +export const useUsageData = () => { + const usageQuery = useAtomValue(usageQueryAtom) + + return { + usage: usageQuery.data, + isUsageLoading: usageQuery.isPending, + mutateUsage: usageQuery.refetch, + error: usageQuery.error, + isError: usageQuery.isError, + isSuccess: usageQuery.isSuccess, + } +} + +/** + * Hook for managing subscription data + * Provides the same interface as the original SWR-based useSubscriptionData hook + */ +export const useSubscriptionData = () => { + const subscriptionQuery = useAtomValue(subscriptionQueryAtom) + + return { + subscription: subscriptionQuery.data, + isSubLoading: subscriptionQuery.isPending, + mutateSubscription: subscriptionQuery.refetch, + error: subscriptionQuery.error, + isError: subscriptionQuery.isError, + isSuccess: subscriptionQuery.isSuccess, + } +} + +/** + * Hook for managing pricing plans data + * Provides the same interface as the original SWR-based usePricingPlans hook + */ +export const usePricingPlans = () => { + const plansQuery = useAtomValue(pricingPlansQueryAtom) + + return { + plans: plansQuery.data, + isLoadingPlan: plansQuery.isPending, + error: plansQuery.error, + isError: plansQuery.isError, + isSuccess: plansQuery.isSuccess, + refetch: plansQuery.refetch, + } +} + +/** + * Hook for managing subscription actions + * Provides mutation functions for subscription management + */ +export const useSubscriptionActions = () => { + const [, switchSubscription] = useAtom(switchSubscriptionAtom) + const [, cancelSubscription] = useAtom(cancelSubscriptionAtom) + const [, checkoutSubscription] = useAtom(checkoutSubscriptionAtom) + const [, editSubscription] = useAtom(editSubscriptionAtom) + + const handleSwitchSubscription = useCallback( + async (payload: {plan: string}) => { + return await switchSubscription(payload) + }, + [switchSubscription], + ) + + const handleCancelSubscription = useCallback(async () => { + return await cancelSubscription() + }, [cancelSubscription]) + + const handleCheckoutSubscription = useCallback( + async (payload: {plan: string; success_url: string}) => { + return await checkoutSubscription(payload) + }, + [checkoutSubscription], + ) + + const handleEditSubscription = useCallback(async () => { + return await editSubscription() + }, [editSubscription]) + + return { + switchSubscription: handleSwitchSubscription, + cancelSubscription: handleCancelSubscription, + checkoutSubscription: handleCheckoutSubscription, + editSubscription: handleEditSubscription, + } +} + +/** + * Combined hook for all billing functionality + * Provides a comprehensive interface for billing management + */ +export const useBilling = () => { + const usage = useUsageData() + const subscription = useSubscriptionData() + const plans = usePricingPlans() + const actions = useSubscriptionActions() + + return { + // Usage data + usage: usage.usage, + isUsageLoading: usage.isUsageLoading, + mutateUsage: usage.mutateUsage, + usageError: usage.error, + + // Subscription data + subscription: subscription.subscription, + isSubLoading: subscription.isSubLoading, + mutateSubscription: subscription.mutateSubscription, + subscriptionError: subscription.error, + + // Plans data + plans: plans.plans, + isLoadingPlan: plans.isLoadingPlan, + plansError: plans.error, + + // Actions + ...actions, + } +} diff --git a/web/ee/src/state/billing/index.ts b/web/ee/src/state/billing/index.ts new file mode 100644 index 0000000000..e49d36ee3b --- /dev/null +++ b/web/ee/src/state/billing/index.ts @@ -0,0 +1,23 @@ +// Billing atoms +export { + usageQueryAtom, + subscriptionQueryAtom, + pricingPlansQueryAtom, + switchSubscriptionMutationAtom, + cancelSubscriptionMutationAtom, + checkoutSubscriptionMutationAtom, + editSubscriptionMutationAtom, + switchSubscriptionAtom, + cancelSubscriptionAtom, + checkoutSubscriptionAtom, + editSubscriptionAtom, +} from "./atoms" + +// Billing hooks +export { + useUsageData, + useSubscriptionData, + usePricingPlans, + useSubscriptionActions, + useBilling, +} from "./hooks" diff --git a/web/ee/src/state/observability/dashboard.ts b/web/ee/src/state/observability/dashboard.ts new file mode 100644 index 0000000000..a21cbd2950 --- /dev/null +++ b/web/ee/src/state/observability/dashboard.ts @@ -0,0 +1,61 @@ +import {useAtom} from "jotai" +import {eagerAtom} from "jotai-eager" +import {atomWithQuery} from "jotai-tanstack-query" + +import {GenerationDashboardData} from "@/oss/lib/types_ee" +import {fetchGenerationsDashboardData} from "@/oss/services/observability/api" +import {routerAppIdAtom} from "@/oss/state/app/atoms/fetcher" +import {projectIdAtom} from "@/oss/state/project" + +const DEFAULT_RANGE = "30_days" + +export const observabilityDashboardQueryAtom = atomWithQuery( + (get) => { + const appId = get(routerAppIdAtom) + const projectId = get(projectIdAtom) + + return { + queryKey: [ + "observability", + "dashboard", + appId ?? "__global__", + projectId ?? null, + DEFAULT_RANGE, + ], + queryFn: async ({signal}) => { + if (!projectId) return null + return fetchGenerationsDashboardData(appId, { + range: DEFAULT_RANGE, + projectId, + signal, + }) + }, + enabled: Boolean(projectId), + staleTime: 1000 * 60, + refetchOnWindowFocus: false, + } + }, +) + +export const observabilityDashboardAtom = eagerAtom((get) => { + const result = (get(observabilityDashboardQueryAtom) as any) + ?.data as GenerationDashboardData | null + return result ?? null +}) + +export const useObservabilityDashboard = () => { + const [query] = useAtom(observabilityDashboardQueryAtom) + + const {data, isPending, isFetching, isLoading, error, refetch, fetchStatus} = query as any + + const fetching = fetchStatus === "fetching" + const loading = Boolean(fetching || isPending || isLoading) + + return { + data: (data as GenerationDashboardData | null) ?? null, + loading, + isFetching: Boolean(isFetching) || fetching, + error, + refetch, + } +} diff --git a/web/ee/src/state/observability/index.ts b/web/ee/src/state/observability/index.ts new file mode 100644 index 0000000000..fec8ad0fe4 --- /dev/null +++ b/web/ee/src/state/observability/index.ts @@ -0,0 +1 @@ +export * from "./dashboard" diff --git a/web/ee/src/state/url/focusDrawer.ts b/web/ee/src/state/url/focusDrawer.ts new file mode 100644 index 0000000000..0a87387caf --- /dev/null +++ b/web/ee/src/state/url/focusDrawer.ts @@ -0,0 +1,131 @@ +import {getDefaultStore} from "jotai" +import Router from "next/router" + +import { + focusDrawerAtom, + openFocusDrawerAtom, + resetFocusDrawerAtom, + setFocusDrawerTargetAtom, +} from "@/oss/components/EvalRunDetails/state/focusScenarioAtom" +import {navigationRequestAtom, type NavigationCommand} from "@/oss/state/appState" + +const isBrowser = typeof window !== "undefined" + +export const FOCUS_SCENARIO_QUERY_KEY = "focusScenarioId" +export const FOCUS_RUN_QUERY_KEY = "focusRunId" + +const ensureCleanFocusParams = (url: URL) => { + let mutated = false + if (url.searchParams.get(FOCUS_SCENARIO_QUERY_KEY)?.trim() === "") { + url.searchParams.delete(FOCUS_SCENARIO_QUERY_KEY) + mutated = true + } + if (url.searchParams.get(FOCUS_RUN_QUERY_KEY)?.trim() === "") { + url.searchParams.delete(FOCUS_RUN_QUERY_KEY) + mutated = true + } + if (!mutated) return false + + const newPath = `${url.pathname}${url.search}${url.hash}` + void Router.replace(newPath, undefined, {shallow: true}).catch((error) => { + console.error("Failed to normalize focus drawer query params:", error) + }) + return true +} + +export const syncFocusDrawerStateFromUrl = (nextUrl?: string) => { + if (!isBrowser) return + + try { + const store = getDefaultStore() + const url = new URL(nextUrl ?? window.location.href, window.location.origin) + + const rawScenario = url.searchParams.get(FOCUS_SCENARIO_QUERY_KEY) + const rawRun = url.searchParams.get(FOCUS_RUN_QUERY_KEY) + const pendingNav = store.get(navigationRequestAtom) as NavigationCommand | null + + const scenarioId = rawScenario?.trim() || undefined + const runId = rawRun?.trim() || undefined + + const currentState = store.get(focusDrawerAtom) + + // Clean up empty params before processing + if (ensureCleanFocusParams(url)) { + // After normalising the URL we bail out; the router callback will re-run with clean params + return + } + + if (!scenarioId) { + const pendingScenarioPatch = + pendingNav?.type === "patch-query" + ? pendingNav.patch[FOCUS_SCENARIO_QUERY_KEY] + : undefined + const hasPendingScenario = + pendingScenarioPatch !== undefined && + (Array.isArray(pendingScenarioPatch) + ? pendingScenarioPatch.length > 0 + : String(pendingScenarioPatch ?? "").length > 0) + if (hasPendingScenario) { + return + } + + const hasStoredTarget = + currentState.focusScenarioId != null || currentState.focusRunId != null + const urlProvided = typeof nextUrl === "string" && nextUrl.length > 0 + // Avoid racing against local open actions (no URL yet) while still reacting to + // deliberate URL transitions that remove the focus params. + const shouldReset = + currentState.isClosing || + (!currentState.open && hasStoredTarget) || + (urlProvided && currentState.open && hasStoredTarget && !currentState.isClosing) + + if (shouldReset) { + store.set(resetFocusDrawerAtom, null) + } + return + } + + const nextTarget = { + focusScenarioId: scenarioId, + focusRunId: runId ?? currentState.focusRunId ?? null, + } + + const alreadyOpen = + currentState.open && + currentState.focusScenarioId === nextTarget.focusScenarioId && + currentState.focusRunId === nextTarget.focusRunId + + if (alreadyOpen && !currentState.isClosing) { + return + } + + // Ensure target is up to date before opening (helps preserve data during transitions) + store.set(setFocusDrawerTargetAtom, nextTarget) + store.set(openFocusDrawerAtom, nextTarget) + } catch (err) { + console.error("Failed to sync focus drawer state from URL:", nextUrl, err) + } +} + +export const clearFocusDrawerQueryParams = () => { + if (!isBrowser) return + try { + const url = new URL(window.location.href) + let mutated = false + if (url.searchParams.has(FOCUS_SCENARIO_QUERY_KEY)) { + url.searchParams.delete(FOCUS_SCENARIO_QUERY_KEY) + mutated = true + } + if (url.searchParams.has(FOCUS_RUN_QUERY_KEY)) { + url.searchParams.delete(FOCUS_RUN_QUERY_KEY) + mutated = true + } + if (!mutated) return + const newPath = `${url.pathname}${url.search}${url.hash}` + void Router.replace(newPath, undefined, {shallow: true}).catch((error) => { + console.error("Failed to clear focus drawer query params:", error) + }) + } catch (err) { + console.error("Failed to clear focus drawer query params:", err) + } +} diff --git a/web/ee/tailwind.config.ts b/web/ee/tailwind.config.ts new file mode 100644 index 0000000000..50a9d02fe3 --- /dev/null +++ b/web/ee/tailwind.config.ts @@ -0,0 +1,3 @@ +import {createConfig} from "@agenta/oss/tailwind.config" + +export default createConfig(["../oss/src/**/*.{js,ts,jsx,tsx}"]) diff --git a/web/ee/tests/1-settings/api-keys-management.spec.ts b/web/ee/tests/1-settings/api-keys-management.spec.ts new file mode 100644 index 0000000000..1395cba61f --- /dev/null +++ b/web/ee/tests/1-settings/api-keys-management.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import apiKeysTests from "@agenta/oss/tests/1-settings/api-keys" + +test.skip("Settings: API Keys Management", apiKeysTests) diff --git a/web/ee/tests/1-settings/model-hub.spec.ts b/web/ee/tests/1-settings/model-hub.spec.ts new file mode 100644 index 0000000000..186de6222c --- /dev/null +++ b/web/ee/tests/1-settings/model-hub.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import modelHubTests from "@agenta/oss/tests/1-settings/model-hub" + +test.describe("Settings: Model Hub", modelHubTests) diff --git a/web/ee/tests/2-app/create.spec.ts b/web/ee/tests/2-app/create.spec.ts new file mode 100644 index 0000000000..de0137e3cd --- /dev/null +++ b/web/ee/tests/2-app/create.spec.ts @@ -0,0 +1,5 @@ +import tests, {test} from "@agenta/oss/tests/2-app" + +test.describe(`EE App Creation Flow`, () => { + tests() +}) diff --git a/web/ee/tests/3-playground/run-variant.spec.ts b/web/ee/tests/3-playground/run-variant.spec.ts new file mode 100644 index 0000000000..5fc8618686 --- /dev/null +++ b/web/ee/tests/3-playground/run-variant.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import playgroundTests from "@agenta/oss/tests/3-playground" + +test.describe("Playground: Run Variant", playgroundTests) diff --git a/web/ee/tests/4-prompt-registry/prompt-registry-flow.spec.ts b/web/ee/tests/4-prompt-registry/prompt-registry-flow.spec.ts new file mode 100644 index 0000000000..511bd060ef --- /dev/null +++ b/web/ee/tests/4-prompt-registry/prompt-registry-flow.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import promptRegistryTests from "@agenta/oss/tests/4-prompt-registry" + +test.describe("Prompt Registry Flow", promptRegistryTests) diff --git a/web/ee/tests/5-testsset/testset.spec.ts b/web/ee/tests/5-testsset/testset.spec.ts new file mode 100644 index 0000000000..5f5ed87486 --- /dev/null +++ b/web/ee/tests/5-testsset/testset.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import testsetTests from "@agenta/oss/tests/5-testsset" + +test.describe("Testsets: Interact with testsets", testsetTests) diff --git a/web/ee/tests/6-auto-evaluation/assets/README.md b/web/ee/tests/6-auto-evaluation/assets/README.md new file mode 100644 index 0000000000..e5e43460b4 --- /dev/null +++ b/web/ee/tests/6-auto-evaluation/assets/README.md @@ -0,0 +1,67 @@ +# Auto Evaluation Test Fixtures + +This directory contains test fixtures for automating the evaluation process in the Agenta platform. These fixtures provide reusable functions to interact with the evaluation UI and perform common evaluation tasks. + +## Available Fixtures + +### 1. `navigateToEvaluation` + +Navigates to the Automatic Evaluation section for a specific application. + +**Parameters:** + +- `appId` (string): The ID of the application to evaluate + +**Usage:** + +```typescript +await test("navigate to evaluation", async ({navigateToEvaluation}) => { + await navigateToEvaluation("your-app-id") +}) +``` + +### 2. `runAutoEvaluation` + +Runs an automatic evaluation with the specified configuration. + +**Parameters (object):** + +- `evaluators` (string[]): List of evaluator names to use +- `testset` (string, optional): Name of the testset to evaluate against +- `variants` (string[]): List of variant names to evaluate + +**Usage:** + +```typescript +await test("run evaluation", async ({runAutoEvaluation}) => { + await runAutoEvaluation({ + evaluators: ["factual-accuracy", "relevance"], + testset: "my-testset", + variants: ["variant-1", "variant-2"], + }) +}) +``` + +## How It Works + +1. **Test Setup**: The fixtures extend the base test fixture with evaluation-specific functionality. +2. **UI Automation**: They handle all the necessary UI interactions, including: + - Navigating to the evaluation section + - Selecting testsets + - Choosing variants + - Configuring evaluators + - Managing the evaluation creation flow +3. **State Management**: The fixtures handle waiting for async operations and ensure the UI is in the correct state before proceeding. + +## Best Practices + +- Always wait for navigation and UI updates to complete +- Use the provided helper methods instead of direct page interactions +- Keep test data (evaluators, testsets, variants) in separate configuration files +- Combine fixtures for complex test scenarios + +## Dependencies + +- Base test fixtures from `@agenta/web-tests` +- Playwright test runner +- Agenta UI components and API helpers diff --git a/web/ee/tests/6-auto-evaluation/assets/types.ts b/web/ee/tests/6-auto-evaluation/assets/types.ts new file mode 100644 index 0000000000..9160b106d5 --- /dev/null +++ b/web/ee/tests/6-auto-evaluation/assets/types.ts @@ -0,0 +1,42 @@ +import {GenerationChatRow, GenerationInputRow} from "@/oss/components/Playground/state/types" +import {ConfigMetadata, OpenAPISpec} from "@/oss/lib/shared/variant/genericTransformer/types" +import {EnhancedVariant} from "@/oss/lib/shared/variant/transformer/types" +import {BaseFixture} from "@agenta/web-tests/tests/fixtures/base.fixture/types" + +export type InvokedVariant = { + variant: EnhancedVariant + allMetadata: Record + inputRow: GenerationInputRow + messageRow?: GenerationChatRow + rowId: string + appId: string + uri: { + runtimePrefix: string + routePath?: string + status?: boolean + } + headers: Record + projectId: string + messageId?: string + chatHistory?: any[] + spec: OpenAPISpec + runId: string +} + +export enum Role { + SYSTEM = "system", + USER = "user", + ASSISTANT = "assistant", + TOOL = "tool", + FUNCTION = "function", +} +export type RunAutoEvalFixtureType = { + evaluators: string[] + testset?: string + variants: string[] +} + +export interface EvaluationFixtures extends BaseFixture { + navigateToEvaluation: (appId: string) => Promise + runAutoEvaluation: (config: RunAutoEvalFixtureType) => Promise +} diff --git a/web/ee/tests/6-auto-evaluation/index.ts b/web/ee/tests/6-auto-evaluation/index.ts new file mode 100644 index 0000000000..420a98798e --- /dev/null +++ b/web/ee/tests/6-auto-evaluation/index.ts @@ -0,0 +1,92 @@ +import {test as baseAutoEvalTest} from "./tests" + +import {expect} from "@agenta/web-tests/utils" +import { + createTagString, + TestCoverage, + TestPath, + TestScope, +} from "@agenta/web-tests/playwright/config/testTags" + +const testAutoEval = () => { + baseAutoEvalTest( + "should run a single evaluation", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.SMOKE), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({page, apiHelpers, runAutoEvaluation, navigateToEvaluation}) => { + // 1. Fetch apps, variants from API + const app = await apiHelpers.getApp("completion") + const appId = app.app_id + + const variants = await apiHelpers.getVariants(appId) + const variantName = variants[0].name || variants[0].variant_name + + // 2. Navigate to evaluation + await navigateToEvaluation(appId) + + // 4. Run auto evaluation + await runAutoEvaluation({ + evaluators: ["Exact Match"], + variants: [variantName], + }) + + await expect(page.locator(".ant-modal").first()).toHaveCount(0) + + // 10. Check evaluation table + const evalTable = page.getByRole("table") + await evalTable.waitFor({state: "visible"}) + + const newRow = evalTable.getByRole("row").first() + await newRow.waitFor({state: "visible"}) + // const evaLoadingState = page.getByText("Running").first() + // await expect(evaLoadingState).toBeVisible() + // await expect(evaLoadingState).not.toBeVisible() + await expect(page.getByText("Completed").first()).toBeVisible() + }, + ) + + baseAutoEvalTest( + "should show an error when attempting to create an evaluation with a mismatched test set", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.SMOKE), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({page, apiHelpers, runAutoEvaluation, navigateToEvaluation}) => { + // 1. Fetch apps, variants from API + const app = await apiHelpers.getApp("chat") + const appId = app.app_id + + const variants = await apiHelpers.getVariants(appId) + const variantName = variants[0].name || variants[0].variant_name + + // 2. Navigate to evaluation + await navigateToEvaluation(appId) + + // 4. Run auto evaluation + await runAutoEvaluation({ + evaluators: ["Exact Match"], + variants: [variantName], + }) + + const message = page.locator(".ant-message").first() + await expect(message).toBeVisible() + await expect(message).toHaveText( + "The testset columns do not match the selected variant input parameters", + ) + }, + ) +} + +export default testAutoEval diff --git a/web/ee/tests/6-auto-evaluation/run-auto-evaluation.spec.ts b/web/ee/tests/6-auto-evaluation/run-auto-evaluation.spec.ts new file mode 100644 index 0000000000..b295d76ced --- /dev/null +++ b/web/ee/tests/6-auto-evaluation/run-auto-evaluation.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import testAutoEval from "." + +test.describe("Auto Evaluation: Run evaluation", testAutoEval) diff --git a/web/ee/tests/6-auto-evaluation/tests.ts b/web/ee/tests/6-auto-evaluation/tests.ts new file mode 100644 index 0000000000..7a29a28a3a --- /dev/null +++ b/web/ee/tests/6-auto-evaluation/tests.ts @@ -0,0 +1,97 @@ +import {test as baseTest} from "@agenta/web-tests/tests/fixtures/base.fixture" +import {expect} from "@agenta/web-tests/utils" +import {EvaluationFixtures, RunAutoEvalFixtureType} from "./assets/types" + +/** + * Evaluation-specific test fixtures extending the base test fixture. + * Provides high-level actions for evaluation tests. + */ +const testWithEvaluationFixtures = baseTest.extend({ + navigateToEvaluation: async ({page, uiHelpers}, use) => { + await use(async (appId: string) => { + await page.goto(`/apps/${appId}/evaluations`) + await uiHelpers.expectPath(`/apps/${appId}/evaluations`) + + // Move to Automatic Evaluation tab + await uiHelpers.clickTab("Automatic Evaluation") + await page.locator("span").filter({hasText: /^Evaluations$/}) + + // Wait for Evaluations to load + const spinner = page.locator(".ant-spin").first() + if (await spinner.count()) { + await spinner.waitFor({state: "hidden"}) + } + }) + }, + + runAutoEvaluation: async ({page, uiHelpers}, use) => { + await use(async ({evaluators, testset, variants}: RunAutoEvalFixtureType) => { + // 1. Open modal + await uiHelpers.clickButton("Start new Evaluation") + const modal = page.locator(".ant-modal").first() + await expect(modal).toBeVisible() + + // Helper: Select tab by name + const goToStep = async (step: string) => { + const tab = modal.getByRole("tab", {name: step}) + await tab.click() + } + + // 2. Select Testset + const selectedTestset = testset + + await goToStep("Testset") + await uiHelpers.selectTableRowInput({ + rowText: selectedTestset, + inputType: "radio", + checked: true, + }) + await expect( + page + .locator(".ant-tabs-tab", {hasText: "Testset"}) + .locator(".ant-tag", {hasText: selectedTestset}), + ).toBeVisible() + + // 3. Select Variant(s) + await goToStep("Variant") + const variantRow = page.getByRole("row").filter({ + has: page + .locator("td", {hasText: variants[0]}) + .locator(".ant-tag", {hasText: "v1"}), + }) + + await expect(variantRow).toBeVisible() + await variantRow.getByRole("radio").check() + + // 4. Select Evaluator(s) + await goToStep("Evaluator") + for (const evaluator of evaluators) { + await uiHelpers.selectTableRowInput({ + rowText: evaluator, + inputType: "checkbox", + checked: true, + }) + await expect( + page + .locator(".ant-tabs-tab", {hasText: "Evaluator"}) + .locator(".ant-tag", {hasText: evaluator}), + ).toBeVisible() + } + + await expect + .poll(async () => { + return await page.locator(".ant-tabs-nav-list .ant-tag").count() + }) + .toBe(3) + + // 5. Create Evaluation + const createButton = page.getByRole("button", {name: "Create"}).last() + await createButton.scrollIntoViewIfNeeded() + await createButton.click() + + await expect(createButton).toHaveClass(/ant-btn-loading/) + }) + }, +}) + +export {testWithEvaluationFixtures as test} diff --git a/web/ee/tests/7-observability/observability.spec.ts b/web/ee/tests/7-observability/observability.spec.ts new file mode 100644 index 0000000000..98908200a9 --- /dev/null +++ b/web/ee/tests/7-observability/observability.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import observabilityTests from "@agenta/oss/tests/7-observability" + +test.describe("Observability: test observability", observabilityTests) diff --git a/web/ee/tests/8-deployment/deploy-variant.spec.ts b/web/ee/tests/8-deployment/deploy-variant.spec.ts new file mode 100644 index 0000000000..0f613a356e --- /dev/null +++ b/web/ee/tests/8-deployment/deploy-variant.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import deploymentTests from "@agenta/oss/tests/8-deployment" + +test.describe("Deployment: test deployment", deploymentTests) diff --git a/web/ee/tests/9-human-annotation/assets/types.ts b/web/ee/tests/9-human-annotation/assets/types.ts new file mode 100644 index 0000000000..968f6d2a00 --- /dev/null +++ b/web/ee/tests/9-human-annotation/assets/types.ts @@ -0,0 +1,22 @@ +import type {BaseFixture} from "@agenta/web-tests/tests/fixtures/base.fixture/types" +import {Locator} from "@agenta/web-tests/utils" + +export type HumanEvaluationConfig = { + testset?: string + variants: string + name: string + skipEvaluatorCreation?: boolean +} + +export interface HumanEvaluationFixtures extends BaseFixture { + navigateToHumanEvaluation: (appId: string) => Promise + navigateToHumanAnnotationRun: (appId: string) => Promise + createHumanEvaluationRun: (config: HumanEvaluationConfig) => Promise + runAllScenarios: () => Promise + verifyStatusUpdate: (row: Locator) => Promise + switchToTableView: () => Promise + runScenarioFromFocusView: () => Promise + navigateBetweenScenarios: () => Promise + annotateFromFocusView: () => Promise + annotateFromTableView: () => Promise +} diff --git a/web/ee/tests/9-human-annotation/human-annotation.spec.ts b/web/ee/tests/9-human-annotation/human-annotation.spec.ts new file mode 100644 index 0000000000..6c26f40717 --- /dev/null +++ b/web/ee/tests/9-human-annotation/human-annotation.spec.ts @@ -0,0 +1,4 @@ +import {test} from "@agenta/web-tests/tests/fixtures/base.fixture" +import humanAnnotationTests from "." + +test.describe("Human Annotation", humanAnnotationTests) diff --git a/web/ee/tests/9-human-annotation/index.ts b/web/ee/tests/9-human-annotation/index.ts new file mode 100644 index 0000000000..4448434e2c --- /dev/null +++ b/web/ee/tests/9-human-annotation/index.ts @@ -0,0 +1,181 @@ +import {test as baseHumanTest, expect} from "./tests" +import { + createTagString, + TestCoverage, + TestPath, + TestScope, +} from "@agenta/web-tests/playwright/config/testTags" + +const humanAnnotationTests = () => { + baseHumanTest( + "should show an error when attempting to create an evaluation with a mismatched test set", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.SMOKE), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({page, apiHelpers, navigateToHumanEvaluation, createHumanEvaluationRun}) => { + const app = await apiHelpers.getApp("chat") + const appId = app.app_id + + const variants = await apiHelpers.getVariants(appId) + const variantName = variants[0].name || variants[0].variant_name + + await navigateToHumanEvaluation(appId) + + await createHumanEvaluationRun({ + variants: variantName, + name: `e2e-human-${Date.now()}`, + }) + + const message = page.locator(".ant-message").first() + await expect(message).toBeVisible() + await expect(message).toHaveText( + "The testset columns do not match the selected variant input parameters", + ) + }, + ) + + baseHumanTest( + "should create human evaluation run", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.SMOKE), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({page, apiHelpers, navigateToHumanEvaluation, createHumanEvaluationRun}) => { + const app = await apiHelpers.getApp() + const appId = app.app_id + + const variants = await apiHelpers.getVariants(appId) + const variantName = variants[0].name || variants[0].variant_name + + await navigateToHumanEvaluation(appId) + + await createHumanEvaluationRun({ + variants: variantName, + name: `e2e-human-${Date.now()}`, + skipEvaluatorCreation: true, + }) + + await expect(page.locator(".ant-modal").first()).toHaveCount(0) + + await expect(page).toHaveURL(/single_model_test\/.*scenarioId=.*/) + }, + ) + + baseHumanTest( + "should run scenarios and update status", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({ + navigateToHumanAnnotationRun, + page, + apiHelpers, + verifyStatusUpdate, + switchToTableView, + runScenarioFromFocusView, + }) => { + const app = await apiHelpers.getApp() + const appId = app.app_id + + await navigateToHumanAnnotationRun(appId) + + // --- Focus View: Single Scenario --- + await runScenarioFromFocusView() + + // --- Focus View: Run All --- + // await page.getByRole("button", {name: "Run All"}).click() + // await expect(page.locator("span").filter({hasText: "Running"})).toBeVisible() + // await expect(page.locator("span").filter({hasText: "Success"})).toBeVisible() + + // --- Table View --- + await switchToTableView() + + // Table Row: Run Individual + const row = page.locator(".ant-table-row").nth(1) + await row.getByRole("button", {name: "Run"}).click() + await verifyStatusUpdate(row) + + // Table View: Run All + await page.getByRole("button", {name: "Run All"}).click() + + const rows = page.locator(".ant-table-row") + const rowCount = await rows.count() + + for (let i = 0; i < rowCount; i++) { + const currentRow = rows.nth(i) + await verifyStatusUpdate(currentRow) + } + }, + ) + + baseHumanTest( + "should allow annotating scenarios", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({ + navigateToHumanAnnotationRun, + apiHelpers, + page, + switchToTableView, + annotateFromFocusView, + annotateFromTableView, + }) => { + const app = await apiHelpers.getApp() + const appId = app.app_id + + await navigateToHumanAnnotationRun(appId) + + await page.locator(".ant-segmented-item").nth(2).click() + + await annotateFromFocusView() + + await switchToTableView() + + // await annotateFromTableView() + }, + ) + + baseHumanTest( + "should navigate scenarios with filters", + { + tag: [ + createTagString("scope", TestScope.EVALUATIONS), + createTagString("coverage", TestCoverage.LIGHT), + createTagString("coverage", TestCoverage.FULL), + createTagString("path", TestPath.HAPPY), + ], + }, + async ({apiHelpers, navigateToHumanAnnotationRun, navigateBetweenScenarios}) => { + const app = await apiHelpers.getApp() + const appId = app.app_id + + await navigateToHumanAnnotationRun(appId) + + await navigateBetweenScenarios() + }, + ) +} + +export default humanAnnotationTests diff --git a/web/ee/tests/9-human-annotation/tests.ts b/web/ee/tests/9-human-annotation/tests.ts new file mode 100644 index 0000000000..ce017df4b7 --- /dev/null +++ b/web/ee/tests/9-human-annotation/tests.ts @@ -0,0 +1,244 @@ +import {test as baseTest} from "@agenta/web-tests/tests/fixtures/base.fixture" +import {expect, Locator} from "@agenta/web-tests/utils" + +import type {HumanEvaluationFixtures, HumanEvaluationConfig} from "./assets/types" +import {waitForApiResponse} from "tests/tests/fixtures/base.fixture/apiHelpers" +import {EvaluationRun} from "@/oss/lib/hooks/usePreviewEvaluations/types" +import {SnakeToCamelCaseKeys} from "@/oss/lib/Types" + +const testWithHumanFixtures = baseTest.extend({ + navigateToHumanEvaluation: async ({page, uiHelpers, apiHelpers}, use) => { + await use(async (appId: string) => { + await page.goto(`/apps/${appId}/evaluations?selectedEvaluation=human_annotation`) + await expect(page).toHaveURL( + `/apps/${appId}/evaluations?selectedEvaluation=human_annotation`, + ) + + const evaluationRunsResponse = await waitForApiResponse<{ + runs: SnakeToCamelCaseKeys[] + count: number + }>(page, { + route: `/api/preview/evaluations/runs/query`, + method: "POST", + }) + + const evaluationRuns = await evaluationRunsResponse + + expect(Array.isArray(evaluationRuns.runs)).toBe(true) + + await expect(page.locator("span").filter({hasText: /^Evaluations$/})).toBeVisible() + + await uiHelpers.clickTab("Human annotation") + + if (evaluationRunsResponse.runs.length > 0) { + await page.locator(".ant-checkbox").first().click() + + // click delete button + await uiHelpers.clickButton("Delete") + + // confirm delete in modal + await uiHelpers.confirmModal("Delete") + } + + await expect(evaluationRunsResponse.runs.length).toBe(0) + + await expect( + page.locator(".ant-btn-primary", {hasText: "Start new evaluation"}).first(), + ).toBeVisible() + }) + }, + + navigateToHumanAnnotationRun: async ({page, uiHelpers, apiHelpers}, use) => { + await use(async (appId: string) => { + await page.goto(`/apps/${appId}/evaluations?selectedEvaluation=human_annotation`) + await expect(page).toHaveURL( + `/apps/${appId}/evaluations?selectedEvaluation=human_annotation`, + ) + + const runs = await apiHelpers.getEvaluationRuns() + + await expect(page.locator("span").filter({hasText: /^Evaluations$/})).toBeVisible() + + await uiHelpers.clickTab("Human annotation") + + await page.locator(`tr[data-row-key="${runs[0].id}"]`).click() + + await expect(page).toHaveURL( + new RegExp(`/apps/${appId}/evaluations/single_model_test/${runs[0].id}(\\?|$)`), + ) + + await expect(page.locator("h4").filter({hasText: runs[0].name})).toBeVisible() + }) + }, + + createHumanEvaluationRun: async ({page, uiHelpers}, use) => { + await use(async (config: HumanEvaluationConfig) => { + await uiHelpers.clickButton("Start new evaluation") + const modal = page.locator(".ant-modal").first() + await expect(modal).toBeVisible() + + const goToStep = async (step: string) => { + await modal.getByRole("tab", {name: step}).click() + } + + await uiHelpers.typeWithDelay('input[placeholder="Enter a name"]', config.name) + + await goToStep("Testset") + await uiHelpers.selectTableRowInput({ + rowText: config.testset, + inputType: "radio", + checked: true, + }) + + await goToStep("Variant") + const variantRow = page.getByRole("row").filter({ + has: page + .locator("td", {hasText: config.variants}) + .locator(".ant-tag", {hasText: "v1"}), + }) + + await expect(variantRow).toBeVisible() + await variantRow.getByRole("radio").check() + + await goToStep("Evaluator") + + const evaluatorName = "evaluator_test" + + if (!config.skipEvaluatorCreation) { + await uiHelpers.clickButton("Create new") + const evalDrawer = page.locator(".ant-drawer-content") + await expect(evalDrawer).toBeVisible() + await expect(evalDrawer).toContainText("Create new evaluator") + + await uiHelpers.typeWithDelay("#evaluatorName", evaluatorName) + await expect(page.locator("#evaluatorSlug")).toHaveValue(evaluatorName) + + await uiHelpers.typeWithDelay("#metrics_0_name", "isTestWorking") + + await page.locator(".ant-select").click() + + const dropdownOption = page.locator('div[title="Boolean (True/False)"]') + await expect(dropdownOption).toBeVisible() + + await dropdownOption.click() + + await uiHelpers.clickButton("Save") + + await expect(evalDrawer).toHaveCount(0) + + const successMessage = page + .locator(".ant-message") + .getByText("Evaluator created successfully") + await expect(successMessage).toBeVisible() + } + + await uiHelpers.selectTableRowInput({ + rowText: evaluatorName, + inputType: "checkbox", + checked: true, + }) + + await expect + .poll(async () => { + return await page.locator(".ant-tabs-nav-list .ant-tag").count() + }) + .toBe(3) + + const createButton = modal.getByRole("button", {name: "Create"}).last() + await createButton.click() + await expect(createButton).toHaveClass(/ant-btn-loading/) + }) + }, + + verifyStatusUpdate: async ({page, uiHelpers}, use) => { + await use(async (row: Locator) => { + await expect(row.locator(".ant-table-cell").nth(1)).toHaveText(/Running|Incomplete/) + await expect(row.getByRole("button", {name: "Annotate"})).toBeVisible() + }) + }, + + switchToTableView: async ({page, uiHelpers}, use) => { + await use(async () => { + await page.locator(".ant-radio-button-wrapper", {hasText: "Table View"}).click() + await expect(page).toHaveURL(/view=table/) + }) + }, + + runScenarioFromFocusView: async ({page, uiHelpers}, use) => { + await use(async () => { + await expect(page.locator("span").filter({hasText: "Pending"})).toBeVisible() + await page.getByRole("button", {name: "Run Scenario"}).first().click() + await expect(page.locator("span").filter({hasText: "Running"})).toBeVisible() + await expect(page.locator("span").filter({hasText: "Incomplete"}).first()).toBeVisible() + }) + }, + + annotateFromFocusView: async ({page}, use) => { + await use(async () => { + const collapseBox = page.locator(".ant-collapse-content-box") + await expect(collapseBox.getByText("isTestWorking")).toBeVisible() + + await collapseBox.locator(".ant-radio-button-wrapper").first().click() + + const annotateBtn = page.getByRole("button", {name: "Annotate"}) + await expect(annotateBtn).toBeEnabled() + + await annotateBtn.click() + + await expect(page.locator("span", {hasText: "Annotating"}).first()).toBeVisible() + + await expect(page.locator("span", {hasText: "Success"})).toHaveCount(2) + }) + }, + + annotateFromTableView: async ({page}, use) => { + await use(async () => { + const row = page.locator(".ant-table-row").first() + + await row.getByRole("button", {name: "Annotate"}).click() + + const drawer = page.locator(".ant-drawer-content") + await expect(drawer).toBeVisible() + await expect(drawer).toContainText("Annotate scenario") + await expect(drawer.getByText("isTestWorking")).toBeVisible() + + await drawer.locator(".ant-radio-button-wrapper").first().click() + + const annotateBtn = drawer.getByRole("button", {name: "Annotate"}) + await expect(annotateBtn).toBeEnabled() + await annotateBtn.click() + + await expect(drawer).toHaveCount(0) + }) + }, + + navigateBetweenScenarios: async ({page}, use) => { + await use(async () => { + const prevBtn = page.getByRole("button", {name: "Prev"}) + const nextBtn = page.getByRole("button", {name: "Next"}) + + // Initial state + await expect(prevBtn).toBeDisabled() + await expect(nextBtn).toBeEnabled() + + // Navigate: 1 → 2 + await expect(page.locator('span[title="Test case: 1"]').first()).toBeVisible() + await nextBtn.click() + await expect(page.locator('span[title="Test case: 2"]').first()).toBeVisible() + + // Navigate: 2 → 3 + await nextBtn.click() + await expect(page.locator('span[title="Test case: 3"]').first()).toBeVisible() + + // Backward: 3 → 2 + await prevBtn.click() + await expect(page.locator('span[title="Test case: 2"]').first()).toBeVisible() + + // Backward: 2 → 1 + await prevBtn.click() + await expect(page.locator('span[title="Test case: 1"]').first()).toBeVisible() + }) + }, +}) + +export {testWithHumanFixtures as test, expect} diff --git a/web/ee/tsconfig.json b/web/ee/tsconfig.json new file mode 100644 index 0000000000..3ded438dae --- /dev/null +++ b/web/ee/tsconfig.json @@ -0,0 +1,12 @@ +{ + "extends": "../oss/tsconfig.json", + "compilerOptions": { + "baseUrl": "..", + "paths": { + "@/oss/*": ["./ee/src/*", "./oss/src/*"], + "@/agenta-oss-common/*": ["./ee/src/*", "./oss/src/*"] + } + }, + "include": ["next-env.d.ts", "**/*.d.ts", "**/*.ts", "**/*.tsx"], + "exclude": ["node_modules"] +} diff --git a/web/oss/package.json b/web/oss/package.json index 970fb92b2b..bf2e603365 100644 --- a/web/oss/package.json +++ b/web/oss/package.json @@ -1,6 +1,6 @@ { "name": "@agenta/oss", - "version": "0.57.2", + "version": "0.58.0", "private": true, "engines": { "node": ">=18" diff --git a/web/package.json b/web/package.json index 6a31901223..641cbc9e9f 100644 --- a/web/package.json +++ b/web/package.json @@ -1,6 +1,6 @@ { "name": "agenta-web", - "version": "0.57.2", + "version": "0.58.0", "workspaces": [ "ee", "oss",