diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index eb9db753..091200d7 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "39d7d274", + "id": "3599c474", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -14,7 +14,7 @@ }, { "cell_type": "markdown", - "id": "60f1d002", + "id": "ee8bed13", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -25,7 +25,7 @@ { "cell_type": "code", "execution_count": null, - "id": "99c42292", + "id": "f43069d1", "metadata": {}, "outputs": [], "source": [ @@ -36,7 +36,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2c959ca9", + "id": "c136bf4f", "metadata": {}, "outputs": [], "source": [ @@ -53,7 +53,7 @@ }, { "cell_type": "markdown", - "id": "bc185897", + "id": "48739393", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -64,15 +64,15 @@ { "cell_type": "code", "execution_count": null, - "id": "dc3a2d9d", + "id": "e459cd98", "metadata": {}, "outputs": [], "source": [ "from data_designer.essentials import (\n", " CategorySamplerParams,\n", + " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", - " InferenceParameters,\n", " LLMTextColumnConfig,\n", " ModelConfig,\n", " PersonFromFakerSamplerParams,\n", @@ -85,7 +85,7 @@ }, { "cell_type": "markdown", - "id": "36c5f571", + "id": "b705d204", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -98,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "61b23c70", + "id": "aee62c85", "metadata": {}, "outputs": [], "source": [ @@ -107,7 +107,7 @@ }, { "cell_type": "markdown", - "id": "3c9b7cb6", + "id": "ae65c557", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -124,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b86f6217", + "id": "1079200d", "metadata": {}, "outputs": [], "source": [ @@ -145,7 +145,7 @@ " alias=MODEL_ALIAS,\n", " model=MODEL_ID,\n", " provider=MODEL_PROVIDER,\n", - " inference_parameters=InferenceParameters(\n", + " inference_parameters=ChatCompletionInferenceParams(\n", " temperature=0.5,\n", " top_p=1.0,\n", " max_tokens=1024,\n", @@ -156,7 +156,7 @@ }, { "cell_type": "markdown", - "id": "1f089871", + "id": "9f15426a", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -171,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3d666193", + "id": "79b8212c", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +180,7 @@ }, { "cell_type": "markdown", - "id": "e88c8881", + "id": "cd1d9e09", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -197,7 +197,7 @@ { "cell_type": "code", "execution_count": null, - "id": "79fb85c6", + "id": "b3f469d6", "metadata": {}, "outputs": [], "source": [ @@ -206,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "5106cc10", + "id": "e44adc6c", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -215,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22b97af1", + "id": "82b32804", "metadata": {}, "outputs": [], "source": [ @@ -296,7 +296,7 @@ }, { "cell_type": "markdown", - "id": "4857b085", + "id": "bd65456c", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -305,7 +305,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "9e90b3cb", + "id": "6d6d4eef", "metadata": {}, "outputs": [], "source": [ @@ -342,7 +342,7 @@ }, { "cell_type": "markdown", - "id": "b36a153b", + "id": "eb7b415c", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -357,7 +357,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4da88fe6", + "id": "ed811560", "metadata": {}, "outputs": [], "source": [ @@ -394,7 +394,7 @@ }, { "cell_type": "markdown", - "id": "5f1b9ac8", + "id": "fdc0a2c8", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -411,7 +411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "543e2f9c", + "id": "59987c81", "metadata": {}, "outputs": [], "source": [ @@ -421,7 +421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26136a8a", + "id": "0823ca7f", "metadata": {}, "outputs": [], "source": [ @@ -432,7 +432,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aca4360d", + "id": "eca4f0bc", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +442,7 @@ }, { "cell_type": "markdown", - "id": "35ca0470", + "id": "edd57f85", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -455,7 +455,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d55b402d", + "id": "5c681eee", "metadata": {}, "outputs": [], "source": [ @@ -465,7 +465,7 @@ }, { "cell_type": "markdown", - "id": "245b48cf", + "id": "14bf06f2", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -478,7 +478,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fc803eb0", + "id": "b7ffead1", "metadata": {}, "outputs": [], "source": [ @@ -488,7 +488,7 @@ { "cell_type": "code", "execution_count": null, - "id": "881c2043", + "id": "aa966388", "metadata": {}, "outputs": [], "source": [ @@ -501,7 +501,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d79860d4", + "id": "98e1085c", "metadata": {}, "outputs": [], "source": [ @@ -513,7 +513,7 @@ }, { "cell_type": "markdown", - "id": "b4b45176", + "id": "e0b9c65a", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index 0cb65da6..fd9a2d69 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "33b48b4e", + "id": "928a8d4a", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs and Jinja Expressions\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "c29f9af1", + "id": "ad16de35", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -27,7 +27,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a5601fb", + "id": "cfb08d4c", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ { "cell_type": "code", "execution_count": null, - "id": "de2f0af4", + "id": "eddeceea", "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "markdown", - "id": "400795be", + "id": "08e2b3bf", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -66,16 +66,16 @@ { "cell_type": "code", "execution_count": null, - "id": "378f6853", + "id": "f057319b", "metadata": {}, "outputs": [], "source": [ "from data_designer.essentials import (\n", " CategorySamplerParams,\n", + " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", " 
ExpressionColumnConfig,\n", - " InferenceParameters,\n", " LLMStructuredColumnConfig,\n", " ModelConfig,\n", " PersonFromFakerSamplerParams,\n", @@ -87,7 +87,7 @@ }, { "cell_type": "markdown", - "id": "15a1ac9f", + "id": "e7d5e529", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -100,7 +100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9d7654e5", + "id": "e30fdeb1", "metadata": {}, "outputs": [], "source": [ @@ -109,7 +109,7 @@ }, { "cell_type": "markdown", - "id": "27ba0edb", + "id": "07b8bfe7", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -126,7 +126,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c24ee00e", + "id": "4ae518af", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ " alias=MODEL_ALIAS,\n", " model=MODEL_ID,\n", " provider=MODEL_PROVIDER,\n", - " inference_parameters=InferenceParameters(\n", + " inference_parameters=ChatCompletionInferenceParams(\n", " temperature=0.5,\n", " top_p=1.0,\n", " max_tokens=1024,\n", @@ -158,7 +158,7 @@ }, { "cell_type": "markdown", - "id": "a106edc9", + "id": "a3fa2eaf", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -173,7 +173,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a167f7c", + "id": "e1b59bd4", "metadata": {}, "outputs": [], "source": [ @@ -182,7 +182,7 @@ }, { "cell_type": "markdown", - "id": "fcf68c72", + "id": "5e220f78", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -209,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f8f034e", + "id": "e814ca47", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "9d5c722a", + "id": "303ebd6f", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -246,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "013caa3d", + "id": "361afe50", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +355,7 @@ }, { "cell_type": "markdown", - "id": "ef426a65", + "id": "b18e511d", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -372,7 +372,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27abbd6d", + "id": "e669c009", "metadata": {}, "outputs": [], "source": [ @@ -426,7 +426,7 @@ }, { "cell_type": "markdown", - "id": "18a8461e", + "id": "9ca31fb8", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -443,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2f75eee6", + "id": "402dede2", "metadata": {}, "outputs": [], "source": [ @@ -453,7 +453,7 @@ { "cell_type": "code", "execution_count": null, - "id": "950c9596", + "id": "c028223a", "metadata": {}, "outputs": [], "source": [ @@ -464,7 +464,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5c04ca5a", + "id": "32d6fdaa", "metadata": {}, "outputs": [], "source": [ @@ -474,7 +474,7 @@ }, { "cell_type": "markdown", - "id": "b0704f47", + "id": "3b76da46", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -487,7 +487,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54a609ce", + "id": "1a9ac292", "metadata": {}, "outputs": [], "source": [ @@ -497,7 +497,7 @@ }, { "cell_type": "markdown", - "id": "6ae0a8a5", + "id": "8e4b20ed", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -510,7 +510,7 @@ { "cell_type": "code", "execution_count": 
null, - "id": "e22c387a", + "id": "eed6e04b", "metadata": {}, "outputs": [], "source": [ @@ -520,7 +520,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9c6b36b3", + "id": "b6d1d270", "metadata": {}, "outputs": [], "source": [ @@ -533,7 +533,7 @@ { "cell_type": "code", "execution_count": null, - "id": "138ed487", + "id": "058d6b65", "metadata": {}, "outputs": [], "source": [ @@ -545,7 +545,7 @@ }, { "cell_type": "markdown", - "id": "fde73253", + "id": "d3affdac", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index e623ac8b..af97af89 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "5a6b2b3f", + "id": "ce777a5d", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "137d8273", + "id": "442b70e6", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -27,7 +27,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f6ce3dc2", + "id": "9ee97c1e", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70d6ffc8", + "id": "18c06754", "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "markdown", - "id": "ce0313c9", + "id": "9394bcd4", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -66,14 +66,14 @@ { "cell_type": "code", "execution_count": null, - "id": "2aa2cdf1", + "id": "6f1099d1", "metadata": {}, "outputs": [], "source": [ "from data_designer.essentials import (\n", + " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", - " InferenceParameters,\n", " ModelConfig,\n", " SeedConfig,\n", ")" @@ -81,7 +81,7 @@ }, { "cell_type": "markdown", - "id": "9769f392", + "id": "bc74d436", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -94,7 +94,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d79db916", + "id": "0f9c59fa", "metadata": {}, "outputs": [], "source": [ @@ -103,7 +103,7 @@ }, { "cell_type": "markdown", - "id": "08dd3894", + "id": "c92d4c3c", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -120,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3994368e", + "id": "543805b1", "metadata": {}, "outputs": [], "source": [ @@ -141,7 +141,7 @@ " alias=MODEL_ALIAS,\n", " model=MODEL_ID,\n", " provider=MODEL_PROVIDER,\n", - " inference_parameters=InferenceParameters(\n", + " inference_parameters=ChatCompletionInferenceParams(\n", " temperature=0.5,\n", " top_p=1.0,\n", " max_tokens=1024,\n", @@ -152,7 +152,7 @@ }, { "cell_type": "markdown", - "id": "5f12d6d2", + "id": "29f69761", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -167,7 +167,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f3d7640", + "id": "7b18a399", "metadata": {}, "outputs": [], "source": [ @@ -176,7 +176,7 @@ }, { "cell_type": "markdown", - "id": "1f08df99", + "id": "5439e926", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -201,7 +201,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f265e74c", + "id": "d4cbb09e", "metadata": {}, "outputs": [], "source": [ @@ -219,7 +219,7 @@ }, { "cell_type": "markdown", - 
"id": "6bffa239", + "id": "367fd06d", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -236,7 +236,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15e486a6", + "id": "1cbcf034", "metadata": {}, "outputs": [], "source": [ @@ -326,7 +326,7 @@ }, { "cell_type": "markdown", - "id": "5cfe2edd", + "id": "ef3d8dcf", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -343,7 +343,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4a59576", + "id": "6aa843f8", "metadata": {}, "outputs": [], "source": [ @@ -353,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1c5aedd4", + "id": "2134a59d", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d17df0a5", + "id": "3c3c39bf", "metadata": {}, "outputs": [], "source": [ @@ -374,7 +374,7 @@ }, { "cell_type": "markdown", - "id": "3389a088", + "id": "bf4b07b5", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -387,7 +387,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b0443498", + "id": "e7813016", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +397,7 @@ }, { "cell_type": "markdown", - "id": "0527a606", + "id": "db2d39b6", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -410,7 +410,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2118b49e", + "id": "2fcc5cb3", "metadata": {}, "outputs": [], "source": [ @@ -420,7 +420,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c4f9ad59", + "id": "036581fa", "metadata": {}, "outputs": [], "source": [ @@ -433,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8517866d", + "id": "74e91fe3", "metadata": {}, "outputs": [], "source": [ @@ -445,7 +445,7 @@ }, { "cell_type": "markdown", - "id": "b62dd069", + "id": "6f65cadc", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index e48f2bde..ddc6488a 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "fc93d603", + "id": "560882a6", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "31146c45", + "id": "b443ea18", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "b237af81", + "id": "d59495b4", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -36,7 +36,7 @@ { "cell_type": "code", "execution_count": null, - "id": "00b316a6", + "id": "e80db6fd", "metadata": {}, "outputs": [], "source": [ @@ -47,7 +47,7 @@ { "cell_type": "code", "execution_count": null, - "id": "847a6f88", + "id": "15ac4714", "metadata": {}, "outputs": [], "source": [ @@ -64,7 +64,7 @@ }, { "cell_type": "markdown", - "id": "5a0e2a31", + "id": "01741037", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -75,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7ec632f1", + "id": "097a9059", "metadata": {}, "outputs": [], "source": [ @@ -93,11 +93,11 @@ "\n", "# Data Designer imports\n", "from data_designer.essentials import (\n", + " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", " ImageContext,\n", " 
ImageFormat,\n", - " InferenceParameters,\n", " LLMTextColumnConfig,\n", " ModalityDataType,\n", " ModelConfig,\n", @@ -106,7 +106,7 @@ }, { "cell_type": "markdown", - "id": "66efe0cc", + "id": "4f6ca947", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -119,7 +119,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e9059d31", + "id": "6508f2b7", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +128,7 @@ }, { "cell_type": "markdown", - "id": "26d60e67", + "id": "7343463e", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -145,7 +145,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5b29b57", + "id": "31019979", "metadata": {}, "outputs": [], "source": [ @@ -157,7 +157,7 @@ " alias=\"vision\",\n", " model=\"meta/llama-4-scout-17b-16e-instruct\",\n", " provider=MODEL_PROVIDER,\n", - " inference_parameters=InferenceParameters(\n", + " inference_parameters=ChatCompletionInferenceParams(\n", " temperature=0.60,\n", " top_p=0.95,\n", " max_tokens=2048,\n", @@ -168,7 +168,7 @@ }, { "cell_type": "markdown", - "id": "c7c68fce", + "id": "186f9f98", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -183,7 +183,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ab84fb1", + "id": "80fbc002", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +192,7 @@ }, { "cell_type": "markdown", - "id": "bdc1fa29", + "id": "203663e2", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -209,7 +209,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3baa7ba2", + "id": "1a2b0ce1", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +224,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f3780aee", + "id": "8cc0f6ea", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +272,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5c46877e", + "id": "9fe4b02c", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +290,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4ca338da", + "id": "e546cc4a", "metadata": {}, "outputs": [], "source": [ @@ -300,7 +300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0cb41a37", + "id": "eebc9963", "metadata": {}, "outputs": [], "source": [ @@ -314,7 +314,7 @@ { "cell_type": "code", "execution_count": null, - "id": "607bb265", + "id": "99b4a3b5", "metadata": { "lines_to_next_cell": 2 }, @@ -343,7 +343,7 @@ }, { "cell_type": "markdown", - "id": "984241b9", + "id": "58ab7a96", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "eca1a5ea", + "id": "7d735a85", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -368,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4d386a38", + "id": "6404dc34", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +378,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6209d609", + "id": "44f3678c", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +389,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5aedc7fd", + "id": "06009be0", "metadata": {}, "outputs": [], "source": [ @@ -399,7 +399,7 @@ }, { "cell_type": "markdown", - "id": "fc339219", + "id": "43e9af07", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -412,7 +412,7 @@ { "cell_type": "code", "execution_count": null, - "id": "87ccc372", + "id": "7bce2b2e", "metadata": {}, "outputs": [], "source": [ @@ -422,7 +422,7 @@ }, { "cell_type": 
"markdown", - "id": "c090f413", + "id": "dc361acc", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -433,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1e1a054c", + "id": "574bbc62", "metadata": { "lines_to_next_cell": 2 }, @@ -457,7 +457,7 @@ }, { "cell_type": "markdown", - "id": "cab83636", + "id": "cc5ebe1c", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -470,7 +470,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7fa66a2e", + "id": "305f90ec", "metadata": {}, "outputs": [], "source": [ @@ -480,7 +480,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8f92b5be", + "id": "ef40fc6b", "metadata": {}, "outputs": [], "source": [ @@ -493,7 +493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d1bddcee", + "id": "72756cf3", "metadata": {}, "outputs": [], "source": [ @@ -505,7 +505,7 @@ }, { "cell_type": "markdown", - "id": "46f68f95", + "id": "e9bac314", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/concepts/columns.md b/docs/concepts/columns.md index cfe0aab9..17123299 100644 --- a/docs/concepts/columns.md +++ b/docs/concepts/columns.md @@ -64,6 +64,22 @@ Define scoring rubrics (relevance, accuracy, fluency, helpfulness) and the judge Use judge columns for data quality filtering (e.g., keep only 4+ rated responses), A/B testing generation strategies, and quality monitoring over time. +### 🧬 Embedding Columns + +Embedding columns generate vector embeddings (numerical representations) for text content using embedding models. These embeddings capture semantic meaning, enabling similarity search, clustering, and semantic analysis. + +Specify a `target_column` containing text, and Data Designer generates embeddings for that content. The target column can contain either a single text string or a list of text strings in stringified JSON format. In the latter case, embeddings are generated for each text string in the list. + +Common use cases: + +- **Semantic search**: Generate embeddings for documents, then find similar content by vector similarity +- **Clustering**: Group similar texts based on embedding proximity +- **Recommendation systems**: Match content by semantic similarity +- **Anomaly detection**: Identify outliers in embedding space + +!!! note "Embedding Models" + Embedding columns require an embedding model configured with `EmbeddingInferenceParams`. These models differ from chat completion modelsβ€”they output vectors rather than text. The generation type is automatically determined by the inference parameters type. + ### 🧩 Expression Columns Expression columns handle simple transformations using **Jinja2 templates**β€”concatenate first and last names, calculate numerical totals, format date strings. No LLM overhead needed. 
diff --git a/docs/concepts/models/configure-model-settings-with-the-cli.md b/docs/concepts/models/configure-model-settings-with-the-cli.md index fa29732e..56315cdf 100644 --- a/docs/concepts/models/configure-model-settings-with-the-cli.md +++ b/docs/concepts/models/configure-model-settings-with-the-cli.md @@ -130,6 +130,6 @@ The CLI will show which configuration files exist and ask for confirmation befor ## See Also - **[Model Providers](model-providers.md)**: Learn about the `ModelProvider` class and provider configuration -- **[Model Configurations](model-configs.md)**: Learn about `ModelConfig` and `InferenceParameters` +- **[Model Configurations](model-configs.md)**: Learn about `ModelConfig` - **[Default Model Settings](default-model-settings.md)**: Pre-configured providers and model settings included with Data Designer - **[Quick Start Guide](../../quick-start.md)**: Get started with a simple example diff --git a/docs/concepts/models/inference-parameters.md b/docs/concepts/models/inference-parameters.md new file mode 100644 index 00000000..6d7d079f --- /dev/null +++ b/docs/concepts/models/inference-parameters.md @@ -0,0 +1,147 @@ +# Inference Parameters + +Inference parameters control how models generate responses during synthetic data generation. Data Designer provides two types of inference parameters: `ChatCompletionInferenceParams` for text/code/structured generation and `EmbeddingInferenceParams` for embedding generation. + +## Overview + +When you create a `ModelConfig`, you can specify inference parameters to adjust model behavior. These parameters control aspects like randomness (temperature), diversity (top_p), context size (max_tokens), and more. Data Designer supports both static values and dynamic distribution-based sampling for certain parameters. + +## Chat Completion Inference Parameters + +The `ChatCompletionInferenceParams` class controls how models generate text completions (for text, code, and structured data generation). It provides fine-grained control over generation behavior and supports both static values and dynamic distribution-based sampling. + +!!! warning "InferenceParameters is Deprecated" + The `InferenceParameters` class is deprecated and will be removed in a future version. Use `ChatCompletionInferenceParams` instead. The old `InferenceParameters` class now shows a deprecation warning when used. + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `temperature` | `float` or `Distribution` | No | Controls randomness in generation (0.0 to 2.0). Higher values = more creative/random | +| `top_p` | `float` or `Distribution` | No | Nucleus sampling parameter (0.0 to 1.0). Controls diversity by filtering low-probability tokens | +| `max_tokens` | `int` | No | Maximum number of tokens for the request, including both input and output tokens (β‰₯ 1) | +| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, β‰₯ 1) | +| `timeout` | `int` | No | API request timeout in seconds (β‰₯ 1) | +| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body | + +!!! note "Default Values" + If `temperature`, `top_p`, or `max_tokens` are not provided, the model provider's default values will be used. Different providers and models may have different defaults. + +!!! 
tip "Controlling Reasoning Effort for GPT-OSS Models" + For gpt-oss models like `gpt-oss-20b` and `gpt-oss-120b`, you can control the reasoning effort using the `extra_body` parameter: + + ```python + from data_designer.essentials import ChatCompletionInferenceParams + + # High reasoning effort (more thorough, slower) + inference_parameters = ChatCompletionInferenceParams( + extra_body={"reasoning_effort": "high"} + ) + + # Medium reasoning effort (balanced) + inference_parameters = ChatCompletionInferenceParams( + extra_body={"reasoning_effort": "medium"} + ) + + # Low reasoning effort (faster, less thorough) + inference_parameters = ChatCompletionInferenceParams( + extra_body={"reasoning_effort": "low"} + ) + ``` + +### Temperature and Top P Guidelines + +- **Temperature**: + - `0.0-0.3`: Highly deterministic, focused outputs (ideal for structured/reasoning tasks) + - `0.4-0.7`: Balanced creativity and coherence (general purpose) + - `0.8-1.0`: Creative, diverse outputs (ideal for creative writing) + - `1.0+`: Highly random and experimental + +- **Top P**: + - `0.1-0.5`: Very focused, only most likely tokens + - `0.6-0.9`: Balanced diversity + - `0.95-1.0`: Maximum diversity, including less likely tokens + +!!! tip "Adjusting Temperature and Top P Together" + When tuning both parameters simultaneously, consider these combinations: + + - **For deterministic/structured outputs**: Low temperature (`0.0-0.3`) + moderate-to-high top_p (`0.8-0.95`) + - The low temperature ensures focus, while top_p allows some token diversity + - **For balanced generation**: Moderate temperature (`0.5-0.7`) + high top_p (`0.9-0.95`) + - This is a good starting point for most use cases + - **For creative outputs**: Higher temperature (`0.8-1.0`) + high top_p (`0.95-1.0`) + - Both parameters work together to maximize diversity + + **Avoid**: Setting both very low (overly restrictive) or adjusting both dramatically at once. When experimenting, adjust one parameter at a time to understand its individual effect. + +## Distribution-Based Inference Parameters + +For `temperature` and `top_p` in `ChatCompletionInferenceParams`, you can specify distributions instead of fixed values. This allows Data Designer to sample different values for each generation request, introducing controlled variability into your synthetic data. + +### Uniform Distribution + +Samples values uniformly between a low and high bound: + +```python +from data_designer.essentials import ( + ChatCompletionInferenceParams, + UniformDistribution, + UniformDistributionParams, +) + +inference_params = ChatCompletionInferenceParams( + temperature=UniformDistribution( + params=UniformDistributionParams(low=0.7, high=1.0) + ), +) +``` + +### Manual Distribution + +Samples from a discrete set of values with optional weights: + +```python +from data_designer.essentials import ( + ChatCompletionInferenceParams, + ManualDistribution, + ManualDistributionParams, +) + +# Equal probability for each value +inference_params = ChatCompletionInferenceParams( + temperature=ManualDistribution( + params=ManualDistributionParams(values=[0.5, 0.7, 0.9]) + ), +) + +# Weighted probabilities (normalized automatically) +inference_params = ChatCompletionInferenceParams( + top_p=ManualDistribution( + params=ManualDistributionParams( + values=[0.8, 0.9, 0.95], + weights=[0.2, 0.5, 0.3] # 20%, 50%, 30% probability + ) + ), +) +``` + +## Embedding Inference Parameters + +The `EmbeddingInferenceParams` class controls how models generate embeddings. 
This is used when working with embedding models for tasks like semantic search or similarity analysis. + +### Fields + +| Field | Type | Required | Description | +|-------|------|----------|-------------| +| `encoding_format` | `Literal["float", "base64"]` | No | Format of the embedding encoding (default: "float") | +| `dimensions` | `int` | No | Number of dimensions for the embedding | +| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, ≥ 1) | +| `timeout` | `int` | No | API request timeout in seconds (≥ 1) | +| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body | + + +## See Also + +- **[Model Configurations](model-configs.md)**: Learn about configuring model settings +- **[Model Providers](model-providers.md)**: Learn about configuring model providers +- **[Default Model Settings](default-model-settings.md)**: Pre-configured model settings included with Data Designer diff --git a/docs/concepts/models/model-configs.md b/docs/concepts/models/model-configs.md index fd87b874..c78aa2ee 100644 --- a/docs/concepts/models/model-configs.md +++ b/docs/concepts/models/model-configs.md @@ -14,136 +14,23 @@ The `ModelConfig` class has the following fields: |-------|------|----------|-------------| | `alias` | `str` | Yes | Unique identifier for this model configuration (e.g., `"my-text-model"`, `"reasoning-model"`) | | `model` | `str` | Yes | Model identifier as recognized by the provider (e.g., `"nvidia/nvidia-nemotron-nano-9b-v2"`, `"gpt-4"`) | -| `inference_parameters` | `InferenceParameters` | No | Controls model behavior during generation (temperature, top_p, max_tokens, etc). Defaults from constructing an empty `InferenceParameters` object are picked up when not provided.| +| `inference_parameters` | `InferenceParamsT` | No | Controls model behavior during generation. Use `ChatCompletionInferenceParams` for text/code/structured generation or `EmbeddingInferenceParams` for embeddings. Defaults to `ChatCompletionInferenceParams()` if not provided. The generation type is automatically determined by the inference parameters type. See [Inference Parameters](inference-parameters.md) for details. | | `provider` | `str` | No | Reference to the name of the Provider to use (e.g., `"nvidia"`, `"openai"`). If not specified, one set as the default provider, which may resolve to the first provider if there are more than one | -## InferenceParameters - -The `InferenceParameters` class controls how the model generates responses. It provides fine-grained control over generation behavior and supports both static values and dynamic distribution-based sampling. - -### Fields - -| Field | Type | Required | Description | -|-------|------|----------|-------------| -| `temperature` | `float` or `Distribution` | No | Controls randomness in generation (0.0 to 2.0). Higher values = more creative/random | -| `top_p` | `float` or `Distribution` | No | Nucleus sampling parameter (0.0 to 1.0). Controls diversity by filtering low-probability tokens | -| `max_tokens` | `int` | No | Maximum number of tokens for the request, including both input and output tokens (≥ 1) | -| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, ≥ 1) | -| `timeout` | `int` | No | API request timeout in seconds (≥ 1) | -| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body | - -!!!
note "Default Values" - If `temperature`, `top_p`, or `max_tokens` are not provided, the model provider's default values will be used. Different providers and models may have different defaults. - -!!! tip "Controlling Reasoning Effort for GPT-OSS Models" - For gpt-oss models like `gpt-oss-20b` and `gpt-oss-120b`, you can control the reasoning effort using the `extra_body` parameter: - - ```python - # High reasoning effort (more thorough, slower) - inference_parameters = InferenceParameters( - extra_body={"reasoning_effort": "high"} - ) - - # Medium reasoning effort (balanced) - inference_parameters = InferenceParameters( - extra_body={"reasoning_effort": "medium"} - ) - - # Low reasoning effort (faster, less thorough) - inference_parameters = InferenceParameters( - extra_body={"reasoning_effort": "low"} - ) - ``` - -### Temperature and Top P Guidelines - -- **Temperature**: - - `0.0-0.3`: Highly deterministic, focused outputs (ideal for structured/reasoning tasks) - - `0.4-0.7`: Balanced creativity and coherence (general purpose) - - `0.8-1.0`: Creative, diverse outputs (ideal for creative writing) - - `1.0+`: Highly random and experimental - -- **Top P**: - - `0.1-0.5`: Very focused, only most likely tokens - - `0.6-0.9`: Balanced diversity - - `0.95-1.0`: Maximum diversity, including less likely tokens - -!!! tip "Adjusting Temperature and Top P Together" - When tuning both parameters simultaneously, consider these combinations: - - - **For deterministic/structured outputs**: Low temperature (`0.0-0.3`) + moderate-to-high top_p (`0.8-0.95`) - - The low temperature ensures focus, while top_p allows some token diversity - - **For balanced generation**: Moderate temperature (`0.5-0.7`) + high top_p (`0.9-0.95`) - - This is a good starting point for most use cases - - **For creative outputs**: Higher temperature (`0.8-1.0`) + high top_p (`0.95-1.0`) - - Both parameters work together to maximize diversity - - **Avoid**: Setting both very low (overly restrictive) or adjusting both dramatically at once. When experimenting, adjust one parameter at a time to understand its individual effect. - -## Distribution-Based Inference Parameters - -For `temperature` and `top_p`, you can specify distributions instead of fixed values. This allows Data Designer to sample different values for each generation request, introducing controlled variability into your synthetic data. 
- -### Uniform Distribution - -Samples values uniformly between a low and high bound: - -```python -from data_designer.essentials import ( - InferenceParameters, - UniformDistribution, - UniformDistributionParams, -) - -inference_params = InferenceParameters( - temperature=UniformDistribution( - params=UniformDistributionParams(low=0.7, high=1.0) - ), -) -``` - -### Manual Distribution - -Samples from a discrete set of values with optional weights: - -```python -from data_designer.essentials import ( - InferenceParameters, - ManualDistribution, - ManualDistributionParams, -) - -# Equal probability for each value -inference_params = InferenceParameters( - temperature=ManualDistribution( - params=ManualDistributionParams(values=[0.5, 0.7, 0.9]) - ), -) - -# Weighted probabilities (normalized automatically) -inference_params = InferenceParameters( - top_p=ManualDistribution( - params=ManualDistributionParams( - values=[0.8, 0.9, 0.95], - weights=[0.2, 0.5, 0.3] # 20%, 50%, 30% probability - ) - ), -) -``` ## Examples ### Basic Model Configuration ```python -from data_designer.essentials import InferenceParameters, ModelConfig +from data_designer.essentials import ChatCompletionInferenceParams, ModelConfig # Simple model configuration with fixed parameters model_config = ModelConfig( alias="my-text-model", model="nvidia/nvidia-nemotron-nano-9b-v2", provider="nvidia", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.85, top_p=0.95, max_tokens=2048, @@ -154,7 +41,12 @@ model_config = ModelConfig( ### Multiple Model Configurations for Different Tasks ```python -from data_designer.essentials import InferenceParameters, ModelConfig +from data_designer.essentials import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + GenerationType, + ModelConfig ) model_configs = [ # Creative tasks @@ -162,7 +54,7 @@ model_configs = [ alias="creative-model", model="nvidia/nvidia-nemotron-nano-9b-v2", provider="nvidia", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.9, top_p=0.95, max_tokens=2048, @@ -173,7 +65,7 @@ model_configs = [ alias="critic-model", model="nvidia/nvidia-nemotron-nano-9b-v2", provider="nvidia", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.25, top_p=0.95, max_tokens=2048, @@ -184,7 +76,7 @@ model_configs = [ alias="reasoning-model", model="openai/gpt-oss-20b", provider="nvidia", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.3, top_p=0.9, max_tokens=4096, @@ -195,49 +87,33 @@ model_configs = [ alias="vision-model", model="nvidia/nemotron-nano-12b-v2-vl", provider="nvidia", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.7, top_p=0.95, max_tokens=2048, ), ), + # Embedding tasks + ModelConfig( + alias="embedding_model", + model="nvidia/llama-3.2-nv-embedqa-1b-v2", + provider="nvidia", + inference_parameters=EmbeddingInferenceParams( + encoding_format="float", + extra_body={ + "input_type": "query" + } + ) + ) ] ``` -!!! tip "Experiment with max_tokens for Task-Specific model configurations" +!!! tip "Experiment with max_tokens for Task-Specific Model Configurations" The number of tokens required to generate a single data entry can vary significantly with use case.
For example, reasoning models often need more tokens to "think through" problems before generating a response. Note that `max_tokens` includes **both input and output tokens** (the total context window used), so factor in your prompt length, any context data, and the expected response length when setting this parameter. -### Using Distribution-Based Parameters - -```python -from data_designer.essentials import ( - InferenceParameters, - ManualDistribution, - ManualDistributionParams, - ModelConfig, - UniformDistribution, - UniformDistributionParams, -) - -# Model with variable temperature and top_p -model_config = ModelConfig( - alias="variable-model", - model="nvidia/nvidia-nemotron-nano-9b-v2", - inference_parameters=InferenceParameters( - # Temperature varies uniformly between 0.7 and 1.0 - temperature=UniformDistribution( - params=UniformDistributionParams(low=0.7, high=1.0) - ), - # Top P samples from discrete values with equal probability - top_p=ManualDistribution( - params=ManualDistributionParams(values=[0.85, 0.90, 0.95]) - ), - max_tokens=2048, - ), -) -``` ## See Also +- **[Inference Parameters](inference-parameters.md)**: Detailed guide to inference parameters and how to configure them - **[Model Providers](model-providers.md)**: Learn about configuring model providers - **[Default Model Settings](default-model-settings.md)**: Pre-configured model settings included with Data Designer - **[Configure Model Settings With the CLI](configure-model-settings-with-the-cli.md)**: Use the CLI to manage model settings diff --git a/docs/concepts/models/model-providers.md b/docs/concepts/models/model-providers.md index 8c7bb7cc..f21ae4ca 100644 --- a/docs/concepts/models/model-providers.md +++ b/docs/concepts/models/model-providers.md @@ -44,7 +44,8 @@ provider = ModelProvider( ## See Also -- **[Model Configurations](model-configs.md)**: Learn about configuring models and inference parameters +- **[Model Configurations](model-configs.md)**: Learn about configuring models +- **[Inference Parameters](inference-parameters.md)**: Detailed guide to inference parameters and how to configure them - **[Default Model Settings](default-model-settings.md)**: Pre-configured providers and model settings included with Data Designer - **[Configure Model Settings With the CLI](configure-model-settings-with-the-cli.md)**: Use the CLI to manage providers and model settings - **[Quick Start Guide](../../quick-start.md)**: Get started with a simple example diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py index de890fb0..b03c758f 100644 --- a/docs/notebook_source/1-the-basics.py +++ b/docs/notebook_source/1-the-basics.py @@ -29,9 +29,9 @@ # %% from data_designer.essentials import ( CategorySamplerParams, + ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, - InferenceParameters, LLMTextColumnConfig, ModelConfig, PersonFromFakerSamplerParams, @@ -82,7 +82,7 @@ alias=MODEL_ALIAS, model=MODEL_ID, provider=MODEL_PROVIDER, - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=1.0, max_tokens=1024, diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py index f968a416..18627a26 100644 --- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py +++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py @@ -31,10 +31,10 @@ # %% from data_designer.essentials import ( 
CategorySamplerParams, + ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, ExpressionColumnConfig, - InferenceParameters, LLMStructuredColumnConfig, ModelConfig, PersonFromFakerSamplerParams, @@ -84,7 +84,7 @@ alias=MODEL_ALIAS, model=MODEL_ID, provider=MODEL_PROVIDER, - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=1.0, max_tokens=1024, diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py index 7c0d07e1..12a972aa 100644 --- a/docs/notebook_source/3-seeding-with-a-dataset.py +++ b/docs/notebook_source/3-seeding-with-a-dataset.py @@ -30,9 +30,9 @@ # %% from data_designer.essentials import ( + ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, - InferenceParameters, ModelConfig, SeedConfig, ) @@ -78,7 +78,7 @@ alias=MODEL_ALIAS, model=MODEL_ID, provider=MODEL_PROVIDER, - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.5, top_p=1.0, max_tokens=1024, diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index 6265e01f..53759f53 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -47,11 +47,11 @@ # Data Designer imports from data_designer.essentials import ( + ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, ImageContext, ImageFormat, - InferenceParameters, LLMTextColumnConfig, ModalityDataType, ModelConfig, @@ -89,7 +89,7 @@ alias="vision", model="meta/llama-4-scout-17b-16e-instruct", provider=MODEL_PROVIDER, - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.60, top_p=0.95, max_tokens=2048, diff --git a/mkdocs.yml b/mkdocs.yml index beae5e89..84517616 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -14,6 +14,7 @@ nav: - Configure with the CLI: concepts/models/configure-model-settings-with-the-cli.md - Model Providers: concepts/models/model-providers.md - Model Configs: concepts/models/model-configs.md + - Inference Parameters: concepts/models/inference-parameters.md - Columns: concepts/columns.md - Validators: concepts/validators.md - Person Sampling: concepts/person_sampling.md diff --git a/src/data_designer/cli/README.md b/src/data_designer/cli/README.md index 751d6446..f15b752e 100644 --- a/src/data_designer/cli/README.md +++ b/src/data_designer/cli/README.md @@ -129,8 +129,10 @@ class ConfigRepository(ABC, Generic[T]): - Field-level validation - Auto-completion support - History navigation (arrow keys) -- Default value handling +- Current value display when editing (`(current value: X)` instead of `(default: X)`) +- Value clearing support (type `'clear'` to remove optional parameter values) - Back navigation support +- Empty input handling (Enter key keeps current value or skips optional fields) #### 6. 
**UI Utilities** (`ui.py`) - **Purpose**: User interface utilities for terminal output and input @@ -179,17 +181,29 @@ model_configs: model: meta/llama-3.1-70b-instruct provider: nvidia inference_parameters: + generation_type: chat-completion temperature: 0.7 top_p: 0.9 max_tokens: 2048 max_parallel_requests: 4 + timeout: 60 - alias: gpt-4 model: gpt-4-turbo provider: openai inference_parameters: + generation_type: chat-completion temperature: 0.8 top_p: 0.95 max_tokens: 4096 + max_parallel_requests: 4 + - alias: embedder + model: text-embedding-3-large + provider: openai + inference_parameters: + generation_type: embedding + encoding_format: float + dimensions: 1024 + max_parallel_requests: 4 ``` ## Usage Examples diff --git a/src/data_designer/cli/commands/list.py b/src/data_designer/cli/commands/list.py index e0531fcd..a93b9d31 100644 --- a/src/data_designer/cli/commands/list.py +++ b/src/data_designer/cli/commands/list.py @@ -6,6 +6,7 @@ from data_designer.cli.repositories.model_repository import ModelRepository from data_designer.cli.repositories.provider_repository import ProviderRepository from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning +from data_designer.config.models import ModelConfig from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor @@ -75,6 +76,37 @@ def display_providers(provider_repo: ProviderRepository) -> None: console.print() +def format_inference_parameters(model_config: ModelConfig) -> str: + """Format inference parameters based on generation type. + + Args: + model_config: Model configuration + + Returns: + Formatted string of inference parameters + """ + params = model_config.inference_parameters + + # Get parameter values as dict, excluding common base parameters + params_dict = params.model_dump(exclude_none=True, mode="json") + + if not params_dict: + return "(none)" + + # Format each parameter + parts = [] + for key, value in params_dict.items(): + # Check if value is a distribution (has dict structure with distribution_type) + if isinstance(value, dict) and "distribution_type" in value: + formatted_value = "dist" + elif isinstance(value, float): + formatted_value = f"{value:.2f}" + else: + formatted_value = str(value) + parts.append(f"{key}={formatted_value}") + return ", ".join(parts) + + def display_models(model_repo: ModelRepository) -> None: """Load and display model configurations. 
@@ -97,30 +129,16 @@ def display_models(model_repo: ModelRepository) -> None: table.add_column("Alias", style=NordColor.NORD14.value, no_wrap=True) table.add_column("Model ID", style=NordColor.NORD4.value) table.add_column("Provider", style=NordColor.NORD9.value, no_wrap=True) - table.add_column("Temperature", style=NordColor.NORD15.value, justify="right") - table.add_column("Top P", style=NordColor.NORD15.value, justify="right") - table.add_column("Max Tokens", style=NordColor.NORD15.value, justify="right") + table.add_column("Inference Parameters", style=NordColor.NORD15.value) for mc in registry.model_configs: - # Handle distribution-based parameters - temp_display = ( - f"{mc.inference_parameters.temperature:.2f}" - if isinstance(mc.inference_parameters.temperature, (int, float)) - else "dist" - ) - top_p_display = ( - f"{mc.inference_parameters.top_p:.2f}" - if isinstance(mc.inference_parameters.top_p, (int, float)) - else "dist" - ) + params_display = format_inference_parameters(mc) table.add_row( mc.alias, mc.model, mc.provider or "(default)", - temp_display, - top_p_display, - str(mc.inference_parameters.max_tokens) if mc.inference_parameters.max_tokens else "(none)", + params_display, ) console.print(table) diff --git a/src/data_designer/cli/controllers/model_controller.py b/src/data_designer/cli/controllers/model_controller.py index dff323f1..2124405b 100644 --- a/src/data_designer/cli/controllers/model_controller.py +++ b/src/data_designer/cli/controllers/model_controller.py @@ -160,9 +160,10 @@ def _handle_update(self, available_providers: list[str]) -> None: return # Check if model has distribution-based parameters - if hasattr(model.inference_parameters.temperature, "sample") or hasattr( - model.inference_parameters.top_p, "sample" - ): + params_dict = model.inference_parameters.model_dump(mode="json", exclude_none=True) + has_distribution = any(isinstance(v, dict) and "distribution_type" in v for v in params_dict.values()) + + if has_distribution: print_warning( "This model uses distribution-based inference parameters, " "which cannot be edited via the CLI. Please edit the configuration file directly." diff --git a/src/data_designer/cli/forms/field.py b/src/data_designer/cli/forms/field.py index 27aefd88..d5eab9ba 100644 --- a/src/data_designer/cli/forms/field.py +++ b/src/data_designer/cli/forms/field.py @@ -5,6 +5,7 @@ from collections.abc import Callable from typing import Any, Generic, TypeVar +from data_designer.cli.ui import BACK, prompt_text_input, select_with_arrows from data_designer.cli.utils import validate_numeric_range T = TypeVar("T") @@ -40,8 +41,14 @@ def value(self) -> T | None: return self._value @value.setter - def value(self, val: T) -> None: - """Set and validate the field value.""" + def value(self, val: T | str) -> None: + """Set and validate the field value. 
Converts empty strings to None for optional fields.""" + # Handle empty string for optional fields (clearing the value) + if val == "" and not self.required: + self._value = None + return + + # Standard validation for non-empty values if self.validator: # For string validators, convert to string first if needed val_str = str(val) if not isinstance(val, str) else val @@ -50,6 +57,40 @@ def value(self, val: T) -> None: raise ValidationError(error_msg or "Invalid value") self._value = val + def _build_prompt_text(self) -> str: + """Build prompt text with current value information.""" + has_current_value = self.default is not None + + if has_current_value: + # Show as "current" instead of "default" with dimmed styling + if not self.required: + return f"{self.prompt} (current value: {self.default}, type 'clear' to remove)" + return f"{self.prompt} (current value: {self.default})" + + return self.prompt + + def _handle_prompt_result(self, result: str | None | Any) -> str | None | Any: + """Handle common prompt result logic (BACK, None, clear keywords, empty input).""" + if result is BACK: + return BACK + + if result is None: + # User cancelled (ESC) + return None + + # Check for special keywords to clear the value + if result and result.lower() in ("clear", "none", "default"): + return "" + + if not result: + # Empty input: return current value if exists + has_current_value = self.default is not None + if has_current_value: + return self.default + return "" + + return result + @abstractmethod def prompt_user(self, allow_back: bool = False) -> T | None | Any: """Prompt user for input.""" @@ -75,21 +116,19 @@ def __init__( def prompt_user(self, allow_back: bool = False) -> str | None | Any: """Prompt user for text input.""" - from data_designer.cli.ui import BACK, prompt_text_input + prompt_text = self._build_prompt_text() + # Don't pass default to prompt_text_input to avoid duplicate "(default: X)" text result = prompt_text_input( - self.prompt, - default=self.default, + prompt_text, + default=None, validator=self.validator, mask=self.mask, completions=self.completions, allow_back=allow_back, ) - if result is BACK: - return BACK - - return result + return self._handle_prompt_result(result) class SelectField(Field[str]): @@ -109,8 +148,6 @@ def __init__( def prompt_user(self, allow_back: bool = False) -> str | None | Any: """Prompt user for selection.""" - from data_designer.cli.ui import BACK, select_with_arrows - result = select_with_arrows( self.options, self.prompt, @@ -144,6 +181,9 @@ def __init__( def range_validator(value: str) -> tuple[bool, str | None]: if not value and not required: return True, None + # Allow special keywords to clear the value + if value and value.lower() in ("clear", "none", "default"): + return True, None if min_value is not None and max_value is not None: is_valid, parsed = validate_numeric_range(value, min_value, max_value) if not is_valid: @@ -163,18 +203,24 @@ def range_validator(value: str) -> tuple[bool, str | None]: def prompt_user(self, allow_back: bool = False) -> float | None | Any: """Prompt user for numeric input.""" - from data_designer.cli.ui import BACK, prompt_text_input - - default_str = str(self.default) if self.default is not None else None + prompt_text = self._build_prompt_text() + # Don't pass default to prompt_text_input to avoid duplicate "(default: X)" text result = prompt_text_input( - self.prompt, - default=default_str, + prompt_text, + default=None, validator=self.validator, allow_back=allow_back, ) - if result is BACK: - return BACK + 
result = self._handle_prompt_result(result) - return float(result) if result else None + # Return special values (BACK, None, empty string, defaults) as-is + if result is BACK or result is None or result == "": + return result + + # Convert numeric strings to float (but not if it's already a float from default) + if isinstance(result, str): + return float(result) + + return result diff --git a/src/data_designer/cli/forms/model_builder.py b/src/data_designer/cli/forms/model_builder.py index 58c96e02..ba627892 100644 --- a/src/data_designer/cli/forms/model_builder.py +++ b/src/data_designer/cli/forms/model_builder.py @@ -6,8 +6,13 @@ from data_designer.cli.forms.builder import FormBuilder from data_designer.cli.forms.field import NumericField, SelectField, TextField from data_designer.cli.forms.form import Form -from data_designer.config.models import ModelConfig -from data_designer.config.utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P +from data_designer.cli.ui import confirm_action, print_error, print_text +from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + GenerationType, + ModelConfig, +) class ModelFormBuilder(FormBuilder[ModelConfig]): @@ -19,7 +24,7 @@ def __init__(self, existing_aliases: set[str] | None = None, available_providers self.available_providers = available_providers or [] def create_form(self, initial_data: dict[str, Any] | None = None) -> Form: - """Create the model configuration form.""" + """Create the model configuration form with basic fields.""" fields = [] # Model alias @@ -29,7 +34,7 @@ def create_form(self, initial_data: dict[str, Any] | None = None) -> Form: "Model alias (used in your configs)", default=initial_data.get("alias") if initial_data else None, required=True, - validator=self._validate_alias, + validator=self.validate_alias, ) ) @@ -61,46 +66,222 @@ def create_form(self, initial_data: dict[str, Any] | None = None) -> Form: # Single provider - will be set automatically pass - # Inference parameters - fields.extend( - [ + # Generation type + # Extract from inference_parameters if present (for existing models) + default_gen_type = GenerationType.CHAT_COMPLETION + if initial_data: + inference_params = initial_data.get("inference_parameters", {}) + default_gen_type = inference_params.get("generation_type", default_gen_type) + + fields.append( + SelectField( + "generation_type", + "Generation type", + options={ + GenerationType.CHAT_COMPLETION: "Chat completion", + GenerationType.EMBEDDING: "Embedding", + }, + default=default_gen_type, + ) + ) + + return Form(self.title, fields) + + def create_inference_params_form( + self, generation_type: GenerationType, initial_params: dict[str, Any] | None = None + ) -> Form: + """Create generation-type-specific inference parameters form.""" + initial_params = initial_params or {} + fields = [] + + if generation_type == GenerationType.CHAT_COMPLETION: + # Temperature + fields.append( NumericField( "temperature", - f"Temperature ({MIN_TEMPERATURE}-{MAX_TEMPERATURE})", - default=initial_data.get("inference_parameters", {}).get("temperature", 0.7) - if initial_data - else 0.7, - min_value=MIN_TEMPERATURE, - max_value=MAX_TEMPERATURE, - ), + "Temperature (0.0-2.0)", + default=initial_params.get("temperature"), + min_value=0.0, + max_value=2.0, + required=False, + help_text="Higher values make output more random, lower values more deterministic", + ) + ) + + # Top P + fields.append( NumericField( "top_p", - f"Top P ({MIN_TOP_P}-{MAX_TOP_P})", 
- default=initial_data.get("inference_parameters", {}).get("top_p", 0.9) if initial_data else 0.9, - min_value=MIN_TOP_P, - max_value=MAX_TOP_P, - ), + "Top P (0.0-1.0)", + default=initial_params.get("top_p"), + min_value=0.0, + max_value=1.0, + required=False, + help_text="Controls diversity via nucleus sampling", + ) + ) + + # Max tokens + fields.append( NumericField( "max_tokens", - "Max tokens", - default=initial_data.get("inference_parameters", {}).get("max_tokens", 2048) - if initial_data - else 2048, - min_value=1, - max_value=100000, - ), - ] - ) + "Max tokens (maximum total tokens including input and output)", + default=initial_params.get("max_tokens"), + min_value=1.0, + required=False, + help_text="Maximum number of tokens including both input prompt and generated response", + ) + ) - return Form(self.title, fields) + # Max parallel requests + fields.append( + NumericField( + "max_parallel_requests", + "Max parallel requests (default: 4)", + default=initial_params.get("max_parallel_requests", 4), + min_value=1.0, + required=False, + help_text="Maximum number of parallel API requests", + ) + ) - def _validate_alias(self, alias: str) -> tuple[bool, str | None]: - """Validate model alias.""" - if not alias: - return False, "Model alias is required" - if alias in self.existing_aliases: - return False, f"Model alias '{alias}' already exists" - return True, None + # Timeout + fields.append( + NumericField( + "timeout", + "Timeout in seconds (optional)", + default=initial_params.get("timeout"), + min_value=1.0, + required=False, + help_text="Timeout for each API request in seconds", + ) + ) + + else: # EMBEDDING + # Encoding format + fields.append( + TextField( + "encoding_format", + "Encoding format (float or base64)", + default=initial_params.get("encoding_format"), + required=False, + validator=self.validate_encoding_format, + ) + ) + + # Dimensions + fields.append( + NumericField( + "dimensions", + "Dimensions (number of dimensions for embeddings)", + default=initial_params.get("dimensions"), + min_value=1.0, + required=False, + help_text="Model-specific dimension size (e.g., 1024, 768)", + ) + ) + + # Max parallel requests (common field) + fields.append( + NumericField( + "max_parallel_requests", + "Max parallel requests (default: 4)", + default=initial_params.get("max_parallel_requests", 4), + min_value=1.0, + required=False, + help_text="Maximum number of parallel API requests", + ) + ) + + # Timeout (common field) + fields.append( + NumericField( + "timeout", + "Timeout in seconds (optional)", + default=initial_params.get("timeout"), + min_value=1.0, + required=False, + help_text="Timeout for each API request in seconds", + ) + ) + + return Form(f"{self.title} - Inference Parameters", fields) + + def build_inference_params(self, generation_type: GenerationType, params_data: dict[str, Any]) -> dict[str, Any]: + """Build inference parameters dictionary from form data with proper type conversions.""" + inference_params = {} + + if generation_type == GenerationType.CHAT_COMPLETION: + if params_data.get("temperature") is not None: + inference_params["temperature"] = params_data["temperature"] + if params_data.get("top_p") is not None: + inference_params["top_p"] = params_data["top_p"] + if params_data.get("max_tokens") is not None: + inference_params["max_tokens"] = int(params_data["max_tokens"]) + + else: # EMBEDDING + # Only include fields with actual values; Pydantic will use defaults for missing fields + if params_data.get("encoding_format"): + 
inference_params["encoding_format"] = params_data["encoding_format"] + if params_data.get("dimensions"): + inference_params["dimensions"] = int(params_data["dimensions"]) + + # Common fields for both generation types + if params_data.get("max_parallel_requests") is not None: + inference_params["max_parallel_requests"] = int(params_data["max_parallel_requests"]) + if params_data.get("timeout") is not None: + inference_params["timeout"] = int(params_data["timeout"]) + + return inference_params + + def run(self, initial_data: dict[str, Any] | None = None) -> ModelConfig | None: + """Run the interactive form with two-step process for generation-type-specific parameters.""" + # Step 1: Collect basic model configuration + basic_form = self.create_form(initial_data) + + if initial_data: + basic_form.set_values(initial_data) + + while True: + basic_result = basic_form.prompt_all(allow_back=True) + + if basic_result is None: + if confirm_action("Cancel configuration?", default=False): + return None + continue + + # Step 2: Collect generation-type-specific inference parameters + generation_type = basic_result.get("generation_type", GenerationType.CHAT_COMPLETION) + initial_params = initial_data.get("inference_parameters") if initial_data else None + + # Print message to indicate we're now configuring inference parameters + gen_type_name = "chat completion" if generation_type == GenerationType.CHAT_COMPLETION else "embedding" + print_text( + f"βš™οΈ Configuring {gen_type_name} inference parameters [dim](Press Enter to keep current value or skip)[/dim]\n" + ) + + params_form = self.create_inference_params_form(generation_type, initial_params) + + params_result = params_form.prompt_all(allow_back=True) + + if params_result is None: + if confirm_action("Cancel configuration?", default=False): + return None + continue + + # Build inference_parameters dict from individual fields + inference_params = self.build_inference_params(generation_type, params_result) + + # Merge results + full_data = {**basic_result, "inference_parameters": inference_params} + + try: + config = self.build_config(full_data) + return config + except Exception as e: + print_error(f"Configuration error: {e}") + if not confirm_action("Try again?", default=True): + return None def build_config(self, form_data: dict[str, Any]) -> ModelConfig: """Build ModelConfig from form data.""" @@ -112,14 +293,40 @@ def build_config(self, form_data: dict[str, Any]) -> ModelConfig: else: provider = None + # Get generation type (from form data, used to determine which inference params to create) + generation_type = form_data.get("generation_type", GenerationType.CHAT_COMPLETION) + + # Get inference parameters dict + inference_params_dict = form_data.get("inference_parameters", {}) + + # Create the appropriate inference parameters type based on generation_type + # The generation_type will be set automatically by the inference params class + if generation_type == GenerationType.EMBEDDING: + inference_params = EmbeddingInferenceParams(**inference_params_dict) + else: + inference_params = ChatCompletionInferenceParams(**inference_params_dict) + return ModelConfig( alias=form_data["alias"], model=form_data["model"], provider=provider, - inference_parameters={ - "temperature": form_data["temperature"], - "top_p": form_data["top_p"], - "max_tokens": int(form_data["max_tokens"]), - "max_parallel_requests": 4, - }, + inference_parameters=inference_params, ) + + def validate_alias(self, alias: str) -> tuple[bool, str | None]: + """Validate model alias.""" + if 
not alias: + return False, "Model alias is required" + if alias in self.existing_aliases: + return False, f"Model alias '{alias}' already exists" + return True, None + + def validate_encoding_format(self, value: str) -> tuple[bool, str | None]: + """Validate encoding format for embedding models.""" + if not value: + return True, None # Optional field + if value.lower() in ("clear", "none", "default"): + return True, None # Allow clearing keywords + if value not in ("float", "base64"): + return False, "Must be either 'float' or 'base64'" + return True, None diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py index 4f0f7675..c0aa46b6 100644 --- a/src/data_designer/config/analysis/column_statistics.py +++ b/src/data_designer/config/analysis/column_statistics.py @@ -119,30 +119,30 @@ class LLMTextColumnStatistics(GeneralColumnStatistics): Stores both prompt and completion token consumption data. Attributes: - completion_tokens_mean: Mean number of completion tokens generated per record. - completion_tokens_median: Median number of completion tokens generated per record. - completion_tokens_stddev: Standard deviation of completion tokens per record. - prompt_tokens_mean: Mean number of prompt tokens used per record. - prompt_tokens_median: Median number of prompt tokens used per record. - prompt_tokens_stddev: Standard deviation of prompt tokens per record. + output_tokens_mean: Mean number of output tokens generated per record. + output_tokens_median: Median number of output tokens generated per record. + output_tokens_stddev: Standard deviation of output tokens per record. + input_tokens_mean: Mean number of input tokens used per record. + input_tokens_median: Median number of input tokens used per record. + input_tokens_stddev: Standard deviation of input tokens per record. column_type: Discriminator field, always "llm-text" for this statistics type. 
""" - completion_tokens_mean: Union[float, MissingValue] - completion_tokens_median: Union[float, MissingValue] - completion_tokens_stddev: Union[float, MissingValue] - prompt_tokens_mean: Union[float, MissingValue] - prompt_tokens_median: Union[float, MissingValue] - prompt_tokens_stddev: Union[float, MissingValue] + output_tokens_mean: Union[float, MissingValue] + output_tokens_median: Union[float, MissingValue] + output_tokens_stddev: Union[float, MissingValue] + input_tokens_mean: Union[float, MissingValue] + input_tokens_median: Union[float, MissingValue] + input_tokens_stddev: Union[float, MissingValue] column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value @field_validator( - "completion_tokens_mean", - "completion_tokens_median", - "completion_tokens_stddev", - "prompt_tokens_mean", - "prompt_tokens_median", - "prompt_tokens_stddev", + "output_tokens_mean", + "output_tokens_median", + "output_tokens_stddev", + "input_tokens_mean", + "input_tokens_median", + "input_tokens_stddev", mode="before", ) def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]: @@ -150,13 +150,13 @@ def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> def create_report_row_data(self) -> dict[str, Any]: prompt_tokens_str = ( - f"{self.prompt_tokens_median:.1f} +/- {self.prompt_tokens_stddev:.1f}" - if not self._is_missing_value(self.prompt_tokens_median) + f"{self.input_tokens_median:.1f} +/- {self.input_tokens_stddev:.1f}" + if not self._is_missing_value(self.input_tokens_median) else "--" ) completion_tokens_str = ( - f"{self.completion_tokens_median:.1f} +/- {self.completion_tokens_stddev:.1f}" - if not self._is_missing_value(self.completion_tokens_median) + f"{self.output_tokens_median:.1f} +/- {self.output_tokens_stddev:.1f}" + if not self._is_missing_value(self.output_tokens_median) else "--" ) return { diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py index 3916c08d..48fe529e 100644 --- a/src/data_designer/config/column_configs.py +++ b/src/data_designer/config/column_configs.py @@ -377,3 +377,24 @@ class SeedDatasetColumnConfig(SingleColumnConfig): """ column_type: Literal["seed-dataset"] = "seed-dataset" + + +class EmbeddingColumnConfig(SingleColumnConfig): + """Configuration for embedding generation columns. + + Embedding columns generate embeddings for text input using a specified model. + + Attributes: + target_column: The column to generate embeddings for. The column could be a single text string or a list of text strings in stringified JSON format. + If it is a list of text strings in stringified JSON format, the embeddings will be generated for each text string. + model_alias: The model to use for embedding generation. + column_type: Discriminator field, always "embedding" for this configuration type. 
+ """ + + target_column: str + model_alias: str + column_type: Literal["embedding"] = "embedding" + + @property + def required_columns(self) -> list[str]: + return [self.target_column] diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py index bb8d8cbd..cbfce4f7 100644 --- a/src/data_designer/config/column_types.py +++ b/src/data_designer/config/column_types.py @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -35,6 +36,7 @@ SamplerColumnConfig, SeedDatasetColumnConfig, ValidationColumnConfig, + EmbeddingColumnConfig, ] ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT) @@ -54,6 +56,7 @@ DataDesignerColumnType.SEED_DATASET: "🌱", DataDesignerColumnType.SAMPLER: "🎲", DataDesignerColumnType.VALIDATION: "πŸ”", + DataDesignerColumnType.EMBEDDING: "🧬", } COLUMN_TYPE_EMOJI_MAP.update( {DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()} @@ -70,27 +73,29 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.VALIDATION, + DataDesignerColumnType.EMBEDDING, } dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType)) return column_type in dag_column_types -def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: - """Return True if the column type is an LLM-generated column.""" +def column_type_is_model_generated(column_type: Union[str, DataDesignerColumnType]) -> bool: + """Return True if the column type is a model-generated column.""" column_type = resolve_string_enum(column_type, DataDesignerColumnType) - llm_generated_column_types = { + model_generated_column_types = { DataDesignerColumnType.LLM_TEXT, DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, } - llm_generated_column_types.update( + model_generated_column_types.update( plugin_manager.get_plugin_column_types( DataDesignerColumnType, required_resources=["model_registry"], ) ) - return column_type in llm_generated_column_types + return column_type in model_generated_column_types def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT: @@ -121,6 +126,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs)) if column_type == DataDesignerColumnType.SEED_DATASET: return SeedDatasetColumnConfig(name=name, **kwargs) + if column_type == DataDesignerColumnType.EMBEDDING: + return EmbeddingColumnConfig(name=name, **kwargs) if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value): return plugin.config_cls(name=name, **kwargs) raise InvalidColumnTypeError(f"πŸ›‘ {column_type} is not a valid column type.") # pragma: no cover @@ -135,6 +142,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]: DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 
acbc8ef8..5f07595d 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -19,7 +19,7 @@ from data_designer.config.column_types import ( ColumnConfigT, DataDesignerColumnType, - column_type_is_llm_generated, + column_type_is_model_generated, get_column_config_from_kwargs, get_column_display_order, ) @@ -447,12 +447,21 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]: return [c for c in self._constraints if c.target_column == target_column] def get_llm_gen_columns(self) -> list[ColumnConfigT]: - """Get all LLM-generated column configurations. + """Get all model-generated column configurations. Returns: - A list of column configurations that use LLM generation. + A list of column configurations that use model generation. """ - return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)] + logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.") + return self.get_model_gen_columns() + + def get_model_gen_columns(self) -> list[ColumnConfigT]: + """Get all model-generated column configurations. + + Returns: + A list of column configurations that use model generation. + """ + return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)] def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]: """Get all column configurations of the specified type. diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py index e65d0a90..5ab6bdf8 100644 --- a/src/data_designer/config/default_model_settings.py +++ b/src/data_designer/config/default_model_settings.py @@ -8,7 +8,13 @@ from pathlib import Path from typing import Any, Literal, Optional -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + InferenceParamsT, + ModelConfig, + ModelProvider, +) from data_designer.config.utils.constants import ( MANAGED_ASSETS_PATH, MODEL_CONFIGS_FILE_PATH, @@ -21,32 +27,43 @@ logger = logging.getLogger(__name__) -def get_default_text_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_text_alias_inference_parameters() -> ChatCompletionInferenceParams: + return ChatCompletionInferenceParams( temperature=0.85, top_p=0.95, ) -def get_default_reasoning_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_reasoning_alias_inference_parameters() -> ChatCompletionInferenceParams: + return ChatCompletionInferenceParams( temperature=0.35, top_p=0.95, ) -def get_default_vision_alias_inference_parameters() -> InferenceParameters: - return InferenceParameters( +def get_default_vision_alias_inference_parameters() -> ChatCompletionInferenceParams: + return ChatCompletionInferenceParams( temperature=0.85, top_p=0.95, ) -def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters: +def get_default_embedding_alias_inference_parameters(provider: str) -> EmbeddingInferenceParams: + args = dict(encoding_format="float") + if provider == "nvidia": + args["extra_body"] = {"input_type": "query"} + return EmbeddingInferenceParams(**args) + + +def get_default_inference_parameters( + model_alias: Literal["text", "reasoning", "vision", "embedding"], provider: str +) -> InferenceParamsT: if 
model_alias == "reasoning": return get_default_reasoning_alias_inference_parameters() elif model_alias == "vision": return get_default_vision_alias_inference_parameters() + elif model_alias == "embedding": + return get_default_embedding_alias_inference_parameters(provider) else: return get_default_text_alias_inference_parameters() @@ -60,7 +77,7 @@ def get_builtin_model_configs() -> list[ModelConfig]: alias=f"{provider}-{model_alias}", model=model_id, provider=provider, - inference_parameters=get_default_inference_parameters(model_alias), + inference_parameters=get_default_inference_parameters(model_alias, provider), ) ) return model_configs @@ -103,7 +120,8 @@ def resolve_seed_default_model_settings() -> None: f"🍾 Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}" ) save_config_file( - MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]} + MODEL_CONFIGS_FILE_PATH, + {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]}, ) if not MODEL_PROVIDERS_FILE_PATH.exists(): @@ -111,7 +129,7 @@ def resolve_seed_default_model_settings() -> None: f"πŸͺ„ Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}" ) save_config_file( - MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]} + MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]} ) if not MANAGED_ASSETS_PATH.exists(): diff --git a/src/data_designer/config/exports.py b/src/data_designer/config/exports.py index eb976659..3839b178 100644 --- a/src/data_designer/config/exports.py +++ b/src/data_designer/config/exports.py @@ -3,6 +3,7 @@ from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -19,6 +20,9 @@ from data_designer.config.dataset_builders import BuildStage from data_designer.config.datastore import DatastoreSettings from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + GenerationType, ImageContext, ImageFormat, InferenceParameters, @@ -81,6 +85,7 @@ def get_config_exports() -> list[str]: CodeLang.__name__, CodeValidatorParams.__name__, ColumnInequalityConstraint.__name__, + ChatCompletionInferenceParams.__name__, DataDesignerColumnType.__name__, DataDesignerConfig.__name__, DataDesignerConfigBuilder.__name__, @@ -89,8 +94,11 @@ def get_config_exports() -> list[str]: DatastoreSettings.__name__, DatetimeSamplerParams.__name__, DropColumnsProcessorConfig.__name__, + EmbeddingColumnConfig.__name__, + EmbeddingInferenceParams.__name__, ExpressionColumnConfig.__name__, GaussianSamplerParams.__name__, + GenerationType.__name__, IndexRange.__name__, InfoType.__name__, ImageContext.__name__, diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py index 09454a88..9b5ac6d7 100644 --- a/src/data_designer/config/models.py +++ b/src/data_designer/config/models.py @@ -5,10 +5,10 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Generic, List, Optional, TypeVar, Union +from typing import Any, Generic, List, Literal, Optional, TypeVar, Union import numpy as np -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, 
model_validator from typing_extensions import Self, TypeAlias from data_designer.config.base import ConfigBase @@ -205,33 +205,59 @@ def sample(self) -> float: DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution] -class InferenceParameters(ConfigBase): - """Configuration for LLM inference parameters. +class GenerationType(str, Enum): + CHAT_COMPLETION = "chat-completion" + EMBEDDING = "embedding" + + +class BaseInferenceParams(ConfigBase, ABC): + """Base configuration for inference parameters. Attributes: - temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling. - top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling. - max_tokens: Maximum number of tokens (includes both input and output tokens). + generation_type: Type of generation (chat-completion or embedding). Acts as discriminator. max_parallel_requests: Maximum number of parallel requests to the model API. timeout: Timeout in seconds for each request. extra_body: Additional parameters to pass to the model API. """ - temperature: Optional[Union[float, DistributionT]] = None - top_p: Optional[Union[float, DistributionT]] = None - max_tokens: Optional[int] = Field(default=None, ge=1) + generation_type: GenerationType max_parallel_requests: int = Field(default=4, ge=1) timeout: Optional[int] = Field(default=None, ge=1) extra_body: Optional[dict[str, Any]] = None @property - def generate_kwargs(self) -> dict[str, Union[float, int]]: + def generate_kwargs(self) -> dict[str, Any]: """Get the generate kwargs for the inference parameters. Returns: A dictionary of the generate kwargs. """ result = {} + if self.timeout is not None: + result["timeout"] = self.timeout + if self.extra_body is not None and self.extra_body != {}: + result["extra_body"] = self.extra_body + return result + + +class ChatCompletionInferenceParams(BaseInferenceParams): + """Configuration for LLM inference parameters. + + Attributes: + generation_type: Type of generation, always "chat-completion" for this class. + temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling. + top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling. + max_tokens: Maximum number of tokens (includes both input and output tokens). 
+ """ + + generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION + temperature: Optional[Union[float, DistributionT]] = None + top_p: Optional[Union[float, DistributionT]] = None + max_tokens: Optional[int] = Field(default=None, ge=1) + + @property + def generate_kwargs(self) -> dict[str, Any]: + result = super().generate_kwargs if self.temperature is not None: result["temperature"] = ( self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature @@ -240,10 +266,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]: result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p if self.max_tokens is not None: result["max_tokens"] = self.max_tokens - if self.timeout is not None: - result["timeout"] = self.timeout - if self.extra_body is not None and self.extra_body != {}: - result["extra_body"] = self.extra_body return result @model_validator(mode="after") @@ -290,6 +312,47 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) - return min_value <= value <= max_value +# Maintain backwards compatibility with a deprecation warning +class InferenceParameters(ChatCompletionInferenceParams): + """ + Deprecated: Use ChatCompletionInferenceParams instead. + This alias will be removed in a future version. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + logger.warning( + "InferenceParameters is deprecated and will be removed in a future version. " + "Use ChatCompletionInferenceParams instead." + ) + super().__init__(*args, **kwargs) + + +class EmbeddingInferenceParams(BaseInferenceParams): + """Configuration for embedding generation parameters. + + Attributes: + generation_type: Type of generation, always "embedding" for this class. + encoding_format: Format of the embedding encoding ("float" or "base64"). + dimensions: Number of dimensions for the embedding. + """ + + generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + + @property + def generate_kwargs(self) -> dict[str, Union[float, int]]: + result = super().generate_kwargs + if self.encoding_format is not None: + result["encoding_format"] = self.encoding_format + if self.dimensions is not None: + result["dimensions"] = self.dimensions + return result + + +InferenceParamsT: TypeAlias = Union[ChatCompletionInferenceParams, EmbeddingInferenceParams, InferenceParameters] + + class ModelConfig(ConfigBase): """Configuration for a model used for generation. @@ -297,14 +360,32 @@ class ModelConfig(ConfigBase): alias: User-defined alias to reference in column configurations. model: Model identifier (e.g., from build.nvidia.com or other providers). inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.). + The generation_type is determined by the type of inference_parameters. provider: Optional model provider name if using custom providers. 
""" alias: str model: str - inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters) + inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams) provider: Optional[str] = None + @property + def generation_type(self) -> GenerationType: + """Get the generation type from the inference parameters.""" + return self.inference_parameters.generation_type + + @field_validator("inference_parameters", mode="before") + @classmethod + def _convert_inference_parameters(cls, value: Any) -> Any: + """Convert raw dict to appropriate inference parameters type based on field presence.""" + if isinstance(value, dict): + # Infer type from presence of embedding-specific fields + if "encoding_format" in value or "dimensions" in value: + return EmbeddingInferenceParams(**value) + else: + return ChatCompletionInferenceParams(**value) + return value + class ModelProvider(ConfigBase): """Configuration for a custom model provider. diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py index 0420240f..12b0d2a3 100644 --- a/src/data_designer/config/utils/constants.py +++ b/src/data_designer/config/utils/constants.py @@ -304,10 +304,12 @@ class NordColor(Enum): "text": "nvidia/nvidia-nemotron-nano-9b-v2", "reasoning": "openai/gpt-oss-20b", "vision": "nvidia/nemotron-nano-12b-v2-vl", + "embedding": "nvidia/llama-3.2-nv-embedqa-1b-v2", }, OPENAI_PROVIDER_NAME: { "text": "gpt-4.1", "reasoning": "gpt-5", "vision": "gpt-5", + "embedding": "text-embedding-3-large", }, } diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py index d30a067a..7d3654ad 100644 --- a/src/data_designer/config/utils/validation.py +++ b/src/data_designer/config/utils/validation.py @@ -15,7 +15,7 @@ from rich.padding import Padding from rich.panel import Panel -from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated +from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated from data_designer.config.processors import ProcessorConfig, ProcessorType from data_designer.config.utils.constants import RICH_CONSOLE_THEME from data_designer.config.utils.misc import ( @@ -119,7 +119,7 @@ def validate_prompt_templates( ) -> list[Violation]: env = ImmutableSandboxedEnvironment() - columns_with_prompts = [c for c in columns if column_type_is_llm_generated(c.column_type)] + columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)] violations = [] for column in columns_with_prompts: diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py index 1f843c86..c2cca08b 100644 --- a/src/data_designer/config/utils/visualization.py +++ b/src/data_designer/config/utils/visualization.py @@ -8,7 +8,7 @@ from collections import OrderedDict from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import pandas as pd @@ -194,6 +194,7 @@ def display_sample_record( + config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT) + config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED) + + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING) ) if len(non_code_columns) > 0: table = 
Table(title="Generated Columns", **table_kws) @@ -201,6 +202,10 @@ def display_sample_record( table.add_column("Value") for col in non_code_columns: if not col.drop: + if col.column_type == DataDesignerColumnType.EMBEDDING: + record[col.name]["embeddings"] = [ + get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings") + ] table.add_row(col.name, convert_to_row_element(record[col.name])) render_list.append(pad_console_element(table)) @@ -269,6 +274,16 @@ def display_sample_record( console.print(Group(*render_list), markup=False) +def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str: + if max_items <= 0: + raise ValueError("max_items must be greater than 0") + if len(long_list) > max_items: + truncated_part = long_list[:max_items] + return f"[{', '.join(str(x) for x in truncated_part)}, ...]" + else: + return str(long_list) + + def display_sampler_table( sampler_params: dict[SamplerType, ConfigBase], title: Optional[str] = None, diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py index 0e0a7c0e..3b2e075b 100644 --- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py +++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py @@ -18,8 +18,10 @@ MissingValue, NumericalDistribution, ) -from data_designer.config.column_configs import LLMTextColumnConfig -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.config.column_configs import ( + LLMTextColumnConfig, +) +from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, @@ -92,7 +94,7 @@ def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[st } -def calculate_prompt_token_stats( +def calculate_input_token_stats( column_config: LLMTextColumnConfig, df: pd.DataFrame ) -> dict[str, float | MissingValue]: try: @@ -109,44 +111,44 @@ def calculate_prompt_token_stats( concatenated_prompt = str(system_prompt + "\n\n" + prompt) num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=()))) except Exception as e: - logger.warning( - f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}" - ) + logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}") return { - "prompt_tokens_mean": MissingValue.CALCULATION_FAILED, - "prompt_tokens_median": MissingValue.CALCULATION_FAILED, - "prompt_tokens_stddev": MissingValue.CALCULATION_FAILED, + "input_tokens_mean": MissingValue.CALCULATION_FAILED, + "input_tokens_median": MissingValue.CALCULATION_FAILED, + "input_tokens_stddev": MissingValue.CALCULATION_FAILED, } return { - "prompt_tokens_mean": np.mean(num_tokens), - "prompt_tokens_median": np.median(num_tokens), - "prompt_tokens_stddev": np.std(num_tokens), + "input_tokens_mean": np.mean(num_tokens), + "input_tokens_median": np.median(num_tokens), + "input_tokens_stddev": np.std(num_tokens), } -def calculate_completion_token_stats(column_name: str, df: pd.DataFrame) -> dict[str, float | MissingValue]: +def calculate_output_token_stats( + column_config: LLMTextColumnConfig, df: pd.DataFrame +) -> dict[str, float | MissingValue]: try: - tokens_per_record = df[column_name].apply( + tokens_per_record = df[column_config.name].apply( lambda value: len(TOKENIZER.encode(str(value), 
disallowed_special=())) ) return { - "completion_tokens_mean": tokens_per_record.mean(), - "completion_tokens_median": tokens_per_record.median(), - "completion_tokens_stddev": tokens_per_record.std(), + "output_tokens_mean": tokens_per_record.mean(), + "output_tokens_median": tokens_per_record.median(), + "output_tokens_stddev": tokens_per_record.std(), } except Exception as e: - logger.warning(f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_name}: {e}") + logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}") return { - "completion_tokens_mean": MissingValue.CALCULATION_FAILED, - "completion_tokens_median": MissingValue.CALCULATION_FAILED, - "completion_tokens_stddev": MissingValue.CALCULATION_FAILED, + "output_tokens_mean": MissingValue.CALCULATION_FAILED, + "output_tokens_median": MissingValue.CALCULATION_FAILED, + "output_tokens_stddev": MissingValue.CALCULATION_FAILED, } def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]: return { - **calculate_prompt_token_stats(column_config, df), - **calculate_completion_token_stats(column_config.name, df), + **calculate_input_token_stats(column_config, df), + **calculate_output_token_stats(column_config, df), } diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index f4ddb60c..b7ad4c6d 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -1,13 +1,20 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +import functools +import logging from abc import ABC, abstractmethod from typing import overload import pandas as pd +from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP +from data_designer.config.models import BaseInferenceParams, ModelConfig from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT +from data_designer.engine.models.facade import ModelFacade + +logger = logging.getLogger(__name__) class GenerationStrategy(StrEnum): @@ -59,3 +66,30 @@ def can_generate_from_scratch(self) -> bool: @abstractmethod def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ... 
+ + +class WithModelGeneration: + @functools.cached_property + def model(self) -> ModelFacade: + return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + + @functools.cached_property + def model_config(self) -> ModelConfig: + return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + + @functools.cached_property + def inference_parameters(self) -> BaseInferenceParams: + return self.model_config.inference_parameters + + def log_pre_generation(self) -> None: + emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] + logger.info(f"{emoji} Preparing {self.config.column_type} column generation") + logger.info(f" |-- column name: {self.config.name!r}") + logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") + if self.model_config.provider is None: + logger.info(f" |-- default model provider: {self._get_provider_name()!r}") + + def _get_provider_name(self) -> str: + model_alias = self.model_config.alias + provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) + return provider.name diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py new file mode 100644 index 00000000..a8623db5 --- /dev/null +++ b/src/data_designer/engine/column_generators/generators/embedding.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from pydantic import BaseModel, computed_field + +from data_designer.config.column_configs import EmbeddingColumnConfig +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + GeneratorMetadata, + WithModelGeneration, +) +from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string +from data_designer.engine.resources.resource_provider import ResourceType + + +class EmbeddingGenerationResult(BaseModel): + embeddings: list[list[float]] + + @computed_field + def num_embeddings(self) -> int: + return len(self.embeddings) + + @computed_field + def dimension(self) -> int: + return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0 + + +class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]): + @staticmethod + def metadata() -> GeneratorMetadata: + return GeneratorMetadata( + name="embedding_cell_generator", + description="Generate embeddings for a text column.", + generation_strategy=GenerationStrategy.CELL_BY_CELL, + required_resources=[ResourceType.MODEL_REGISTRY], + ) + + def generate(self, data: dict) -> dict: + deserialized_record = deserialize_json_values(data) + input_texts = parse_list_string(deserialized_record[self.config.target_column]) + embeddings = self.model.generate_text_embeddings(input_texts=input_texts) + data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json") + return data diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_completion.py similarity index 71% rename from src/data_designer/engine/column_generators/generators/llm_generators.py rename to src/data_designer/engine/column_generators/generators/llm_completion.py index ee0ab58a..ef47ed01 100644 --- a/src/data_designer/engine/column_generators/generators/llm_generators.py +++ 
b/src/data_designer/engine/column_generators/generators/llm_completion.py @@ -10,43 +10,41 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP -from data_designer.config.models import InferenceParameters, ModelConfig from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, GenerationStrategy, GeneratorMetadata, + WithModelGeneration, ) from data_designer.engine.column_generators.utils.prompt_renderer import ( PromptType, RecordBasedPromptRenderer, create_response_recipe, ) -from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.recipes.base import ResponseRecipe from data_designer.engine.processing.utils import deserialize_json_values from data_designer.engine.resources.resource_provider import ResourceType -DEFAULT_MAX_CONVERSATION_RESTARTS = 5 -DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +DEFAULT_MAX_CONVERSATION_RESTARTS = 5 +DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0 -class WithLLMGeneration: +class WithChatCompletionGeneration(WithModelGeneration): @functools.cached_property - def model(self) -> ModelFacade: - return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias) + def response_recipe(self) -> ResponseRecipe: + return create_response_recipe(self.config, self.model_config) - @functools.cached_property - def model_config(self) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias) + @property + def max_conversation_correction_steps(self) -> int: + return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - @functools.cached_property - def inference_parameters(self) -> InferenceParameters: - return self.model_config.inference_parameters + @property + def max_conversation_restarts(self) -> int: + return DEFAULT_MAX_CONVERSATION_RESTARTS @functools.cached_property def prompt_renderer(self) -> RecordBasedPromptRenderer: @@ -59,18 +57,6 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer: }, ) - @functools.cached_property - def response_recipe(self) -> ResponseRecipe: - return create_response_recipe(self.config, self.model_config) - - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - - @property - def max_conversation_restarts(self) -> int: - return DEFAULT_MAX_CONVERSATION_RESTARTS - def generate(self, data: dict) -> dict: deserialized_record = deserialize_json_values(data) @@ -96,7 +82,6 @@ def generate(self, data: dict) -> dict: max_correction_steps=self.max_conversation_correction_steps, max_conversation_restarts=self.max_conversation_restarts, purpose=f"running generation for column '{self.config.name}'", - **self.inference_parameters.generate_kwargs, ) data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response)) @@ -106,21 +91,8 @@ def generate(self, data: dict) -> dict: return data - def log_pre_generation(self) -> None: - emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type} column generation") - logger.info(f" |-- column name: {self.config.name!r}") - logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") - if self.model_config.provider is None: - logger.info(f" |-- default model provider: 
{self._get_provider_name()!r}") - - def _get_provider_name(self) -> str: - model_alias = self.model_config.alias - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return provider.name - -class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]): +class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -131,7 +103,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]): +class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -142,7 +114,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]): +class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -153,7 +125,7 @@ def metadata() -> GeneratorMetadata: ) -class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]): +class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]): @staticmethod def metadata() -> GeneratorMetadata: return GeneratorMetadata( @@ -163,10 +135,6 @@ def metadata() -> GeneratorMetadata: required_resources=[ResourceType.MODEL_REGISTRY], ) - @property - def max_conversation_correction_steps(self) -> int: - return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS - @property def max_conversation_restarts(self) -> int: return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py index 61b43753..6da40269 100644 --- a/src/data_designer/engine/column_generators/registry.py +++ b/src/data_designer/engine/column_generators/registry.py @@ -3,6 +3,7 @@ from data_designer.config.base import ConfigBase from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -12,8 +13,9 @@ ) from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.base import ColumnGenerator +from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, @@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig) registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig) registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig) + registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig) registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig) registry.register(DataDesignerColumnType.SEED_DATASET, 
SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig) registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig) registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig) - if with_plugins: for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR): registry.register( diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py index beeedd6f..063aa15c 100644 --- a/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -10,15 +10,18 @@ import pandas as pd -from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated +from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated from data_designer.config.dataset_builders import BuildStage from data_designer.config.processors import ( DropColumnsProcessorConfig, ProcessorConfig, ProcessorType, ) -from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy -from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + GenerationStrategy, + WithModelGeneration, +) from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError from data_designer.engine.dataset_builders.multi_column_configs import ( @@ -72,7 +75,7 @@ def single_column_configs(self) -> list[ColumnConfigT]: @functools.cached_property def llm_generated_column_configs(self) -> list[ColumnConfigT]: - return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)] + return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)] def build( self, @@ -169,7 +172,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None: max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR - if isinstance(generator, WithLLMGeneration): + if isinstance(generator, WithModelGeneration): max_workers = generator.inference_parameters.max_parallel_requests self._fan_out_with_threads(generator, max_workers=max_workers) @@ -178,12 +181,12 @@ def _run_full_column_generator(self, generator: ColumnGenerator) -> None: self.batch_manager.update_records(df.to_dict(orient="records")) def _run_model_health_check_if_needed(self) -> bool: - if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs): + if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs): self._resource_provider.model_registry.run_health_check( - set(config.model_alias for config in self.llm_generated_column_configs) + list(set(config.model_alias for config in self.llm_generated_column_configs)) ) - def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None: + def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None: if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} " 
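With the registry and builder changes above, embedding columns run through the same cell-by-cell fan-out as LLM columns, with concurrency taken from the model's max_parallel_requests. A minimal sketch of how the new pieces compose, assuming the default NVIDIA embedding model ID added in constants.py; the alias, column names, and dimensions value are illustrative, not part of this change:

# Sketch only: wire up the embedding config types introduced in this change.
from data_designer.config.column_configs import EmbeddingColumnConfig
from data_designer.config.models import EmbeddingInferenceParams, ModelConfig

embedding_model = ModelConfig(
    alias="nvidia-embedding",
    model="nvidia/llama-3.2-nv-embedqa-1b-v2",  # default embedding model ID (constants.py)
    inference_parameters=EmbeddingInferenceParams(
        encoding_format="float",
        dimensions=1024,             # assumed model-specific dimension size
        max_parallel_requests=4,     # drives the cell-by-cell fan-out above
    ),
)

review_embeddings = EmbeddingColumnConfig(
    name="review_embedding",
    target_column="review_text",     # plain text or a JSON-encoded list of strings
    model_alias=embedding_model.alias,
)

The resulting column value is the EmbeddingGenerationResult payload (embeddings plus num_embeddings and dimension) produced by the EmbeddingCellGenerator registered above.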
diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py index e6b83e71..8d969809 100644 --- a/src/data_designer/engine/models/facade.py +++ b/src/data_designer/engine/models/facade.py @@ -9,9 +9,9 @@ from typing import Any from litellm.types.router import DeploymentTypedDict, LiteLLM_Params -from litellm.types.utils import ModelResponse +from litellm.types.utils import EmbeddingResponse, ModelResponse -from data_designer.config.models import ModelConfig, ModelProvider +from data_designer.config.models import GenerationType, ModelConfig, ModelProvider from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.errors import ( GenerationValidationFailureError, @@ -49,6 +49,10 @@ def model_name(self) -> str: def model_provider(self) -> ModelProvider: return self._model_provider_registry.get_provider(self._model_config.provider) + @property + def model_generation_type(self) -> GenerationType: + return self._model_config.generation_type + @property def model_provider_name(self) -> str: return self.model_provider.name @@ -64,13 +68,12 @@ def usage_stats(self) -> ModelUsageStats: def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse: logger.debug( f"Prompting model {self.model_name!r}...", - extra={"model": self.model_name, "messages": messages, "sensitive": True}, + extra={"model": self.model_name, "messages": messages}, ) response = None - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + kwargs = self.consolidate_kwargs(**kwargs) try: - response = self._router.completion(self.model_name, messages, **kwargs) + response = self._router.completion(model=self.model_name, messages=messages, **kwargs) logger.debug( f"Received completion from model {self.model_name!r}", extra={ @@ -84,9 +87,50 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = except Exception as e: raise e finally: - if not skip_usage_tracking: + if not skip_usage_tracking and response is not None: self._track_usage(response) + def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + return kwargs + + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response = None + try: + response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.data) if response.data else 0, + "usage": self._usage_stats.model_dump(), + }, + ) + if response.data and len(response.data) == len(input_texts): + return [data["embedding"] for data in response.data] + else: + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") + except Exception as e: + raise e + finally: + if not 
skip_usage_tracking and response is not None: + self._track_usage_from_embedding(response) + @catch_llm_exceptions def generate( self, @@ -218,8 +262,21 @@ def _track_usage(self, response: ModelResponse | None) -> None: ): self._usage_stats.extend( token_usage=TokenUsageStats( - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, + input_tokens=response.usage.prompt_tokens, + output_tokens=response.usage.completion_tokens, + ), + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) + + def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None: + if response is None: + self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return + if response.usage is not None and response.usage.prompt_tokens is not None: + self._usage_stats.extend( + token_usage=TokenUsageStats( + input_tokens=response.usage.prompt_tokens, + output_tokens=0, ), request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py index aafd8c80..8fe61d59 100644 --- a/src/data_designer/engine/models/registry.py +++ b/src/data_designer/engine/models/registry.py @@ -5,7 +5,7 @@ import logging -from data_designer.config.models import ModelConfig +from data_designer.config.models import GenerationType, ModelConfig from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches @@ -73,7 +73,7 @@ def get_model_provider(self, *, model_alias: str) -> ModelProvider: model_config = self.get_model_config(model_alias=model_alias) return self._model_provider_registry.get_provider(model_config.provider) - def run_health_check(self, model_aliases: set[str]) -> None: + def run_health_check(self, model_aliases: list[str]) -> None: logger.info("🩺 Running health checks for models...") for model_alias in model_aliases: model = self.get_model(model_alias=model_alias) @@ -81,15 +81,24 @@ def run_health_check(self, model_aliases: set[str]) -> None: f" |-- πŸ‘€ Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..." 
) try: - model.generate( - prompt="Hello!", - parser=lambda x: x, - system_prompt="You are a helpful assistant.", - max_correction_steps=0, - max_conversation_restarts=0, - skip_usage_tracking=True, - purpose="running health checks", - ) + if model.model_generation_type == GenerationType.EMBEDDING: + model.generate_text_embeddings( + input_texts=["Hello!"], + skip_usage_tracking=True, + purpose="running health checks", + ) + elif model.model_generation_type == GenerationType.CHAT_COMPLETION: + model.generate( + prompt="Hello!", + parser=lambda x: x, + system_prompt="You are a helpful assistant.", + max_correction_steps=0, + max_conversation_restarts=0, + skip_usage_tracking=True, + purpose="running health checks", + ) + else: + raise ValueError(f"Unsupported generation type: {model.model_generation_type}") logger.info(" |-- βœ… Passed!") except Exception as e: logger.error(" |-- ❌ Failed!") diff --git a/src/data_designer/engine/models/usage.py b/src/data_designer/engine/models/usage.py index afbea9df..a5c23062 100644 --- a/src/data_designer/engine/models/usage.py +++ b/src/data_designer/engine/models/usage.py @@ -11,20 +11,20 @@ class TokenUsageStats(BaseModel): - prompt_tokens: int = 0 - completion_tokens: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 @computed_field def total_tokens(self) -> int: - return self.prompt_tokens + self.completion_tokens + return self.input_tokens + self.output_tokens @property def has_usage(self) -> bool: return self.total_tokens > 0 - def extend(self, *, prompt_tokens: int, completion_tokens: int) -> None: - self.prompt_tokens += prompt_tokens - self.completion_tokens += completion_tokens + def extend(self, *, input_tokens: int, output_tokens: int) -> None: + self.input_tokens += input_tokens + self.output_tokens += output_tokens class RequestUsageStats(BaseModel): @@ -56,9 +56,7 @@ def extend( self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None ) -> None: if token_usage is not None: - self.token_usage.extend( - prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens - ) + self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens) if request_usage is not None: self.request_usage.extend( successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py index 3579b3bd..5d42c40e 100644 --- a/src/data_designer/engine/processing/utils.py +++ b/src/data_designer/engine/processing/utils.py @@ -1,8 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import ast import json import logging +import re from typing import Any, TypeVar, Union, overload import pandas as pd @@ -100,6 +102,42 @@ def deserialize_json_values(data): return data +def parse_list_string(text: str) -> list[str]: + """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas.""" + text = text.strip() + + # Try JSON first + try: + list_obj = json.loads(text) + if isinstance(list_obj, list): + return _clean_whitespace(list_obj) + except json.JSONDecodeError: + pass + + # Remove trailing commas before closing brackets (common in JSON-like strings) + text_cleaned = re.sub(r",\s*]", "]", text) + text_cleaned = re.sub(r",\s*}", "}", text_cleaned) + + # Try JSON again with cleaned text + try: + return _clean_whitespace(json.loads(text_cleaned)) + except json.JSONDecodeError: + pass + + # Try Python literal eval (handles single quotes) + try: + return _clean_whitespace(ast.literal_eval(text_cleaned)) + except (ValueError, SyntaxError): + pass + + # If all else fails, return the original text + return [text.strip()] + + +def _clean_whitespace(texts: list[str]) -> list[str]: + return [text.strip() for text in texts] + + def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None: joined_columns = set() for df in datasets: diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 758e837e..0597a504 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -9,16 +9,16 @@ from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository from data_designer.cli.services.model_service import ModelService from data_designer.cli.services.provider_service import ProviderService -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider @pytest.fixture -def stub_inference_parameters() -> InferenceParameters: - return InferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) +def stub_inference_parameters() -> ChatCompletionInferenceParams: + return ChatCompletionInferenceParams(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4) @pytest.fixture -def stub_model_configs(stub_inference_parameters: InferenceParameters) -> list[ModelConfig]: +def stub_model_configs(stub_inference_parameters: ChatCompletionInferenceParams) -> list[ModelConfig]: return [ ModelConfig( alias="test-alias-1", @@ -41,7 +41,7 @@ def stub_new_model_config() -> ModelConfig: alias="test-alias-3", model="test-model-3", provider="test-provider-1", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.7, top_p=0.9, max_tokens=2048, diff --git a/tests/cli/controllers/test_model_controller.py b/tests/cli/controllers/test_model_controller.py index b630b04a..5dfd91fe 100644 --- a/tests/cli/controllers/test_model_controller.py +++ b/tests/cli/controllers/test_model_controller.py @@ -9,7 +9,7 @@ from data_designer.cli.controllers.model_controller import ModelController from data_designer.cli.repositories.model_repository import ModelConfigRegistry from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig @pytest.fixture @@ -141,7 +141,7 @@ def test_run_updates_model( 
alias="test-alias-1-updated", model="test-model-1-updated", provider="test-provider-1", - inference_parameters=InferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024), + inference_parameters=ChatCompletionInferenceParams(temperature=0.8, top_p=0.95, max_tokens=1024), ) mock_builder = MagicMock() diff --git a/tests/cli/forms/test_field.py b/tests/cli/forms/test_field.py index 31db43a1..4f9ba82c 100644 --- a/tests/cli/forms/test_field.py +++ b/tests/cli/forms/test_field.py @@ -55,7 +55,7 @@ def test_text_field_validator_receives_string() -> None: validator.assert_called_with("text") -@patch("data_designer.cli.ui.prompt_text_input") +@patch("data_designer.cli.forms.field.prompt_text_input") def test_text_field_prompt_user_returns_input(mock_prompt: Mock) -> None: """Test TextField prompt_user returns user input.""" mock_prompt.return_value = "user input" @@ -64,8 +64,8 @@ def test_text_field_prompt_user_returns_input(mock_prompt: Mock) -> None: assert field.prompt_user() == "user input" -@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL") -@patch("data_designer.cli.ui.prompt_text_input") +@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL") +@patch("data_designer.cli.forms.field.prompt_text_input") def test_text_field_prompt_user_handles_back_navigation(mock_prompt: Mock) -> None: """Test TextField prompt_user properly returns BACK sentinel.""" mock_prompt.return_value = "BACK_SENTINEL" @@ -87,7 +87,7 @@ def test_select_field_value_setter() -> None: assert field.value == "1" -@patch("data_designer.cli.ui.select_with_arrows") +@patch("data_designer.cli.forms.field.select_with_arrows") def test_select_field_prompt_user_returns_selection(mock_select: Mock) -> None: """Test SelectField prompt_user returns user selection.""" mock_select.return_value = "opt1" @@ -97,8 +97,8 @@ def test_select_field_prompt_user_returns_selection(mock_select: Mock) -> None: assert field.prompt_user() == "opt1" -@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL") -@patch("data_designer.cli.ui.select_with_arrows") +@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL") +@patch("data_designer.cli.forms.field.select_with_arrows") def test_select_field_prompt_user_handles_back_navigation(mock_select: Mock) -> None: """Test SelectField prompt_user properly returns BACK sentinel.""" mock_select.return_value = "BACK_SENTINEL" @@ -253,7 +253,7 @@ def test_numeric_field_value_setter_rejects_invalid() -> None: # NumericField prompt_user tests -@patch("data_designer.cli.ui.prompt_text_input") +@patch("data_designer.cli.forms.field.prompt_text_input") def test_numeric_field_prompt_user_returns_float(mock_prompt: Mock) -> None: """Test NumericField prompt_user converts string to float.""" mock_prompt.return_value = "42" @@ -265,19 +265,19 @@ def test_numeric_field_prompt_user_returns_float(mock_prompt: Mock) -> None: assert isinstance(result, float) -@patch("data_designer.cli.ui.prompt_text_input") +@patch("data_designer.cli.forms.field.prompt_text_input") def test_numeric_field_prompt_user_returns_none_for_empty(mock_prompt: Mock) -> None: - """Test NumericField prompt_user returns None for empty input.""" + """Test NumericField prompt_user returns empty string for empty input on optional field.""" mock_prompt.return_value = "" field = NumericField(name="optional", prompt="Enter value", required=False) result = field.prompt_user() - assert result is None + assert result == "" -@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL") -@patch("data_designer.cli.ui.prompt_text_input") 
+@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL") +@patch("data_designer.cli.forms.field.prompt_text_input") def test_numeric_field_prompt_user_handles_back_navigation(mock_prompt: Mock) -> None: """Test NumericField prompt_user properly returns BACK sentinel.""" mock_prompt.return_value = "BACK_SENTINEL" @@ -318,3 +318,133 @@ def test_validator_converts_non_string_values() -> None: # Validator should be called with string representation validator.assert_called_once_with("42.5") + + +# Tests for clearing values with 'clear' keyword +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_numeric_field_accepts_clear_keyword(mock_prompt: Mock) -> None: + """Test NumericField accepts 'clear' keyword to remove value.""" + mock_prompt.return_value = "clear" + field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False) + + result = field.prompt_user() + + assert result == "" + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_numeric_field_accepts_none_keyword(mock_prompt: Mock) -> None: + """Test NumericField accepts 'none' keyword to remove value.""" + mock_prompt.return_value = "none" + field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False) + + result = field.prompt_user() + + assert result == "" + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_numeric_field_accepts_default_keyword(mock_prompt: Mock) -> None: + """Test NumericField accepts 'default' keyword to remove value.""" + mock_prompt.return_value = "default" + field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False) + + result = field.prompt_user() + + assert result == "" + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_numeric_field_returns_default_for_empty_when_has_default(mock_prompt: Mock) -> None: + """Test NumericField returns default value when user enters nothing and default exists.""" + mock_prompt.return_value = "" + field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False) + + result = field.prompt_user() + + assert result == 42.0 + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_numeric_field_shows_current_label_with_default(mock_prompt: Mock) -> None: + """Test NumericField shows '(current: X)' instead of '(default: X)' when default exists.""" + mock_prompt.return_value = "" + field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False) + + field.prompt_user() + + # Check that prompt_text_input was called with current value info in the prompt + call_args = mock_prompt.call_args + prompt_arg = call_args[0][0] + assert "current" in prompt_arg.lower() + assert "42.0" in prompt_arg + assert "clear" in prompt_arg.lower() + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_text_field_shows_current_label_with_default(mock_prompt: Mock) -> None: + """Test TextField shows '(current: X)' instead of '(default: X)' when default exists.""" + mock_prompt.return_value = "" + field = TextField(name="name", prompt="Enter name", default="test", required=False) + + field.prompt_user() + + # Check that prompt_text_input was called with current value info in the prompt + call_args = mock_prompt.call_args + prompt_arg = call_args[0][0] + assert "current" in prompt_arg.lower() + assert "test" in prompt_arg + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_text_field_returns_default_for_empty_when_has_default(mock_prompt: Mock) -> 
None: + """Test TextField returns default value when user enters nothing and default exists.""" + mock_prompt.return_value = "" + field = TextField(name="name", prompt="Enter name", default="test", required=False) + + result = field.prompt_user() + + assert result == "test" + + +def test_numeric_field_value_setter_converts_empty_string_to_none() -> None: + """Test NumericField value setter converts empty string to None for optional fields.""" + field = NumericField(name="optional", prompt="Enter value", required=False) + + field.value = "" + + assert field.value is None + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_text_field_accepts_clear_keyword(mock_prompt: Mock) -> None: + """Test TextField accepts 'clear' keyword to remove value.""" + mock_prompt.return_value = "clear" + field = TextField(name="optional", prompt="Enter value", default="test", required=False) + + result = field.prompt_user() + + assert result == "" + + +@patch("data_designer.cli.forms.field.prompt_text_input") +def test_text_field_shows_clear_instruction_for_optional_with_default(mock_prompt: Mock) -> None: + """Test TextField shows clear instruction for optional fields with default values.""" + mock_prompt.return_value = "" + field = TextField(name="optional", prompt="Enter value", default="test", required=False) + + field.prompt_user() + + # Check that prompt includes 'clear' instruction + call_args = mock_prompt.call_args + prompt_arg = call_args[0][0] + assert "clear" in prompt_arg.lower() + + +def test_text_field_value_setter_converts_empty_string_to_none() -> None: + """Test TextField value setter converts empty string to None for optional fields.""" + field = TextField(name="optional", prompt="Enter value", required=False) + + field.value = "" + + assert field.value is None diff --git a/tests/cli/forms/test_model_builder.py b/tests/cli/forms/test_model_builder.py index 806f2dcf..e75eb745 100644 --- a/tests/cli/forms/test_model_builder.py +++ b/tests/cli/forms/test_model_builder.py @@ -5,7 +5,7 @@ from data_designer.cli.forms.field import ValidationError from data_designer.cli.forms.model_builder import ModelFormBuilder -from data_designer.config.models import ModelConfig +from data_designer.config.models import GenerationType, ModelConfig # Alias validation tests - test through public form interface @@ -102,17 +102,16 @@ def test_form_omits_provider_field_with_no_providers() -> None: def test_form_has_all_required_fields() -> None: - """Test form includes all essential configuration fields.""" + """Test basic form includes essential configuration fields (inference params are in separate form).""" builder = ModelFormBuilder() form = builder.create_form() - # All required fields must be present + # All required fields must be present in basic form assert form.get_field("alias") is not None assert form.get_field("model") is not None - assert form.get_field("temperature") is not None - assert form.get_field("top_p") is not None - assert form.get_field("max_tokens") is not None + assert form.get_field("generation_type") is not None + # inference_parameters are now collected in a separate form via _create_inference_params_form # Initial data handling tests @@ -122,6 +121,7 @@ def test_form_uses_initial_data_for_field_defaults() -> None: "alias": "my-model", "model": "gpt-4", "inference_parameters": { + "generation_type": GenerationType.CHAT_COMPLETION, "temperature": 0.5, "top_p": 0.8, "max_tokens": 1024, @@ -133,9 +133,28 @@ def test_form_uses_initial_data_for_field_defaults() -> None: 
assert form.get_field("alias").default == "my-model" assert form.get_field("model").default == "gpt-4" - assert form.get_field("temperature").default == 0.5 - assert form.get_field("top_p").default == 0.8 - assert form.get_field("max_tokens").default == 1024 + assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION + + +def test_form_extracts_generation_type_from_inference_parameters() -> None: + """Test form correctly extracts generation_type from nested inference_parameters for embedding models.""" + initial_data = { + "alias": "embedding-model", + "model": "text-embedding-3", + "inference_parameters": { + "generation_type": GenerationType.EMBEDDING, + "encoding_format": "base64", + "dimensions": 512, + }, + "provider": "openai", + } + builder = ModelFormBuilder() + + form = builder.create_form(initial_data) + + assert form.get_field("alias").default == "embedding-model" + assert form.get_field("model").default == "text-embedding-3" + assert form.get_field("generation_type").default == GenerationType.EMBEDDING def test_form_uses_standard_defaults_without_initial_data() -> None: @@ -146,9 +165,7 @@ def test_form_uses_standard_defaults_without_initial_data() -> None: assert form.get_field("alias").default is None assert form.get_field("model").default is None - assert form.get_field("temperature").default == 0.7 - assert form.get_field("top_p").default == 0.9 - assert form.get_field("max_tokens").default == 2048 + assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION def test_form_handles_partial_initial_data() -> None: @@ -164,10 +181,8 @@ def test_form_handles_partial_initial_data() -> None: # Should use provided values assert form.get_field("alias").default == "my-model" assert form.get_field("model").default == "gpt-4" - # Should fall back to standard defaults for missing values - assert form.get_field("temperature").default == 0.7 - assert form.get_field("top_p").default == 0.9 - assert form.get_field("max_tokens").default == 2048 + # Should fall back to standard defaults for missing generation_type + assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION def test_form_provider_defaults_to_first_when_multiple_available() -> None: @@ -199,9 +214,11 @@ def test_build_config_uses_provider_from_form_data() -> None: "alias": "my-model", "model": "gpt-4", "provider": "anthropic", - "temperature": 0.7, - "top_p": 0.9, - "max_tokens": 2048, + "inference_parameters": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 2048, + }, } config = builder.build_config(form_data) @@ -215,9 +232,11 @@ def test_build_config_infers_single_available_provider() -> None: form_data = { "alias": "my-model", "model": "gpt-4", - "temperature": 0.7, - "top_p": 0.9, - "max_tokens": 2048, + "inference_parameters": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 2048, + }, } config = builder.build_config(form_data) @@ -231,9 +250,11 @@ def test_build_config_sets_provider_none_when_unavailable() -> None: form_data = { "alias": "my-model", "model": "gpt-4", - "temperature": 0.7, - "top_p": 0.9, - "max_tokens": 2048, + "inference_parameters": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 2048, + }, } config = builder.build_config(form_data) @@ -247,9 +268,11 @@ def test_build_config_creates_valid_model_config() -> None: form_data = { "alias": "test-model", "model": "gpt-4-turbo", - "temperature": 0.5, - "top_p": 0.8, - "max_tokens": 1024, + "inference_parameters": { + "temperature": 0.5, + "top_p": 0.8, + 
"max_tokens": 1024, + }, } config = builder.build_config(form_data) @@ -265,19 +288,20 @@ def test_build_config_creates_valid_model_config() -> None: def test_build_config_converts_max_tokens_to_int() -> None: - """Test build_config converts max_tokens from float to int.""" + """Test build_config handles numeric values in inference parameters.""" builder = ModelFormBuilder() form_data = { "alias": "my-model", "model": "gpt-4", - "temperature": 0.7, - "top_p": 0.9, - "max_tokens": 2048.0, # NumericField returns float + "inference_parameters": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 2048, + }, } config = builder.build_config(form_data) - assert isinstance(config.inference_parameters.max_tokens, int) assert config.inference_parameters.max_tokens == 2048 @@ -289,9 +313,11 @@ def test_build_config_prefers_explicit_provider_over_inference() -> None: "alias": "my-model", "model": "gpt-4", "provider": "custom", # Explicitly overridden - "temperature": 0.7, - "top_p": 0.9, - "max_tokens": 2048, + "inference_parameters": { + "temperature": 0.7, + "top_p": 0.9, + "max_tokens": 2048, + }, } config = builder.build_config(form_data) @@ -322,14 +348,17 @@ def test_full_workflow_creates_valid_config() -> None: # Simulate user accepting defaults (get_values would return these) form.set_values(initial_data) - # Flatten inference_parameters for form data + + # Form data now includes inference_parameters as a dict form_data = { "alias": "new-model", "model": "claude-3-opus", "provider": "anthropic", - "temperature": 0.6, - "top_p": 0.95, - "max_tokens": 4096, + "inference_parameters": { + "temperature": 0.6, + "top_p": 0.95, + "max_tokens": 4096, + }, } # Build config @@ -342,3 +371,140 @@ def test_full_workflow_creates_valid_config() -> None: assert config.inference_parameters.temperature == 0.6 assert config.inference_parameters.top_p == 0.95 assert config.inference_parameters.max_tokens == 4096 + + +# Tests for new two-step form process +def test_create_inference_params_form_for_chat_completion() -> None: + """Test creating inference parameters form for chat completion models.""" + builder = ModelFormBuilder() + + params_form = builder.create_inference_params_form(GenerationType.CHAT_COMPLETION) + + # Should have chat completion specific fields + assert params_form.get_field("temperature") is not None + assert params_form.get_field("top_p") is not None + assert params_form.get_field("max_tokens") is not None + # Should not have embedding fields + assert params_form.get_field("encoding_format") is None + assert params_form.get_field("dimensions") is None + + +def test_create_inference_params_form_for_embedding() -> None: + """Test creating inference parameters form for embedding models.""" + builder = ModelFormBuilder() + + params_form = builder.create_inference_params_form(GenerationType.EMBEDDING) + + # Should have embedding specific fields + assert params_form.get_field("encoding_format") is not None + assert params_form.get_field("dimensions") is not None + # Should not have chat completion fields + assert params_form.get_field("temperature") is None + assert params_form.get_field("top_p") is None + assert params_form.get_field("max_tokens") is None + + +def test_create_inference_params_form_uses_initial_params() -> None: + """Test inference parameters form uses initial values from existing config.""" + builder = ModelFormBuilder() + initial_params = {"temperature": 0.8, "top_p": 0.95, "max_tokens": 2048} + + params_form = 
builder.create_inference_params_form(GenerationType.CHAT_COMPLETION, initial_params) + + assert params_form.get_field("temperature").default == 0.8 + assert params_form.get_field("top_p").default == 0.95 + assert params_form.get_field("max_tokens").default == 2048 + + +def test_build_inference_params_chat_completion_with_all_values() -> None: + """Test building inference params dict from chat completion form data.""" + builder = ModelFormBuilder() + params_data = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024.0} + + result = builder.build_inference_params(GenerationType.CHAT_COMPLETION, params_data) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + +def test_build_inference_params_chat_completion_with_partial_values() -> None: + """Test building inference params dict with only some values provided.""" + builder = ModelFormBuilder() + params_data = {"temperature": 0.7, "top_p": None, "max_tokens": None} + + result = builder.build_inference_params(GenerationType.CHAT_COMPLETION, params_data) + + # Only provided values should be included + assert result == {"temperature": 0.7} + + +def test_build_inference_params_embedding_with_all_values() -> None: + """Test building inference params dict from embedding form data.""" + builder = ModelFormBuilder() + params_data = {"encoding_format": "float", "dimensions": 1024.0} + + result = builder.build_inference_params(GenerationType.EMBEDDING, params_data) + + assert result == {"encoding_format": "float", "dimensions": 1024} + + +def test_build_inference_params_embedding_with_partial_values() -> None: + """Test building embedding inference params with only some values provided.""" + builder = ModelFormBuilder() + params_data = {"encoding_format": "float", "dimensions": None} + + result = builder.build_inference_params(GenerationType.EMBEDDING, params_data) + + # encoding_format always included, dimensions omitted if not provided + assert result == {"encoding_format": "float"} + + +def test_build_inference_params_embedding_all_cleared() -> None: + """Test building embedding inference params when both are cleared.""" + builder = ModelFormBuilder() + params_data = {"encoding_format": None, "dimensions": None} + + result = builder.build_inference_params(GenerationType.EMBEDDING, params_data) + + # Empty dict; Pydantic will use defaults (encoding_format="float", dimensions=None) + assert result == {} + + +def test_validate_encoding_format_accepts_valid_values() -> None: + """Test encoding format validation accepts 'float' and 'base64'.""" + builder = ModelFormBuilder() + + is_valid, error = builder.validate_encoding_format("float") + assert is_valid is True + assert error is None + + is_valid, error = builder.validate_encoding_format("base64") + assert is_valid is True + assert error is None + + +def test_validate_encoding_format_rejects_invalid_values() -> None: + """Test encoding format validation rejects invalid values.""" + builder = ModelFormBuilder() + + is_valid, error = builder.validate_encoding_format("invalid") + assert is_valid is False + assert "float" in error and "base64" in error + + +def test_validate_encoding_format_accepts_empty_string() -> None: + """Test encoding format validation accepts empty string (optional field).""" + builder = ModelFormBuilder() + + is_valid, error = builder.validate_encoding_format("") + assert is_valid is True + assert error is None + + +def test_validate_encoding_format_accepts_clear_keywords() -> None: + """Test encoding format validation accepts clearing keywords.""" + 
builder = ModelFormBuilder() + + for keyword in ("clear", "none", "default", "CLEAR", "None"): + is_valid, error = builder.validate_encoding_format(keyword) + assert is_valid is True, f"Failed for keyword: {keyword}" + assert error is None diff --git a/tests/cli/repositories/test_model_repository.py b/tests/cli/repositories/test_model_repository.py index 01884b5c..624cd360 100644 --- a/tests/cli/repositories/test_model_repository.py +++ b/tests/cli/repositories/test_model_repository.py @@ -21,7 +21,9 @@ def test_load_does_not_exist(): def test_load_exists(tmp_path: Path, stub_model_configs: list[ModelConfig]): model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME - save_config_file(model_configs_file_path, {"model_configs": [mc.model_dump() for mc in stub_model_configs]}) + save_config_file( + model_configs_file_path, {"model_configs": [mc.model_dump(mode="json") for mc in stub_model_configs]} + ) repository = ModelRepository(tmp_path) assert repository.load() is not None assert repository.load().model_configs == stub_model_configs diff --git a/tests/cli/services/test_model_service.py b/tests/cli/services/test_model_service.py index 1d9bf5aa..2551a376 100644 --- a/tests/cli/services/test_model_service.py +++ b/tests/cli/services/test_model_service.py @@ -7,7 +7,7 @@ from data_designer.cli.repositories.model_repository import ModelRepository from data_designer.cli.services.model_service import ModelService -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig def test_list_all(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]): @@ -30,7 +30,9 @@ def test_add( assert stub_model_service.list_all() == stub_model_configs + [stub_new_model_config] -def test_add_duplicate_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_add_duplicate_alias( + stub_model_service: ModelService, stub_inference_parameters: ChatCompletionInferenceParams +): """Test adding a model with an alias that already exists.""" duplicate_model = ModelConfig( alias="test-alias-1", @@ -61,7 +63,9 @@ def test_update_nonexistent_model(stub_model_service: ModelService, stub_new_mod stub_model_service.update("nonexistent", stub_new_model_config) -def test_update_to_existing_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters): +def test_update_to_existing_alias( + stub_model_service: ModelService, stub_inference_parameters: ChatCompletionInferenceParams +): """Test updating a model to an alias that already exists.""" updated_model = ModelConfig( alias="test-alias-2", # Already exists diff --git a/tests/config/analysis/conftest.py b/tests/config/analysis/conftest.py index 42e6fb96..2f683718 100644 --- a/tests/config/analysis/conftest.py +++ b/tests/config/analysis/conftest.py @@ -38,12 +38,12 @@ def sample_llm_text_column_stats(): num_unique=490, pyarrow_dtype="string", simple_dtype="str", - completion_tokens_mean=150.5, - completion_tokens_median=150.0, - completion_tokens_stddev=25.2, - prompt_tokens_mean=50.0, - prompt_tokens_median=50.0, - prompt_tokens_stddev=10.0, + output_tokens_mean=150.5, + output_tokens_median=150.0, + output_tokens_stddev=25.2, + input_tokens_mean=50.0, + input_tokens_median=50.0, + input_tokens_stddev=10.0, ) diff --git a/tests/config/analysis/test_column_statistics.py b/tests/config/analysis/test_column_statistics.py index 34a00a39..a3a262b2 100644 --- 
a/tests/config/analysis/test_column_statistics.py +++ b/tests/config/analysis/test_column_statistics.py @@ -124,12 +124,12 @@ def test_llm_text_column_statistics_with_missing_values( ): llm_text_column_statistics = llm_text_column_statistics_based_class( **stub_general_stats_args_with_missing_values, - completion_tokens_mean=MissingValue.CALCULATION_FAILED, - completion_tokens_median=MissingValue.CALCULATION_FAILED, - completion_tokens_stddev=MissingValue.CALCULATION_FAILED, - prompt_tokens_mean=MissingValue.CALCULATION_FAILED, - prompt_tokens_median=MissingValue.CALCULATION_FAILED, - prompt_tokens_stddev=MissingValue.CALCULATION_FAILED, + output_tokens_mean=MissingValue.CALCULATION_FAILED, + output_tokens_median=MissingValue.CALCULATION_FAILED, + output_tokens_stddev=MissingValue.CALCULATION_FAILED, + input_tokens_mean=MissingValue.CALCULATION_FAILED, + input_tokens_median=MissingValue.CALCULATION_FAILED, + input_tokens_stddev=MissingValue.CALCULATION_FAILED, ) assert llm_text_column_statistics.column_type == column_type assert llm_text_column_statistics.create_report_row_data() == { @@ -155,12 +155,12 @@ def test_llm_text_column_statistics_with_valid_values( ): llm_text_column_statistics = llm_text_column_statistics_based_class( **stub_general_stats_args_with_valid_values, - completion_tokens_mean=150.0, - completion_tokens_median=150.0, - completion_tokens_stddev=25.2, - prompt_tokens_mean=50.0, - prompt_tokens_median=50.0, - prompt_tokens_stddev=10.0, + output_tokens_mean=150.0, + output_tokens_median=150.0, + output_tokens_stddev=25.2, + input_tokens_mean=50.0, + input_tokens_median=50.0, + input_tokens_stddev=10.0, ) assert llm_text_column_statistics.column_type == column_type assert llm_text_column_statistics.create_report_row_data() == { diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py index ef5e63f3..cbc95756 100644 --- a/tests/config/test_columns.py +++ b/tests/config/test_columns.py @@ -5,6 +5,7 @@ from pydantic import ValidationError from data_designer.config.column_configs import ( + EmbeddingColumnConfig, ExpressionColumnConfig, LLMCodeColumnConfig, LLMJudgeColumnConfig, @@ -17,7 +18,7 @@ ) from data_designer.config.column_types import ( DataDesignerColumnType, - column_type_is_llm_generated, + column_type_is_model_generated, column_type_used_in_execution_dag, get_column_config_from_kwargs, get_column_display_order, @@ -49,20 +50,22 @@ def test_data_designer_column_type_get_display_order(): DataDesignerColumnType.LLM_CODE, DataDesignerColumnType.LLM_STRUCTURED, DataDesignerColumnType.LLM_JUDGE, + DataDesignerColumnType.EMBEDDING, DataDesignerColumnType.VALIDATION, DataDesignerColumnType.EXPRESSION, ] def test_data_designer_column_type_is_llm_generated(): - assert column_type_is_llm_generated(DataDesignerColumnType.LLM_TEXT) - assert column_type_is_llm_generated(DataDesignerColumnType.LLM_CODE) - assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED) - assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE) - assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER) - assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION) - assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION) - assert not column_type_is_llm_generated(DataDesignerColumnType.SEED_DATASET) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_TEXT) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_CODE) + assert 
column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED) + assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE) + assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING) + assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER) + assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION) + assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION) + assert not column_type_is_model_generated(DataDesignerColumnType.SEED_DATASET) def test_data_designer_column_type_is_in_dag(): @@ -72,6 +75,7 @@ def test_data_designer_column_type_is_in_dag(): assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED) assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT) assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION) + assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER) assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET) @@ -212,6 +216,20 @@ def test_validation_column_config(): assert validation_column_config.batch_size == 5 +def test_embedding_column_config(): + embedding_column_config = EmbeddingColumnConfig( + name="test_embedding", + target_column="test_column", + model_alias=stub_model_alias, + ) + + assert embedding_column_config.column_type == DataDesignerColumnType.EMBEDDING + assert embedding_column_config.target_column == "test_column" + assert embedding_column_config.model_alias == stub_model_alias + assert embedding_column_config.required_columns == ["test_column"] + assert embedding_column_config.side_effect_columns == [] + + def test_get_column_config_from_kwargs(): assert isinstance( get_column_config_from_kwargs( @@ -278,6 +296,16 @@ def test_get_column_config_from_kwargs(): ExpressionColumnConfig, ) + assert isinstance( + get_column_config_from_kwargs( + name="test_embedding", + column_type=DataDesignerColumnType.EMBEDDING, + target_column="test_column", + model_alias=stub_model_alias, + ), + EmbeddingColumnConfig, + ) + # sampler params is a dictionary assert isinstance( get_column_config_from_kwargs( diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 8a90cc1a..f30af9e7 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -26,7 +26,7 @@ from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy @@ -54,7 +54,7 @@ def stub_data_designer_builder(stub_data_designer_builder_config_str): def test_loading_model_configs_in_constructor(stub_model_configs): - stub_model_configs_dict = [mc.model_dump() for mc in stub_model_configs] + stub_model_configs_dict = [mc.model_dump(mode="json") for mc in stub_model_configs] # test loading model configs from a list builder = 
DataDesignerConfigBuilder(model_configs=stub_model_configs) assert builder.model_configs == stub_model_configs @@ -670,7 +670,7 @@ def test_add_model_config(stub_empty_builder): new_model_config = ModelConfig( alias="new-model", model="openai/gpt-4", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.7, top_p=0.95, max_tokens=1024, @@ -691,7 +691,7 @@ def test_add_model_config(stub_empty_builder): alias="provider-model", model="anthropic/claude-3", provider="anthropic", - inference_parameters=InferenceParameters(temperature=0.8), + inference_parameters=ChatCompletionInferenceParams(temperature=0.8), ) stub_empty_builder.add_model_config(provider_model_config) @@ -717,7 +717,7 @@ def test_add_model_config_duplicate_alias(stub_empty_builder): duplicate_model_config = ModelConfig( alias="stub-model", model="different/model", - inference_parameters=InferenceParameters(temperature=0.5), + inference_parameters=ChatCompletionInferenceParams(temperature=0.5), ) with pytest.raises( @@ -733,12 +733,12 @@ def test_delete_model_config(stub_empty_builder): model_config_1 = ModelConfig( alias="model-to-delete", model="model/delete", - inference_parameters=InferenceParameters(temperature=0.5), + inference_parameters=ChatCompletionInferenceParams(temperature=0.5), ) model_config_2 = ModelConfig( alias="model-to-keep", model="model/keep", - inference_parameters=InferenceParameters(temperature=0.6), + inference_parameters=ChatCompletionInferenceParams(temperature=0.6), ) stub_empty_builder.add_model_config(model_config_1) stub_empty_builder.add_model_config(model_config_2) diff --git a/tests/config/test_default_model_settings.py b/tests/config/test_default_model_settings.py index 222bb410..9619af02 100644 --- a/tests/config/test_default_model_settings.py +++ b/tests/config/test_default_model_settings.py @@ -18,28 +18,35 @@ get_default_providers, resolve_seed_default_model_settings, ) -from data_designer.config.models import InferenceParameters +from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key def test_get_default_inference_parameters(): - assert get_default_inference_parameters("text") == InferenceParameters( + assert get_default_inference_parameters("text", "nvidia") == ChatCompletionInferenceParams( temperature=0.85, top_p=0.95, ) - assert get_default_inference_parameters("reasoning") == InferenceParameters( + assert get_default_inference_parameters("reasoning", "nvidia") == ChatCompletionInferenceParams( temperature=0.35, top_p=0.95, ) - assert get_default_inference_parameters("vision") == InferenceParameters( + assert get_default_inference_parameters("vision", "nvidia") == ChatCompletionInferenceParams( temperature=0.85, top_p=0.95, ) + assert get_default_inference_parameters("embedding", "nvidia") == EmbeddingInferenceParams( + encoding_format="float", + extra_body={"input_type": "query"}, + ) + assert get_default_inference_parameters("embedding", "openai") == EmbeddingInferenceParams( + encoding_format="float", + ) def test_get_builtin_model_configs(): builtin_model_configs = get_builtin_model_configs() - assert len(builtin_model_configs) == 6 + assert len(builtin_model_configs) == 8 assert builtin_model_configs[0].alias == "nvidia-text" assert builtin_model_configs[0].model == "nvidia/nvidia-nemotron-nano-9b-v2" assert builtin_model_configs[0].provider == "nvidia" @@ -49,11 +56,21 @@ def 
test_get_builtin_model_configs(): assert builtin_model_configs[2].alias == "nvidia-vision" assert builtin_model_configs[2].model == "nvidia/nemotron-nano-12b-v2-vl" assert builtin_model_configs[2].provider == "nvidia" - assert builtin_model_configs[3].alias == "openai-text" - assert builtin_model_configs[3].model == "gpt-4.1" - assert builtin_model_configs[3].provider == "openai" - assert builtin_model_configs[4].alias == "openai-reasoning" - assert builtin_model_configs[4].model == "gpt-5" + assert builtin_model_configs[3].alias == "nvidia-embedding" + assert builtin_model_configs[3].model == "nvidia/llama-3.2-nv-embedqa-1b-v2" + assert builtin_model_configs[3].provider == "nvidia" + assert builtin_model_configs[4].alias == "openai-text" + assert builtin_model_configs[4].model == "gpt-4.1" + assert builtin_model_configs[4].provider == "openai" + assert builtin_model_configs[5].alias == "openai-reasoning" + assert builtin_model_configs[5].model == "gpt-5" + assert builtin_model_configs[5].provider == "openai" + assert builtin_model_configs[6].alias == "openai-vision" + assert builtin_model_configs[6].model == "gpt-5" + assert builtin_model_configs[6].provider == "openai" + assert builtin_model_configs[7].alias == "openai-embedding" + assert builtin_model_configs[7].model == "text-embedding-3-large" + assert builtin_model_configs[7].provider == "openai" def test_get_builtin_model_providers(): diff --git a/tests/config/test_models.py b/tests/config/test_models.py index 186a4f19..6f6618c2 100644 --- a/tests/config/test_models.py +++ b/tests/config/test_models.py @@ -11,9 +11,11 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + GenerationType, ImageContext, ImageFormat, - InferenceParameters, ManualDistribution, ManualDistributionParams, ModalityDataType, @@ -46,13 +48,13 @@ def test_image_context_validate_image_format(): def test_inference_parameters_default_construction(): - empty_inference_parameters = InferenceParameters() + empty_inference_parameters = ChatCompletionInferenceParams() assert empty_inference_parameters.generate_kwargs == {} assert empty_inference_parameters.max_parallel_requests == 4 def test_inference_parameters_generate_kwargs(): - assert InferenceParameters( + assert ChatCompletionInferenceParams( temperature=0.95, top_p=0.95, max_tokens=100, @@ -67,9 +69,9 @@ def test_inference_parameters_generate_kwargs(): "extra_body": {"reasoning_effort": "high"}, } - assert InferenceParameters().generate_kwargs == {} + assert ChatCompletionInferenceParams().generate_kwargs == {} - inference_parameters_kwargs = InferenceParameters( + inference_parameters_kwargs = ChatCompletionInferenceParams( temperature=UniformDistribution(params=UniformDistributionParams(low=0.0, high=1.0)), top_p=ManualDistribution(params=ManualDistributionParams(values=[0.0, 1.0], weights=[0.5, 0.5])), ).generate_kwargs @@ -131,32 +133,38 @@ def test_inference_parameters_temperature_validation(): # All temp values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + ChatCompletionInferenceParams( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - 
InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5))) + ChatCompletionInferenceParams( + temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5)) + ) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0))) + ChatCompletionInferenceParams( + temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0)) + ) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=3.0) + ChatCompletionInferenceParams(temperature=3.0) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(temperature=-1.0) + ChatCompletionInferenceParams(temperature=-1.0) # Valid temperature values shouldn't raise validation errors try: - InferenceParameters(temperature=0.1) - InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0))) - InferenceParameters( + ChatCompletionInferenceParams(temperature=0.1) + ChatCompletionInferenceParams( + temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)) + ) + ChatCompletionInferenceParams( temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters temperature validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters temperature validation") def test_generation_parameters_top_p_validation(): @@ -164,31 +172,31 @@ def test_generation_parameters_top_p_validation(): # All top_p values provide in a manual destribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters( + ChatCompletionInferenceParams( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.5], weights=[0.5, 0.5])) ) # High and low values of uniform distribution should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) + ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5))) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) + ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0))) # Static values should be valid with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=1.5) + ChatCompletionInferenceParams(top_p=1.5) with pytest.raises(ValidationError, match=expected_error_msg): - InferenceParameters(top_p=-0.1) + ChatCompletionInferenceParams(top_p=-0.1) # Valid top_p values shouldn't raise validation errors try: - InferenceParameters(top_p=0.1) - InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) - InferenceParameters( + ChatCompletionInferenceParams(top_p=0.1) + ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0))) + ChatCompletionInferenceParams( top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5])) ) except Exception: - pytest.fail("Unexpected exception raised during 
InferenceParameters top_p validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters top_p validation") def test_generation_parameters_max_tokens_validation(): @@ -196,15 +204,15 @@ def test_generation_parameters_max_tokens_validation(): ValidationError, match="Input should be greater than or equal to 1", ): - InferenceParameters(max_tokens=0) + ChatCompletionInferenceParams(max_tokens=0) # Valid max_tokens values shouldn't raise validation errors try: - InferenceParameters(max_tokens=128_000) - InferenceParameters(max_tokens=4096) - InferenceParameters(max_tokens=1) + ChatCompletionInferenceParams(max_tokens=128_000) + ChatCompletionInferenceParams(max_tokens=4096) + ChatCompletionInferenceParams(max_tokens=1) except Exception: - pytest.fail("Unexpected exception raised during InferenceParameters max_tokens validation") + pytest.fail("Unexpected exception raised during CompletionInferenceParameters max_tokens validation") def test_load_model_configs(): @@ -212,7 +220,7 @@ def test_load_model_configs(): ModelConfig(alias="test", model="test"), ModelConfig(alias="test2", model="test2"), ] - stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs] + stub_model_configs_dict_list = [mc.model_dump(mode="json") for mc in stub_model_configs] assert load_model_configs([]) == [] assert load_model_configs(stub_model_configs) == stub_model_configs @@ -248,6 +256,43 @@ def test_load_model_configs(): load_model_configs(tmp_file.name) -def test_model_config_default_construction(): +def test_model_config_construction(): + # test default construction model_config = ModelConfig(alias="test", model="test") - assert model_config.inference_parameters == InferenceParameters() + assert model_config.inference_parameters == ChatCompletionInferenceParams() + assert model_config.generation_type == GenerationType.CHAT_COMPLETION + + # test construction with completion inference parameters + completion_params = ChatCompletionInferenceParams(temperature=0.5, top_p=0.5, max_tokens=100) + model_config = ModelConfig(alias="test", model="test", inference_parameters=completion_params) + assert model_config.inference_parameters == completion_params + assert model_config.generation_type == GenerationType.CHAT_COMPLETION + + # test construction with embedding inference parameters + embedding_params = EmbeddingInferenceParams(dimensions=100) + model_config = ModelConfig(alias="test", model="test", inference_parameters=embedding_params) + assert model_config.inference_parameters == embedding_params + assert model_config.generation_type == GenerationType.EMBEDDING + + +def test_model_config_generation_type_from_dict(): + # Test that generation_type in dict is used to create the right inference params type + model_config = ModelConfig.model_validate( + { + "alias": "test", + "model": "test", + "inference_parameters": {"generation_type": "embedding", "dimensions": 100}, + } + ) + assert isinstance(model_config.inference_parameters, EmbeddingInferenceParams) + assert model_config.generation_type == GenerationType.EMBEDDING + + model_config = ModelConfig.model_validate( + { + "alias": "test", + "model": "test", + "inference_parameters": {"generation_type": "chat-completion", "temperature": 0.5}, + } + ) + assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams) + assert model_config.generation_type == GenerationType.CHAT_COMPLETION diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py index ab0eebc5..fa55824e 
100644 --- a/tests/config/utils/test_visualization.py +++ b/tests/config/utils/test_visualization.py @@ -8,7 +8,7 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.utils.code_lang import CodeLang -from data_designer.config.utils.visualization import display_sample_record, mask_api_key +from data_designer.config.utils.visualization import display_sample_record, get_truncated_list_as_string, mask_api_key from data_designer.config.validator_params import CodeValidatorParams @@ -75,3 +75,14 @@ def test_mask_api_key(): # None or empty returns "(not set)" assert mask_api_key(None) == "(not set)" assert mask_api_key("") == "(not set)" + + +def test_get_truncated_list_as_string(): + assert get_truncated_list_as_string([1, 2, 3, 4, 5]) == "[1, 2, ...]" + assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=1) == "[1, ...]" + assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=3) == "[1, 2, 3, ...]" + assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=10) == "[1, 2, 3, 4, 5]" + with pytest.raises(ValueError): + get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=-1) + with pytest.raises(ValueError): + get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=0) diff --git a/tests/conftest.py b/tests/conftest.py index 31dc0057..5f0205e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,7 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.datastore import DatastoreSettings -from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider @pytest.fixture @@ -135,7 +135,7 @@ def stub_model_configs() -> list[ModelConfig]: ModelConfig( alias="stub-model", model="stub-model", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.9, top_p=0.9, max_tokens=2048, diff --git a/tests/engine/analysis/test_column_statistics_calculator.py b/tests/engine/analysis/test_column_statistics_calculator.py index d21ae179..203e03f1 100644 --- a/tests/engine/analysis/test_column_statistics_calculator.py +++ b/tests/engine/analysis/test_column_statistics_calculator.py @@ -43,10 +43,10 @@ def test_llm_generated_column_statistics(stub_df, column_configs): assert stats.column_name == column_config.name assert stats.column_type == column_config.column_type assert stats.num_records == len(stub_df) - assert isinstance(stats.completion_tokens_mean, float) - assert isinstance(stats.completion_tokens_stddev, float) - assert isinstance(stats.prompt_tokens_mean, float) - assert isinstance(stats.prompt_tokens_stddev, float) + assert isinstance(stats.output_tokens_mean, float) + assert isinstance(stats.output_tokens_stddev, float) + assert isinstance(stats.input_tokens_mean, float) + assert isinstance(stats.input_tokens_stddev, float) def test_sampler_column_statistics(stub_df, column_configs): diff --git a/tests/engine/analysis/utils/test_column_statistics_calculations.py b/tests/engine/analysis/utils/test_column_statistics_calculations.py index 6458aa49..65afefd2 100644 --- a/tests/engine/analysis/utils/test_column_statistics_calculations.py +++ b/tests/engine/analysis/utils/test_column_statistics_calculations.py @@ -19,9 +19,9 @@ from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting from 
data_designer.engine.analysis.utils.column_statistics_calculations import ( calculate_column_distribution, - calculate_completion_token_stats, calculate_general_column_info, - calculate_prompt_token_stats, + calculate_input_token_stats, + calculate_output_token_stats, calculate_validation_column_info, convert_pyarrow_dtype_to_simple_dtype, ensure_boolean, @@ -169,38 +169,39 @@ def test_calculate_general_column_info(stub_df_with_mixed_column_types): assert result["num_unique"] == MissingValue.CALCULATION_FAILED -def test_calculate_prompt_token_stats(mock_prompt_renderer_render, stub_column_config, stub_df_responses): +def test_calculate_input_token_stats(mock_prompt_renderer_render, stub_column_config, stub_df_responses): prompt_cycle = cycle(["System prompt", "Test prompt"]) mock_prompt_renderer_render.side_effect = lambda *args, **kwargs: next(prompt_cycle) - result = calculate_prompt_token_stats(stub_column_config, stub_df_responses) - assert "prompt_tokens_mean" in result - assert "prompt_tokens_stddev" in result - assert "prompt_tokens_median" in result - assert isinstance(result["prompt_tokens_mean"], float) - assert isinstance(result["prompt_tokens_stddev"], float) - assert isinstance(result["prompt_tokens_median"], float) + result = calculate_input_token_stats(stub_column_config, stub_df_responses) + assert "input_tokens_mean" in result + assert "input_tokens_stddev" in result + assert "input_tokens_median" in result + assert isinstance(result["input_tokens_mean"], float) + assert isinstance(result["input_tokens_stddev"], float) + assert isinstance(result["input_tokens_median"], float) mock_prompt_renderer_render.side_effect = Exception("Test error") - result = calculate_prompt_token_stats(stub_column_config, stub_df_responses) - assert result["prompt_tokens_mean"] == MissingValue.CALCULATION_FAILED - assert result["prompt_tokens_stddev"] == MissingValue.CALCULATION_FAILED - assert result["prompt_tokens_median"] == MissingValue.CALCULATION_FAILED - - -def test_calculate_completion_token_stats(stub_column_config, stub_df_responses): - result = calculate_completion_token_stats(stub_column_config.name, stub_df_responses) - assert "completion_tokens_mean" in result - assert "completion_tokens_stddev" in result - assert "completion_tokens_median" in result - assert isinstance(result["completion_tokens_mean"], float) - assert isinstance(result["completion_tokens_stddev"], float) - assert isinstance(result["completion_tokens_median"], float) - - result = calculate_completion_token_stats("nonexistent_column", stub_df_responses) - assert result["completion_tokens_mean"] == MissingValue.CALCULATION_FAILED - assert result["completion_tokens_stddev"] == MissingValue.CALCULATION_FAILED - assert result["completion_tokens_median"] == MissingValue.CALCULATION_FAILED + result = calculate_input_token_stats(stub_column_config, stub_df_responses) + assert result["input_tokens_mean"] == MissingValue.CALCULATION_FAILED + assert result["input_tokens_stddev"] == MissingValue.CALCULATION_FAILED + assert result["input_tokens_median"] == MissingValue.CALCULATION_FAILED + + +def test_calculate_output_token_stats(stub_column_config, stub_df_responses): + result = calculate_output_token_stats(stub_column_config, stub_df_responses) + assert "output_tokens_mean" in result + assert "output_tokens_stddev" in result + assert "output_tokens_median" in result + assert isinstance(result["output_tokens_mean"], float) + assert isinstance(result["output_tokens_stddev"], float) + assert 
isinstance(result["output_tokens_median"], float) + + stub_column_config.name = "nonexistent_column" + result = calculate_output_token_stats(stub_column_config, stub_df_responses) + assert result["output_tokens_mean"] == MissingValue.CALCULATION_FAILED + assert result["output_tokens_stddev"] == MissingValue.CALCULATION_FAILED + assert result["output_tokens_median"] == MissingValue.CALCULATION_FAILED def test_calculate_validation_column_info(stub_column_config, stub_df_code_validation): diff --git a/tests/engine/column_generators/generators/test_embedding.py b/tests/engine/column_generators/generators/test_embedding.py new file mode 100644 index 00000000..75efdfc6 --- /dev/null +++ b/tests/engine/column_generators/generators/test_embedding.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest + +from data_designer.config.column_configs import EmbeddingColumnConfig +from data_designer.engine.column_generators.generators.base import GenerationStrategy +from data_designer.engine.column_generators.generators.embedding import ( + EmbeddingCellGenerator, + EmbeddingGenerationResult, +) + + +@pytest.fixture +def stub_embedding_column_config(): + return EmbeddingColumnConfig(name="test_embedding", target_column="test_column", model_alias="test_model") + + +@pytest.fixture +def stub_embeddings() -> list[list[float]]: + return [[0.1, 0.2], [0.3, 0.4]] + + +def test_embedding_cell_generator_metadata(stub_embedding_column_config, stub_resource_provider): + metadata = EmbeddingCellGenerator( + config=stub_embedding_column_config, resource_provider=stub_resource_provider + ).metadata() + assert metadata.name == "embedding_cell_generator" + assert metadata.description == "Generate embeddings for a text column." 
+ assert metadata.generation_strategy == GenerationStrategy.CELL_BY_CELL + + +def test_embedding_cell_generator_generate(stub_embedding_column_config, stub_resource_provider, stub_embeddings): + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_text_embeddings", + return_value=stub_embeddings, + ) as mock_generate: + embedding_cell_generator = EmbeddingCellGenerator( + config=stub_embedding_column_config, resource_provider=stub_resource_provider + ) + data = embedding_cell_generator.generate(data={"test_column": "['test1', 'test2']"}) + assert stub_embedding_column_config.name in data + assert data[stub_embedding_column_config.name] == EmbeddingGenerationResult( + embeddings=stub_embeddings + ).model_dump(mode="json") + mock_generate.assert_called_once_with(input_texts=["test1", "test2"]) diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py similarity index 92% rename from tests/engine/column_generators/generators/test_llm_generators.py rename to tests/engine/column_generators/generators/test_llm_completion_generators.py index 259f3a08..0b787b7e 100644 --- a/tests/engine/column_generators/generators/test_llm_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -11,7 +11,7 @@ LLMStructuredColumnConfig, LLMTextColumnConfig, ) -from data_designer.engine.column_generators.generators.llm_generators import ( +from data_designer.engine.column_generators.generators.llm_completion import ( DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS, DEFAULT_MAX_CONVERSATION_RESTARTS, REASONING_TRACE_COLUMN_POSTFIX, @@ -94,7 +94,7 @@ def test_generate_method(): assert call_args[1]["multi_modal_context"] is None -@patch("data_designer.engine.column_generators.generators.llm_generators.logger", autospec=True) +@patch("data_designer.engine.column_generators.generators.base.logger", autospec=True) def test_log_pre_generation(mock_logger): generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks() mock_model_config.model_dump_json.return_value = '{"test": "config"}' @@ -259,20 +259,3 @@ def test_generate_with_json_deserialization(): result = generator.generate(data) assert result["test_column"] == {"result": "json_output"} - - -def test_generate_with_inference_parameters(): - generator, _, mock_model, _, mock_inference_params, mock_prompt_renderer, mock_response_recipe = ( - _create_generator_with_mocks() - ) - - mock_inference_params.generate_kwargs = {"temperature": 0.7, "max_tokens": 100} - _setup_generate_mocks(mock_prompt_renderer, mock_response_recipe, mock_model) - - data = {"input": "test_input"} - generator.generate(data) - - call_args = mock_model.generate.call_args - assert call_args[1]["temperature"] == 0.7 - assert call_args[1]["max_tokens"] == 100 - assert call_args[1]["purpose"] == "running generation for column 'test_column'" diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py index f70b0d90..0d325937 100644 --- a/tests/engine/column_generators/test_registry.py +++ b/tests/engine/column_generators/test_registry.py @@ -3,7 +3,7 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator -from data_designer.engine.column_generators.generators.llm_generators import ( +from 
data_designer.engine.column_generators.generators.llm_completion import ( LLMCodeCellGenerator, LLMJudgeCellGenerator, LLMStructuredCellGenerator, diff --git a/tests/engine/models/conftest.py b/tests/engine/models/conftest.py index 95e6941f..4e1dd059 100644 --- a/tests/engine/models/conftest.py +++ b/tests/engine/models/conftest.py @@ -5,7 +5,11 @@ import pytest -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import ( + ChatCompletionInferenceParams, + EmbeddingInferenceParams, + ModelConfig, +) from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry from data_designer.engine.models.registry import ModelRegistry, create_model_registry from data_designer.engine.secret_resolver import SecretsFileResolver @@ -38,7 +42,7 @@ def stub_model_configs() -> list[ModelConfig]: alias="stub-text", model="stub-model-text", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ), @@ -46,10 +50,18 @@ def stub_model_configs() -> list[ModelConfig]: alias="stub-reasoning", model="stub-model-reasoning", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ), + ModelConfig( + alias="stub-embedding", + model="stub-model-embedding", + provider="stub-model-provider", + inference_parameters=EmbeddingInferenceParams( + dimensions=100, + ), + ), ] diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py index aad035cd..97d8cacf 100644 --- a/tests/engine/models/test_facade.py +++ b/tests/engine/models/test_facade.py @@ -5,7 +5,7 @@ from unittest.mock import patch import pytest -from litellm.types.utils import Choices, Message, ModelResponse +from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse from data_designer.engine.models.errors import ModelGenerationValidationFailureError from data_designer.engine.models.facade import ModelFacade @@ -30,10 +30,20 @@ def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_prov @pytest.fixture -def stub_expected_response(): +def stub_completion_messages(): + return [{"role": "user", "content": "test"}] + + +@pytest.fixture +def stub_expected_completion_response(): return ModelResponse(choices=Choices(message=Message(content="Test response"))) +@pytest.fixture +def stub_expected_embedding_response(): + return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) + + @pytest.mark.parametrize( "max_correction_steps,max_conversation_restarts,total_calls", [ @@ -105,6 +115,24 @@ def test_usage_stats_property(stub_model_facade): assert hasattr(stub_model_facade.usage_stats, "model_dump") +def test_consolidate_kwargs(stub_model_configs, stub_model_facade): + # Model config generate kwargs are used as base, and purpose is removed + result = stub_model_facade.consolidate_kwargs(purpose="test") + assert result == stub_model_configs[0].inference_parameters.generate_kwargs + + # kwargs overrides model config generate kwargs + result = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test") + assert result == {**stub_model_configs[0].inference_parameters.generate_kwargs, "temperature": 0.01} + + # Provider extra_body overrides all other kwargs + 
stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"} + result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test") + assert result == { + **stub_model_configs[0].inference_parameters.generate_kwargs, + "extra_body": {"foo_provider": "bar_provider", "foo": "bar"}, + } + + @pytest.mark.parametrize( "skip_usage_tracking", [ @@ -112,63 +140,85 @@ def test_usage_stats_property(stub_model_facade): True, ], ) -def test_completion_success(stub_model_facade, stub_expected_response, skip_usage_tracking): - stub_model_facade._router.completion = lambda model_name, messages, **kwargs: stub_expected_response - - messages = [{"role": "user", "content": "test"}] - result = stub_model_facade.completion(messages, skip_usage_tracking=skip_usage_tracking) - - assert result == stub_expected_response - - -def test_completion_with_exception(stub_model_facade): - def raise_exception(*args, **kwargs): - raise Exception("Router error") +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def test_completion_success( + mock_router_completion, + stub_completion_messages, + stub_model_configs, + stub_model_facade, + stub_expected_completion_response, + skip_usage_tracking, +): + mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response + result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) + assert result == stub_expected_completion_response + assert mock_router_completion.call_count == 1 + assert mock_router_completion.call_args[1] == { + "model": "stub-model-text", + "messages": stub_completion_messages, + **stub_model_configs[0].inference_parameters.generate_kwargs, + } - stub_model_facade._router.completion = raise_exception - messages = [{"role": "user", "content": "test"}] +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def test_completion_with_exception(mock_router_completion, stub_completion_messages, stub_model_facade): + mock_router_completion.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): - stub_model_facade.completion(messages) + stub_model_facade.completion(stub_completion_messages) -def test_completion_with_kwargs(stub_model_facade, stub_expected_response): +@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) +def test_completion_with_kwargs( + mock_router_completion, + stub_completion_messages, + stub_model_configs, + stub_model_facade, + stub_expected_completion_response, +): captured_kwargs = {} - def mock_completion(model_name, messages, **kwargs): + def mock_completion(self, model, messages, **kwargs): captured_kwargs.update(kwargs) - return stub_expected_response + return stub_expected_completion_response - stub_model_facade._router.completion = mock_completion + mock_router_completion.side_effect = mock_completion - messages = [{"role": "user", "content": "test"}] kwargs = {"temperature": 0.7, "max_tokens": 100} - result = stub_model_facade.completion(messages, **kwargs) + result = stub_model_facade.completion(stub_completion_messages, **kwargs) - assert result == stub_expected_response - assert captured_kwargs == kwargs + assert result == stub_expected_completion_response + # completion kwargs overrides model config generate kwargs + assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} 
-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True) -def test_completion_with_extra_body(mock_router_completion, stub_model_facade): - messages = [{"role": "user", "content": "test"}] - - # completion call has no extra body argument and provider has no extra body - _ = stub_model_facade.completion(messages) - assert len(mock_router_completion.call_args) == 2 - assert mock_router_completion.call_args[0][1] == "stub-model-text" - assert mock_router_completion.call_args[0][2] == messages - - # completion call has no extra body argument and provider has extra body. - # Should pull extra body from model provider - custom_extra_body = {"some_custom_key": "some_custom_value"} - stub_model_facade.model_provider.extra_body = custom_extra_body - _ = stub_model_facade.completion(messages) - assert mock_router_completion.call_args[1] == {"extra_body": custom_extra_body} - - # completion call has extra body argument and provider has extra body. - # Should merge the two with provider extra body taking precedence - completion_extra_body = {"some_completion_key": "some_completion_value", "some_custom_key": "some_different_value"} - _ = stub_model_facade.completion(messages, extra_body=completion_extra_body) - assert mock_router_completion.call_args[1] == {"extra_body": {**completion_extra_body, **custom_extra_body}} +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response): + mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response + input_texts = ["test1", "test2"] + result = stub_model_facade.generate_text_embeddings(input_texts) + assert result == [data["embedding"] for data in stub_expected_embedding_response.data] + + +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade): + mock_router_embedding.side_effect = Exception("Router error") + + with pytest.raises(Exception, match="Router error"): + stub_model_facade.generate_text_embeddings(["test1", "test2"]) + + +@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True) +def test_generate_text_embeddings_with_kwargs( + mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response +): + captured_kwargs = {} + + def mock_embedding(self, model, input, **kwargs): + captured_kwargs.update(kwargs) + return stub_expected_embedding_response + + mock_router_embedding.side_effect = mock_embedding + kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"} + _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs) + assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} diff --git a/tests/engine/models/test_model_registry.py b/tests/engine/models/test_model_registry.py index 89b95a50..4ea5a447 100644 --- a/tests/engine/models/test_model_registry.py +++ b/tests/engine/models/test_model_registry.py @@ -4,9 +4,8 @@ from unittest.mock import patch import pytest -from litellm import AuthenticationError -from data_designer.config.models import InferenceParameters, ModelConfig +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.engine.models.errors import ModelAuthenticationError from data_designer.engine.models.facade import ModelFacade from 
data_designer.engine.models.registry import ModelRegistry, create_model_registry @@ -24,7 +23,7 @@ def stub_new_model_config(): alias="stub-vision", model="stub-model-vision", provider="stub-model-provider", - inference_parameters=InferenceParameters( + inference_parameters=ChatCompletionInferenceParams( temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100 ), ) @@ -36,7 +35,7 @@ def stub_no_usage_config(): alias="no-usage", model="no-usage-model", provider="stub-model-provider", - inference_parameters=InferenceParameters(), + inference_parameters=ChatCompletionInferenceParams(), ) @@ -72,14 +71,15 @@ def test_register_model_configs(stub_model_registry, stub_new_model_config): stub_model_registry.register_model_configs([stub_new_model_config]) # Verify configs are registered - assert len(stub_model_registry.model_configs) == 3 + assert len(stub_model_registry.model_configs) == 4 # Trigger lazy initialization by requesting models assert stub_model_registry.get_model(model_alias="stub-text").model_name == "stub-model-text" assert stub_model_registry.get_model(model_alias="stub-reasoning").model_name == "stub-model-reasoning" assert stub_model_registry.get_model(model_alias="stub-vision").model_name == "stub-model-vision" + assert stub_model_registry.get_model(model_alias="stub-embedding").model_name == "stub-model-embedding" - assert len(stub_model_registry.models) == 3 + assert len(stub_model_registry.models) == 4 assert all(isinstance(model, ModelFacade) for model in stub_model_registry.models.values()) @@ -126,19 +126,19 @@ def test_get_model_usage_stats( reasoning_model = stub_model_registry.get_model(model_alias="stub-reasoning") text_model.usage_stats.extend( - token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=100), + token_usage=TokenUsageStats(input_tokens=10, output_tokens=100), request_usage=RequestUsageStats(successful_requests=10, failed_requests=0), ) reasoning_model.usage_stats.extend( - token_usage=TokenUsageStats(prompt_tokens=5, completion_tokens=200), + token_usage=TokenUsageStats(input_tokens=5, output_tokens=200), request_usage=RequestUsageStats(successful_requests=100, failed_requests=10), ) usage_stats = stub_model_registry.get_model_usage_stats(total_time_elapsed=10) assert set(usage_stats.keys()) == set(expected_keys) if "stub-model-text" in usage_stats: - assert usage_stats["stub-model-text"]["token_usage"]["prompt_tokens"] == 10 - assert usage_stats["stub-model-text"]["token_usage"]["completion_tokens"] == 100 + assert usage_stats["stub-model-text"]["token_usage"]["input_tokens"] == 10 + assert usage_stats["stub-model-text"]["token_usage"]["output_tokens"] == 100 assert usage_stats["stub-model-text"]["token_usage"]["total_tokens"] == 110 assert usage_stats["stub-model-text"]["request_usage"]["successful_requests"] == 10 assert usage_stats["stub-model-text"]["request_usage"]["failed_requests"] == 0 @@ -150,42 +150,52 @@ def test_get_model_usage_stats( # Trigger lazy initialization text_model = stub_model_registry.get_model(model_alias="stub-text") text_model.usage_stats.extend( - token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=100), + token_usage=TokenUsageStats(input_tokens=10, output_tokens=100), request_usage=RequestUsageStats(successful_requests=10, failed_requests=0), ) usage_stats = stub_model_registry.get_model_usage_stats(total_time_elapsed=10) assert set(usage_stats.keys()) == set(expected_keys) -@pytest.mark.parametrize( - "test_case,mock_side_effect,expected_exception,expected_call_count", - [ - 
("success", None, None, 2), - ( - "authentication_error", - AuthenticationError("Invalid API key", llm_provider="openai", model="stub-model-text"), - ModelAuthenticationError, - 1, - ), - ], -) +@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_run_health_check_success(mock_completion, mock_generate_text_embeddings, stub_model_registry): + model_aliases = {"stub-text", "stub-reasoning", "stub-embedding"} + stub_model_registry.run_health_check(model_aliases) + assert mock_completion.call_count == 2 + assert mock_generate_text_embeddings.call_count == 1 + + +@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) @patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_run_health_check( - mock_completion, stub_model_registry, test_case, mock_side_effect, expected_exception, expected_call_count +def test_run_health_check_completion_authentication_error( + mock_completion, mock_generate_text_embeddings, stub_model_registry ): - if mock_side_effect: - mock_completion.side_effect = mock_side_effect + auth_error = ModelAuthenticationError("Invalid API key for completion model") + mock_completion.side_effect = auth_error + model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"] - # Pass model aliases for health check - model_aliases = {"stub-text", "stub-reasoning"} + with pytest.raises(ModelAuthenticationError): + stub_model_registry.run_health_check(model_aliases) - if expected_exception: - with pytest.raises(expected_exception): - stub_model_registry.run_health_check(model_aliases) - else: + mock_completion.assert_called_once() + mock_generate_text_embeddings.assert_not_called() + + +@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True) +@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) +def test_run_health_check_embedding_authentication_error( + mock_completion, mock_generate_text_embeddings, stub_model_registry +): + auth_error = ModelAuthenticationError("Invalid API key for embedding model") + mock_generate_text_embeddings.side_effect = auth_error + model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"] + + with pytest.raises(ModelAuthenticationError): stub_model_registry.run_health_check(model_aliases) - assert mock_completion.call_count == expected_call_count + assert mock_completion.call_count == 2 + mock_generate_text_embeddings.assert_called_once() @pytest.mark.parametrize( diff --git a/tests/engine/models/test_usage.py b/tests/engine/models/test_usage.py index 3e92a433..3ddb09f3 100644 --- a/tests/engine/models/test_usage.py +++ b/tests/engine/models/test_usage.py @@ -6,14 +6,14 @@ def test_token_usage_stats(): token_usage_stats = TokenUsageStats() - assert token_usage_stats.prompt_tokens == 0 - assert token_usage_stats.completion_tokens == 0 + assert token_usage_stats.input_tokens == 0 + assert token_usage_stats.output_tokens == 0 assert token_usage_stats.total_tokens == 0 assert token_usage_stats.has_usage is False - token_usage_stats.extend(prompt_tokens=10, completion_tokens=20) - assert token_usage_stats.prompt_tokens == 10 - assert token_usage_stats.completion_tokens == 20 + token_usage_stats.extend(input_tokens=10, output_tokens=20) + assert token_usage_stats.input_tokens == 10 + assert token_usage_stats.output_tokens == 20 assert token_usage_stats.total_tokens == 30 assert 
token_usage_stats.has_usage is True @@ -34,31 +34,31 @@ def test_request_usage_stats(): def test_model_usage_stats(): model_usage_stats = ModelUsageStats() - assert model_usage_stats.token_usage.prompt_tokens == 0 - assert model_usage_stats.token_usage.completion_tokens == 0 + assert model_usage_stats.token_usage.input_tokens == 0 + assert model_usage_stats.token_usage.output_tokens == 0 assert model_usage_stats.request_usage.successful_requests == 0 assert model_usage_stats.request_usage.failed_requests == 0 assert model_usage_stats.has_usage is False assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == { - "token_usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, "request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0}, "tokens_per_second": 0, "requests_per_minute": 0, } model_usage_stats.extend( - token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=20), + token_usage=TokenUsageStats(input_tokens=10, output_tokens=20), request_usage=RequestUsageStats(successful_requests=2, failed_requests=1), ) - assert model_usage_stats.token_usage.prompt_tokens == 10 - assert model_usage_stats.token_usage.completion_tokens == 20 + assert model_usage_stats.token_usage.input_tokens == 10 + assert model_usage_stats.token_usage.output_tokens == 20 assert model_usage_stats.request_usage.successful_requests == 2 assert model_usage_stats.request_usage.failed_requests == 1 assert model_usage_stats.has_usage is True assert model_usage_stats.get_usage_stats(total_time_elapsed=2) == { - "token_usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "token_usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}, "request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3}, "tokens_per_second": 15, "requests_per_minute": 90, diff --git a/tests/engine/processing/test_utils.py b/tests/engine/processing/test_utils.py index a41e0ec2..dec0fe6a 100644 --- a/tests/engine/processing/test_utils.py +++ b/tests/engine/processing/test_utils.py @@ -9,6 +9,7 @@ from data_designer.engine.processing.utils import ( concat_datasets, deserialize_json_values, + parse_list_string, ) @@ -116,3 +117,19 @@ def test_concat_datasets_logging(mock_logger, stub_sample_dataframes): def test_deserialize_json_values_scenarios(test_case, input_data, expected_result): result = deserialize_json_values(input_data) assert result == expected_result + + +@pytest.mark.parametrize( + "input_string,expected_result", + [ + ('["a", "b", "c"]', ["a", "b", "c"]), # valid stringified json array + ('[" a ", " b", "c "]', ["a", "b", "c"]), # valid stringified json array with whitespace + ('["a", "b", "c",]', ["a", "b", "c"]), # valid stringified json array with trailing comma + ("['a', 'b', 'c']", ["a", "b", "c"]), # valid python-style list with single quotes + ("['a', 'b', 'c', ]", ["a", "b", "c"]), # valid python-style list with trailing comma + ("simple string ", ["simple string"]), # simple string with whitespace + ], +) +def test_parse_list_string_scenarios(input_string, expected_result): + result = parse_list_string(input_string) + assert result == expected_result diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py index 35cad0f3..e7b6288a 100644 --- a/tests/essentials/test_init.py +++ b/tests/essentials/test_init.py @@ -14,6 +14,7 @@ BernoulliSamplerParams, BinomialSamplerParams, CategorySamplerParams, + 
ChatCompletionInferenceParams, CodeLang, CodeValidatorParams, ColumnInequalityConstraint, @@ -23,8 +24,10 @@ DatastoreSeedDatasetReference, DatastoreSettings, DatetimeSamplerParams, + EmbeddingInferenceParams, ExpressionColumnConfig, GaussianSamplerParams, + GenerationType, ImageContext, ImageFormat, InferenceParameters, @@ -109,6 +112,9 @@ def test_model_config_imports(): assert ImageContext is not None assert ImageFormat is not None assert InferenceParameters is not None + assert ChatCompletionInferenceParams is not None + assert EmbeddingInferenceParams is not None + assert GenerationType is not None assert ManualDistribution is not None assert ManualDistributionParams is not None assert Modality is not None @@ -232,6 +238,7 @@ def test_all_contains_column_configs(): assert "Score" in __all__ assert "SeedDatasetColumnConfig" in __all__ assert "ValidationColumnConfig" in __all__ + assert "EmbeddingColumnConfig" in __all__ def test_all_contains_sampler_params(): @@ -250,6 +257,8 @@ def test_all_contains_sampler_params(): assert "TimeDeltaSamplerParams" in __all__ assert "UniformSamplerParams" in __all__ assert "UUIDSamplerParams" in __all__ + assert "PersonFromFakerSamplerParams" in __all__ + assert "ProcessorType" in __all__ def test_all_contains_constraints(): @@ -263,6 +272,9 @@ def test_all_contains_model_configs(): assert "ImageContext" in __all__ assert "ImageFormat" in __all__ assert "InferenceParameters" in __all__ + assert "ChatCompletionInferenceParams" in __all__ + assert "EmbeddingInferenceParams" in __all__ + assert "GenerationType" in __all__ assert "ManualDistribution" in __all__ assert "ManualDistributionParams" in __all__ assert "Modality" in __all__