diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb
index eb9db753..091200d7 100644
--- a/docs/colab_notebooks/1-the-basics.ipynb
+++ b/docs/colab_notebooks/1-the-basics.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "39d7d274",
+ "id": "3599c474",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: The Basics\n",
@@ -14,7 +14,7 @@
},
{
"cell_type": "markdown",
- "id": "60f1d002",
+ "id": "ee8bed13",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -25,7 +25,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "99c42292",
+ "id": "f43069d1",
"metadata": {},
"outputs": [],
"source": [
@@ -36,7 +36,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2c959ca9",
+ "id": "c136bf4f",
"metadata": {},
"outputs": [],
"source": [
@@ -53,7 +53,7 @@
},
{
"cell_type": "markdown",
- "id": "bc185897",
+ "id": "48739393",
"metadata": {},
"source": [
"### π¦ Import the essentials\n",
@@ -64,15 +64,15 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "dc3a2d9d",
+ "id": "e459cd98",
"metadata": {},
"outputs": [],
"source": [
"from data_designer.essentials import (\n",
" CategorySamplerParams,\n",
+ " ChatCompletionInferenceParams,\n",
" DataDesigner,\n",
" DataDesignerConfigBuilder,\n",
- " InferenceParameters,\n",
" LLMTextColumnConfig,\n",
" ModelConfig,\n",
" PersonFromFakerSamplerParams,\n",
@@ -85,7 +85,7 @@
},
{
"cell_type": "markdown",
- "id": "36c5f571",
+ "id": "b705d204",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -98,7 +98,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "61b23c70",
+ "id": "aee62c85",
"metadata": {},
"outputs": [],
"source": [
@@ -107,7 +107,7 @@
},
{
"cell_type": "markdown",
- "id": "3c9b7cb6",
+ "id": "ae65c557",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -124,7 +124,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b86f6217",
+ "id": "1079200d",
"metadata": {},
"outputs": [],
"source": [
@@ -145,7 +145,7 @@
" alias=MODEL_ALIAS,\n",
" model=MODEL_ID,\n",
" provider=MODEL_PROVIDER,\n",
- " inference_parameters=InferenceParameters(\n",
+ " inference_parameters=ChatCompletionInferenceParams(\n",
" temperature=0.5,\n",
" top_p=1.0,\n",
" max_tokens=1024,\n",
@@ -156,7 +156,7 @@
},
{
"cell_type": "markdown",
- "id": "1f089871",
+ "id": "9f15426a",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -171,7 +171,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3d666193",
+ "id": "79b8212c",
"metadata": {},
"outputs": [],
"source": [
@@ -180,7 +180,7 @@
},
{
"cell_type": "markdown",
- "id": "e88c8881",
+ "id": "cd1d9e09",
"metadata": {},
"source": [
"## π² Getting started with sampler columns\n",
@@ -197,7 +197,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "79fb85c6",
+ "id": "b3f469d6",
"metadata": {},
"outputs": [],
"source": [
@@ -206,7 +206,7 @@
},
{
"cell_type": "markdown",
- "id": "5106cc10",
+ "id": "e44adc6c",
"metadata": {},
"source": [
"Let's start designing our product review dataset by adding product category and subcategory columns.\n"
@@ -215,7 +215,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "22b97af1",
+ "id": "82b32804",
"metadata": {},
"outputs": [],
"source": [
@@ -296,7 +296,7 @@
},
{
"cell_type": "markdown",
- "id": "4857b085",
+ "id": "bd65456c",
"metadata": {},
"source": [
"Next, let's add samplers to generate data related to the customer and their review.\n"
@@ -305,7 +305,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9e90b3cb",
+ "id": "6d6d4eef",
"metadata": {},
"outputs": [],
"source": [
@@ -342,7 +342,7 @@
},
{
"cell_type": "markdown",
- "id": "b36a153b",
+ "id": "eb7b415c",
"metadata": {},
"source": [
"## π¦ LLM-generated columns\n",
@@ -357,7 +357,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4da88fe6",
+ "id": "ed811560",
"metadata": {},
"outputs": [],
"source": [
@@ -394,7 +394,7 @@
},
{
"cell_type": "markdown",
- "id": "5f1b9ac8",
+ "id": "fdc0a2c8",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -411,7 +411,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "543e2f9c",
+ "id": "59987c81",
"metadata": {},
"outputs": [],
"source": [
@@ -421,7 +421,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "26136a8a",
+ "id": "0823ca7f",
"metadata": {},
"outputs": [],
"source": [
@@ -432,7 +432,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "aca4360d",
+ "id": "eca4f0bc",
"metadata": {},
"outputs": [],
"source": [
@@ -442,7 +442,7 @@
},
{
"cell_type": "markdown",
- "id": "35ca0470",
+ "id": "edd57f85",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -455,7 +455,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d55b402d",
+ "id": "5c681eee",
"metadata": {},
"outputs": [],
"source": [
@@ -465,7 +465,7 @@
},
{
"cell_type": "markdown",
- "id": "245b48cf",
+ "id": "14bf06f2",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -478,7 +478,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "fc803eb0",
+ "id": "b7ffead1",
"metadata": {},
"outputs": [],
"source": [
@@ -488,7 +488,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "881c2043",
+ "id": "aa966388",
"metadata": {},
"outputs": [],
"source": [
@@ -501,7 +501,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d79860d4",
+ "id": "98e1085c",
"metadata": {},
"outputs": [],
"source": [
@@ -513,7 +513,7 @@
},
{
"cell_type": "markdown",
- "id": "b4b45176",
+ "id": "e0b9c65a",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
index 0cb65da6..fd9a2d69 100644
--- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
+++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "33b48b4e",
+ "id": "928a8d4a",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Structured Outputs and Jinja Expressions\n",
@@ -16,7 +16,7 @@
},
{
"cell_type": "markdown",
- "id": "c29f9af1",
+ "id": "ad16de35",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -27,7 +27,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3a5601fb",
+ "id": "cfb08d4c",
"metadata": {},
"outputs": [],
"source": [
@@ -38,7 +38,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "de2f0af4",
+ "id": "eddeceea",
"metadata": {},
"outputs": [],
"source": [
@@ -55,7 +55,7 @@
},
{
"cell_type": "markdown",
- "id": "400795be",
+ "id": "08e2b3bf",
"metadata": {},
"source": [
"### π¦ Import the essentials\n",
@@ -66,16 +66,16 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "378f6853",
+ "id": "f057319b",
"metadata": {},
"outputs": [],
"source": [
"from data_designer.essentials import (\n",
" CategorySamplerParams,\n",
+ " ChatCompletionInferenceParams,\n",
" DataDesigner,\n",
" DataDesignerConfigBuilder,\n",
" ExpressionColumnConfig,\n",
- " InferenceParameters,\n",
" LLMStructuredColumnConfig,\n",
" ModelConfig,\n",
" PersonFromFakerSamplerParams,\n",
@@ -87,7 +87,7 @@
},
{
"cell_type": "markdown",
- "id": "15a1ac9f",
+ "id": "e7d5e529",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -100,7 +100,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9d7654e5",
+ "id": "e30fdeb1",
"metadata": {},
"outputs": [],
"source": [
@@ -109,7 +109,7 @@
},
{
"cell_type": "markdown",
- "id": "27ba0edb",
+ "id": "07b8bfe7",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -126,7 +126,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c24ee00e",
+ "id": "4ae518af",
"metadata": {},
"outputs": [],
"source": [
@@ -147,7 +147,7 @@
" alias=MODEL_ALIAS,\n",
" model=MODEL_ID,\n",
" provider=MODEL_PROVIDER,\n",
- " inference_parameters=InferenceParameters(\n",
+ " inference_parameters=ChatCompletionInferenceParams(\n",
" temperature=0.5,\n",
" top_p=1.0,\n",
" max_tokens=1024,\n",
@@ -158,7 +158,7 @@
},
{
"cell_type": "markdown",
- "id": "a106edc9",
+ "id": "a3fa2eaf",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -173,7 +173,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3a167f7c",
+ "id": "e1b59bd4",
"metadata": {},
"outputs": [],
"source": [
@@ -182,7 +182,7 @@
},
{
"cell_type": "markdown",
- "id": "fcf68c72",
+ "id": "5e220f78",
"metadata": {},
"source": [
"### π§βπ¨ Designing our data\n",
@@ -209,7 +209,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8f8f034e",
+ "id": "e814ca47",
"metadata": {},
"outputs": [],
"source": [
@@ -237,7 +237,7 @@
},
{
"cell_type": "markdown",
- "id": "9d5c722a",
+ "id": "303ebd6f",
"metadata": {},
"source": [
"Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n"
@@ -246,7 +246,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "013caa3d",
+ "id": "361afe50",
"metadata": {},
"outputs": [],
"source": [
@@ -355,7 +355,7 @@
},
{
"cell_type": "markdown",
- "id": "ef426a65",
+ "id": "b18e511d",
"metadata": {},
"source": [
"Next, we will use more advanced Jinja expressions to create new columns.\n",
@@ -372,7 +372,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "27abbd6d",
+ "id": "e669c009",
"metadata": {},
"outputs": [],
"source": [
@@ -426,7 +426,7 @@
},
{
"cell_type": "markdown",
- "id": "18a8461e",
+ "id": "9ca31fb8",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -443,7 +443,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2f75eee6",
+ "id": "402dede2",
"metadata": {},
"outputs": [],
"source": [
@@ -453,7 +453,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "950c9596",
+ "id": "c028223a",
"metadata": {},
"outputs": [],
"source": [
@@ -464,7 +464,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5c04ca5a",
+ "id": "32d6fdaa",
"metadata": {},
"outputs": [],
"source": [
@@ -474,7 +474,7 @@
},
{
"cell_type": "markdown",
- "id": "b0704f47",
+ "id": "3b76da46",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -487,7 +487,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "54a609ce",
+ "id": "1a9ac292",
"metadata": {},
"outputs": [],
"source": [
@@ -497,7 +497,7 @@
},
{
"cell_type": "markdown",
- "id": "6ae0a8a5",
+ "id": "8e4b20ed",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -510,7 +510,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e22c387a",
+ "id": "eed6e04b",
"metadata": {},
"outputs": [],
"source": [
@@ -520,7 +520,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "9c6b36b3",
+ "id": "b6d1d270",
"metadata": {},
"outputs": [],
"source": [
@@ -533,7 +533,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "138ed487",
+ "id": "058d6b65",
"metadata": {},
"outputs": [],
"source": [
@@ -545,7 +545,7 @@
},
{
"cell_type": "markdown",
- "id": "fde73253",
+ "id": "d3affdac",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
index e623ac8b..af97af89 100644
--- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
+++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "5a6b2b3f",
+ "id": "ce777a5d",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n",
@@ -16,7 +16,7 @@
},
{
"cell_type": "markdown",
- "id": "137d8273",
+ "id": "442b70e6",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -27,7 +27,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f6ce3dc2",
+ "id": "9ee97c1e",
"metadata": {},
"outputs": [],
"source": [
@@ -38,7 +38,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "70d6ffc8",
+ "id": "18c06754",
"metadata": {},
"outputs": [],
"source": [
@@ -55,7 +55,7 @@
},
{
"cell_type": "markdown",
- "id": "ce0313c9",
+ "id": "9394bcd4",
"metadata": {},
"source": [
"### π¦ Import the essentials\n",
@@ -66,14 +66,14 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2aa2cdf1",
+ "id": "6f1099d1",
"metadata": {},
"outputs": [],
"source": [
"from data_designer.essentials import (\n",
+ " ChatCompletionInferenceParams,\n",
" DataDesigner,\n",
" DataDesignerConfigBuilder,\n",
- " InferenceParameters,\n",
" ModelConfig,\n",
" SeedConfig,\n",
")"
@@ -81,7 +81,7 @@
},
{
"cell_type": "markdown",
- "id": "9769f392",
+ "id": "bc74d436",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -94,7 +94,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d79db916",
+ "id": "0f9c59fa",
"metadata": {},
"outputs": [],
"source": [
@@ -103,7 +103,7 @@
},
{
"cell_type": "markdown",
- "id": "08dd3894",
+ "id": "c92d4c3c",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -120,7 +120,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3994368e",
+ "id": "543805b1",
"metadata": {},
"outputs": [],
"source": [
@@ -141,7 +141,7 @@
" alias=MODEL_ALIAS,\n",
" model=MODEL_ID,\n",
" provider=MODEL_PROVIDER,\n",
- " inference_parameters=InferenceParameters(\n",
+ " inference_parameters=ChatCompletionInferenceParams(\n",
" temperature=0.5,\n",
" top_p=1.0,\n",
" max_tokens=1024,\n",
@@ -152,7 +152,7 @@
},
{
"cell_type": "markdown",
- "id": "5f12d6d2",
+ "id": "29f69761",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -167,7 +167,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0f3d7640",
+ "id": "7b18a399",
"metadata": {},
"outputs": [],
"source": [
@@ -176,7 +176,7 @@
},
{
"cell_type": "markdown",
- "id": "1f08df99",
+ "id": "5439e926",
"metadata": {},
"source": [
"## π₯ Prepare a seed dataset\n",
@@ -201,7 +201,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f265e74c",
+ "id": "d4cbb09e",
"metadata": {},
"outputs": [],
"source": [
@@ -219,7 +219,7 @@
},
{
"cell_type": "markdown",
- "id": "6bffa239",
+ "id": "367fd06d",
"metadata": {},
"source": [
"## π¨ Designing our synthetic patient notes dataset\n",
@@ -236,7 +236,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "15e486a6",
+ "id": "1cbcf034",
"metadata": {},
"outputs": [],
"source": [
@@ -326,7 +326,7 @@
},
{
"cell_type": "markdown",
- "id": "5cfe2edd",
+ "id": "ef3d8dcf",
"metadata": {},
"source": [
"### π Iteration is key βΒ preview the dataset!\n",
@@ -343,7 +343,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d4a59576",
+ "id": "6aa843f8",
"metadata": {},
"outputs": [],
"source": [
@@ -353,7 +353,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1c5aedd4",
+ "id": "2134a59d",
"metadata": {},
"outputs": [],
"source": [
@@ -364,7 +364,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d17df0a5",
+ "id": "3c3c39bf",
"metadata": {},
"outputs": [],
"source": [
@@ -374,7 +374,7 @@
},
{
"cell_type": "markdown",
- "id": "3389a088",
+ "id": "bf4b07b5",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -387,7 +387,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "b0443498",
+ "id": "e7813016",
"metadata": {},
"outputs": [],
"source": [
@@ -397,7 +397,7 @@
},
{
"cell_type": "markdown",
- "id": "0527a606",
+ "id": "db2d39b6",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -410,7 +410,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2118b49e",
+ "id": "2fcc5cb3",
"metadata": {},
"outputs": [],
"source": [
@@ -420,7 +420,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "c4f9ad59",
+ "id": "036581fa",
"metadata": {},
"outputs": [],
"source": [
@@ -433,7 +433,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8517866d",
+ "id": "74e91fe3",
"metadata": {},
"outputs": [],
"source": [
@@ -445,7 +445,7 @@
},
{
"cell_type": "markdown",
- "id": "b62dd069",
+ "id": "6f65cadc",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb
index e48f2bde..ddc6488a 100644
--- a/docs/colab_notebooks/4-providing-images-as-context.ipynb
+++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "id": "fc93d603",
+ "id": "560882a6",
"metadata": {},
"source": [
"# π¨ Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation"
@@ -10,7 +10,7 @@
},
{
"cell_type": "markdown",
- "id": "31146c45",
+ "id": "b443ea18",
"metadata": {},
"source": [
"#### π What you'll learn\n",
@@ -25,7 +25,7 @@
},
{
"cell_type": "markdown",
- "id": "b237af81",
+ "id": "d59495b4",
"metadata": {},
"source": [
"### β‘ Colab Setup\n",
@@ -36,7 +36,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "00b316a6",
+ "id": "e80db6fd",
"metadata": {},
"outputs": [],
"source": [
@@ -47,7 +47,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "847a6f88",
+ "id": "15ac4714",
"metadata": {},
"outputs": [],
"source": [
@@ -64,7 +64,7 @@
},
{
"cell_type": "markdown",
- "id": "5a0e2a31",
+ "id": "01741037",
"metadata": {},
"source": [
"### π¦ Import the essentials\n",
@@ -75,7 +75,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7ec632f1",
+ "id": "097a9059",
"metadata": {},
"outputs": [],
"source": [
@@ -93,11 +93,11 @@
"\n",
"# Data Designer imports\n",
"from data_designer.essentials import (\n",
+ " ChatCompletionInferenceParams,\n",
" DataDesigner,\n",
" DataDesignerConfigBuilder,\n",
" ImageContext,\n",
" ImageFormat,\n",
- " InferenceParameters,\n",
" LLMTextColumnConfig,\n",
" ModalityDataType,\n",
" ModelConfig,\n",
@@ -106,7 +106,7 @@
},
{
"cell_type": "markdown",
- "id": "66efe0cc",
+ "id": "4f6ca947",
"metadata": {},
"source": [
"### βοΈ Initialize the Data Designer interface\n",
@@ -119,7 +119,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e9059d31",
+ "id": "6508f2b7",
"metadata": {},
"outputs": [],
"source": [
@@ -128,7 +128,7 @@
},
{
"cell_type": "markdown",
- "id": "26d60e67",
+ "id": "7343463e",
"metadata": {},
"source": [
"### ποΈ Define model configurations\n",
@@ -145,7 +145,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "e5b29b57",
+ "id": "31019979",
"metadata": {},
"outputs": [],
"source": [
@@ -157,7 +157,7 @@
" alias=\"vision\",\n",
" model=\"meta/llama-4-scout-17b-16e-instruct\",\n",
" provider=MODEL_PROVIDER,\n",
- " inference_parameters=InferenceParameters(\n",
+ " inference_parameters=ChatCompletionInferenceParams(\n",
" temperature=0.60,\n",
" top_p=0.95,\n",
" max_tokens=2048,\n",
@@ -168,7 +168,7 @@
},
{
"cell_type": "markdown",
- "id": "c7c68fce",
+ "id": "186f9f98",
"metadata": {},
"source": [
"### ποΈ Initialize the Data Designer Config Builder\n",
@@ -183,7 +183,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "2ab84fb1",
+ "id": "80fbc002",
"metadata": {},
"outputs": [],
"source": [
@@ -192,7 +192,7 @@
},
{
"cell_type": "markdown",
- "id": "bdc1fa29",
+ "id": "203663e2",
"metadata": {},
"source": [
"### π± Seed Dataset Creation\n",
@@ -209,7 +209,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "3baa7ba2",
+ "id": "1a2b0ce1",
"metadata": {},
"outputs": [],
"source": [
@@ -224,7 +224,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f3780aee",
+ "id": "8cc0f6ea",
"metadata": {},
"outputs": [],
"source": [
@@ -272,7 +272,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5c46877e",
+ "id": "9fe4b02c",
"metadata": {},
"outputs": [],
"source": [
@@ -290,7 +290,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4ca338da",
+ "id": "e546cc4a",
"metadata": {},
"outputs": [],
"source": [
@@ -300,7 +300,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "0cb41a37",
+ "id": "eebc9963",
"metadata": {},
"outputs": [],
"source": [
@@ -314,7 +314,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "607bb265",
+ "id": "99b4a3b5",
"metadata": {
"lines_to_next_cell": 2
},
@@ -343,7 +343,7 @@
},
{
"cell_type": "markdown",
- "id": "984241b9",
+ "id": "58ab7a96",
"metadata": {
"lines_to_next_cell": 2
},
@@ -351,7 +351,7 @@
},
{
"cell_type": "markdown",
- "id": "eca1a5ea",
+ "id": "7d735a85",
"metadata": {},
"source": [
"### π Iteration is key β preview the dataset!\n",
@@ -368,7 +368,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "4d386a38",
+ "id": "6404dc34",
"metadata": {},
"outputs": [],
"source": [
@@ -378,7 +378,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "6209d609",
+ "id": "44f3678c",
"metadata": {},
"outputs": [],
"source": [
@@ -389,7 +389,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "5aedc7fd",
+ "id": "06009be0",
"metadata": {},
"outputs": [],
"source": [
@@ -399,7 +399,7 @@
},
{
"cell_type": "markdown",
- "id": "fc339219",
+ "id": "43e9af07",
"metadata": {},
"source": [
"### π Analyze the generated data\n",
@@ -412,7 +412,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "87ccc372",
+ "id": "7bce2b2e",
"metadata": {},
"outputs": [],
"source": [
@@ -422,7 +422,7 @@
},
{
"cell_type": "markdown",
- "id": "c090f413",
+ "id": "dc361acc",
"metadata": {},
"source": [
"### π Visual Inspection\n",
@@ -433,7 +433,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "1e1a054c",
+ "id": "574bbc62",
"metadata": {
"lines_to_next_cell": 2
},
@@ -457,7 +457,7 @@
},
{
"cell_type": "markdown",
- "id": "cab83636",
+ "id": "cc5ebe1c",
"metadata": {},
"source": [
"### π Scale up!\n",
@@ -470,7 +470,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "7fa66a2e",
+ "id": "305f90ec",
"metadata": {},
"outputs": [],
"source": [
@@ -480,7 +480,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "8f92b5be",
+ "id": "ef40fc6b",
"metadata": {},
"outputs": [],
"source": [
@@ -493,7 +493,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "d1bddcee",
+ "id": "72756cf3",
"metadata": {},
"outputs": [],
"source": [
@@ -505,7 +505,7 @@
},
{
"cell_type": "markdown",
- "id": "46f68f95",
+ "id": "e9bac314",
"metadata": {},
"source": [
"## βοΈ Next Steps\n",
diff --git a/docs/concepts/columns.md b/docs/concepts/columns.md
index cfe0aab9..17123299 100644
--- a/docs/concepts/columns.md
+++ b/docs/concepts/columns.md
@@ -64,6 +64,22 @@ Define scoring rubrics (relevance, accuracy, fluency, helpfulness) and the judge
Use judge columns for data quality filtering (e.g., keep only 4+ rated responses), A/B testing generation strategies, and quality monitoring over time.
+### 𧬠Embedding Columns
+
+Embedding columns generate vector embeddings (numerical representations) for text content using embedding models. These embeddings capture semantic meaning, enabling similarity search, clustering, and semantic analysis.
+
+Specify a `target_column` containing text, and Data Designer generates embeddings for that content. The target column can hold either a single text string or a list of text strings serialized as a JSON string; in the latter case, an embedding is generated for each string in the list.
+
+Common use cases:
+
+- **Semantic search**: Generate embeddings for documents, then find similar content by vector similarity
+- **Clustering**: Group similar texts based on embedding proximity
+- **Recommendation systems**: Match content by semantic similarity
+- **Anomaly detection**: Identify outliers in embedding space
+
+!!! note "Embedding Models"
+    Embedding columns require an embedding model configured with `EmbeddingInferenceParams`. These models differ from chat completion models in that they output vectors rather than text. The generation type is automatically determined by the inference parameters type.
+
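+As a rough sketch of how the pieces fit together, the snippet below configures an embedding model and shows, in comments, what a column definition pointing at existing text via `target_column` might look like. The `EmbeddingColumnConfig` name and its arguments are assumptions for illustration only; check `data_designer.essentials` for the exact column config class in your installation.
+
+```python
+from data_designer.essentials import EmbeddingInferenceParams, ModelConfig
+
+# Embedding model; its alias would be referenced by the embedding column.
+embedding_model = ModelConfig(
+    alias="embedding-model",
+    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+    provider="nvidia",
+    inference_parameters=EmbeddingInferenceParams(encoding_format="float"),
+)
+
+# Hypothetical column config: embeds the text already generated in "review".
+# review_embedding = EmbeddingColumnConfig(
+#     name="review_embedding",
+#     target_column="review",
+#     model_alias="embedding-model",
+# )
+```
+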
### π§© Expression Columns
Expression columns handle simple transformations using **Jinja2 templates**βconcatenate first and last names, calculate numerical totals, format date strings. No LLM overhead needed.
diff --git a/docs/concepts/models/configure-model-settings-with-the-cli.md b/docs/concepts/models/configure-model-settings-with-the-cli.md
index fa29732e..56315cdf 100644
--- a/docs/concepts/models/configure-model-settings-with-the-cli.md
+++ b/docs/concepts/models/configure-model-settings-with-the-cli.md
@@ -130,6 +130,6 @@ The CLI will show which configuration files exist and ask for confirmation befor
## See Also
- **[Model Providers](model-providers.md)**: Learn about the `ModelProvider` class and provider configuration
-- **[Model Configurations](model-configs.md)**: Learn about `ModelConfig` and `InferenceParameters`
+- **[Model Configurations](model-configs.md)**: Learn about `ModelConfig`
- **[Default Model Settings](default-model-settings.md)**: Pre-configured providers and model settings included with Data Designer
- **[Quick Start Guide](../../quick-start.md)**: Get started with a simple example
diff --git a/docs/concepts/models/inference-parameters.md b/docs/concepts/models/inference-parameters.md
new file mode 100644
index 00000000..6d7d079f
--- /dev/null
+++ b/docs/concepts/models/inference-parameters.md
@@ -0,0 +1,147 @@
+# Inference Parameters
+
+Inference parameters control how models generate responses during synthetic data generation. Data Designer provides two types of inference parameters: `ChatCompletionInferenceParams` for text/code/structured generation and `EmbeddingInferenceParams` for embedding generation.
+
+## Overview
+
+When you create a `ModelConfig`, you can specify inference parameters to adjust model behavior. These parameters control aspects like randomness (temperature), diversity (top_p), context size (max_tokens), and more. Data Designer supports both static values and dynamic distribution-based sampling for certain parameters.
+
+## Chat Completion Inference Parameters
+
+The `ChatCompletionInferenceParams` class controls how models generate text completions (for text, code, and structured data generation). It provides fine-grained control over generation behavior and supports both static values and dynamic distribution-based sampling.
+
+!!! warning "InferenceParameters is Deprecated"
+    The `InferenceParameters` class is deprecated and will be removed in a future version; it now emits a deprecation warning when used. Use `ChatCompletionInferenceParams` instead.
+
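+In most existing configurations the migration is a straightforward rename; the constructor arguments stay the same. For example:
+
+```python
+# Before (deprecated):
+# from data_designer.essentials import InferenceParameters
+# params = InferenceParameters(temperature=0.5, top_p=1.0, max_tokens=1024)
+
+# After:
+from data_designer.essentials import ChatCompletionInferenceParams
+
+params = ChatCompletionInferenceParams(temperature=0.5, top_p=1.0, max_tokens=1024)
+```
+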
+### Fields
+
+| Field | Type | Required | Description |
+|-------|------|----------|-------------|
+| `temperature` | `float` or `Distribution` | No | Controls randomness in generation (0.0 to 2.0). Higher values = more creative/random |
+| `top_p` | `float` or `Distribution` | No | Nucleus sampling parameter (0.0 to 1.0). Controls diversity by filtering low-probability tokens |
+| `max_tokens` | `int` | No | Maximum number of tokens for the request, including both input and output tokens (≥ 1) |
+| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, ≥ 1) |
+| `timeout` | `int` | No | API request timeout in seconds (≥ 1) |
+| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body |
+
+!!! note "Default Values"
+ If `temperature`, `top_p`, or `max_tokens` are not provided, the model provider's default values will be used. Different providers and models may have different defaults.
+
+!!! tip "Controlling Reasoning Effort for GPT-OSS Models"
+ For gpt-oss models like `gpt-oss-20b` and `gpt-oss-120b`, you can control the reasoning effort using the `extra_body` parameter:
+
+ ```python
+ from data_designer.essentials import ChatCompletionInferenceParams
+
+ # High reasoning effort (more thorough, slower)
+ inference_parameters = ChatCompletionInferenceParams(
+ extra_body={"reasoning_effort": "high"}
+ )
+
+ # Medium reasoning effort (balanced)
+ inference_parameters = ChatCompletionInferenceParams(
+ extra_body={"reasoning_effort": "medium"}
+ )
+
+ # Low reasoning effort (faster, less thorough)
+ inference_parameters = ChatCompletionInferenceParams(
+ extra_body={"reasoning_effort": "low"}
+ )
+ ```
+
+### Temperature and Top P Guidelines
+
+- **Temperature**:
+ - `0.0-0.3`: Highly deterministic, focused outputs (ideal for structured/reasoning tasks)
+ - `0.4-0.7`: Balanced creativity and coherence (general purpose)
+ - `0.8-1.0`: Creative, diverse outputs (ideal for creative writing)
+ - `1.0+`: Highly random and experimental
+
+- **Top P**:
+ - `0.1-0.5`: Very focused, only most likely tokens
+ - `0.6-0.9`: Balanced diversity
+ - `0.95-1.0`: Maximum diversity, including less likely tokens
+
+!!! tip "Adjusting Temperature and Top P Together"
+ When tuning both parameters simultaneously, consider these combinations:
+
+ - **For deterministic/structured outputs**: Low temperature (`0.0-0.3`) + moderate-to-high top_p (`0.8-0.95`)
+ - The low temperature ensures focus, while top_p allows some token diversity
+ - **For balanced generation**: Moderate temperature (`0.5-0.7`) + high top_p (`0.9-0.95`)
+ - This is a good starting point for most use cases
+ - **For creative outputs**: Higher temperature (`0.8-1.0`) + high top_p (`0.95-1.0`)
+ - Both parameters work together to maximize diversity
+
+ **Avoid**: Setting both very low (overly restrictive) or adjusting both dramatically at once. When experimenting, adjust one parameter at a time to understand its individual effect.
+
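+For reference, here is what two of these recommended combinations look like as concrete parameter objects. The specific values are illustrative starting points, not requirements:
+
+```python
+from data_designer.essentials import ChatCompletionInferenceParams
+
+# Balanced generation: moderate temperature with high top_p.
+balanced_params = ChatCompletionInferenceParams(
+    temperature=0.6,
+    top_p=0.95,
+    max_tokens=2048,
+)
+
+# Deterministic/structured outputs: low temperature, moderate-to-high top_p.
+structured_params = ChatCompletionInferenceParams(
+    temperature=0.2,
+    top_p=0.9,
+    max_tokens=2048,
+)
+```
+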
+## Distribution-Based Inference Parameters
+
+For `temperature` and `top_p` in `ChatCompletionInferenceParams`, you can specify distributions instead of fixed values. This allows Data Designer to sample different values for each generation request, introducing controlled variability into your synthetic data.
+
+### Uniform Distribution
+
+Samples values uniformly between a low and high bound:
+
+```python
+from data_designer.essentials import (
+ ChatCompletionInferenceParams,
+ UniformDistribution,
+ UniformDistributionParams,
+)
+
+inference_params = ChatCompletionInferenceParams(
+ temperature=UniformDistribution(
+ params=UniformDistributionParams(low=0.7, high=1.0)
+ ),
+)
+```
+
+### Manual Distribution
+
+Samples from a discrete set of values with optional weights:
+
+```python
+from data_designer.essentials import (
+ ChatCompletionInferenceParams,
+ ManualDistribution,
+ ManualDistributionParams,
+)
+
+# Equal probability for each value
+inference_params = ChatCompletionInferenceParams(
+ temperature=ManualDistribution(
+ params=ManualDistributionParams(values=[0.5, 0.7, 0.9])
+ ),
+)
+
+# Weighted probabilities (normalized automatically)
+inference_params = ChatCompletionInferenceParams(
+ top_p=ManualDistribution(
+ params=ManualDistributionParams(
+ values=[0.8, 0.9, 0.95],
+ weights=[0.2, 0.5, 0.3] # 20%, 50%, 30% probability
+ )
+ ),
+)
+```
+
+## Embedding Inference Parameters
+
+The `EmbeddingInferenceParams` class controls how models generate embeddings. This is used when working with embedding models for tasks like semantic search or similarity analysis.
+
+### Fields
+
+| Field | Type | Required | Description |
+|-------|------|----------|-------------|
+| `encoding_format` | `Literal["float", "base64"]` | No | Format of the embedding encoding (default: "float") |
+| `dimensions` | `int` | No | Number of dimensions for the embedding |
+| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, ≥ 1) |
+| `timeout` | `int` | No | API request timeout in seconds (≥ 1) |
+| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body |
+
+
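+For example, an embedding model configuration might look like the following. The model ID and the `input_type` key in `extra_body` are illustrative; use whatever your embedding provider expects:
+
+```python
+from data_designer.essentials import EmbeddingInferenceParams, ModelConfig
+
+embedding_model = ModelConfig(
+    alias="embedding-model",
+    model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+    provider="nvidia",
+    inference_parameters=EmbeddingInferenceParams(
+        encoding_format="float",
+        dimensions=1024,
+        extra_body={"input_type": "query"},
+    ),
+)
+```
+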
+## See Also
+
+- **[Model Configurations](model-configs.md)**: Learn about configuring model settings
+- **[Model Providers](model-providers.md)**: Learn about configuring model providers
+- **[Default Model Settings](default-model-settings.md)**: Pre-configured model settings included with Data Designer
diff --git a/docs/concepts/models/model-configs.md b/docs/concepts/models/model-configs.md
index fd87b874..c78aa2ee 100644
--- a/docs/concepts/models/model-configs.md
+++ b/docs/concepts/models/model-configs.md
@@ -14,136 +14,23 @@ The `ModelConfig` class has the following fields:
|-------|------|----------|-------------|
| `alias` | `str` | Yes | Unique identifier for this model configuration (e.g., `"my-text-model"`, `"reasoning-model"`) |
| `model` | `str` | Yes | Model identifier as recognized by the provider (e.g., `"nvidia/nvidia-nemotron-nano-9b-v2"`, `"gpt-4"`) |
-| `inference_parameters` | `InferenceParameters` | No | Controls model behavior during generation (temperature, top_p, max_tokens, etc). Defaults from constructing an empty `InferenceParameters` object are picked up when not provided.|
+| `inference_parameters` | `InferenceParamsT` | No | Controls model behavior during generation. Use `ChatCompletionInferenceParams` for text/code/structured generation or `EmbeddingInferenceParams` for embeddings. Defaults to `ChatCompletionInferenceParams()` if not provided. The generation type is automatically determined by the inference parameters type. See [Inference Parameters](inference-parameters.md) for details. |
| `provider` | `str` | No | Reference to the name of the Provider to use (e.g., `"nvidia"`, `"openai"`). If not specified, one set as the default provider, which may resolve to the first provider if there are more than one |
-## InferenceParameters
-
-The `InferenceParameters` class controls how the model generates responses. It provides fine-grained control over generation behavior and supports both static values and dynamic distribution-based sampling.
-
-### Fields
-
-| Field | Type | Required | Description |
-|-------|------|----------|-------------|
-| `temperature` | `float` or `Distribution` | No | Controls randomness in generation (0.0 to 2.0). Higher values = more creative/random |
-| `top_p` | `float` or `Distribution` | No | Nucleus sampling parameter (0.0 to 1.0). Controls diversity by filtering low-probability tokens |
-| `max_tokens` | `int` | No | Maximum number of tokens for the request, including both input and output tokens (β₯ 1) |
-| `max_parallel_requests` | `int` | No | Maximum concurrent API requests (default: 4, β₯ 1) |
-| `timeout` | `int` | No | API request timeout in seconds (β₯ 1) |
-| `extra_body` | `dict[str, Any]` | No | Additional parameters to include in the API request body |
-
-!!! note "Default Values"
- If `temperature`, `top_p`, or `max_tokens` are not provided, the model provider's default values will be used. Different providers and models may have different defaults.
-
-!!! tip "Controlling Reasoning Effort for GPT-OSS Models"
- For gpt-oss models like `gpt-oss-20b` and `gpt-oss-120b`, you can control the reasoning effort using the `extra_body` parameter:
-
- ```python
- # High reasoning effort (more thorough, slower)
- inference_parameters = InferenceParameters(
- extra_body={"reasoning_effort": "high"}
- )
-
- # Medium reasoning effort (balanced)
- inference_parameters = InferenceParameters(
- extra_body={"reasoning_effort": "medium"}
- )
-
- # Low reasoning effort (faster, less thorough)
- inference_parameters = InferenceParameters(
- extra_body={"reasoning_effort": "low"}
- )
- ```
-
-### Temperature and Top P Guidelines
-
-- **Temperature**:
- - `0.0-0.3`: Highly deterministic, focused outputs (ideal for structured/reasoning tasks)
- - `0.4-0.7`: Balanced creativity and coherence (general purpose)
- - `0.8-1.0`: Creative, diverse outputs (ideal for creative writing)
- - `1.0+`: Highly random and experimental
-
-- **Top P**:
- - `0.1-0.5`: Very focused, only most likely tokens
- - `0.6-0.9`: Balanced diversity
- - `0.95-1.0`: Maximum diversity, including less likely tokens
-
-!!! tip "Adjusting Temperature and Top P Together"
- When tuning both parameters simultaneously, consider these combinations:
-
- - **For deterministic/structured outputs**: Low temperature (`0.0-0.3`) + moderate-to-high top_p (`0.8-0.95`)
- - The low temperature ensures focus, while top_p allows some token diversity
- - **For balanced generation**: Moderate temperature (`0.5-0.7`) + high top_p (`0.9-0.95`)
- - This is a good starting point for most use cases
- - **For creative outputs**: Higher temperature (`0.8-1.0`) + high top_p (`0.95-1.0`)
- - Both parameters work together to maximize diversity
-
- **Avoid**: Setting both very low (overly restrictive) or adjusting both dramatically at once. When experimenting, adjust one parameter at a time to understand its individual effect.
-
-## Distribution-Based Inference Parameters
-
-For `temperature` and `top_p`, you can specify distributions instead of fixed values. This allows Data Designer to sample different values for each generation request, introducing controlled variability into your synthetic data.
-
-### Uniform Distribution
-
-Samples values uniformly between a low and high bound:
-
-```python
-from data_designer.essentials import (
- InferenceParameters,
- UniformDistribution,
- UniformDistributionParams,
-)
-
-inference_params = InferenceParameters(
- temperature=UniformDistribution(
- params=UniformDistributionParams(low=0.7, high=1.0)
- ),
-)
-```
-
-### Manual Distribution
-
-Samples from a discrete set of values with optional weights:
-
-```python
-from data_designer.essentials import (
- InferenceParameters,
- ManualDistribution,
- ManualDistributionParams,
-)
-
-# Equal probability for each value
-inference_params = InferenceParameters(
- temperature=ManualDistribution(
- params=ManualDistributionParams(values=[0.5, 0.7, 0.9])
- ),
-)
-
-# Weighted probabilities (normalized automatically)
-inference_params = InferenceParameters(
- top_p=ManualDistribution(
- params=ManualDistributionParams(
- values=[0.8, 0.9, 0.95],
- weights=[0.2, 0.5, 0.3] # 20%, 50%, 30% probability
- )
- ),
-)
-```
## Examples
### Basic Model Configuration
```python
-from data_designer.essentials import InferenceParameters, ModelConfig
+from data_designer.essentials import ChatCompletionInferenceParams, ModelConfig
# Simple model configuration with fixed parameters
model_config = ModelConfig(
alias="my-text-model",
model="nvidia/nvidia-nemotron-nano-9b-v2",
provider="nvidia",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.85,
top_p=0.95,
max_tokens=2048,
@@ -154,7 +41,12 @@ model_config = ModelConfig(
### Multiple Model Configurations for Different Tasks
```python
-from data_designer.essentials import InferenceParameters, ModelConfig
+from data_designer.essentials import (
+    ChatCompletionInferenceParams,
+    EmbeddingInferenceParams,
+    ModelConfig,
+)
model_configs = [
# Creative tasks
@@ -162,7 +54,7 @@ model_configs = [
alias="creative-model",
model="nvidia/nvidia-nemotron-nano-9b-v2",
provider="nvidia",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.9,
top_p=0.95,
max_tokens=2048,
@@ -173,7 +65,7 @@ model_configs = [
alias="critic-model",
model="nvidia/nvidia-nemotron-nano-9b-v2",
provider="nvidia",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.25,
top_p=0.95,
max_tokens=2048,
@@ -184,7 +76,7 @@ model_configs = [
alias="reasoning-model",
model="openai/gpt-oss-20b",
provider="nvidia",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.3,
top_p=0.9,
max_tokens=4096,
@@ -195,49 +87,33 @@ model_configs = [
alias="vision-model",
model="nvidia/nemotron-nano-12b-v2-vl",
provider="nvidia",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.7,
top_p=0.95,
max_tokens=2048,
),
),
+    # Embedding tasks
+    ModelConfig(
+        alias="embedding-model",
+        model="nvidia/llama-3.2-nv-embedqa-1b-v2",
+        provider="nvidia",
+        inference_parameters=EmbeddingInferenceParams(
+            encoding_format="float",
+            extra_body={
+                "input_type": "query",
+            },
+        ),
+    ),
]
```
-!!! tip "Experiment with max_tokens for Task-Specific model configurations"
+!!! tip "Experiment with max_tokens for Task-Specific Model Configurations"
The number of tokens required to generate a single data entry can vary significantly with use case. For example, reasoning models often need more tokens to "think through" problems before generating a response. Note that `max_tokens` includes **both input and output tokens** (the total context window used), so factor in your prompt length, any context data, and the expected response length when setting this parameter.
-### Using Distribution-Based Parameters
-
-```python
-from data_designer.essentials import (
- InferenceParameters,
- ManualDistribution,
- ManualDistributionParams,
- ModelConfig,
- UniformDistribution,
- UniformDistributionParams,
-)
-
-# Model with variable temperature and top_p
-model_config = ModelConfig(
- alias="variable-model",
- model="nvidia/nvidia-nemotron-nano-9b-v2",
- inference_parameters=InferenceParameters(
- # Temperature varies uniformly between 0.7 and 1.0
- temperature=UniformDistribution(
- params=UniformDistributionParams(low=0.7, high=1.0)
- ),
- # Top P samples from discrete values with equal probability
- top_p=ManualDistribution(
- params=ManualDistributionParams(values=[0.85, 0.90, 0.95])
- ),
- max_tokens=2048,
- ),
-)
-```
## See Also
+- **[Inference Parameters](inference-parameters.md)**: Detailed guide to inference parameters and how to configure them
- **[Model Providers](model-providers.md)**: Learn about configuring model providers
- **[Default Model Settings](default-model-settings.md)**: Pre-configured model settings included with Data Designer
- **[Configure Model Settings With the CLI](configure-model-settings-with-the-cli.md)**: Use the CLI to manage model settings
diff --git a/docs/concepts/models/model-providers.md b/docs/concepts/models/model-providers.md
index 8c7bb7cc..f21ae4ca 100644
--- a/docs/concepts/models/model-providers.md
+++ b/docs/concepts/models/model-providers.md
@@ -44,7 +44,8 @@ provider = ModelProvider(
## See Also
-- **[Model Configurations](model-configs.md)**: Learn about configuring models and inference parameters
+- **[Model Configurations](model-configs.md)**: Learn about configuring models
+- **[Inference Parameters](inference-parameters.md)**: Detailed guide to inference parameters and how to configure them
- **[Default Model Settings](default-model-settings.md)**: Pre-configured providers and model settings included with Data Designer
- **[Configure Model Settings With the CLI](configure-model-settings-with-the-cli.md)**: Use the CLI to manage providers and model settings
- **[Quick Start Guide](../../quick-start.md)**: Get started with a simple example
diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py
index de890fb0..b03c758f 100644
--- a/docs/notebook_source/1-the-basics.py
+++ b/docs/notebook_source/1-the-basics.py
@@ -29,9 +29,9 @@
# %%
from data_designer.essentials import (
CategorySamplerParams,
+ ChatCompletionInferenceParams,
DataDesigner,
DataDesignerConfigBuilder,
- InferenceParameters,
LLMTextColumnConfig,
ModelConfig,
PersonFromFakerSamplerParams,
@@ -82,7 +82,7 @@
alias=MODEL_ALIAS,
model=MODEL_ID,
provider=MODEL_PROVIDER,
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.5,
top_p=1.0,
max_tokens=1024,
diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
index f968a416..18627a26 100644
--- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
+++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py
@@ -31,10 +31,10 @@
# %%
from data_designer.essentials import (
CategorySamplerParams,
+ ChatCompletionInferenceParams,
DataDesigner,
DataDesignerConfigBuilder,
ExpressionColumnConfig,
- InferenceParameters,
LLMStructuredColumnConfig,
ModelConfig,
PersonFromFakerSamplerParams,
@@ -84,7 +84,7 @@
alias=MODEL_ALIAS,
model=MODEL_ID,
provider=MODEL_PROVIDER,
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.5,
top_p=1.0,
max_tokens=1024,
diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py
index 7c0d07e1..12a972aa 100644
--- a/docs/notebook_source/3-seeding-with-a-dataset.py
+++ b/docs/notebook_source/3-seeding-with-a-dataset.py
@@ -30,9 +30,9 @@
# %%
from data_designer.essentials import (
+ ChatCompletionInferenceParams,
DataDesigner,
DataDesignerConfigBuilder,
- InferenceParameters,
ModelConfig,
SeedConfig,
)
@@ -78,7 +78,7 @@
alias=MODEL_ALIAS,
model=MODEL_ID,
provider=MODEL_PROVIDER,
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.5,
top_p=1.0,
max_tokens=1024,
diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py
index 6265e01f..53759f53 100644
--- a/docs/notebook_source/4-providing-images-as-context.py
+++ b/docs/notebook_source/4-providing-images-as-context.py
@@ -47,11 +47,11 @@
# Data Designer imports
from data_designer.essentials import (
+ ChatCompletionInferenceParams,
DataDesigner,
DataDesignerConfigBuilder,
ImageContext,
ImageFormat,
- InferenceParameters,
LLMTextColumnConfig,
ModalityDataType,
ModelConfig,
@@ -89,7 +89,7 @@
alias="vision",
model="meta/llama-4-scout-17b-16e-instruct",
provider=MODEL_PROVIDER,
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.60,
top_p=0.95,
max_tokens=2048,
diff --git a/mkdocs.yml b/mkdocs.yml
index beae5e89..84517616 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -14,6 +14,7 @@ nav:
- Configure with the CLI: concepts/models/configure-model-settings-with-the-cli.md
- Model Providers: concepts/models/model-providers.md
- Model Configs: concepts/models/model-configs.md
+ - Inference Parameters: concepts/models/inference-parameters.md
- Columns: concepts/columns.md
- Validators: concepts/validators.md
- Person Sampling: concepts/person_sampling.md
diff --git a/src/data_designer/cli/README.md b/src/data_designer/cli/README.md
index 751d6446..f15b752e 100644
--- a/src/data_designer/cli/README.md
+++ b/src/data_designer/cli/README.md
@@ -129,8 +129,10 @@ class ConfigRepository(ABC, Generic[T]):
- Field-level validation
- Auto-completion support
- History navigation (arrow keys)
-- Default value handling
+- Current value display when editing (`(current value: X)` instead of `(default: X)`)
+- Value clearing support (type `'clear'` to remove optional parameter values)
- Back navigation support
+- Empty input handling (Enter key keeps current value or skips optional fields)
#### 6. **UI Utilities** (`ui.py`)
- **Purpose**: User interface utilities for terminal output and input
@@ -179,17 +181,29 @@ model_configs:
model: meta/llama-3.1-70b-instruct
provider: nvidia
inference_parameters:
+ generation_type: chat-completion
temperature: 0.7
top_p: 0.9
max_tokens: 2048
max_parallel_requests: 4
+ timeout: 60
- alias: gpt-4
model: gpt-4-turbo
provider: openai
inference_parameters:
+ generation_type: chat-completion
temperature: 0.8
top_p: 0.95
max_tokens: 4096
+ max_parallel_requests: 4
+ - alias: embedder
+ model: text-embedding-3-large
+ provider: openai
+ inference_parameters:
+ generation_type: embedding
+ encoding_format: float
+ dimensions: 1024
+ max_parallel_requests: 4
```
## Usage Examples
diff --git a/src/data_designer/cli/commands/list.py b/src/data_designer/cli/commands/list.py
index e0531fcd..a93b9d31 100644
--- a/src/data_designer/cli/commands/list.py
+++ b/src/data_designer/cli/commands/list.py
@@ -6,6 +6,7 @@
from data_designer.cli.repositories.model_repository import ModelRepository
from data_designer.cli.repositories.provider_repository import ProviderRepository
from data_designer.cli.ui import console, print_error, print_header, print_info, print_warning
+from data_designer.config.models import ModelConfig
from data_designer.config.utils.constants import DATA_DESIGNER_HOME, NordColor
@@ -75,6 +76,37 @@ def display_providers(provider_repo: ProviderRepository) -> None:
console.print()
+def format_inference_parameters(model_config: ModelConfig) -> str:
+ """Format inference parameters based on generation type.
+
+ Args:
+ model_config: Model configuration
+
+ Returns:
+ Formatted string of inference parameters
+ """
+ params = model_config.inference_parameters
+
+    # Dump parameter values to a dict, dropping fields that are unset (None)
+ params_dict = params.model_dump(exclude_none=True, mode="json")
+
+ if not params_dict:
+ return "(none)"
+
+ # Format each parameter
+ parts = []
+ for key, value in params_dict.items():
+ # Check if value is a distribution (has dict structure with distribution_type)
+ if isinstance(value, dict) and "distribution_type" in value:
+ formatted_value = "dist"
+ elif isinstance(value, float):
+ formatted_value = f"{value:.2f}"
+ else:
+ formatted_value = str(value)
+ parts.append(f"{key}={formatted_value}")
+ return ", ".join(parts)
+
+
def display_models(model_repo: ModelRepository) -> None:
"""Load and display model configurations.
@@ -97,30 +129,16 @@ def display_models(model_repo: ModelRepository) -> None:
table.add_column("Alias", style=NordColor.NORD14.value, no_wrap=True)
table.add_column("Model ID", style=NordColor.NORD4.value)
table.add_column("Provider", style=NordColor.NORD9.value, no_wrap=True)
- table.add_column("Temperature", style=NordColor.NORD15.value, justify="right")
- table.add_column("Top P", style=NordColor.NORD15.value, justify="right")
- table.add_column("Max Tokens", style=NordColor.NORD15.value, justify="right")
+ table.add_column("Inference Parameters", style=NordColor.NORD15.value)
for mc in registry.model_configs:
- # Handle distribution-based parameters
- temp_display = (
- f"{mc.inference_parameters.temperature:.2f}"
- if isinstance(mc.inference_parameters.temperature, (int, float))
- else "dist"
- )
- top_p_display = (
- f"{mc.inference_parameters.top_p:.2f}"
- if isinstance(mc.inference_parameters.top_p, (int, float))
- else "dist"
- )
+ params_display = format_inference_parameters(mc)
table.add_row(
mc.alias,
mc.model,
mc.provider or "(default)",
- temp_display,
- top_p_display,
- str(mc.inference_parameters.max_tokens) if mc.inference_parameters.max_tokens else "(none)",
+ params_display,
)
console.print(table)
diff --git a/src/data_designer/cli/controllers/model_controller.py b/src/data_designer/cli/controllers/model_controller.py
index dff323f1..2124405b 100644
--- a/src/data_designer/cli/controllers/model_controller.py
+++ b/src/data_designer/cli/controllers/model_controller.py
@@ -160,9 +160,10 @@ def _handle_update(self, available_providers: list[str]) -> None:
return
# Check if model has distribution-based parameters
- if hasattr(model.inference_parameters.temperature, "sample") or hasattr(
- model.inference_parameters.top_p, "sample"
- ):
+ params_dict = model.inference_parameters.model_dump(mode="json", exclude_none=True)
+ has_distribution = any(isinstance(v, dict) and "distribution_type" in v for v in params_dict.values())
+
+ if has_distribution:
print_warning(
"This model uses distribution-based inference parameters, "
"which cannot be edited via the CLI. Please edit the configuration file directly."
diff --git a/src/data_designer/cli/forms/field.py b/src/data_designer/cli/forms/field.py
index 27aefd88..d5eab9ba 100644
--- a/src/data_designer/cli/forms/field.py
+++ b/src/data_designer/cli/forms/field.py
@@ -5,6 +5,7 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar
+from data_designer.cli.ui import BACK, prompt_text_input, select_with_arrows
from data_designer.cli.utils import validate_numeric_range
T = TypeVar("T")
@@ -40,8 +41,14 @@ def value(self) -> T | None:
return self._value
@value.setter
- def value(self, val: T) -> None:
- """Set and validate the field value."""
+ def value(self, val: T | str) -> None:
+ """Set and validate the field value. Converts empty strings to None for optional fields."""
+ # Handle empty string for optional fields (clearing the value)
+ if val == "" and not self.required:
+ self._value = None
+ return
+
+ # Standard validation for non-empty values
if self.validator:
# For string validators, convert to string first if needed
val_str = str(val) if not isinstance(val, str) else val
@@ -50,6 +57,40 @@ def value(self, val: T) -> None:
raise ValidationError(error_msg or "Invalid value")
self._value = val
+ def _build_prompt_text(self) -> str:
+ """Build prompt text with current value information."""
+ has_current_value = self.default is not None
+
+ if has_current_value:
+ # Show as "current" instead of "default" with dimmed styling
+ if not self.required:
+ return f"{self.prompt} (current value: {self.default}, type 'clear' to remove)"
+ return f"{self.prompt} (current value: {self.default})"
+
+ return self.prompt
+
+ def _handle_prompt_result(self, result: str | None | Any) -> str | None | Any:
+ """Handle common prompt result logic (BACK, None, clear keywords, empty input)."""
+ if result is BACK:
+ return BACK
+
+ if result is None:
+ # User cancelled (ESC)
+ return None
+
+ # Check for special keywords to clear the value
+ if result and result.lower() in ("clear", "none", "default"):
+ return ""
+
+ if not result:
+ # Empty input: return current value if exists
+ has_current_value = self.default is not None
+ if has_current_value:
+ return self.default
+ return ""
+
+ return result
+
@abstractmethod
def prompt_user(self, allow_back: bool = False) -> T | None | Any:
"""Prompt user for input."""
@@ -75,21 +116,19 @@ def __init__(
def prompt_user(self, allow_back: bool = False) -> str | None | Any:
"""Prompt user for text input."""
- from data_designer.cli.ui import BACK, prompt_text_input
+ prompt_text = self._build_prompt_text()
+ # Don't pass default to prompt_text_input to avoid duplicate "(default: X)" text
result = prompt_text_input(
- self.prompt,
- default=self.default,
+ prompt_text,
+ default=None,
validator=self.validator,
mask=self.mask,
completions=self.completions,
allow_back=allow_back,
)
- if result is BACK:
- return BACK
-
- return result
+ return self._handle_prompt_result(result)
class SelectField(Field[str]):
@@ -109,8 +148,6 @@ def __init__(
def prompt_user(self, allow_back: bool = False) -> str | None | Any:
"""Prompt user for selection."""
- from data_designer.cli.ui import BACK, select_with_arrows
-
result = select_with_arrows(
self.options,
self.prompt,
@@ -144,6 +181,9 @@ def __init__(
def range_validator(value: str) -> tuple[bool, str | None]:
if not value and not required:
return True, None
+ # Allow special keywords to clear the value
+ if value and value.lower() in ("clear", "none", "default"):
+ return True, None
if min_value is not None and max_value is not None:
is_valid, parsed = validate_numeric_range(value, min_value, max_value)
if not is_valid:
@@ -163,18 +203,24 @@ def range_validator(value: str) -> tuple[bool, str | None]:
def prompt_user(self, allow_back: bool = False) -> float | None | Any:
"""Prompt user for numeric input."""
- from data_designer.cli.ui import BACK, prompt_text_input
-
- default_str = str(self.default) if self.default is not None else None
+ prompt_text = self._build_prompt_text()
+ # Don't pass default to prompt_text_input to avoid duplicate "(default: X)" text
result = prompt_text_input(
- self.prompt,
- default=default_str,
+ prompt_text,
+ default=None,
validator=self.validator,
allow_back=allow_back,
)
- if result is BACK:
- return BACK
+ result = self._handle_prompt_result(result)
- return float(result) if result else None
+ # Return special values (BACK, None, empty string, defaults) as-is
+ if result is BACK or result is None or result == "":
+ return result
+
+ # Convert numeric strings to float (but not if it's already a float from default)
+ if isinstance(result, str):
+ return float(result)
+
+ return result
diff --git a/src/data_designer/cli/forms/model_builder.py b/src/data_designer/cli/forms/model_builder.py
index 58c96e02..ba627892 100644
--- a/src/data_designer/cli/forms/model_builder.py
+++ b/src/data_designer/cli/forms/model_builder.py
@@ -6,8 +6,13 @@
from data_designer.cli.forms.builder import FormBuilder
from data_designer.cli.forms.field import NumericField, SelectField, TextField
from data_designer.cli.forms.form import Form
-from data_designer.config.models import ModelConfig
-from data_designer.config.utils.constants import MAX_TEMPERATURE, MAX_TOP_P, MIN_TEMPERATURE, MIN_TOP_P
+from data_designer.cli.ui import confirm_action, print_error, print_text
+from data_designer.config.models import (
+ ChatCompletionInferenceParams,
+ EmbeddingInferenceParams,
+ GenerationType,
+ ModelConfig,
+)
class ModelFormBuilder(FormBuilder[ModelConfig]):
@@ -19,7 +24,7 @@ def __init__(self, existing_aliases: set[str] | None = None, available_providers
self.available_providers = available_providers or []
def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
- """Create the model configuration form."""
+ """Create the model configuration form with basic fields."""
fields = []
# Model alias
@@ -29,7 +34,7 @@ def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
"Model alias (used in your configs)",
default=initial_data.get("alias") if initial_data else None,
required=True,
- validator=self._validate_alias,
+ validator=self.validate_alias,
)
)
@@ -61,46 +66,222 @@ def create_form(self, initial_data: dict[str, Any] | None = None) -> Form:
# Single provider - will be set automatically
pass
- # Inference parameters
- fields.extend(
- [
+ # Generation type
+ # Extract from inference_parameters if present (for existing models)
+ default_gen_type = GenerationType.CHAT_COMPLETION
+ if initial_data:
+ inference_params = initial_data.get("inference_parameters", {})
+ default_gen_type = inference_params.get("generation_type", default_gen_type)
+
+ fields.append(
+ SelectField(
+ "generation_type",
+ "Generation type",
+ options={
+ GenerationType.CHAT_COMPLETION: "Chat completion",
+ GenerationType.EMBEDDING: "Embedding",
+ },
+ default=default_gen_type,
+ )
+ )
+
+ return Form(self.title, fields)
+
+ def create_inference_params_form(
+ self, generation_type: GenerationType, initial_params: dict[str, Any] | None = None
+ ) -> Form:
+ """Create generation-type-specific inference parameters form."""
+ initial_params = initial_params or {}
+ fields = []
+
+ if generation_type == GenerationType.CHAT_COMPLETION:
+ # Temperature
+ fields.append(
NumericField(
"temperature",
- f"Temperature ({MIN_TEMPERATURE}-{MAX_TEMPERATURE})",
- default=initial_data.get("inference_parameters", {}).get("temperature", 0.7)
- if initial_data
- else 0.7,
- min_value=MIN_TEMPERATURE,
- max_value=MAX_TEMPERATURE,
- ),
+ "Temperature (0.0-2.0)",
+ default=initial_params.get("temperature"),
+ min_value=0.0,
+ max_value=2.0,
+ required=False,
+ help_text="Higher values make output more random, lower values more deterministic",
+ )
+ )
+
+ # Top P
+ fields.append(
NumericField(
"top_p",
- f"Top P ({MIN_TOP_P}-{MAX_TOP_P})",
- default=initial_data.get("inference_parameters", {}).get("top_p", 0.9) if initial_data else 0.9,
- min_value=MIN_TOP_P,
- max_value=MAX_TOP_P,
- ),
+ "Top P (0.0-1.0)",
+ default=initial_params.get("top_p"),
+ min_value=0.0,
+ max_value=1.0,
+ required=False,
+ help_text="Controls diversity via nucleus sampling",
+ )
+ )
+
+ # Max tokens
+ fields.append(
NumericField(
"max_tokens",
- "Max tokens",
- default=initial_data.get("inference_parameters", {}).get("max_tokens", 2048)
- if initial_data
- else 2048,
- min_value=1,
- max_value=100000,
- ),
- ]
- )
+ "Max tokens (maximum total tokens including input and output)",
+ default=initial_params.get("max_tokens"),
+ min_value=1.0,
+ required=False,
+ help_text="Maximum number of tokens including both input prompt and generated response",
+ )
+ )
- return Form(self.title, fields)
+ # Max parallel requests
+ fields.append(
+ NumericField(
+ "max_parallel_requests",
+ "Max parallel requests (default: 4)",
+ default=initial_params.get("max_parallel_requests", 4),
+ min_value=1.0,
+ required=False,
+ help_text="Maximum number of parallel API requests",
+ )
+ )
- def _validate_alias(self, alias: str) -> tuple[bool, str | None]:
- """Validate model alias."""
- if not alias:
- return False, "Model alias is required"
- if alias in self.existing_aliases:
- return False, f"Model alias '{alias}' already exists"
- return True, None
+ # Timeout
+ fields.append(
+ NumericField(
+ "timeout",
+ "Timeout in seconds (optional)",
+ default=initial_params.get("timeout"),
+ min_value=1.0,
+ required=False,
+ help_text="Timeout for each API request in seconds",
+ )
+ )
+
+ else: # EMBEDDING
+ # Encoding format
+ fields.append(
+ TextField(
+ "encoding_format",
+ "Encoding format (float or base64)",
+ default=initial_params.get("encoding_format"),
+ required=False,
+ validator=self.validate_encoding_format,
+ )
+ )
+
+ # Dimensions
+ fields.append(
+ NumericField(
+ "dimensions",
+ "Dimensions (number of dimensions for embeddings)",
+ default=initial_params.get("dimensions"),
+ min_value=1.0,
+ required=False,
+ help_text="Model-specific dimension size (e.g., 1024, 768)",
+ )
+ )
+
+ # Max parallel requests (common field)
+ fields.append(
+ NumericField(
+ "max_parallel_requests",
+ "Max parallel requests (default: 4)",
+ default=initial_params.get("max_parallel_requests", 4),
+ min_value=1.0,
+ required=False,
+ help_text="Maximum number of parallel API requests",
+ )
+ )
+
+ # Timeout (common field)
+ fields.append(
+ NumericField(
+ "timeout",
+ "Timeout in seconds (optional)",
+ default=initial_params.get("timeout"),
+ min_value=1.0,
+ required=False,
+ help_text="Timeout for each API request in seconds",
+ )
+ )
+
+ return Form(f"{self.title} - Inference Parameters", fields)
+
+ def build_inference_params(self, generation_type: GenerationType, params_data: dict[str, Any]) -> dict[str, Any]:
+ """Build inference parameters dictionary from form data with proper type conversions."""
+ inference_params = {}
+
+ if generation_type == GenerationType.CHAT_COMPLETION:
+ if params_data.get("temperature") is not None:
+ inference_params["temperature"] = params_data["temperature"]
+ if params_data.get("top_p") is not None:
+ inference_params["top_p"] = params_data["top_p"]
+ if params_data.get("max_tokens") is not None:
+ inference_params["max_tokens"] = int(params_data["max_tokens"])
+
+ else: # EMBEDDING
+ # Only include fields with actual values; Pydantic will use defaults for missing fields
+ if params_data.get("encoding_format"):
+ inference_params["encoding_format"] = params_data["encoding_format"]
+ if params_data.get("dimensions"):
+ inference_params["dimensions"] = int(params_data["dimensions"])
+
+ # Common fields for both generation types
+ if params_data.get("max_parallel_requests") is not None:
+ inference_params["max_parallel_requests"] = int(params_data["max_parallel_requests"])
+ if params_data.get("timeout") is not None:
+ inference_params["timeout"] = int(params_data["timeout"])
+
+ return inference_params
+
+ def run(self, initial_data: dict[str, Any] | None = None) -> ModelConfig | None:
+ """Run the interactive form with two-step process for generation-type-specific parameters."""
+ # Step 1: Collect basic model configuration
+ basic_form = self.create_form(initial_data)
+
+ if initial_data:
+ basic_form.set_values(initial_data)
+
+ while True:
+ basic_result = basic_form.prompt_all(allow_back=True)
+
+ if basic_result is None:
+ if confirm_action("Cancel configuration?", default=False):
+ return None
+ continue
+
+ # Step 2: Collect generation-type-specific inference parameters
+ generation_type = basic_result.get("generation_type", GenerationType.CHAT_COMPLETION)
+ initial_params = initial_data.get("inference_parameters") if initial_data else None
+
+ # Print message to indicate we're now configuring inference parameters
+ gen_type_name = "chat completion" if generation_type == GenerationType.CHAT_COMPLETION else "embedding"
+ print_text(
+ f"βοΈ Configuring {gen_type_name} inference parameters [dim](Press Enter to keep current value or skip)[/dim]\n"
+ )
+
+ params_form = self.create_inference_params_form(generation_type, initial_params)
+
+ params_result = params_form.prompt_all(allow_back=True)
+
+ if params_result is None:
+ if confirm_action("Cancel configuration?", default=False):
+ return None
+ continue
+
+ # Build inference_parameters dict from individual fields
+ inference_params = self.build_inference_params(generation_type, params_result)
+
+ # Merge results
+ full_data = {**basic_result, "inference_parameters": inference_params}
+
+ try:
+ config = self.build_config(full_data)
+ return config
+ except Exception as e:
+ print_error(f"Configuration error: {e}")
+ if not confirm_action("Try again?", default=True):
+ return None
def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
"""Build ModelConfig from form data."""
@@ -112,14 +293,40 @@ def build_config(self, form_data: dict[str, Any]) -> ModelConfig:
else:
provider = None
+ # Get generation type (from form data, used to determine which inference params to create)
+ generation_type = form_data.get("generation_type", GenerationType.CHAT_COMPLETION)
+
+ # Get inference parameters dict
+ inference_params_dict = form_data.get("inference_parameters", {})
+
+ # Create the appropriate inference parameters type based on generation_type
+ # The generation_type will be set automatically by the inference params class
+ if generation_type == GenerationType.EMBEDDING:
+ inference_params = EmbeddingInferenceParams(**inference_params_dict)
+ else:
+ inference_params = ChatCompletionInferenceParams(**inference_params_dict)
+
return ModelConfig(
alias=form_data["alias"],
model=form_data["model"],
provider=provider,
- inference_parameters={
- "temperature": form_data["temperature"],
- "top_p": form_data["top_p"],
- "max_tokens": int(form_data["max_tokens"]),
- "max_parallel_requests": 4,
- },
+ inference_parameters=inference_params,
)
+
+ def validate_alias(self, alias: str) -> tuple[bool, str | None]:
+ """Validate model alias."""
+ if not alias:
+ return False, "Model alias is required"
+ if alias in self.existing_aliases:
+ return False, f"Model alias '{alias}' already exists"
+ return True, None
+
+ def validate_encoding_format(self, value: str) -> tuple[bool, str | None]:
+ """Validate encoding format for embedding models."""
+ if not value:
+ return True, None # Optional field
+ if value.lower() in ("clear", "none", "default"):
+ return True, None # Allow clearing keywords
+ if value not in ("float", "base64"):
+ return False, "Must be either 'float' or 'base64'"
+ return True, None
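To make the end result of the two-step flow concrete, here is a small sketch of the ModelConfig objects build_config now produces (model ids are taken from the defaults later in this patch; alias and provider names are illustrative):

    from data_designer.config.models import (
        ChatCompletionInferenceParams,
        EmbeddingInferenceParams,
        ModelConfig,
    )

    # Chat-completion model: sampling parameters live on the params object.
    chat_model = ModelConfig(
        alias="openai-text",
        model="gpt-4.1",
        provider="openai",
        inference_parameters=ChatCompletionInferenceParams(
            temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4
        ),
    )

    # Embedding model: encoding_format/dimensions instead of sampling parameters.
    embedding_model = ModelConfig(
        alias="openai-embedding",
        model="text-embedding-3-large",
        provider="openai",
        inference_parameters=EmbeddingInferenceParams(encoding_format="float", dimensions=1024),
    )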
diff --git a/src/data_designer/config/analysis/column_statistics.py b/src/data_designer/config/analysis/column_statistics.py
index 4f0f7675..c0aa46b6 100644
--- a/src/data_designer/config/analysis/column_statistics.py
+++ b/src/data_designer/config/analysis/column_statistics.py
@@ -119,30 +119,30 @@ class LLMTextColumnStatistics(GeneralColumnStatistics):
Stores both prompt and completion token consumption data.
Attributes:
- completion_tokens_mean: Mean number of completion tokens generated per record.
- completion_tokens_median: Median number of completion tokens generated per record.
- completion_tokens_stddev: Standard deviation of completion tokens per record.
- prompt_tokens_mean: Mean number of prompt tokens used per record.
- prompt_tokens_median: Median number of prompt tokens used per record.
- prompt_tokens_stddev: Standard deviation of prompt tokens per record.
+ output_tokens_mean: Mean number of output tokens generated per record.
+ output_tokens_median: Median number of output tokens generated per record.
+ output_tokens_stddev: Standard deviation of output tokens per record.
+ input_tokens_mean: Mean number of input tokens used per record.
+ input_tokens_median: Median number of input tokens used per record.
+ input_tokens_stddev: Standard deviation of input tokens per record.
column_type: Discriminator field, always "llm-text" for this statistics type.
"""
- completion_tokens_mean: Union[float, MissingValue]
- completion_tokens_median: Union[float, MissingValue]
- completion_tokens_stddev: Union[float, MissingValue]
- prompt_tokens_mean: Union[float, MissingValue]
- prompt_tokens_median: Union[float, MissingValue]
- prompt_tokens_stddev: Union[float, MissingValue]
+ output_tokens_mean: Union[float, MissingValue]
+ output_tokens_median: Union[float, MissingValue]
+ output_tokens_stddev: Union[float, MissingValue]
+ input_tokens_mean: Union[float, MissingValue]
+ input_tokens_median: Union[float, MissingValue]
+ input_tokens_stddev: Union[float, MissingValue]
column_type: Literal[DataDesignerColumnType.LLM_TEXT.value] = DataDesignerColumnType.LLM_TEXT.value
@field_validator(
- "completion_tokens_mean",
- "completion_tokens_median",
- "completion_tokens_stddev",
- "prompt_tokens_mean",
- "prompt_tokens_median",
- "prompt_tokens_stddev",
+ "output_tokens_mean",
+ "output_tokens_median",
+ "output_tokens_stddev",
+ "input_tokens_mean",
+ "input_tokens_median",
+ "input_tokens_stddev",
mode="before",
)
def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) -> Union[float, int, MissingValue]:
@@ -150,13 +150,13 @@ def llm_column_ensure_python_floats(cls, v: Union[float, int, MissingValue]) ->
def create_report_row_data(self) -> dict[str, Any]:
prompt_tokens_str = (
- f"{self.prompt_tokens_median:.1f} +/- {self.prompt_tokens_stddev:.1f}"
- if not self._is_missing_value(self.prompt_tokens_median)
+ f"{self.input_tokens_median:.1f} +/- {self.input_tokens_stddev:.1f}"
+ if not self._is_missing_value(self.input_tokens_median)
else "--"
)
completion_tokens_str = (
- f"{self.completion_tokens_median:.1f} +/- {self.completion_tokens_stddev:.1f}"
- if not self._is_missing_value(self.completion_tokens_median)
+ f"{self.output_tokens_median:.1f} +/- {self.output_tokens_stddev:.1f}"
+ if not self._is_missing_value(self.output_tokens_median)
else "--"
)
return {
diff --git a/src/data_designer/config/column_configs.py b/src/data_designer/config/column_configs.py
index 3916c08d..48fe529e 100644
--- a/src/data_designer/config/column_configs.py
+++ b/src/data_designer/config/column_configs.py
@@ -377,3 +377,24 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
"""
column_type: Literal["seed-dataset"] = "seed-dataset"
+
+
+class EmbeddingColumnConfig(SingleColumnConfig):
+ """Configuration for embedding generation columns.
+
+ Embedding columns generate embeddings for text input using a specified model.
+
+ Attributes:
+ target_column: The column to generate embeddings for. Each value may be a single text string or a list of text strings in stringified JSON format;
+ for a list of strings, one embedding is generated per string.
+ model_alias: The model to use for embedding generation.
+ column_type: Discriminator field, always "embedding" for this configuration type.
+ """
+
+ target_column: str
+ model_alias: str
+ column_type: Literal["embedding"] = "embedding"
+
+ @property
+ def required_columns(self) -> list[str]:
+ return [self.target_column]
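A minimal usage sketch for the new column config (the alias is assumed to match an embedding model defined in your model configs):

    from data_designer.config.column_configs import EmbeddingColumnConfig

    embedding_column = EmbeddingColumnConfig(
        name="review_embedding",          # output column
        target_column="review_text",      # text, or JSON-encoded list of texts, to embed
        model_alias="openai-embedding",   # must reference an embedding model config
    )
    assert embedding_column.required_columns == ["review_text"]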
diff --git a/src/data_designer/config/column_types.py b/src/data_designer/config/column_types.py
index bb8d8cbd..cbfce4f7 100644
--- a/src/data_designer/config/column_types.py
+++ b/src/data_designer/config/column_types.py
@@ -6,6 +6,7 @@
from typing_extensions import TypeAlias
from data_designer.config.column_configs import (
+ EmbeddingColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -35,6 +36,7 @@
SamplerColumnConfig,
SeedDatasetColumnConfig,
ValidationColumnConfig,
+ EmbeddingColumnConfig,
]
ColumnConfigT = plugin_manager.inject_into_column_config_type_union(ColumnConfigT)
@@ -54,6 +56,7 @@
DataDesignerColumnType.SEED_DATASET: "π±",
DataDesignerColumnType.SAMPLER: "π²",
DataDesignerColumnType.VALIDATION: "π",
+ DataDesignerColumnType.EMBEDDING: "π§¬",
}
COLUMN_TYPE_EMOJI_MAP.update(
{DataDesignerColumnType(p.name): p.emoji for p in plugin_manager.get_column_generator_plugins()}
@@ -70,27 +73,29 @@ def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumn
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.VALIDATION,
+ DataDesignerColumnType.EMBEDDING,
}
dag_column_types.update(plugin_manager.get_plugin_column_types(DataDesignerColumnType))
return column_type in dag_column_types
-def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
- """Return True if the column type is an LLM-generated column."""
+def column_type_is_model_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
+ """Return True if the column type is a model-generated column."""
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
- llm_generated_column_types = {
+ model_generated_column_types = {
DataDesignerColumnType.LLM_TEXT,
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
+ DataDesignerColumnType.EMBEDDING,
}
- llm_generated_column_types.update(
+ model_generated_column_types.update(
plugin_manager.get_plugin_column_types(
DataDesignerColumnType,
required_resources=["model_registry"],
)
)
- return column_type in llm_generated_column_types
+ return column_type in model_generated_column_types
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
@@ -121,6 +126,8 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
return SamplerColumnConfig(name=name, **_resolve_sampler_kwargs(name, kwargs))
if column_type == DataDesignerColumnType.SEED_DATASET:
return SeedDatasetColumnConfig(name=name, **kwargs)
+ if column_type == DataDesignerColumnType.EMBEDDING:
+ return EmbeddingColumnConfig(name=name, **kwargs)
if plugin := plugin_manager.get_column_generator_plugin_if_exists(column_type.value):
return plugin.config_cls(name=name, **kwargs)
raise InvalidColumnTypeError(f"π {column_type} is not a valid column type.") # pragma: no cover
@@ -135,6 +142,7 @@ def get_column_display_order() -> list[DataDesignerColumnType]:
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
+ DataDesignerColumnType.EMBEDDING,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
]
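With the registrations above, the kwargs-based factory can also build embedding columns; a small sketch (column and alias names are illustrative):

    from data_designer.config.column_types import (
        DataDesignerColumnType,
        get_column_config_from_kwargs,
    )

    config = get_column_config_from_kwargs(
        "review_embedding",
        DataDesignerColumnType.EMBEDDING,
        target_column="review_text",
        model_alias="openai-embedding",
    )  # returns an EmbeddingColumnConfig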
diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py
index acbc8ef8..5f07595d 100644
--- a/src/data_designer/config/config_builder.py
+++ b/src/data_designer/config/config_builder.py
@@ -19,7 +19,7 @@
from data_designer.config.column_types import (
ColumnConfigT,
DataDesignerColumnType,
- column_type_is_llm_generated,
+ column_type_is_model_generated,
get_column_config_from_kwargs,
get_column_display_order,
)
@@ -447,12 +447,21 @@ def get_constraints(self, target_column: str) -> list[ColumnConstraintT]:
return [c for c in self._constraints if c.target_column == target_column]
def get_llm_gen_columns(self) -> list[ColumnConfigT]:
- """Get all LLM-generated column configurations.
+ """Get all model-generated column configurations.
Returns:
- A list of column configurations that use LLM generation.
+ A list of column configurations that use model generation.
"""
- return [c for c in self._column_configs.values() if column_type_is_llm_generated(c.column_type)]
+ logger.warning("get_llm_gen_columns is deprecated. Use get_model_gen_columns instead.")
+ return self.get_model_gen_columns()
+
+ def get_model_gen_columns(self) -> list[ColumnConfigT]:
+ """Get all model-generated column configurations.
+
+ Returns:
+ A list of column configurations that use model generation.
+ """
+ return [c for c in self._column_configs.values() if column_type_is_model_generated(c.column_type)]
def get_columns_of_type(self, column_type: DataDesignerColumnType) -> list[ColumnConfigT]:
"""Get all column configurations of the specified type.
diff --git a/src/data_designer/config/default_model_settings.py b/src/data_designer/config/default_model_settings.py
index e65d0a90..5ab6bdf8 100644
--- a/src/data_designer/config/default_model_settings.py
+++ b/src/data_designer/config/default_model_settings.py
@@ -8,7 +8,13 @@
from pathlib import Path
from typing import Any, Literal, Optional
-from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
+from data_designer.config.models import (
+ ChatCompletionInferenceParams,
+ EmbeddingInferenceParams,
+ InferenceParamsT,
+ ModelConfig,
+ ModelProvider,
+)
from data_designer.config.utils.constants import (
MANAGED_ASSETS_PATH,
MODEL_CONFIGS_FILE_PATH,
@@ -21,32 +27,43 @@
logger = logging.getLogger(__name__)
-def get_default_text_alias_inference_parameters() -> InferenceParameters:
- return InferenceParameters(
+def get_default_text_alias_inference_parameters() -> ChatCompletionInferenceParams:
+ return ChatCompletionInferenceParams(
temperature=0.85,
top_p=0.95,
)
-def get_default_reasoning_alias_inference_parameters() -> InferenceParameters:
- return InferenceParameters(
+def get_default_reasoning_alias_inference_parameters() -> ChatCompletionInferenceParams:
+ return ChatCompletionInferenceParams(
temperature=0.35,
top_p=0.95,
)
-def get_default_vision_alias_inference_parameters() -> InferenceParameters:
- return InferenceParameters(
+def get_default_vision_alias_inference_parameters() -> ChatCompletionInferenceParams:
+ return ChatCompletionInferenceParams(
temperature=0.85,
top_p=0.95,
)
-def get_default_inference_parameters(model_alias: Literal["text", "reasoning", "vision"]) -> InferenceParameters:
+def get_default_embedding_alias_inference_parameters(provider: str) -> EmbeddingInferenceParams:
+ args = dict(encoding_format="float")
+ if provider == "nvidia":
+ args["extra_body"] = {"input_type": "query"}
+ return EmbeddingInferenceParams(**args)
+
+
+def get_default_inference_parameters(
+ model_alias: Literal["text", "reasoning", "vision", "embedding"], provider: str
+) -> InferenceParamsT:
if model_alias == "reasoning":
return get_default_reasoning_alias_inference_parameters()
elif model_alias == "vision":
return get_default_vision_alias_inference_parameters()
+ elif model_alias == "embedding":
+ return get_default_embedding_alias_inference_parameters(provider)
else:
return get_default_text_alias_inference_parameters()
@@ -60,7 +77,7 @@ def get_builtin_model_configs() -> list[ModelConfig]:
alias=f"{provider}-{model_alias}",
model=model_id,
provider=provider,
- inference_parameters=get_default_inference_parameters(model_alias),
+ inference_parameters=get_default_inference_parameters(model_alias, provider),
)
)
return model_configs
@@ -103,7 +120,8 @@ def resolve_seed_default_model_settings() -> None:
f"πΎ Default model configs were not found, so writing the following to {str(MODEL_CONFIGS_FILE_PATH)!r}"
)
save_config_file(
- MODEL_CONFIGS_FILE_PATH, {"model_configs": [mc.model_dump() for mc in get_builtin_model_configs()]}
+ MODEL_CONFIGS_FILE_PATH,
+ {"model_configs": [mc.model_dump(mode="json") for mc in get_builtin_model_configs()]},
)
if not MODEL_PROVIDERS_FILE_PATH.exists():
@@ -111,7 +129,7 @@ def resolve_seed_default_model_settings() -> None:
f"πͺ Default model providers were not found, so writing the following to {str(MODEL_PROVIDERS_FILE_PATH)!r}"
)
save_config_file(
- MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump() for p in get_builtin_model_providers()]}
+ MODEL_PROVIDERS_FILE_PATH, {"providers": [p.model_dump(mode="json") for p in get_builtin_model_providers()]}
)
if not MANAGED_ASSETS_PATH.exists():
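For reference, the provider-aware defaults resolve roughly like this (a sketch of the functions above, not additional behavior):

    from data_designer.config.default_model_settings import get_default_inference_parameters

    get_default_inference_parameters("embedding", "nvidia")
    # -> EmbeddingInferenceParams(encoding_format="float", extra_body={"input_type": "query"})

    get_default_inference_parameters("embedding", "openai")
    # -> EmbeddingInferenceParams(encoding_format="float")

    get_default_inference_parameters("text", "openai")
    # -> ChatCompletionInferenceParams(temperature=0.85, top_p=0.95)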
diff --git a/src/data_designer/config/exports.py b/src/data_designer/config/exports.py
index eb976659..3839b178 100644
--- a/src/data_designer/config/exports.py
+++ b/src/data_designer/config/exports.py
@@ -3,6 +3,7 @@
from data_designer.config.analysis.column_profilers import JudgeScoreProfilerConfig
from data_designer.config.column_configs import (
+ EmbeddingColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -19,6 +20,9 @@
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.datastore import DatastoreSettings
from data_designer.config.models import (
+ ChatCompletionInferenceParams,
+ EmbeddingInferenceParams,
+ GenerationType,
ImageContext,
ImageFormat,
InferenceParameters,
@@ -81,6 +85,7 @@ def get_config_exports() -> list[str]:
CodeLang.__name__,
CodeValidatorParams.__name__,
ColumnInequalityConstraint.__name__,
+ ChatCompletionInferenceParams.__name__,
DataDesignerColumnType.__name__,
DataDesignerConfig.__name__,
DataDesignerConfigBuilder.__name__,
@@ -89,8 +94,11 @@ def get_config_exports() -> list[str]:
DatastoreSettings.__name__,
DatetimeSamplerParams.__name__,
DropColumnsProcessorConfig.__name__,
+ EmbeddingColumnConfig.__name__,
+ EmbeddingInferenceParams.__name__,
ExpressionColumnConfig.__name__,
GaussianSamplerParams.__name__,
+ GenerationType.__name__,
IndexRange.__name__,
InfoType.__name__,
ImageContext.__name__,
diff --git a/src/data_designer/config/models.py b/src/data_designer/config/models.py
index 09454a88..9b5ac6d7 100644
--- a/src/data_designer/config/models.py
+++ b/src/data_designer/config/models.py
@@ -5,10 +5,10 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
-from typing import Any, Generic, List, Optional, TypeVar, Union
+from typing import Any, Generic, List, Literal, Optional, TypeVar, Union
import numpy as np
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self, TypeAlias
from data_designer.config.base import ConfigBase
@@ -205,33 +205,59 @@ def sample(self) -> float:
DistributionT: TypeAlias = Union[UniformDistribution, ManualDistribution]
-class InferenceParameters(ConfigBase):
- """Configuration for LLM inference parameters.
+class GenerationType(str, Enum):
+ CHAT_COMPLETION = "chat-completion"
+ EMBEDDING = "embedding"
+
+
+class BaseInferenceParams(ConfigBase, ABC):
+ """Base configuration for inference parameters.
Attributes:
- temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
- top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
- max_tokens: Maximum number of tokens (includes both input and output tokens).
+ generation_type: Type of generation (chat-completion or embedding). Acts as discriminator.
max_parallel_requests: Maximum number of parallel requests to the model API.
timeout: Timeout in seconds for each request.
extra_body: Additional parameters to pass to the model API.
"""
- temperature: Optional[Union[float, DistributionT]] = None
- top_p: Optional[Union[float, DistributionT]] = None
- max_tokens: Optional[int] = Field(default=None, ge=1)
+ generation_type: GenerationType
max_parallel_requests: int = Field(default=4, ge=1)
timeout: Optional[int] = Field(default=None, ge=1)
extra_body: Optional[dict[str, Any]] = None
@property
- def generate_kwargs(self) -> dict[str, Union[float, int]]:
+ def generate_kwargs(self) -> dict[str, Any]:
"""Get the generate kwargs for the inference parameters.
Returns:
A dictionary of the generate kwargs.
"""
result = {}
+ if self.timeout is not None:
+ result["timeout"] = self.timeout
+ if self.extra_body is not None and self.extra_body != {}:
+ result["extra_body"] = self.extra_body
+ return result
+
+
+class ChatCompletionInferenceParams(BaseInferenceParams):
+ """Configuration for LLM inference parameters.
+
+ Attributes:
+ generation_type: Type of generation, always "chat-completion" for this class.
+ temperature: Sampling temperature (0.0-2.0). Can be a fixed value or a distribution for dynamic sampling.
+ top_p: Nucleus sampling probability (0.0-1.0). Can be a fixed value or a distribution for dynamic sampling.
+ max_tokens: Maximum number of tokens (includes both input and output tokens).
+ """
+
+ generation_type: Literal[GenerationType.CHAT_COMPLETION] = GenerationType.CHAT_COMPLETION
+ temperature: Optional[Union[float, DistributionT]] = None
+ top_p: Optional[Union[float, DistributionT]] = None
+ max_tokens: Optional[int] = Field(default=None, ge=1)
+
+ @property
+ def generate_kwargs(self) -> dict[str, Any]:
+ result = super().generate_kwargs
if self.temperature is not None:
result["temperature"] = (
self.temperature.sample() if hasattr(self.temperature, "sample") else self.temperature
@@ -240,10 +266,6 @@ def generate_kwargs(self) -> dict[str, Union[float, int]]:
result["top_p"] = self.top_p.sample() if hasattr(self.top_p, "sample") else self.top_p
if self.max_tokens is not None:
result["max_tokens"] = self.max_tokens
- if self.timeout is not None:
- result["timeout"] = self.timeout
- if self.extra_body is not None and self.extra_body != {}:
- result["extra_body"] = self.extra_body
return result
@model_validator(mode="after")
@@ -290,6 +312,47 @@ def _is_value_in_range(self, value: float, min_value: float, max_value: float) -
return min_value <= value <= max_value
+# Maintain backwards compatibility with a deprecation warning
+class InferenceParameters(ChatCompletionInferenceParams):
+ """
+ Deprecated: Use ChatCompletionInferenceParams instead.
+ This alias will be removed in a future version.
+ """
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ logger.warning(
+ "InferenceParameters is deprecated and will be removed in a future version. "
+ "Use ChatCompletionInferenceParams instead."
+ )
+ super().__init__(*args, **kwargs)
+
+
+class EmbeddingInferenceParams(BaseInferenceParams):
+ """Configuration for embedding generation parameters.
+
+ Attributes:
+ generation_type: Type of generation, always "embedding" for this class.
+ encoding_format: Format of the embedding encoding ("float" or "base64").
+ dimensions: Number of dimensions for the embedding.
+ """
+
+ generation_type: Literal[GenerationType.EMBEDDING] = GenerationType.EMBEDDING
+ encoding_format: Literal["float", "base64"] = "float"
+ dimensions: Optional[int] = None
+
+ @property
+ def generate_kwargs(self) -> dict[str, Any]:
+ result = super().generate_kwargs
+ if self.encoding_format is not None:
+ result["encoding_format"] = self.encoding_format
+ if self.dimensions is not None:
+ result["dimensions"] = self.dimensions
+ return result
+
+
+InferenceParamsT: TypeAlias = Union[ChatCompletionInferenceParams, EmbeddingInferenceParams, InferenceParameters]
+
+
class ModelConfig(ConfigBase):
"""Configuration for a model used for generation.
@@ -297,14 +360,32 @@ class ModelConfig(ConfigBase):
alias: User-defined alias to reference in column configurations.
model: Model identifier (e.g., from build.nvidia.com or other providers).
inference_parameters: Inference parameters for the model (temperature, top_p, max_tokens, etc.).
+ The generation_type is determined by the type of inference_parameters.
provider: Optional model provider name if using custom providers.
"""
alias: str
model: str
- inference_parameters: InferenceParameters = Field(default_factory=InferenceParameters)
+ inference_parameters: InferenceParamsT = Field(default_factory=ChatCompletionInferenceParams)
provider: Optional[str] = None
+ @property
+ def generation_type(self) -> GenerationType:
+ """Get the generation type from the inference parameters."""
+ return self.inference_parameters.generation_type
+
+ @field_validator("inference_parameters", mode="before")
+ @classmethod
+ def _convert_inference_parameters(cls, value: Any) -> Any:
+ """Convert raw dict to appropriate inference parameters type based on field presence."""
+ if isinstance(value, dict):
+ # Infer type from presence of embedding-specific fields
+ if "encoding_format" in value or "dimensions" in value:
+ return EmbeddingInferenceParams(**value)
+ else:
+ return ChatCompletionInferenceParams(**value)
+ return value
+
class ModelProvider(ConfigBase):
"""Configuration for a custom model provider.
diff --git a/src/data_designer/config/utils/constants.py b/src/data_designer/config/utils/constants.py
index 0420240f..12b0d2a3 100644
--- a/src/data_designer/config/utils/constants.py
+++ b/src/data_designer/config/utils/constants.py
@@ -304,10 +304,12 @@ class NordColor(Enum):
"text": "nvidia/nvidia-nemotron-nano-9b-v2",
"reasoning": "openai/gpt-oss-20b",
"vision": "nvidia/nemotron-nano-12b-v2-vl",
+ "embedding": "nvidia/llama-3.2-nv-embedqa-1b-v2",
},
OPENAI_PROVIDER_NAME: {
"text": "gpt-4.1",
"reasoning": "gpt-5",
"vision": "gpt-5",
+ "embedding": "text-embedding-3-large",
},
}
diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/config/utils/validation.py
index d30a067a..7d3654ad 100644
--- a/src/data_designer/config/utils/validation.py
+++ b/src/data_designer/config/utils/validation.py
@@ -15,7 +15,7 @@
from rich.padding import Padding
from rich.panel import Panel
-from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_llm_generated
+from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType, column_type_is_model_generated
from data_designer.config.processors import ProcessorConfig, ProcessorType
from data_designer.config.utils.constants import RICH_CONSOLE_THEME
from data_designer.config.utils.misc import (
@@ -119,7 +119,7 @@ def validate_prompt_templates(
) -> list[Violation]:
env = ImmutableSandboxedEnvironment()
- columns_with_prompts = [c for c in columns if column_type_is_llm_generated(c.column_type)]
+ columns_with_prompts = [c for c in columns if column_type_is_model_generated(c.column_type)]
violations = []
for column in columns_with_prompts:
diff --git a/src/data_designer/config/utils/visualization.py b/src/data_designer/config/utils/visualization.py
index 1f843c86..c2cca08b 100644
--- a/src/data_designer/config/utils/visualization.py
+++ b/src/data_designer/config/utils/visualization.py
@@ -8,7 +8,7 @@
from collections import OrderedDict
from enum import Enum
from functools import cached_property
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import pandas as pd
@@ -194,6 +194,7 @@ def display_sample_record(
+ config_builder.get_columns_of_type(DataDesignerColumnType.EXPRESSION)
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_TEXT)
+ config_builder.get_columns_of_type(DataDesignerColumnType.LLM_STRUCTURED)
+ + config_builder.get_columns_of_type(DataDesignerColumnType.EMBEDDING)
)
if len(non_code_columns) > 0:
table = Table(title="Generated Columns", **table_kws)
@@ -201,6 +202,10 @@ def display_sample_record(
table.add_column("Value")
for col in non_code_columns:
if not col.drop:
+ if col.column_type == DataDesignerColumnType.EMBEDDING:
+ record[col.name]["embeddings"] = [
+ get_truncated_list_as_string(embd) for embd in record[col.name].get("embeddings")
+ ]
table.add_row(col.name, convert_to_row_element(record[col.name]))
render_list.append(pad_console_element(table))
@@ -269,6 +274,16 @@ def display_sample_record(
console.print(Group(*render_list), markup=False)
+def get_truncated_list_as_string(long_list: list[Any], max_items: int = 2) -> str:
+ if max_items <= 0:
+ raise ValueError("max_items must be greater than 0")
+ if len(long_list) > max_items:
+ truncated_part = long_list[:max_items]
+ return f"[{', '.join(str(x) for x in truncated_part)}, ...]"
+ else:
+ return str(long_list)
+
+
def display_sampler_table(
sampler_params: dict[SamplerType, ConfigBase],
title: Optional[str] = None,
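The truncation helper above keeps only the first max_items entries, which keeps long embedding vectors readable in the sample-record table; for example:

    get_truncated_list_as_string([0.12, 0.34, 0.56, 0.78])  # '[0.12, 0.34, ...]'
    get_truncated_list_as_string([0.12, 0.34])              # '[0.12, 0.34]'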
diff --git a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py
index 0e0a7c0e..3b2e075b 100644
--- a/src/data_designer/engine/analysis/utils/column_statistics_calculations.py
+++ b/src/data_designer/engine/analysis/utils/column_statistics_calculations.py
@@ -18,8 +18,10 @@
MissingValue,
NumericalDistribution,
)
-from data_designer.config.column_configs import LLMTextColumnConfig
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.config.column_configs import (
+ LLMTextColumnConfig,
+)
+from data_designer.engine.column_generators.utils.prompt_renderer import (
PromptType,
RecordBasedPromptRenderer,
create_response_recipe,
@@ -92,7 +94,7 @@ def calculate_general_column_info(column_name: str, df: pd.DataFrame) -> dict[st
}
-def calculate_prompt_token_stats(
+def calculate_input_token_stats(
column_config: LLMTextColumnConfig, df: pd.DataFrame
) -> dict[str, float | MissingValue]:
try:
@@ -109,44 +111,44 @@ def calculate_prompt_token_stats(
concatenated_prompt = str(system_prompt + "\n\n" + prompt)
num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
except Exception as e:
- logger.warning(
- f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}"
- )
+ logger.warning(f"{WARNING_PREFIX} failed to calculate input token stats for column {column_config.name!r}: {e}")
return {
- "prompt_tokens_mean": MissingValue.CALCULATION_FAILED,
- "prompt_tokens_median": MissingValue.CALCULATION_FAILED,
- "prompt_tokens_stddev": MissingValue.CALCULATION_FAILED,
+ "input_tokens_mean": MissingValue.CALCULATION_FAILED,
+ "input_tokens_median": MissingValue.CALCULATION_FAILED,
+ "input_tokens_stddev": MissingValue.CALCULATION_FAILED,
}
return {
- "prompt_tokens_mean": np.mean(num_tokens),
- "prompt_tokens_median": np.median(num_tokens),
- "prompt_tokens_stddev": np.std(num_tokens),
+ "input_tokens_mean": np.mean(num_tokens),
+ "input_tokens_median": np.median(num_tokens),
+ "input_tokens_stddev": np.std(num_tokens),
}
-def calculate_completion_token_stats(column_name: str, df: pd.DataFrame) -> dict[str, float | MissingValue]:
+def calculate_output_token_stats(
+ column_config: LLMTextColumnConfig, df: pd.DataFrame
+) -> dict[str, float | MissingValue]:
try:
- tokens_per_record = df[column_name].apply(
+ tokens_per_record = df[column_config.name].apply(
lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
)
return {
- "completion_tokens_mean": tokens_per_record.mean(),
- "completion_tokens_median": tokens_per_record.median(),
- "completion_tokens_stddev": tokens_per_record.std(),
+ "output_tokens_mean": tokens_per_record.mean(),
+ "output_tokens_median": tokens_per_record.median(),
+ "output_tokens_stddev": tokens_per_record.std(),
}
except Exception as e:
- logger.warning(f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_name}: {e}")
+ logger.warning(f"{WARNING_PREFIX} failed to calculate output token stats for column {column_config.name}: {e}")
return {
- "completion_tokens_mean": MissingValue.CALCULATION_FAILED,
- "completion_tokens_median": MissingValue.CALCULATION_FAILED,
- "completion_tokens_stddev": MissingValue.CALCULATION_FAILED,
+ "output_tokens_mean": MissingValue.CALCULATION_FAILED,
+ "output_tokens_median": MissingValue.CALCULATION_FAILED,
+ "output_tokens_stddev": MissingValue.CALCULATION_FAILED,
}
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
return {
- **calculate_prompt_token_stats(column_config, df),
- **calculate_completion_token_stats(column_config.name, df),
+ **calculate_input_token_stats(column_config, df),
+ **calculate_output_token_stats(column_config, df),
}
diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py
index f4ddb60c..b7ad4c6d 100644
--- a/src/data_designer/engine/column_generators/generators/base.py
+++ b/src/data_designer/engine/column_generators/generators/base.py
@@ -1,13 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+import functools
+import logging
from abc import ABC, abstractmethod
from typing import overload
import pandas as pd
+from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
+from data_designer.config.models import BaseInferenceParams, ModelConfig
from data_designer.config.utils.type_helpers import StrEnum
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
+from data_designer.engine.models.facade import ModelFacade
+
+logger = logging.getLogger(__name__)
class GenerationStrategy(StrEnum):
@@ -59,3 +66,30 @@ def can_generate_from_scratch(self) -> bool:
@abstractmethod
def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
+
+
+class WithModelGeneration:
+ @functools.cached_property
+ def model(self) -> ModelFacade:
+ return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+
+ @functools.cached_property
+ def model_config(self) -> ModelConfig:
+ return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+
+ @functools.cached_property
+ def inference_parameters(self) -> BaseInferenceParams:
+ return self.model_config.inference_parameters
+
+ def log_pre_generation(self) -> None:
+ emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
+ logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
+ logger.info(f" |-- column name: {self.config.name!r}")
+ logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
+ if self.model_config.provider is None:
+ logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
+
+ def _get_provider_name(self) -> str:
+ model_alias = self.model_config.alias
+ provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
+ return provider.name
diff --git a/src/data_designer/engine/column_generators/generators/embedding.py b/src/data_designer/engine/column_generators/generators/embedding.py
new file mode 100644
index 00000000..a8623db5
--- /dev/null
+++ b/src/data_designer/engine/column_generators/generators/embedding.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from pydantic import BaseModel, computed_field
+
+from data_designer.config.column_configs import EmbeddingColumnConfig
+from data_designer.engine.column_generators.generators.base import (
+ ColumnGenerator,
+ GenerationStrategy,
+ GeneratorMetadata,
+ WithModelGeneration,
+)
+from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
+from data_designer.engine.resources.resource_provider import ResourceType
+
+
+class EmbeddingGenerationResult(BaseModel):
+ embeddings: list[list[float]]
+
+ @computed_field
+ def num_embeddings(self) -> int:
+ return len(self.embeddings)
+
+ @computed_field
+ def dimension(self) -> int:
+ return len(self.embeddings[0]) if len(self.embeddings) > 0 else 0
+
+
+class EmbeddingCellGenerator(WithModelGeneration, ColumnGenerator[EmbeddingColumnConfig]):
+ @staticmethod
+ def metadata() -> GeneratorMetadata:
+ return GeneratorMetadata(
+ name="embedding_cell_generator",
+ description="Generate embeddings for a text column.",
+ generation_strategy=GenerationStrategy.CELL_BY_CELL,
+ required_resources=[ResourceType.MODEL_REGISTRY],
+ )
+
+ def generate(self, data: dict) -> dict:
+ deserialized_record = deserialize_json_values(data)
+ input_texts = parse_list_string(deserialized_record[self.config.target_column])
+ embeddings = self.model.generate_text_embeddings(input_texts=input_texts)
+ data[self.config.name] = EmbeddingGenerationResult(embeddings=embeddings).model_dump(mode="json")
+ return data
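Because EmbeddingGenerationResult uses computed fields, num_embeddings and dimension are included when the result is serialized into the record; for example:

    EmbeddingGenerationResult(embeddings=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]).model_dump(mode="json")
    # -> {'embeddings': [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], 'num_embeddings': 2, 'dimension': 3}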
diff --git a/src/data_designer/engine/column_generators/generators/llm_generators.py b/src/data_designer/engine/column_generators/generators/llm_completion.py
similarity index 71%
rename from src/data_designer/engine/column_generators/generators/llm_generators.py
rename to src/data_designer/engine/column_generators/generators/llm_completion.py
index ee0ab58a..ef47ed01 100644
--- a/src/data_designer/engine/column_generators/generators/llm_generators.py
+++ b/src/data_designer/engine/column_generators/generators/llm_completion.py
@@ -10,43 +10,41 @@
LLMStructuredColumnConfig,
LLMTextColumnConfig,
)
-from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
-from data_designer.config.models import InferenceParameters, ModelConfig
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
from data_designer.engine.column_generators.generators.base import (
ColumnGenerator,
GenerationStrategy,
GeneratorMetadata,
+ WithModelGeneration,
)
from data_designer.engine.column_generators.utils.prompt_renderer import (
PromptType,
RecordBasedPromptRenderer,
create_response_recipe,
)
-from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.recipes.base import ResponseRecipe
from data_designer.engine.processing.utils import deserialize_json_values
from data_designer.engine.resources.resource_provider import ResourceType
-DEFAULT_MAX_CONVERSATION_RESTARTS = 5
-DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
+logger = logging.getLogger(__name__)
-logger = logging.getLogger(__name__)
+DEFAULT_MAX_CONVERSATION_RESTARTS = 5
+DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
-class WithLLMGeneration:
+class WithChatCompletionGeneration(WithModelGeneration):
@functools.cached_property
- def model(self) -> ModelFacade:
- return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)
+ def response_recipe(self) -> ResponseRecipe:
+ return create_response_recipe(self.config, self.model_config)
- @functools.cached_property
- def model_config(self) -> ModelConfig:
- return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)
+ @property
+ def max_conversation_correction_steps(self) -> int:
+ return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
- @functools.cached_property
- def inference_parameters(self) -> InferenceParameters:
- return self.model_config.inference_parameters
+ @property
+ def max_conversation_restarts(self) -> int:
+ return DEFAULT_MAX_CONVERSATION_RESTARTS
@functools.cached_property
def prompt_renderer(self) -> RecordBasedPromptRenderer:
@@ -59,18 +57,6 @@ def prompt_renderer(self) -> RecordBasedPromptRenderer:
},
)
- @functools.cached_property
- def response_recipe(self) -> ResponseRecipe:
- return create_response_recipe(self.config, self.model_config)
-
- @property
- def max_conversation_correction_steps(self) -> int:
- return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
- @property
- def max_conversation_restarts(self) -> int:
- return DEFAULT_MAX_CONVERSATION_RESTARTS
-
def generate(self, data: dict) -> dict:
deserialized_record = deserialize_json_values(data)
@@ -96,7 +82,6 @@ def generate(self, data: dict) -> dict:
max_correction_steps=self.max_conversation_correction_steps,
max_conversation_restarts=self.max_conversation_restarts,
purpose=f"running generation for column '{self.config.name}'",
- **self.inference_parameters.generate_kwargs,
)
data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))
@@ -106,21 +91,8 @@ def generate(self, data: dict) -> dict:
return data
- def log_pre_generation(self) -> None:
- emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
- logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
- logger.info(f" |-- column name: {self.config.name!r}")
- logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
- if self.model_config.provider is None:
- logger.info(f" |-- default model provider: {self._get_provider_name()!r}")
-
- def _get_provider_name(self) -> str:
- model_alias = self.model_config.alias
- provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
- return provider.name
-
-class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
+class LLMTextCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMTextColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
@@ -131,7 +103,7 @@ def metadata() -> GeneratorMetadata:
)
-class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
+class LLMCodeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMCodeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
@@ -142,7 +114,7 @@ def metadata() -> GeneratorMetadata:
)
-class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
+class LLMStructuredCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
@@ -153,7 +125,7 @@ def metadata() -> GeneratorMetadata:
)
-class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
+class LLMJudgeCellGenerator(WithChatCompletionGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
@staticmethod
def metadata() -> GeneratorMetadata:
return GeneratorMetadata(
@@ -163,10 +135,6 @@ def metadata() -> GeneratorMetadata:
required_resources=[ResourceType.MODEL_REGISTRY],
)
- @property
- def max_conversation_correction_steps(self) -> int:
- return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS
-
@property
def max_conversation_restarts(self) -> int:
return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
diff --git a/src/data_designer/engine/column_generators/registry.py b/src/data_designer/engine/column_generators/registry.py
index 61b43753..6da40269 100644
--- a/src/data_designer/engine/column_generators/registry.py
+++ b/src/data_designer/engine/column_generators/registry.py
@@ -3,6 +3,7 @@
from data_designer.config.base import ConfigBase
from data_designer.config.column_configs import (
+ EmbeddingColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -12,8 +13,9 @@
)
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.engine.column_generators.generators.base import ColumnGenerator
+from data_designer.engine.column_generators.generators.embedding import EmbeddingCellGenerator
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
LLMCodeCellGenerator,
LLMJudgeCellGenerator,
LLMStructuredCellGenerator,
@@ -40,11 +42,11 @@ def create_default_column_generator_registry(with_plugins: bool = True) -> Colum
registry.register(DataDesignerColumnType.LLM_CODE, LLMCodeCellGenerator, LLMCodeColumnConfig)
registry.register(DataDesignerColumnType.LLM_JUDGE, LLMJudgeCellGenerator, LLMJudgeColumnConfig)
registry.register(DataDesignerColumnType.EXPRESSION, ExpressionColumnGenerator, ExpressionColumnConfig)
+ registry.register(DataDesignerColumnType.EMBEDDING, EmbeddingCellGenerator, EmbeddingColumnConfig)
registry.register(DataDesignerColumnType.SAMPLER, SamplerColumnGenerator, SamplerMultiColumnConfig)
registry.register(DataDesignerColumnType.SEED_DATASET, SeedDatasetColumnGenerator, SeedDatasetMultiColumnConfig)
registry.register(DataDesignerColumnType.VALIDATION, ValidationColumnGenerator, ValidationColumnConfig)
registry.register(DataDesignerColumnType.LLM_STRUCTURED, LLMStructuredCellGenerator, LLMStructuredColumnConfig)
-
if with_plugins:
for plugin in PluginRegistry().get_plugins(PluginType.COLUMN_GENERATOR):
registry.register(
diff --git a/src/data_designer/engine/dataset_builders/column_wise_builder.py b/src/data_designer/engine/dataset_builders/column_wise_builder.py
index beeedd6f..063aa15c 100644
--- a/src/data_designer/engine/dataset_builders/column_wise_builder.py
+++ b/src/data_designer/engine/dataset_builders/column_wise_builder.py
@@ -10,15 +10,18 @@
import pandas as pd
-from data_designer.config.column_types import ColumnConfigT, column_type_is_llm_generated
+from data_designer.config.column_types import ColumnConfigT, column_type_is_model_generated
from data_designer.config.dataset_builders import BuildStage
from data_designer.config.processors import (
DropColumnsProcessorConfig,
ProcessorConfig,
ProcessorType,
)
-from data_designer.engine.column_generators.generators.base import ColumnGenerator, GenerationStrategy
-from data_designer.engine.column_generators.generators.llm_generators import WithLLMGeneration
+from data_designer.engine.column_generators.generators.base import (
+ ColumnGenerator,
+ GenerationStrategy,
+ WithModelGeneration,
+)
from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
from data_designer.engine.dataset_builders.errors import DatasetGenerationError, DatasetProcessingError
from data_designer.engine.dataset_builders.multi_column_configs import (
@@ -72,7 +75,7 @@ def single_column_configs(self) -> list[ColumnConfigT]:
@functools.cached_property
def llm_generated_column_configs(self) -> list[ColumnConfigT]:
- return [config for config in self.single_column_configs if column_type_is_llm_generated(config.column_type)]
+ return [config for config in self.single_column_configs if column_type_is_model_generated(config.column_type)]
def build(
self,
@@ -169,7 +172,7 @@ def _run_from_scratch_column_generator(self, generator: ColumnGenerator) -> None
def _run_cell_by_cell_generator(self, generator: ColumnGenerator) -> None:
max_workers = MAX_CONCURRENCY_PER_NON_LLM_GENERATOR
- if isinstance(generator, WithLLMGeneration):
+ if isinstance(generator, WithModelGeneration):
max_workers = generator.inference_parameters.max_parallel_requests
self._fan_out_with_threads(generator, max_workers=max_workers)
@@ -178,12 +181,12 @@ def _run_full_column_generator(self, generator: ColumnGenerator) -> None:
self.batch_manager.update_records(df.to_dict(orient="records"))
def _run_model_health_check_if_needed(self) -> bool:
- if any(column_type_is_llm_generated(config.column_type) for config in self.single_column_configs):
+ if any(column_type_is_model_generated(config.column_type) for config in self.single_column_configs):
self._resource_provider.model_registry.run_health_check(
- set(config.model_alias for config in self.llm_generated_column_configs)
+ list(set(config.model_alias for config in self.llm_generated_column_configs))
)
- def _fan_out_with_threads(self, generator: WithLLMGeneration, max_workers: int) -> None:
+ def _fan_out_with_threads(self, generator: WithModelGeneration, max_workers: int) -> None:
if generator.generation_strategy != GenerationStrategy.CELL_BY_CELL:
raise DatasetGenerationError(
f"Generator {generator.metadata().name} is not a {GenerationStrategy.CELL_BY_CELL} "
diff --git a/src/data_designer/engine/models/facade.py b/src/data_designer/engine/models/facade.py
index e6b83e71..8d969809 100644
--- a/src/data_designer/engine/models/facade.py
+++ b/src/data_designer/engine/models/facade.py
@@ -9,9 +9,9 @@
from typing import Any
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import EmbeddingResponse, ModelResponse
-from data_designer.config.models import ModelConfig, ModelProvider
+from data_designer.config.models import GenerationType, ModelConfig, ModelProvider
from data_designer.engine.model_provider import ModelProviderRegistry
from data_designer.engine.models.errors import (
GenerationValidationFailureError,
@@ -49,6 +49,10 @@ def model_name(self) -> str:
def model_provider(self) -> ModelProvider:
return self._model_provider_registry.get_provider(self._model_config.provider)
+ @property
+ def model_generation_type(self) -> GenerationType:
+ return self._model_config.generation_type
+
@property
def model_provider_name(self) -> str:
return self.model_provider.name
@@ -64,13 +68,12 @@ def usage_stats(self) -> ModelUsageStats:
def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool = False, **kwargs) -> ModelResponse:
logger.debug(
f"Prompting model {self.model_name!r}...",
- extra={"model": self.model_name, "messages": messages, "sensitive": True},
+ extra={"model": self.model_name, "messages": messages},
)
response = None
- if self.model_provider.extra_body:
- kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+ kwargs = self.consolidate_kwargs(**kwargs)
try:
- response = self._router.completion(self.model_name, messages, **kwargs)
+ response = self._router.completion(model=self.model_name, messages=messages, **kwargs)
logger.debug(
f"Received completion from model {self.model_name!r}",
extra={
@@ -84,9 +87,50 @@ def completion(self, messages: list[dict[str, str]], skip_usage_tracking: bool =
except Exception as e:
raise e
finally:
- if not skip_usage_tracking:
+ if not skip_usage_tracking and response is not None:
self._track_usage(response)
+ def consolidate_kwargs(self, **kwargs) -> dict[str, Any]:
+ # Remove purpose from kwargs to avoid passing it to the model
+ kwargs.pop("purpose", None)
+ kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs}
+ if self.model_provider.extra_body:
+ kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body}
+ return kwargs
+
+ @catch_llm_exceptions
+ def generate_text_embeddings(
+ self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs
+ ) -> list[list[float]]:
+ logger.debug(
+ f"Generating embeddings with model {self.model_name!r}...",
+ extra={
+ "model": self.model_name,
+ "input_count": len(input_texts),
+ },
+ )
+ kwargs = self.consolidate_kwargs(**kwargs)
+ response = None
+ try:
+ response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs)
+ logger.debug(
+ f"Received embeddings from model {self.model_name!r}",
+ extra={
+ "model": self.model_name,
+ "embedding_count": len(response.data) if response.data else 0,
+ "usage": self._usage_stats.model_dump(),
+ },
+ )
+ if response.data and len(response.data) == len(input_texts):
+ return [data["embedding"] for data in response.data]
+ else:
+ raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}")
+ except Exception as e:
+ raise e
+ finally:
+ if not skip_usage_tracking and response is not None:
+ self._track_usage_from_embedding(response)
+
@catch_llm_exceptions
def generate(
self,
@@ -218,8 +262,21 @@ def _track_usage(self, response: ModelResponse | None) -> None:
):
self._usage_stats.extend(
token_usage=TokenUsageStats(
- prompt_tokens=response.usage.prompt_tokens,
- completion_tokens=response.usage.completion_tokens,
+ input_tokens=response.usage.prompt_tokens,
+ output_tokens=response.usage.completion_tokens,
+ ),
+ request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
+ )
+
+ def _track_usage_from_embedding(self, response: EmbeddingResponse | None) -> None:
+ if response is None:
+ self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1))
+ return
+ if response.usage is not None and response.usage.prompt_tokens is not None:
+ self._usage_stats.extend(
+ token_usage=TokenUsageStats(
+ input_tokens=response.usage.prompt_tokens,
+ output_tokens=0,
),
request_usage=RequestUsageStats(successful_requests=1, failed_requests=0),
)
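
For reference, a minimal sketch of how the new embedding path and kwargs consolidation in facade.py behave. It assumes an already-constructed ModelFacade instance named `model` (construction is outside this patch); the comments describe only what the patched code above does.

# Hedged usage sketch -- `model` is an assumed, pre-built ModelFacade instance.
texts = ["first document", "second document"]

# consolidate_kwargs() merges parameters in this order:
#   defaults from model_config.inference_parameters.generate_kwargs,
#   overridden by caller-supplied kwargs,
#   then provider-level extra_body entries are folded into kwargs["extra_body"].
vectors = model.generate_text_embeddings(input_texts=texts)

# One embedding per input text; a count mismatch raises ValueError inside the facade.
assert len(vectors) == len(texts)
assert all(isinstance(value, float) for value in vectors[0])
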
diff --git a/src/data_designer/engine/models/registry.py b/src/data_designer/engine/models/registry.py
index aafd8c80..8fe61d59 100644
--- a/src/data_designer/engine/models/registry.py
+++ b/src/data_designer/engine/models/registry.py
@@ -5,7 +5,7 @@
import logging
-from data_designer.config.models import ModelConfig
+from data_designer.config.models import GenerationType, ModelConfig
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.litellm_overrides import apply_litellm_patches
@@ -73,7 +73,7 @@ def get_model_provider(self, *, model_alias: str) -> ModelProvider:
model_config = self.get_model_config(model_alias=model_alias)
return self._model_provider_registry.get_provider(model_config.provider)
- def run_health_check(self, model_aliases: set[str]) -> None:
+ def run_health_check(self, model_aliases: list[str]) -> None:
logger.info("π©Ί Running health checks for models...")
for model_alias in model_aliases:
model = self.get_model(model_alias=model_alias)
@@ -81,15 +81,24 @@ def run_health_check(self, model_aliases: set[str]) -> None:
f" |-- π Checking {model.model_name!r} in provider named {model.model_provider_name!r} for model alias {model.model_alias!r}..."
)
try:
- model.generate(
- prompt="Hello!",
- parser=lambda x: x,
- system_prompt="You are a helpful assistant.",
- max_correction_steps=0,
- max_conversation_restarts=0,
- skip_usage_tracking=True,
- purpose="running health checks",
- )
+ if model.model_generation_type == GenerationType.EMBEDDING:
+ model.generate_text_embeddings(
+ input_texts=["Hello!"],
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ elif model.model_generation_type == GenerationType.CHAT_COMPLETION:
+ model.generate(
+ prompt="Hello!",
+ parser=lambda x: x,
+ system_prompt="You are a helpful assistant.",
+ max_correction_steps=0,
+ max_conversation_restarts=0,
+ skip_usage_tracking=True,
+ purpose="running health checks",
+ )
+ else:
+ raise ValueError(f"Unsupported generation type: {model.model_generation_type}")
logger.info(" |-- β
Passed!")
except Exception as e:
logger.error(" |-- β Failed!")
diff --git a/src/data_designer/engine/models/usage.py b/src/data_designer/engine/models/usage.py
index afbea9df..a5c23062 100644
--- a/src/data_designer/engine/models/usage.py
+++ b/src/data_designer/engine/models/usage.py
@@ -11,20 +11,20 @@
class TokenUsageStats(BaseModel):
- prompt_tokens: int = 0
- completion_tokens: int = 0
+ input_tokens: int = 0
+ output_tokens: int = 0
@computed_field
def total_tokens(self) -> int:
- return self.prompt_tokens + self.completion_tokens
+ return self.input_tokens + self.output_tokens
@property
def has_usage(self) -> bool:
return self.total_tokens > 0
- def extend(self, *, prompt_tokens: int, completion_tokens: int) -> None:
- self.prompt_tokens += prompt_tokens
- self.completion_tokens += completion_tokens
+ def extend(self, *, input_tokens: int, output_tokens: int) -> None:
+ self.input_tokens += input_tokens
+ self.output_tokens += output_tokens
class RequestUsageStats(BaseModel):
@@ -56,9 +56,7 @@ def extend(
self, *, token_usage: TokenUsageStats | None = None, request_usage: RequestUsageStats | None = None
) -> None:
if token_usage is not None:
- self.token_usage.extend(
- prompt_tokens=token_usage.prompt_tokens, completion_tokens=token_usage.completion_tokens
- )
+ self.token_usage.extend(input_tokens=token_usage.input_tokens, output_tokens=token_usage.output_tokens)
if request_usage is not None:
self.request_usage.extend(
successful_requests=request_usage.successful_requests, failed_requests=request_usage.failed_requests
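
The renamed counters behave exactly like the old ones; a small runnable check (import path taken from this patch):

from data_designer.engine.models.usage import TokenUsageStats

stats = TokenUsageStats()
stats.extend(input_tokens=120, output_tokens=45)
stats.extend(input_tokens=80, output_tokens=15)

assert stats.input_tokens == 200
assert stats.output_tokens == 60
assert stats.total_tokens == 260  # computed field: input_tokens + output_tokens
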
diff --git a/src/data_designer/engine/processing/utils.py b/src/data_designer/engine/processing/utils.py
index 3579b3bd..5d42c40e 100644
--- a/src/data_designer/engine/processing/utils.py
+++ b/src/data_designer/engine/processing/utils.py
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+import ast
import json
import logging
+import re
from typing import Any, TypeVar, Union, overload
import pandas as pd
@@ -100,6 +102,42 @@ def deserialize_json_values(data):
return data
+def parse_list_string(text: str) -> list[str]:
+ """Parse a list from a string, handling JSON arrays, Python lists, and trailing commas."""
+ text = text.strip()
+
+ # Try JSON first
+ try:
+ list_obj = json.loads(text)
+ if isinstance(list_obj, list):
+ return _clean_whitespace(list_obj)
+ except json.JSONDecodeError:
+ pass
+
+ # Remove trailing commas before closing brackets (common in JSON-like strings)
+ text_cleaned = re.sub(r",\s*]", "]", text)
+ text_cleaned = re.sub(r",\s*}", "}", text_cleaned)
+
+ # Try JSON again with cleaned text
+ try:
+ return _clean_whitespace(json.loads(text_cleaned))
+ except json.JSONDecodeError:
+ pass
+
+ # Try Python literal eval (handles single quotes)
+ try:
+ return _clean_whitespace(ast.literal_eval(text_cleaned))
+ except (ValueError, SyntaxError):
+ pass
+
+ # If all else fails, return the original text wrapped in a single-element list
+ return [text.strip()]
+
+
+def _clean_whitespace(texts: list[str]) -> list[str]:
+ return [text.strip() for text in texts]
+
+
def _verify_columns_are_unique(datasets: list[pd.DataFrame]) -> None:
joined_columns = set()
for df in datasets:
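
A few concrete inputs and the outputs that follow directly from the parser above (import path taken from this patch):

from data_designer.engine.processing.utils import parse_list_string

parse_list_string('["a", "b", "c"]')    # valid JSON array            -> ["a", "b", "c"]
parse_list_string('["a", "b",]')        # JSON with trailing comma    -> ["a", "b"]
parse_list_string("['a', 'b']")         # Python-style single quotes  -> ["a", "b"]
parse_list_string("not a list at all")  # fallback: original text     -> ["not a list at all"]
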
diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py
index 758e837e..0597a504 100644
--- a/tests/cli/conftest.py
+++ b/tests/cli/conftest.py
@@ -9,16 +9,16 @@
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
from data_designer.cli.services.model_service import ModelService
from data_designer.cli.services.provider_service import ProviderService
-from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
@pytest.fixture
-def stub_inference_parameters() -> InferenceParameters:
- return InferenceParameters(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4)
+def stub_inference_parameters() -> ChatCompletionInferenceParams:
+ return ChatCompletionInferenceParams(temperature=0.7, top_p=0.9, max_tokens=2048, max_parallel_requests=4)
@pytest.fixture
-def stub_model_configs(stub_inference_parameters: InferenceParameters) -> list[ModelConfig]:
+def stub_model_configs(stub_inference_parameters: ChatCompletionInferenceParams) -> list[ModelConfig]:
return [
ModelConfig(
alias="test-alias-1",
@@ -41,7 +41,7 @@ def stub_new_model_config() -> ModelConfig:
alias="test-alias-3",
model="test-model-3",
provider="test-provider-1",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.7,
top_p=0.9,
max_tokens=2048,
diff --git a/tests/cli/controllers/test_model_controller.py b/tests/cli/controllers/test_model_controller.py
index b630b04a..5dfd91fe 100644
--- a/tests/cli/controllers/test_model_controller.py
+++ b/tests/cli/controllers/test_model_controller.py
@@ -9,7 +9,7 @@
from data_designer.cli.controllers.model_controller import ModelController
from data_designer.cli.repositories.model_repository import ModelConfigRegistry
from data_designer.cli.repositories.provider_repository import ModelProviderRegistry, ProviderRepository
-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
@pytest.fixture
@@ -141,7 +141,7 @@ def test_run_updates_model(
alias="test-alias-1-updated",
model="test-model-1-updated",
provider="test-provider-1",
- inference_parameters=InferenceParameters(temperature=0.8, top_p=0.95, max_tokens=1024),
+ inference_parameters=ChatCompletionInferenceParams(temperature=0.8, top_p=0.95, max_tokens=1024),
)
mock_builder = MagicMock()
diff --git a/tests/cli/forms/test_field.py b/tests/cli/forms/test_field.py
index 31db43a1..4f9ba82c 100644
--- a/tests/cli/forms/test_field.py
+++ b/tests/cli/forms/test_field.py
@@ -55,7 +55,7 @@ def test_text_field_validator_receives_string() -> None:
validator.assert_called_with("text")
-@patch("data_designer.cli.ui.prompt_text_input")
+@patch("data_designer.cli.forms.field.prompt_text_input")
def test_text_field_prompt_user_returns_input(mock_prompt: Mock) -> None:
"""Test TextField prompt_user returns user input."""
mock_prompt.return_value = "user input"
@@ -64,8 +64,8 @@ def test_text_field_prompt_user_returns_input(mock_prompt: Mock) -> None:
assert field.prompt_user() == "user input"
-@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL")
-@patch("data_designer.cli.ui.prompt_text_input")
+@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL")
+@patch("data_designer.cli.forms.field.prompt_text_input")
def test_text_field_prompt_user_handles_back_navigation(mock_prompt: Mock) -> None:
"""Test TextField prompt_user properly returns BACK sentinel."""
mock_prompt.return_value = "BACK_SENTINEL"
@@ -87,7 +87,7 @@ def test_select_field_value_setter() -> None:
assert field.value == "1"
-@patch("data_designer.cli.ui.select_with_arrows")
+@patch("data_designer.cli.forms.field.select_with_arrows")
def test_select_field_prompt_user_returns_selection(mock_select: Mock) -> None:
"""Test SelectField prompt_user returns user selection."""
mock_select.return_value = "opt1"
@@ -97,8 +97,8 @@ def test_select_field_prompt_user_returns_selection(mock_select: Mock) -> None:
assert field.prompt_user() == "opt1"
-@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL")
-@patch("data_designer.cli.ui.select_with_arrows")
+@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL")
+@patch("data_designer.cli.forms.field.select_with_arrows")
def test_select_field_prompt_user_handles_back_navigation(mock_select: Mock) -> None:
"""Test SelectField prompt_user properly returns BACK sentinel."""
mock_select.return_value = "BACK_SENTINEL"
@@ -253,7 +253,7 @@ def test_numeric_field_value_setter_rejects_invalid() -> None:
# NumericField prompt_user tests
-@patch("data_designer.cli.ui.prompt_text_input")
+@patch("data_designer.cli.forms.field.prompt_text_input")
def test_numeric_field_prompt_user_returns_float(mock_prompt: Mock) -> None:
"""Test NumericField prompt_user converts string to float."""
mock_prompt.return_value = "42"
@@ -265,19 +265,19 @@ def test_numeric_field_prompt_user_returns_float(mock_prompt: Mock) -> None:
assert isinstance(result, float)
-@patch("data_designer.cli.ui.prompt_text_input")
+@patch("data_designer.cli.forms.field.prompt_text_input")
def test_numeric_field_prompt_user_returns_none_for_empty(mock_prompt: Mock) -> None:
- """Test NumericField prompt_user returns None for empty input."""
+ """Test NumericField prompt_user returns empty string for empty input on optional field."""
mock_prompt.return_value = ""
field = NumericField(name="optional", prompt="Enter value", required=False)
result = field.prompt_user()
- assert result is None
+ assert result == ""
-@patch("data_designer.cli.ui.BACK", "BACK_SENTINEL")
-@patch("data_designer.cli.ui.prompt_text_input")
+@patch("data_designer.cli.forms.field.BACK", "BACK_SENTINEL")
+@patch("data_designer.cli.forms.field.prompt_text_input")
def test_numeric_field_prompt_user_handles_back_navigation(mock_prompt: Mock) -> None:
"""Test NumericField prompt_user properly returns BACK sentinel."""
mock_prompt.return_value = "BACK_SENTINEL"
@@ -318,3 +318,133 @@ def test_validator_converts_non_string_values() -> None:
# Validator should be called with string representation
validator.assert_called_once_with("42.5")
+
+
+# Tests for clearing values with 'clear' keyword
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_numeric_field_accepts_clear_keyword(mock_prompt: Mock) -> None:
+ """Test NumericField accepts 'clear' keyword to remove value."""
+ mock_prompt.return_value = "clear"
+ field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False)
+
+ result = field.prompt_user()
+
+ assert result == ""
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_numeric_field_accepts_none_keyword(mock_prompt: Mock) -> None:
+ """Test NumericField accepts 'none' keyword to remove value."""
+ mock_prompt.return_value = "none"
+ field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False)
+
+ result = field.prompt_user()
+
+ assert result == ""
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_numeric_field_accepts_default_keyword(mock_prompt: Mock) -> None:
+ """Test NumericField accepts 'default' keyword to remove value."""
+ mock_prompt.return_value = "default"
+ field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False)
+
+ result = field.prompt_user()
+
+ assert result == ""
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_numeric_field_returns_default_for_empty_when_has_default(mock_prompt: Mock) -> None:
+ """Test NumericField returns default value when user enters nothing and default exists."""
+ mock_prompt.return_value = ""
+ field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False)
+
+ result = field.prompt_user()
+
+ assert result == 42.0
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_numeric_field_shows_current_label_with_default(mock_prompt: Mock) -> None:
+ """Test NumericField shows '(current: X)' instead of '(default: X)' when default exists."""
+ mock_prompt.return_value = ""
+ field = NumericField(name="optional", prompt="Enter value", default=42.0, required=False)
+
+ field.prompt_user()
+
+ # Check that prompt_text_input was called with current value info in the prompt
+ call_args = mock_prompt.call_args
+ prompt_arg = call_args[0][0]
+ assert "current" in prompt_arg.lower()
+ assert "42.0" in prompt_arg
+ assert "clear" in prompt_arg.lower()
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_text_field_shows_current_label_with_default(mock_prompt: Mock) -> None:
+ """Test TextField shows '(current: X)' instead of '(default: X)' when default exists."""
+ mock_prompt.return_value = ""
+ field = TextField(name="name", prompt="Enter name", default="test", required=False)
+
+ field.prompt_user()
+
+ # Check that prompt_text_input was called with current value info in the prompt
+ call_args = mock_prompt.call_args
+ prompt_arg = call_args[0][0]
+ assert "current" in prompt_arg.lower()
+ assert "test" in prompt_arg
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_text_field_returns_default_for_empty_when_has_default(mock_prompt: Mock) -> None:
+ """Test TextField returns default value when user enters nothing and default exists."""
+ mock_prompt.return_value = ""
+ field = TextField(name="name", prompt="Enter name", default="test", required=False)
+
+ result = field.prompt_user()
+
+ assert result == "test"
+
+
+def test_numeric_field_value_setter_converts_empty_string_to_none() -> None:
+ """Test NumericField value setter converts empty string to None for optional fields."""
+ field = NumericField(name="optional", prompt="Enter value", required=False)
+
+ field.value = ""
+
+ assert field.value is None
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_text_field_accepts_clear_keyword(mock_prompt: Mock) -> None:
+ """Test TextField accepts 'clear' keyword to remove value."""
+ mock_prompt.return_value = "clear"
+ field = TextField(name="optional", prompt="Enter value", default="test", required=False)
+
+ result = field.prompt_user()
+
+ assert result == ""
+
+
+@patch("data_designer.cli.forms.field.prompt_text_input")
+def test_text_field_shows_clear_instruction_for_optional_with_default(mock_prompt: Mock) -> None:
+ """Test TextField shows clear instruction for optional fields with default values."""
+ mock_prompt.return_value = ""
+ field = TextField(name="optional", prompt="Enter value", default="test", required=False)
+
+ field.prompt_user()
+
+ # Check that prompt includes 'clear' instruction
+ call_args = mock_prompt.call_args
+ prompt_arg = call_args[0][0]
+ assert "clear" in prompt_arg.lower()
+
+
+def test_text_field_value_setter_converts_empty_string_to_none() -> None:
+ """Test TextField value setter converts empty string to None for optional fields."""
+ field = TextField(name="optional", prompt="Enter value", required=False)
+
+ field.value = ""
+
+ assert field.value is None
diff --git a/tests/cli/forms/test_model_builder.py b/tests/cli/forms/test_model_builder.py
index 806f2dcf..e75eb745 100644
--- a/tests/cli/forms/test_model_builder.py
+++ b/tests/cli/forms/test_model_builder.py
@@ -5,7 +5,7 @@
from data_designer.cli.forms.field import ValidationError
from data_designer.cli.forms.model_builder import ModelFormBuilder
-from data_designer.config.models import ModelConfig
+from data_designer.config.models import GenerationType, ModelConfig
# Alias validation tests - test through public form interface
@@ -102,17 +102,16 @@ def test_form_omits_provider_field_with_no_providers() -> None:
def test_form_has_all_required_fields() -> None:
- """Test form includes all essential configuration fields."""
+ """Test basic form includes essential configuration fields (inference params are in separate form)."""
builder = ModelFormBuilder()
form = builder.create_form()
- # All required fields must be present
+ # All required fields must be present in basic form
assert form.get_field("alias") is not None
assert form.get_field("model") is not None
- assert form.get_field("temperature") is not None
- assert form.get_field("top_p") is not None
- assert form.get_field("max_tokens") is not None
+ assert form.get_field("generation_type") is not None
+ # inference_parameters are now collected in a separate form via _create_inference_params_form
# Initial data handling tests
@@ -122,6 +121,7 @@ def test_form_uses_initial_data_for_field_defaults() -> None:
"alias": "my-model",
"model": "gpt-4",
"inference_parameters": {
+ "generation_type": GenerationType.CHAT_COMPLETION,
"temperature": 0.5,
"top_p": 0.8,
"max_tokens": 1024,
@@ -133,9 +133,28 @@ def test_form_uses_initial_data_for_field_defaults() -> None:
assert form.get_field("alias").default == "my-model"
assert form.get_field("model").default == "gpt-4"
- assert form.get_field("temperature").default == 0.5
- assert form.get_field("top_p").default == 0.8
- assert form.get_field("max_tokens").default == 1024
+ assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION
+
+
+def test_form_extracts_generation_type_from_inference_parameters() -> None:
+ """Test form correctly extracts generation_type from nested inference_parameters for embedding models."""
+ initial_data = {
+ "alias": "embedding-model",
+ "model": "text-embedding-3",
+ "inference_parameters": {
+ "generation_type": GenerationType.EMBEDDING,
+ "encoding_format": "base64",
+ "dimensions": 512,
+ },
+ "provider": "openai",
+ }
+ builder = ModelFormBuilder()
+
+ form = builder.create_form(initial_data)
+
+ assert form.get_field("alias").default == "embedding-model"
+ assert form.get_field("model").default == "text-embedding-3"
+ assert form.get_field("generation_type").default == GenerationType.EMBEDDING
def test_form_uses_standard_defaults_without_initial_data() -> None:
@@ -146,9 +165,7 @@ def test_form_uses_standard_defaults_without_initial_data() -> None:
assert form.get_field("alias").default is None
assert form.get_field("model").default is None
- assert form.get_field("temperature").default == 0.7
- assert form.get_field("top_p").default == 0.9
- assert form.get_field("max_tokens").default == 2048
+ assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION
def test_form_handles_partial_initial_data() -> None:
@@ -164,10 +181,8 @@ def test_form_handles_partial_initial_data() -> None:
# Should use provided values
assert form.get_field("alias").default == "my-model"
assert form.get_field("model").default == "gpt-4"
- # Should fall back to standard defaults for missing values
- assert form.get_field("temperature").default == 0.7
- assert form.get_field("top_p").default == 0.9
- assert form.get_field("max_tokens").default == 2048
+ # Should fall back to standard defaults for missing generation_type
+ assert form.get_field("generation_type").default == GenerationType.CHAT_COMPLETION
def test_form_provider_defaults_to_first_when_multiple_available() -> None:
@@ -199,9 +214,11 @@ def test_build_config_uses_provider_from_form_data() -> None:
"alias": "my-model",
"model": "gpt-4",
"provider": "anthropic",
- "temperature": 0.7,
- "top_p": 0.9,
- "max_tokens": 2048,
+ "inference_parameters": {
+ "temperature": 0.7,
+ "top_p": 0.9,
+ "max_tokens": 2048,
+ },
}
config = builder.build_config(form_data)
@@ -215,9 +232,11 @@ def test_build_config_infers_single_available_provider() -> None:
form_data = {
"alias": "my-model",
"model": "gpt-4",
- "temperature": 0.7,
- "top_p": 0.9,
- "max_tokens": 2048,
+ "inference_parameters": {
+ "temperature": 0.7,
+ "top_p": 0.9,
+ "max_tokens": 2048,
+ },
}
config = builder.build_config(form_data)
@@ -231,9 +250,11 @@ def test_build_config_sets_provider_none_when_unavailable() -> None:
form_data = {
"alias": "my-model",
"model": "gpt-4",
- "temperature": 0.7,
- "top_p": 0.9,
- "max_tokens": 2048,
+ "inference_parameters": {
+ "temperature": 0.7,
+ "top_p": 0.9,
+ "max_tokens": 2048,
+ },
}
config = builder.build_config(form_data)
@@ -247,9 +268,11 @@ def test_build_config_creates_valid_model_config() -> None:
form_data = {
"alias": "test-model",
"model": "gpt-4-turbo",
- "temperature": 0.5,
- "top_p": 0.8,
- "max_tokens": 1024,
+ "inference_parameters": {
+ "temperature": 0.5,
+ "top_p": 0.8,
+ "max_tokens": 1024,
+ },
}
config = builder.build_config(form_data)
@@ -265,19 +288,20 @@ def test_build_config_creates_valid_model_config() -> None:
def test_build_config_converts_max_tokens_to_int() -> None:
- """Test build_config converts max_tokens from float to int."""
+ """Test build_config handles numeric values in inference parameters."""
builder = ModelFormBuilder()
form_data = {
"alias": "my-model",
"model": "gpt-4",
- "temperature": 0.7,
- "top_p": 0.9,
- "max_tokens": 2048.0, # NumericField returns float
+ "inference_parameters": {
+ "temperature": 0.7,
+ "top_p": 0.9,
+ "max_tokens": 2048,
+ },
}
config = builder.build_config(form_data)
- assert isinstance(config.inference_parameters.max_tokens, int)
assert config.inference_parameters.max_tokens == 2048
@@ -289,9 +313,11 @@ def test_build_config_prefers_explicit_provider_over_inference() -> None:
"alias": "my-model",
"model": "gpt-4",
"provider": "custom", # Explicitly overridden
- "temperature": 0.7,
- "top_p": 0.9,
- "max_tokens": 2048,
+ "inference_parameters": {
+ "temperature": 0.7,
+ "top_p": 0.9,
+ "max_tokens": 2048,
+ },
}
config = builder.build_config(form_data)
@@ -322,14 +348,17 @@ def test_full_workflow_creates_valid_config() -> None:
# Simulate user accepting defaults (get_values would return these)
form.set_values(initial_data)
- # Flatten inference_parameters for form data
+
+ # Form data now includes inference_parameters as a dict
form_data = {
"alias": "new-model",
"model": "claude-3-opus",
"provider": "anthropic",
- "temperature": 0.6,
- "top_p": 0.95,
- "max_tokens": 4096,
+ "inference_parameters": {
+ "temperature": 0.6,
+ "top_p": 0.95,
+ "max_tokens": 4096,
+ },
}
# Build config
@@ -342,3 +371,140 @@ def test_full_workflow_creates_valid_config() -> None:
assert config.inference_parameters.temperature == 0.6
assert config.inference_parameters.top_p == 0.95
assert config.inference_parameters.max_tokens == 4096
+
+
+# Tests for new two-step form process
+def test_create_inference_params_form_for_chat_completion() -> None:
+ """Test creating inference parameters form for chat completion models."""
+ builder = ModelFormBuilder()
+
+ params_form = builder.create_inference_params_form(GenerationType.CHAT_COMPLETION)
+
+ # Should have chat completion specific fields
+ assert params_form.get_field("temperature") is not None
+ assert params_form.get_field("top_p") is not None
+ assert params_form.get_field("max_tokens") is not None
+ # Should not have embedding fields
+ assert params_form.get_field("encoding_format") is None
+ assert params_form.get_field("dimensions") is None
+
+
+def test_create_inference_params_form_for_embedding() -> None:
+ """Test creating inference parameters form for embedding models."""
+ builder = ModelFormBuilder()
+
+ params_form = builder.create_inference_params_form(GenerationType.EMBEDDING)
+
+ # Should have embedding specific fields
+ assert params_form.get_field("encoding_format") is not None
+ assert params_form.get_field("dimensions") is not None
+ # Should not have chat completion fields
+ assert params_form.get_field("temperature") is None
+ assert params_form.get_field("top_p") is None
+ assert params_form.get_field("max_tokens") is None
+
+
+def test_create_inference_params_form_uses_initial_params() -> None:
+ """Test inference parameters form uses initial values from existing config."""
+ builder = ModelFormBuilder()
+ initial_params = {"temperature": 0.8, "top_p": 0.95, "max_tokens": 2048}
+
+ params_form = builder.create_inference_params_form(GenerationType.CHAT_COMPLETION, initial_params)
+
+ assert params_form.get_field("temperature").default == 0.8
+ assert params_form.get_field("top_p").default == 0.95
+ assert params_form.get_field("max_tokens").default == 2048
+
+
+def test_build_inference_params_chat_completion_with_all_values() -> None:
+ """Test building inference params dict from chat completion form data."""
+ builder = ModelFormBuilder()
+ params_data = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024.0}
+
+ result = builder.build_inference_params(GenerationType.CHAT_COMPLETION, params_data)
+
+ assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}
+
+
+def test_build_inference_params_chat_completion_with_partial_values() -> None:
+ """Test building inference params dict with only some values provided."""
+ builder = ModelFormBuilder()
+ params_data = {"temperature": 0.7, "top_p": None, "max_tokens": None}
+
+ result = builder.build_inference_params(GenerationType.CHAT_COMPLETION, params_data)
+
+ # Only provided values should be included
+ assert result == {"temperature": 0.7}
+
+
+def test_build_inference_params_embedding_with_all_values() -> None:
+ """Test building inference params dict from embedding form data."""
+ builder = ModelFormBuilder()
+ params_data = {"encoding_format": "float", "dimensions": 1024.0}
+
+ result = builder.build_inference_params(GenerationType.EMBEDDING, params_data)
+
+ assert result == {"encoding_format": "float", "dimensions": 1024}
+
+
+def test_build_inference_params_embedding_with_partial_values() -> None:
+ """Test building embedding inference params with only some values provided."""
+ builder = ModelFormBuilder()
+ params_data = {"encoding_format": "float", "dimensions": None}
+
+ result = builder.build_inference_params(GenerationType.EMBEDDING, params_data)
+
+ # encoding_format always included, dimensions omitted if not provided
+ assert result == {"encoding_format": "float"}
+
+
+def test_build_inference_params_embedding_all_cleared() -> None:
+ """Test building embedding inference params when both are cleared."""
+ builder = ModelFormBuilder()
+ params_data = {"encoding_format": None, "dimensions": None}
+
+ result = builder.build_inference_params(GenerationType.EMBEDDING, params_data)
+
+ # Empty dict; Pydantic will use defaults (encoding_format="float", dimensions=None)
+ assert result == {}
+
+
+def test_validate_encoding_format_accepts_valid_values() -> None:
+ """Test encoding format validation accepts 'float' and 'base64'."""
+ builder = ModelFormBuilder()
+
+ is_valid, error = builder.validate_encoding_format("float")
+ assert is_valid is True
+ assert error is None
+
+ is_valid, error = builder.validate_encoding_format("base64")
+ assert is_valid is True
+ assert error is None
+
+
+def test_validate_encoding_format_rejects_invalid_values() -> None:
+ """Test encoding format validation rejects invalid values."""
+ builder = ModelFormBuilder()
+
+ is_valid, error = builder.validate_encoding_format("invalid")
+ assert is_valid is False
+ assert "float" in error and "base64" in error
+
+
+def test_validate_encoding_format_accepts_empty_string() -> None:
+ """Test encoding format validation accepts empty string (optional field)."""
+ builder = ModelFormBuilder()
+
+ is_valid, error = builder.validate_encoding_format("")
+ assert is_valid is True
+ assert error is None
+
+
+def test_validate_encoding_format_accepts_clear_keywords() -> None:
+ """Test encoding format validation accepts clearing keywords."""
+ builder = ModelFormBuilder()
+
+ for keyword in ("clear", "none", "default", "CLEAR", "None"):
+ is_valid, error = builder.validate_encoding_format(keyword)
+ assert is_valid is True, f"Failed for keyword: {keyword}"
+ assert error is None
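
Taken together, these tests imply the following two-step flow. This is a hedged sketch of that flow; the surrounding controller wiring is assumed, and only the ModelFormBuilder calls come from this patch.

from data_designer.cli.forms.model_builder import ModelFormBuilder
from data_designer.config.models import GenerationType

builder = ModelFormBuilder()

# Step 1: the basic form collects alias, model, provider and generation_type.
basic_form = builder.create_form({"alias": "my-embedder", "model": "text-embedding-3"})

# Step 2: a generation-type-specific form collects the inference parameters.
params_form = builder.create_inference_params_form(GenerationType.EMBEDDING, {"dimensions": 512})

# build_inference_params drops unset values and coerces numeric inputs (floats -> ints).
params = builder.build_inference_params(
    GenerationType.EMBEDDING, {"encoding_format": "float", "dimensions": 512.0}
)
assert params == {"encoding_format": "float", "dimensions": 512}
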
diff --git a/tests/cli/repositories/test_model_repository.py b/tests/cli/repositories/test_model_repository.py
index 01884b5c..624cd360 100644
--- a/tests/cli/repositories/test_model_repository.py
+++ b/tests/cli/repositories/test_model_repository.py
@@ -21,7 +21,9 @@ def test_load_does_not_exist():
def test_load_exists(tmp_path: Path, stub_model_configs: list[ModelConfig]):
model_configs_file_path = tmp_path / MODEL_CONFIGS_FILE_NAME
- save_config_file(model_configs_file_path, {"model_configs": [mc.model_dump() for mc in stub_model_configs]})
+ save_config_file(
+ model_configs_file_path, {"model_configs": [mc.model_dump(mode="json") for mc in stub_model_configs]}
+ )
repository = ModelRepository(tmp_path)
assert repository.load() is not None
assert repository.load().model_configs == stub_model_configs
diff --git a/tests/cli/services/test_model_service.py b/tests/cli/services/test_model_service.py
index 1d9bf5aa..2551a376 100644
--- a/tests/cli/services/test_model_service.py
+++ b/tests/cli/services/test_model_service.py
@@ -7,7 +7,7 @@
from data_designer.cli.repositories.model_repository import ModelRepository
from data_designer.cli.services.model_service import ModelService
-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
def test_list_all(stub_model_service: ModelService, stub_model_configs: list[ModelConfig]):
@@ -30,7 +30,9 @@ def test_add(
assert stub_model_service.list_all() == stub_model_configs + [stub_new_model_config]
-def test_add_duplicate_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters):
+def test_add_duplicate_alias(
+ stub_model_service: ModelService, stub_inference_parameters: ChatCompletionInferenceParams
+):
"""Test adding a model with an alias that already exists."""
duplicate_model = ModelConfig(
alias="test-alias-1",
@@ -61,7 +63,9 @@ def test_update_nonexistent_model(stub_model_service: ModelService, stub_new_mod
stub_model_service.update("nonexistent", stub_new_model_config)
-def test_update_to_existing_alias(stub_model_service: ModelService, stub_inference_parameters: InferenceParameters):
+def test_update_to_existing_alias(
+ stub_model_service: ModelService, stub_inference_parameters: ChatCompletionInferenceParams
+):
"""Test updating a model to an alias that already exists."""
updated_model = ModelConfig(
alias="test-alias-2", # Already exists
diff --git a/tests/config/analysis/conftest.py b/tests/config/analysis/conftest.py
index 42e6fb96..2f683718 100644
--- a/tests/config/analysis/conftest.py
+++ b/tests/config/analysis/conftest.py
@@ -38,12 +38,12 @@ def sample_llm_text_column_stats():
num_unique=490,
pyarrow_dtype="string",
simple_dtype="str",
- completion_tokens_mean=150.5,
- completion_tokens_median=150.0,
- completion_tokens_stddev=25.2,
- prompt_tokens_mean=50.0,
- prompt_tokens_median=50.0,
- prompt_tokens_stddev=10.0,
+ output_tokens_mean=150.5,
+ output_tokens_median=150.0,
+ output_tokens_stddev=25.2,
+ input_tokens_mean=50.0,
+ input_tokens_median=50.0,
+ input_tokens_stddev=10.0,
)
diff --git a/tests/config/analysis/test_column_statistics.py b/tests/config/analysis/test_column_statistics.py
index 34a00a39..a3a262b2 100644
--- a/tests/config/analysis/test_column_statistics.py
+++ b/tests/config/analysis/test_column_statistics.py
@@ -124,12 +124,12 @@ def test_llm_text_column_statistics_with_missing_values(
):
llm_text_column_statistics = llm_text_column_statistics_based_class(
**stub_general_stats_args_with_missing_values,
- completion_tokens_mean=MissingValue.CALCULATION_FAILED,
- completion_tokens_median=MissingValue.CALCULATION_FAILED,
- completion_tokens_stddev=MissingValue.CALCULATION_FAILED,
- prompt_tokens_mean=MissingValue.CALCULATION_FAILED,
- prompt_tokens_median=MissingValue.CALCULATION_FAILED,
- prompt_tokens_stddev=MissingValue.CALCULATION_FAILED,
+ output_tokens_mean=MissingValue.CALCULATION_FAILED,
+ output_tokens_median=MissingValue.CALCULATION_FAILED,
+ output_tokens_stddev=MissingValue.CALCULATION_FAILED,
+ input_tokens_mean=MissingValue.CALCULATION_FAILED,
+ input_tokens_median=MissingValue.CALCULATION_FAILED,
+ input_tokens_stddev=MissingValue.CALCULATION_FAILED,
)
assert llm_text_column_statistics.column_type == column_type
assert llm_text_column_statistics.create_report_row_data() == {
@@ -155,12 +155,12 @@ def test_llm_text_column_statistics_with_valid_values(
):
llm_text_column_statistics = llm_text_column_statistics_based_class(
**stub_general_stats_args_with_valid_values,
- completion_tokens_mean=150.0,
- completion_tokens_median=150.0,
- completion_tokens_stddev=25.2,
- prompt_tokens_mean=50.0,
- prompt_tokens_median=50.0,
- prompt_tokens_stddev=10.0,
+ output_tokens_mean=150.0,
+ output_tokens_median=150.0,
+ output_tokens_stddev=25.2,
+ input_tokens_mean=50.0,
+ input_tokens_median=50.0,
+ input_tokens_stddev=10.0,
)
assert llm_text_column_statistics.column_type == column_type
assert llm_text_column_statistics.create_report_row_data() == {
diff --git a/tests/config/test_columns.py b/tests/config/test_columns.py
index ef5e63f3..cbc95756 100644
--- a/tests/config/test_columns.py
+++ b/tests/config/test_columns.py
@@ -5,6 +5,7 @@
from pydantic import ValidationError
from data_designer.config.column_configs import (
+ EmbeddingColumnConfig,
ExpressionColumnConfig,
LLMCodeColumnConfig,
LLMJudgeColumnConfig,
@@ -17,7 +18,7 @@
)
from data_designer.config.column_types import (
DataDesignerColumnType,
- column_type_is_llm_generated,
+ column_type_is_model_generated,
column_type_used_in_execution_dag,
get_column_config_from_kwargs,
get_column_display_order,
@@ -49,20 +50,22 @@ def test_data_designer_column_type_get_display_order():
DataDesignerColumnType.LLM_CODE,
DataDesignerColumnType.LLM_STRUCTURED,
DataDesignerColumnType.LLM_JUDGE,
+ DataDesignerColumnType.EMBEDDING,
DataDesignerColumnType.VALIDATION,
DataDesignerColumnType.EXPRESSION,
]
def test_data_designer_column_type_is_llm_generated():
- assert column_type_is_llm_generated(DataDesignerColumnType.LLM_TEXT)
- assert column_type_is_llm_generated(DataDesignerColumnType.LLM_CODE)
- assert column_type_is_llm_generated(DataDesignerColumnType.LLM_STRUCTURED)
- assert column_type_is_llm_generated(DataDesignerColumnType.LLM_JUDGE)
- assert not column_type_is_llm_generated(DataDesignerColumnType.SAMPLER)
- assert not column_type_is_llm_generated(DataDesignerColumnType.VALIDATION)
- assert not column_type_is_llm_generated(DataDesignerColumnType.EXPRESSION)
- assert not column_type_is_llm_generated(DataDesignerColumnType.SEED_DATASET)
+ assert column_type_is_model_generated(DataDesignerColumnType.LLM_TEXT)
+ assert column_type_is_model_generated(DataDesignerColumnType.LLM_CODE)
+ assert column_type_is_model_generated(DataDesignerColumnType.LLM_STRUCTURED)
+ assert column_type_is_model_generated(DataDesignerColumnType.LLM_JUDGE)
+ assert column_type_is_model_generated(DataDesignerColumnType.EMBEDDING)
+ assert not column_type_is_model_generated(DataDesignerColumnType.SAMPLER)
+ assert not column_type_is_model_generated(DataDesignerColumnType.VALIDATION)
+ assert not column_type_is_model_generated(DataDesignerColumnType.EXPRESSION)
+ assert not column_type_is_model_generated(DataDesignerColumnType.SEED_DATASET)
def test_data_designer_column_type_is_in_dag():
@@ -72,6 +75,7 @@ def test_data_designer_column_type_is_in_dag():
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED)
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT)
assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION)
+ assert column_type_used_in_execution_dag(DataDesignerColumnType.EMBEDDING)
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER)
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET)
@@ -212,6 +216,20 @@ def test_validation_column_config():
assert validation_column_config.batch_size == 5
+def test_embedding_column_config():
+ embedding_column_config = EmbeddingColumnConfig(
+ name="test_embedding",
+ target_column="test_column",
+ model_alias=stub_model_alias,
+ )
+
+ assert embedding_column_config.column_type == DataDesignerColumnType.EMBEDDING
+ assert embedding_column_config.target_column == "test_column"
+ assert embedding_column_config.model_alias == stub_model_alias
+ assert embedding_column_config.required_columns == ["test_column"]
+ assert embedding_column_config.side_effect_columns == []
+
+
def test_get_column_config_from_kwargs():
assert isinstance(
get_column_config_from_kwargs(
@@ -278,6 +296,16 @@ def test_get_column_config_from_kwargs():
ExpressionColumnConfig,
)
+ assert isinstance(
+ get_column_config_from_kwargs(
+ name="test_embedding",
+ column_type=DataDesignerColumnType.EMBEDDING,
+ target_column="test_column",
+ model_alias=stub_model_alias,
+ ),
+ EmbeddingColumnConfig,
+ )
+
# sampler params is a dictionary
assert isinstance(
get_column_config_from_kwargs(
diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py
index 8a90cc1a..f30af9e7 100644
--- a/tests/config/test_config_builder.py
+++ b/tests/config/test_config_builder.py
@@ -26,7 +26,7 @@
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.datastore import DatastoreSettings
from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError
-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint
from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams
from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy
@@ -54,7 +54,7 @@ def stub_data_designer_builder(stub_data_designer_builder_config_str):
def test_loading_model_configs_in_constructor(stub_model_configs):
- stub_model_configs_dict = [mc.model_dump() for mc in stub_model_configs]
+ stub_model_configs_dict = [mc.model_dump(mode="json") for mc in stub_model_configs]
# test loading model configs from a list
builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
assert builder.model_configs == stub_model_configs
@@ -670,7 +670,7 @@ def test_add_model_config(stub_empty_builder):
new_model_config = ModelConfig(
alias="new-model",
model="openai/gpt-4",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.7,
top_p=0.95,
max_tokens=1024,
@@ -691,7 +691,7 @@ def test_add_model_config(stub_empty_builder):
alias="provider-model",
model="anthropic/claude-3",
provider="anthropic",
- inference_parameters=InferenceParameters(temperature=0.8),
+ inference_parameters=ChatCompletionInferenceParams(temperature=0.8),
)
stub_empty_builder.add_model_config(provider_model_config)
@@ -717,7 +717,7 @@ def test_add_model_config_duplicate_alias(stub_empty_builder):
duplicate_model_config = ModelConfig(
alias="stub-model",
model="different/model",
- inference_parameters=InferenceParameters(temperature=0.5),
+ inference_parameters=ChatCompletionInferenceParams(temperature=0.5),
)
with pytest.raises(
@@ -733,12 +733,12 @@ def test_delete_model_config(stub_empty_builder):
model_config_1 = ModelConfig(
alias="model-to-delete",
model="model/delete",
- inference_parameters=InferenceParameters(temperature=0.5),
+ inference_parameters=ChatCompletionInferenceParams(temperature=0.5),
)
model_config_2 = ModelConfig(
alias="model-to-keep",
model="model/keep",
- inference_parameters=InferenceParameters(temperature=0.6),
+ inference_parameters=ChatCompletionInferenceParams(temperature=0.6),
)
stub_empty_builder.add_model_config(model_config_1)
stub_empty_builder.add_model_config(model_config_2)
diff --git a/tests/config/test_default_model_settings.py b/tests/config/test_default_model_settings.py
index 222bb410..9619af02 100644
--- a/tests/config/test_default_model_settings.py
+++ b/tests/config/test_default_model_settings.py
@@ -18,28 +18,35 @@
get_default_providers,
resolve_seed_default_model_settings,
)
-from data_designer.config.models import InferenceParameters
+from data_designer.config.models import ChatCompletionInferenceParams, EmbeddingInferenceParams
from data_designer.config.utils.visualization import get_nvidia_api_key, get_openai_api_key
def test_get_default_inference_parameters():
- assert get_default_inference_parameters("text") == InferenceParameters(
+ assert get_default_inference_parameters("text", "nvidia") == ChatCompletionInferenceParams(
temperature=0.85,
top_p=0.95,
)
- assert get_default_inference_parameters("reasoning") == InferenceParameters(
+ assert get_default_inference_parameters("reasoning", "nvidia") == ChatCompletionInferenceParams(
temperature=0.35,
top_p=0.95,
)
- assert get_default_inference_parameters("vision") == InferenceParameters(
+ assert get_default_inference_parameters("vision", "nvidia") == ChatCompletionInferenceParams(
temperature=0.85,
top_p=0.95,
)
+ assert get_default_inference_parameters("embedding", "nvidia") == EmbeddingInferenceParams(
+ encoding_format="float",
+ extra_body={"input_type": "query"},
+ )
+ assert get_default_inference_parameters("embedding", "openai") == EmbeddingInferenceParams(
+ encoding_format="float",
+ )
def test_get_builtin_model_configs():
builtin_model_configs = get_builtin_model_configs()
- assert len(builtin_model_configs) == 6
+ assert len(builtin_model_configs) == 8
assert builtin_model_configs[0].alias == "nvidia-text"
assert builtin_model_configs[0].model == "nvidia/nvidia-nemotron-nano-9b-v2"
assert builtin_model_configs[0].provider == "nvidia"
@@ -49,11 +56,21 @@ def test_get_builtin_model_configs():
assert builtin_model_configs[2].alias == "nvidia-vision"
assert builtin_model_configs[2].model == "nvidia/nemotron-nano-12b-v2-vl"
assert builtin_model_configs[2].provider == "nvidia"
- assert builtin_model_configs[3].alias == "openai-text"
- assert builtin_model_configs[3].model == "gpt-4.1"
- assert builtin_model_configs[3].provider == "openai"
- assert builtin_model_configs[4].alias == "openai-reasoning"
- assert builtin_model_configs[4].model == "gpt-5"
+ assert builtin_model_configs[3].alias == "nvidia-embedding"
+ assert builtin_model_configs[3].model == "nvidia/llama-3.2-nv-embedqa-1b-v2"
+ assert builtin_model_configs[3].provider == "nvidia"
+ assert builtin_model_configs[4].alias == "openai-text"
+ assert builtin_model_configs[4].model == "gpt-4.1"
+ assert builtin_model_configs[4].provider == "openai"
+ assert builtin_model_configs[5].alias == "openai-reasoning"
+ assert builtin_model_configs[5].model == "gpt-5"
+ assert builtin_model_configs[5].provider == "openai"
+ assert builtin_model_configs[6].alias == "openai-vision"
+ assert builtin_model_configs[6].model == "gpt-5"
+ assert builtin_model_configs[6].provider == "openai"
+ assert builtin_model_configs[7].alias == "openai-embedding"
+ assert builtin_model_configs[7].model == "text-embedding-3-large"
+ assert builtin_model_configs[7].provider == "openai"
def test_get_builtin_model_providers():
diff --git a/tests/config/test_models.py b/tests/config/test_models.py
index 186a4f19..6f6618c2 100644
--- a/tests/config/test_models.py
+++ b/tests/config/test_models.py
@@ -11,9 +11,11 @@
from data_designer.config.errors import InvalidConfigError
from data_designer.config.models import (
+ ChatCompletionInferenceParams,
+ EmbeddingInferenceParams,
+ GenerationType,
ImageContext,
ImageFormat,
- InferenceParameters,
ManualDistribution,
ManualDistributionParams,
ModalityDataType,
@@ -46,13 +48,13 @@ def test_image_context_validate_image_format():
def test_inference_parameters_default_construction():
- empty_inference_parameters = InferenceParameters()
+ empty_inference_parameters = ChatCompletionInferenceParams()
assert empty_inference_parameters.generate_kwargs == {}
assert empty_inference_parameters.max_parallel_requests == 4
def test_inference_parameters_generate_kwargs():
- assert InferenceParameters(
+ assert ChatCompletionInferenceParams(
temperature=0.95,
top_p=0.95,
max_tokens=100,
@@ -67,9 +69,9 @@ def test_inference_parameters_generate_kwargs():
"extra_body": {"reasoning_effort": "high"},
}
- assert InferenceParameters().generate_kwargs == {}
+ assert ChatCompletionInferenceParams().generate_kwargs == {}
- inference_parameters_kwargs = InferenceParameters(
+ inference_parameters_kwargs = ChatCompletionInferenceParams(
temperature=UniformDistribution(params=UniformDistributionParams(low=0.0, high=1.0)),
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.0, 1.0], weights=[0.5, 0.5])),
).generate_kwargs
@@ -131,32 +133,38 @@ def test_inference_parameters_temperature_validation():
# All temp values provided in a manual distribution should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(
+ ChatCompletionInferenceParams(
temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.5], weights=[0.5, 0.5]))
)
# High and low values of uniform distribution should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5)))
+ ChatCompletionInferenceParams(
+ temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.5))
+ )
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0)))
+ ChatCompletionInferenceParams(
+ temperature=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=2.0))
+ )
# Static values should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(temperature=3.0)
+ ChatCompletionInferenceParams(temperature=3.0)
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(temperature=-1.0)
+ ChatCompletionInferenceParams(temperature=-1.0)
# Valid temperature values shouldn't raise validation errors
try:
- InferenceParameters(temperature=0.1)
- InferenceParameters(temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0)))
- InferenceParameters(
+ ChatCompletionInferenceParams(temperature=0.1)
+ ChatCompletionInferenceParams(
+ temperature=UniformDistribution(params=UniformDistributionParams(low=0.5, high=2.0))
+ )
+ ChatCompletionInferenceParams(
temperature=ManualDistribution(params=ManualDistributionParams(values=[0.5, 2.0], weights=[0.5, 0.5]))
)
except Exception:
- pytest.fail("Unexpected exception raised during InferenceParameters temperature validation")
+ pytest.fail("Unexpected exception raised during CompletionInferenceParameters temperature validation")
def test_generation_parameters_top_p_validation():
@@ -164,31 +172,31 @@ def test_generation_parameters_top_p_validation():
# All top_p values provided in a manual distribution should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(
+ ChatCompletionInferenceParams(
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.5], weights=[0.5, 0.5]))
)
# High and low values of uniform distribution should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5)))
+ ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.5)))
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0)))
+ ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=-0.5, high=1.0)))
# Static values should be valid
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(top_p=1.5)
+ ChatCompletionInferenceParams(top_p=1.5)
with pytest.raises(ValidationError, match=expected_error_msg):
- InferenceParameters(top_p=-0.1)
+ ChatCompletionInferenceParams(top_p=-0.1)
# Valid top_p values shouldn't raise validation errors
try:
- InferenceParameters(top_p=0.1)
- InferenceParameters(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0)))
- InferenceParameters(
+ ChatCompletionInferenceParams(top_p=0.1)
+ ChatCompletionInferenceParams(top_p=UniformDistribution(params=UniformDistributionParams(low=0.5, high=1.0)))
+ ChatCompletionInferenceParams(
top_p=ManualDistribution(params=ManualDistributionParams(values=[0.5, 1.0], weights=[0.5, 0.5]))
)
except Exception:
- pytest.fail("Unexpected exception raised during InferenceParameters top_p validation")
+ pytest.fail("Unexpected exception raised during CompletionInferenceParameters top_p validation")
def test_generation_parameters_max_tokens_validation():
@@ -196,15 +204,15 @@ def test_generation_parameters_max_tokens_validation():
ValidationError,
match="Input should be greater than or equal to 1",
):
- InferenceParameters(max_tokens=0)
+ ChatCompletionInferenceParams(max_tokens=0)
# Valid max_tokens values shouldn't raise validation errors
try:
- InferenceParameters(max_tokens=128_000)
- InferenceParameters(max_tokens=4096)
- InferenceParameters(max_tokens=1)
+ ChatCompletionInferenceParams(max_tokens=128_000)
+ ChatCompletionInferenceParams(max_tokens=4096)
+ ChatCompletionInferenceParams(max_tokens=1)
except Exception:
- pytest.fail("Unexpected exception raised during InferenceParameters max_tokens validation")
+ pytest.fail("Unexpected exception raised during CompletionInferenceParameters max_tokens validation")
def test_load_model_configs():
@@ -212,7 +220,7 @@ def test_load_model_configs():
ModelConfig(alias="test", model="test"),
ModelConfig(alias="test2", model="test2"),
]
- stub_model_configs_dict_list = [mc.model_dump() for mc in stub_model_configs]
+ stub_model_configs_dict_list = [mc.model_dump(mode="json") for mc in stub_model_configs]
assert load_model_configs([]) == []
assert load_model_configs(stub_model_configs) == stub_model_configs
@@ -248,6 +256,43 @@ def test_load_model_configs():
load_model_configs(tmp_file.name)
-def test_model_config_default_construction():
+def test_model_config_construction():
+ # test default construction
model_config = ModelConfig(alias="test", model="test")
- assert model_config.inference_parameters == InferenceParameters()
+ assert model_config.inference_parameters == ChatCompletionInferenceParams()
+ assert model_config.generation_type == GenerationType.CHAT_COMPLETION
+
+ # test construction with completion inference parameters
+ completion_params = ChatCompletionInferenceParams(temperature=0.5, top_p=0.5, max_tokens=100)
+ model_config = ModelConfig(alias="test", model="test", inference_parameters=completion_params)
+ assert model_config.inference_parameters == completion_params
+ assert model_config.generation_type == GenerationType.CHAT_COMPLETION
+
+ # test construction with embedding inference parameters
+ embedding_params = EmbeddingInferenceParams(dimensions=100)
+ model_config = ModelConfig(alias="test", model="test", inference_parameters=embedding_params)
+ assert model_config.inference_parameters == embedding_params
+ assert model_config.generation_type == GenerationType.EMBEDDING
+
+
+def test_model_config_generation_type_from_dict():
+ # Test that generation_type in dict is used to create the right inference params type
+ model_config = ModelConfig.model_validate(
+ {
+ "alias": "test",
+ "model": "test",
+ "inference_parameters": {"generation_type": "embedding", "dimensions": 100},
+ }
+ )
+ assert isinstance(model_config.inference_parameters, EmbeddingInferenceParams)
+ assert model_config.generation_type == GenerationType.EMBEDDING
+
+ model_config = ModelConfig.model_validate(
+ {
+ "alias": "test",
+ "model": "test",
+ "inference_parameters": {"generation_type": "chat-completion", "temperature": 0.5},
+ }
+ )
+ assert isinstance(model_config.inference_parameters, ChatCompletionInferenceParams)
+ assert model_config.generation_type == GenerationType.CHAT_COMPLETION
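
One likely reason the fixtures above switched to model_dump(mode="json"): the nested inference parameters now carry a generation_type discriminator, and a JSON-mode dump keeps it in a form that survives a file round trip. A hedged sketch, assuming the discriminator is included in the dump (which is what the dict-validation test above exercises on the way back in):

from data_designer.config.models import EmbeddingInferenceParams, ModelConfig

cfg = ModelConfig(
    alias="emb",
    model="text-embedding-3",
    inference_parameters=EmbeddingInferenceParams(dimensions=64),
)
payload = cfg.model_dump(mode="json")           # JSON-safe dict, suitable for config files
restored = ModelConfig.model_validate(payload)  # discriminator selects the params subclass again
assert isinstance(restored.inference_parameters, EmbeddingInferenceParams)
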
diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py
index ab0eebc5..fa55824e 100644
--- a/tests/config/utils/test_visualization.py
+++ b/tests/config/utils/test_visualization.py
@@ -8,7 +8,7 @@
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.utils.code_lang import CodeLang
-from data_designer.config.utils.visualization import display_sample_record, mask_api_key
+from data_designer.config.utils.visualization import display_sample_record, get_truncated_list_as_string, mask_api_key
from data_designer.config.validator_params import CodeValidatorParams
@@ -75,3 +75,14 @@ def test_mask_api_key():
# None or empty returns "(not set)"
assert mask_api_key(None) == "(not set)"
assert mask_api_key("") == "(not set)"
+
+
+def test_get_truncated_list_as_string():
+ assert get_truncated_list_as_string([1, 2, 3, 4, 5]) == "[1, 2, ...]"
+ assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=1) == "[1, ...]"
+ assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=3) == "[1, 2, 3, ...]"
+ assert get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=10) == "[1, 2, 3, 4, 5]"
+ with pytest.raises(ValueError):
+ get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=-1)
+ with pytest.raises(ValueError):
+ get_truncated_list_as_string([1, 2, 3, 4, 5], max_items=0)
diff --git a/tests/conftest.py b/tests/conftest.py
index 31dc0057..5f0205e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,7 @@
from data_designer.config.config_builder import DataDesignerConfigBuilder
from data_designer.config.data_designer_config import DataDesignerConfig
from data_designer.config.datastore import DatastoreSettings
-from data_designer.config.models import InferenceParameters, ModelConfig, ModelProvider
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
@pytest.fixture
@@ -135,7 +135,7 @@ def stub_model_configs() -> list[ModelConfig]:
ModelConfig(
alias="stub-model",
model="stub-model",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.9,
top_p=0.9,
max_tokens=2048,
diff --git a/tests/engine/analysis/test_column_statistics_calculator.py b/tests/engine/analysis/test_column_statistics_calculator.py
index d21ae179..203e03f1 100644
--- a/tests/engine/analysis/test_column_statistics_calculator.py
+++ b/tests/engine/analysis/test_column_statistics_calculator.py
@@ -43,10 +43,10 @@ def test_llm_generated_column_statistics(stub_df, column_configs):
assert stats.column_name == column_config.name
assert stats.column_type == column_config.column_type
assert stats.num_records == len(stub_df)
- assert isinstance(stats.completion_tokens_mean, float)
- assert isinstance(stats.completion_tokens_stddev, float)
- assert isinstance(stats.prompt_tokens_mean, float)
- assert isinstance(stats.prompt_tokens_stddev, float)
+ assert isinstance(stats.output_tokens_mean, float)
+ assert isinstance(stats.output_tokens_stddev, float)
+ assert isinstance(stats.input_tokens_mean, float)
+ assert isinstance(stats.input_tokens_stddev, float)
def test_sampler_column_statistics(stub_df, column_configs):
diff --git a/tests/engine/analysis/utils/test_column_statistics_calculations.py b/tests/engine/analysis/utils/test_column_statistics_calculations.py
index 6458aa49..65afefd2 100644
--- a/tests/engine/analysis/utils/test_column_statistics_calculations.py
+++ b/tests/engine/analysis/utils/test_column_statistics_calculations.py
@@ -19,9 +19,9 @@
from data_designer.config.utils.numerical_helpers import prepare_number_for_reporting
from data_designer.engine.analysis.utils.column_statistics_calculations import (
calculate_column_distribution,
- calculate_completion_token_stats,
calculate_general_column_info,
- calculate_prompt_token_stats,
+ calculate_input_token_stats,
+ calculate_output_token_stats,
calculate_validation_column_info,
convert_pyarrow_dtype_to_simple_dtype,
ensure_boolean,
@@ -169,38 +169,39 @@ def test_calculate_general_column_info(stub_df_with_mixed_column_types):
assert result["num_unique"] == MissingValue.CALCULATION_FAILED
-def test_calculate_prompt_token_stats(mock_prompt_renderer_render, stub_column_config, stub_df_responses):
+def test_calculate_input_token_stats(mock_prompt_renderer_render, stub_column_config, stub_df_responses):
prompt_cycle = cycle(["System prompt", "Test prompt"])
mock_prompt_renderer_render.side_effect = lambda *args, **kwargs: next(prompt_cycle)
- result = calculate_prompt_token_stats(stub_column_config, stub_df_responses)
- assert "prompt_tokens_mean" in result
- assert "prompt_tokens_stddev" in result
- assert "prompt_tokens_median" in result
- assert isinstance(result["prompt_tokens_mean"], float)
- assert isinstance(result["prompt_tokens_stddev"], float)
- assert isinstance(result["prompt_tokens_median"], float)
+ result = calculate_input_token_stats(stub_column_config, stub_df_responses)
+ assert "input_tokens_mean" in result
+ assert "input_tokens_stddev" in result
+ assert "input_tokens_median" in result
+ assert isinstance(result["input_tokens_mean"], float)
+ assert isinstance(result["input_tokens_stddev"], float)
+ assert isinstance(result["input_tokens_median"], float)
mock_prompt_renderer_render.side_effect = Exception("Test error")
- result = calculate_prompt_token_stats(stub_column_config, stub_df_responses)
- assert result["prompt_tokens_mean"] == MissingValue.CALCULATION_FAILED
- assert result["prompt_tokens_stddev"] == MissingValue.CALCULATION_FAILED
- assert result["prompt_tokens_median"] == MissingValue.CALCULATION_FAILED
-
-
-def test_calculate_completion_token_stats(stub_column_config, stub_df_responses):
- result = calculate_completion_token_stats(stub_column_config.name, stub_df_responses)
- assert "completion_tokens_mean" in result
- assert "completion_tokens_stddev" in result
- assert "completion_tokens_median" in result
- assert isinstance(result["completion_tokens_mean"], float)
- assert isinstance(result["completion_tokens_stddev"], float)
- assert isinstance(result["completion_tokens_median"], float)
-
- result = calculate_completion_token_stats("nonexistent_column", stub_df_responses)
- assert result["completion_tokens_mean"] == MissingValue.CALCULATION_FAILED
- assert result["completion_tokens_stddev"] == MissingValue.CALCULATION_FAILED
- assert result["completion_tokens_median"] == MissingValue.CALCULATION_FAILED
+ result = calculate_input_token_stats(stub_column_config, stub_df_responses)
+ assert result["input_tokens_mean"] == MissingValue.CALCULATION_FAILED
+ assert result["input_tokens_stddev"] == MissingValue.CALCULATION_FAILED
+ assert result["input_tokens_median"] == MissingValue.CALCULATION_FAILED
+
+
+def test_calculate_output_token_stats(stub_column_config, stub_df_responses):
+ result = calculate_output_token_stats(stub_column_config, stub_df_responses)
+ assert "output_tokens_mean" in result
+ assert "output_tokens_stddev" in result
+ assert "output_tokens_median" in result
+ assert isinstance(result["output_tokens_mean"], float)
+ assert isinstance(result["output_tokens_stddev"], float)
+ assert isinstance(result["output_tokens_median"], float)
+
+ stub_column_config.name = "nonexistent_column"
+ result = calculate_output_token_stats(stub_column_config, stub_df_responses)
+ assert result["output_tokens_mean"] == MissingValue.CALCULATION_FAILED
+ assert result["output_tokens_stddev"] == MissingValue.CALCULATION_FAILED
+ assert result["output_tokens_median"] == MissingValue.CALCULATION_FAILED
def test_calculate_validation_column_info(stub_column_config, stub_df_code_validation):
diff --git a/tests/engine/column_generators/generators/test_embedding.py b/tests/engine/column_generators/generators/test_embedding.py
new file mode 100644
index 00000000..75efdfc6
--- /dev/null
+++ b/tests/engine/column_generators/generators/test_embedding.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import patch
+
+import pytest
+
+from data_designer.config.column_configs import EmbeddingColumnConfig
+from data_designer.engine.column_generators.generators.base import GenerationStrategy
+from data_designer.engine.column_generators.generators.embedding import (
+ EmbeddingCellGenerator,
+ EmbeddingGenerationResult,
+)
+
+
+@pytest.fixture
+def stub_embedding_column_config():
+ return EmbeddingColumnConfig(name="test_embedding", target_column="test_column", model_alias="test_model")
+
+
+@pytest.fixture
+def stub_embeddings() -> list[list[float]]:
+ return [[0.1, 0.2], [0.3, 0.4]]
+
+
+def test_embedding_cell_generator_metadata(stub_embedding_column_config, stub_resource_provider):
+ metadata = EmbeddingCellGenerator(
+ config=stub_embedding_column_config, resource_provider=stub_resource_provider
+ ).metadata()
+ assert metadata.name == "embedding_cell_generator"
+ assert metadata.description == "Generate embeddings for a text column."
+ assert metadata.generation_strategy == GenerationStrategy.CELL_BY_CELL
+
+
+def test_embedding_cell_generator_generate(stub_embedding_column_config, stub_resource_provider, stub_embeddings):
+ with patch.object(
+ stub_resource_provider.model_registry.get_model.return_value,
+ "generate_text_embeddings",
+ return_value=stub_embeddings,
+ ) as mock_generate:
+ embedding_cell_generator = EmbeddingCellGenerator(
+ config=stub_embedding_column_config, resource_provider=stub_resource_provider
+ )
+ data = embedding_cell_generator.generate(data={"test_column": "['test1', 'test2']"})
+ assert stub_embedding_column_config.name in data
+ assert data[stub_embedding_column_config.name] == EmbeddingGenerationResult(
+ embeddings=stub_embeddings
+ ).model_dump(mode="json")
+ mock_generate.assert_called_once_with(input_texts=["test1", "test2"])
diff --git a/tests/engine/column_generators/generators/test_llm_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py
similarity index 92%
rename from tests/engine/column_generators/generators/test_llm_generators.py
rename to tests/engine/column_generators/generators/test_llm_completion_generators.py
index 259f3a08..0b787b7e 100644
--- a/tests/engine/column_generators/generators/test_llm_generators.py
+++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py
@@ -11,7 +11,7 @@
LLMStructuredColumnConfig,
LLMTextColumnConfig,
)
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS,
DEFAULT_MAX_CONVERSATION_RESTARTS,
REASONING_TRACE_COLUMN_POSTFIX,
@@ -94,7 +94,7 @@ def test_generate_method():
assert call_args[1]["multi_modal_context"] is None
-@patch("data_designer.engine.column_generators.generators.llm_generators.logger", autospec=True)
+@patch("data_designer.engine.column_generators.generators.base.logger", autospec=True)
def test_log_pre_generation(mock_logger):
generator, mock_resource_provider, _, mock_model_config, _, _, _ = _create_generator_with_mocks()
mock_model_config.model_dump_json.return_value = '{"test": "config"}'
@@ -259,20 +259,3 @@ def test_generate_with_json_deserialization():
result = generator.generate(data)
assert result["test_column"] == {"result": "json_output"}
-
-
-def test_generate_with_inference_parameters():
- generator, _, mock_model, _, mock_inference_params, mock_prompt_renderer, mock_response_recipe = (
- _create_generator_with_mocks()
- )
-
- mock_inference_params.generate_kwargs = {"temperature": 0.7, "max_tokens": 100}
- _setup_generate_mocks(mock_prompt_renderer, mock_response_recipe, mock_model)
-
- data = {"input": "test_input"}
- generator.generate(data)
-
- call_args = mock_model.generate.call_args
- assert call_args[1]["temperature"] == 0.7
- assert call_args[1]["max_tokens"] == 100
- assert call_args[1]["purpose"] == "running generation for column 'test_column'"
diff --git a/tests/engine/column_generators/test_registry.py b/tests/engine/column_generators/test_registry.py
index f70b0d90..0d325937 100644
--- a/tests/engine/column_generators/test_registry.py
+++ b/tests/engine/column_generators/test_registry.py
@@ -3,7 +3,7 @@
from data_designer.config.column_types import DataDesignerColumnType
from data_designer.engine.column_generators.generators.expression import ExpressionColumnGenerator
-from data_designer.engine.column_generators.generators.llm_generators import (
+from data_designer.engine.column_generators.generators.llm_completion import (
LLMCodeCellGenerator,
LLMJudgeCellGenerator,
LLMStructuredCellGenerator,
diff --git a/tests/engine/models/conftest.py b/tests/engine/models/conftest.py
index 95e6941f..4e1dd059 100644
--- a/tests/engine/models/conftest.py
+++ b/tests/engine/models/conftest.py
@@ -5,7 +5,11 @@
import pytest
-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import (
+ ChatCompletionInferenceParams,
+ EmbeddingInferenceParams,
+ ModelConfig,
+)
from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry
from data_designer.engine.models.registry import ModelRegistry, create_model_registry
from data_designer.engine.secret_resolver import SecretsFileResolver
@@ -38,7 +42,7 @@ def stub_model_configs() -> list[ModelConfig]:
alias="stub-text",
model="stub-model-text",
provider="stub-model-provider",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
),
),
@@ -46,10 +50,18 @@ def stub_model_configs() -> list[ModelConfig]:
alias="stub-reasoning",
model="stub-model-reasoning",
provider="stub-model-provider",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
),
),
+ ModelConfig(
+ alias="stub-embedding",
+ model="stub-model-embedding",
+ provider="stub-model-provider",
+ inference_parameters=EmbeddingInferenceParams(
+ dimensions=100,
+ ),
+ ),
]
diff --git a/tests/engine/models/test_facade.py b/tests/engine/models/test_facade.py
index aad035cd..97d8cacf 100644
--- a/tests/engine/models/test_facade.py
+++ b/tests/engine/models/test_facade.py
@@ -5,7 +5,7 @@
from unittest.mock import patch
import pytest
-from litellm.types.utils import Choices, Message, ModelResponse
+from litellm.types.utils import Choices, EmbeddingResponse, Message, ModelResponse
from data_designer.engine.models.errors import ModelGenerationValidationFailureError
from data_designer.engine.models.facade import ModelFacade
@@ -30,10 +30,20 @@ def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_prov
@pytest.fixture
-def stub_expected_response():
+def stub_completion_messages():
+ return [{"role": "user", "content": "test"}]
+
+
+@pytest.fixture
+def stub_expected_completion_response():
return ModelResponse(choices=Choices(message=Message(content="Test response")))
+@pytest.fixture
+def stub_expected_embedding_response():
+ return EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2)
+
+
@pytest.mark.parametrize(
"max_correction_steps,max_conversation_restarts,total_calls",
[
@@ -105,6 +115,24 @@ def test_usage_stats_property(stub_model_facade):
assert hasattr(stub_model_facade.usage_stats, "model_dump")
+def test_consolidate_kwargs(stub_model_configs, stub_model_facade):
+ # Model config generate kwargs are used as base, and purpose is removed
+ result = stub_model_facade.consolidate_kwargs(purpose="test")
+ assert result == stub_model_configs[0].inference_parameters.generate_kwargs
+
+ # Explicit kwargs override the model config generate kwargs
+ result = stub_model_facade.consolidate_kwargs(temperature=0.01, purpose="test")
+ assert result == {**stub_model_configs[0].inference_parameters.generate_kwargs, "temperature": 0.01}
+
+ # Provider extra_body is merged into the extra_body kwarg and takes precedence on conflicting keys
+ stub_model_facade.model_provider.extra_body = {"foo_provider": "bar_provider"}
+ result = stub_model_facade.consolidate_kwargs(extra_body={"foo": "bar"}, purpose="test")
+ assert result == {
+ **stub_model_configs[0].inference_parameters.generate_kwargs,
+ "extra_body": {"foo_provider": "bar_provider", "foo": "bar"},
+ }
+
+
@pytest.mark.parametrize(
"skip_usage_tracking",
[
@@ -112,63 +140,85 @@ def test_usage_stats_property(stub_model_facade):
True,
],
)
-def test_completion_success(stub_model_facade, stub_expected_response, skip_usage_tracking):
- stub_model_facade._router.completion = lambda model_name, messages, **kwargs: stub_expected_response
-
- messages = [{"role": "user", "content": "test"}]
- result = stub_model_facade.completion(messages, skip_usage_tracking=skip_usage_tracking)
-
- assert result == stub_expected_response
-
-
-def test_completion_with_exception(stub_model_facade):
- def raise_exception(*args, **kwargs):
- raise Exception("Router error")
+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_success(
+ mock_router_completion,
+ stub_completion_messages,
+ stub_model_configs,
+ stub_model_facade,
+ stub_expected_completion_response,
+ skip_usage_tracking,
+):
+ mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response
+ result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking)
+ assert result == stub_expected_completion_response
+ assert mock_router_completion.call_count == 1
+ assert mock_router_completion.call_args[1] == {
+ "model": "stub-model-text",
+ "messages": stub_completion_messages,
+ **stub_model_configs[0].inference_parameters.generate_kwargs,
+ }
- stub_model_facade._router.completion = raise_exception
- messages = [{"role": "user", "content": "test"}]
+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_with_exception(mock_router_completion, stub_completion_messages, stub_model_facade):
+ mock_router_completion.side_effect = Exception("Router error")
with pytest.raises(Exception, match="Router error"):
- stub_model_facade.completion(messages)
+ stub_model_facade.completion(stub_completion_messages)
-def test_completion_with_kwargs(stub_model_facade, stub_expected_response):
+@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
+def test_completion_with_kwargs(
+ mock_router_completion,
+ stub_completion_messages,
+ stub_model_configs,
+ stub_model_facade,
+ stub_expected_completion_response,
+):
captured_kwargs = {}
- def mock_completion(model_name, messages, **kwargs):
+ def mock_completion(self, model, messages, **kwargs):
captured_kwargs.update(kwargs)
- return stub_expected_response
+ return stub_expected_completion_response
- stub_model_facade._router.completion = mock_completion
+ mock_router_completion.side_effect = mock_completion
- messages = [{"role": "user", "content": "test"}]
kwargs = {"temperature": 0.7, "max_tokens": 100}
- result = stub_model_facade.completion(messages, **kwargs)
+ result = stub_model_facade.completion(stub_completion_messages, **kwargs)
- assert result == stub_expected_response
- assert captured_kwargs == kwargs
+ assert result == stub_expected_completion_response
+ # completion kwargs overrides model config generate kwargs
+ assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs}
-@patch("data_designer.engine.models.facade.CustomRouter.completion", autospec=True)
-def test_completion_with_extra_body(mock_router_completion, stub_model_facade):
- messages = [{"role": "user", "content": "test"}]
-
- # completion call has no extra body argument and provider has no extra body
- _ = stub_model_facade.completion(messages)
- assert len(mock_router_completion.call_args) == 2
- assert mock_router_completion.call_args[0][1] == "stub-model-text"
- assert mock_router_completion.call_args[0][2] == messages
-
- # completion call has no extra body argument and provider has extra body.
- # Should pull extra body from model provider
- custom_extra_body = {"some_custom_key": "some_custom_value"}
- stub_model_facade.model_provider.extra_body = custom_extra_body
- _ = stub_model_facade.completion(messages)
- assert mock_router_completion.call_args[1] == {"extra_body": custom_extra_body}
-
- # completion call has extra body argument and provider has extra body.
- # Should merge the two with provider extra body taking precedence
- completion_extra_body = {"some_completion_key": "some_completion_value", "some_custom_key": "some_different_value"}
- _ = stub_model_facade.completion(messages, extra_body=completion_extra_body)
- assert mock_router_completion.call_args[1] == {"extra_body": {**completion_extra_body, **custom_extra_body}}
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_success(mock_router_embedding, stub_model_facade, stub_expected_embedding_response):
+ mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response
+ input_texts = ["test1", "test2"]
+ result = stub_model_facade.generate_text_embeddings(input_texts)
+ assert result == [data["embedding"] for data in stub_expected_embedding_response.data]
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_with_exception(mock_router_embedding, stub_model_facade):
+ mock_router_embedding.side_effect = Exception("Router error")
+
+ with pytest.raises(Exception, match="Router error"):
+ stub_model_facade.generate_text_embeddings(["test1", "test2"])
+
+
+@patch("data_designer.engine.models.facade.CustomRouter.embedding", autospec=True)
+def test_generate_text_embeddings_with_kwargs(
+ mock_router_embedding, stub_model_configs, stub_model_facade, stub_expected_embedding_response
+):
+ captured_kwargs = {}
+
+ def mock_embedding(self, model, input, **kwargs):
+ captured_kwargs.update(kwargs)
+ return stub_expected_embedding_response
+
+ mock_router_embedding.side_effect = mock_embedding
+ kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"}
+ _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs)
+ assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs}
diff --git a/tests/engine/models/test_model_registry.py b/tests/engine/models/test_model_registry.py
index 89b95a50..4ea5a447 100644
--- a/tests/engine/models/test_model_registry.py
+++ b/tests/engine/models/test_model_registry.py
@@ -4,9 +4,8 @@
from unittest.mock import patch
import pytest
-from litellm import AuthenticationError
-from data_designer.config.models import InferenceParameters, ModelConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig
from data_designer.engine.models.errors import ModelAuthenticationError
from data_designer.engine.models.facade import ModelFacade
from data_designer.engine.models.registry import ModelRegistry, create_model_registry
@@ -24,7 +23,7 @@ def stub_new_model_config():
alias="stub-vision",
model="stub-model-vision",
provider="stub-model-provider",
- inference_parameters=InferenceParameters(
+ inference_parameters=ChatCompletionInferenceParams(
temperature=0.80, top_p=0.95, max_tokens=100, max_parallel_requests=10, timeout=100
),
)
@@ -36,7 +35,7 @@ def stub_no_usage_config():
alias="no-usage",
model="no-usage-model",
provider="stub-model-provider",
- inference_parameters=InferenceParameters(),
+ inference_parameters=ChatCompletionInferenceParams(),
)
@@ -72,14 +71,15 @@ def test_register_model_configs(stub_model_registry, stub_new_model_config):
stub_model_registry.register_model_configs([stub_new_model_config])
# Verify configs are registered
- assert len(stub_model_registry.model_configs) == 3
+ assert len(stub_model_registry.model_configs) == 4
# Trigger lazy initialization by requesting models
assert stub_model_registry.get_model(model_alias="stub-text").model_name == "stub-model-text"
assert stub_model_registry.get_model(model_alias="stub-reasoning").model_name == "stub-model-reasoning"
assert stub_model_registry.get_model(model_alias="stub-vision").model_name == "stub-model-vision"
+ assert stub_model_registry.get_model(model_alias="stub-embedding").model_name == "stub-model-embedding"
- assert len(stub_model_registry.models) == 3
+ assert len(stub_model_registry.models) == 4
assert all(isinstance(model, ModelFacade) for model in stub_model_registry.models.values())
@@ -126,19 +126,19 @@ def test_get_model_usage_stats(
reasoning_model = stub_model_registry.get_model(model_alias="stub-reasoning")
text_model.usage_stats.extend(
- token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=100),
+ token_usage=TokenUsageStats(input_tokens=10, output_tokens=100),
request_usage=RequestUsageStats(successful_requests=10, failed_requests=0),
)
reasoning_model.usage_stats.extend(
- token_usage=TokenUsageStats(prompt_tokens=5, completion_tokens=200),
+ token_usage=TokenUsageStats(input_tokens=5, output_tokens=200),
request_usage=RequestUsageStats(successful_requests=100, failed_requests=10),
)
usage_stats = stub_model_registry.get_model_usage_stats(total_time_elapsed=10)
assert set(usage_stats.keys()) == set(expected_keys)
if "stub-model-text" in usage_stats:
- assert usage_stats["stub-model-text"]["token_usage"]["prompt_tokens"] == 10
- assert usage_stats["stub-model-text"]["token_usage"]["completion_tokens"] == 100
+ assert usage_stats["stub-model-text"]["token_usage"]["input_tokens"] == 10
+ assert usage_stats["stub-model-text"]["token_usage"]["output_tokens"] == 100
assert usage_stats["stub-model-text"]["token_usage"]["total_tokens"] == 110
assert usage_stats["stub-model-text"]["request_usage"]["successful_requests"] == 10
assert usage_stats["stub-model-text"]["request_usage"]["failed_requests"] == 0
@@ -150,42 +150,52 @@ def test_get_model_usage_stats(
# Trigger lazy initialization
text_model = stub_model_registry.get_model(model_alias="stub-text")
text_model.usage_stats.extend(
- token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=100),
+ token_usage=TokenUsageStats(input_tokens=10, output_tokens=100),
request_usage=RequestUsageStats(successful_requests=10, failed_requests=0),
)
usage_stats = stub_model_registry.get_model_usage_stats(total_time_elapsed=10)
assert set(usage_stats.keys()) == set(expected_keys)
-@pytest.mark.parametrize(
- "test_case,mock_side_effect,expected_exception,expected_call_count",
- [
- ("success", None, None, 2),
- (
- "authentication_error",
- AuthenticationError("Invalid API key", llm_provider="openai", model="stub-model-text"),
- ModelAuthenticationError,
- 1,
- ),
- ],
-)
+@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_run_health_check_success(mock_completion, mock_generate_text_embeddings, stub_model_registry):
+ model_aliases = {"stub-text", "stub-reasoning", "stub-embedding"}
+ stub_model_registry.run_health_check(model_aliases)
+ assert mock_completion.call_count == 2
+ assert mock_generate_text_embeddings.call_count == 1
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
-def test_run_health_check(
- mock_completion, stub_model_registry, test_case, mock_side_effect, expected_exception, expected_call_count
+def test_run_health_check_completion_authentication_error(
+ mock_completion, mock_generate_text_embeddings, stub_model_registry
):
- if mock_side_effect:
- mock_completion.side_effect = mock_side_effect
+ auth_error = ModelAuthenticationError("Invalid API key for completion model")
+ mock_completion.side_effect = auth_error
+ model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"]
- # Pass model aliases for health check
- model_aliases = {"stub-text", "stub-reasoning"}
+ with pytest.raises(ModelAuthenticationError):
+ stub_model_registry.run_health_check(model_aliases)
- if expected_exception:
- with pytest.raises(expected_exception):
- stub_model_registry.run_health_check(model_aliases)
- else:
+ mock_completion.assert_called_once()
+ mock_generate_text_embeddings.assert_not_called()
+
+
+@patch("data_designer.engine.models.facade.ModelFacade.generate_text_embeddings", autospec=True)
+@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True)
+def test_run_health_check_embedding_authentication_error(
+ mock_completion, mock_generate_text_embeddings, stub_model_registry
+):
+ auth_error = ModelAuthenticationError("Invalid API key for embedding model")
+ mock_generate_text_embeddings.side_effect = auth_error
+ model_aliases = ["stub-text", "stub-reasoning", "stub-embedding"]
+
+ with pytest.raises(ModelAuthenticationError):
stub_model_registry.run_health_check(model_aliases)
- assert mock_completion.call_count == expected_call_count
+ assert mock_completion.call_count == 2
+ mock_generate_text_embeddings.assert_called_once()
@pytest.mark.parametrize(
diff --git a/tests/engine/models/test_usage.py b/tests/engine/models/test_usage.py
index 3e92a433..3ddb09f3 100644
--- a/tests/engine/models/test_usage.py
+++ b/tests/engine/models/test_usage.py
@@ -6,14 +6,14 @@
def test_token_usage_stats():
token_usage_stats = TokenUsageStats()
- assert token_usage_stats.prompt_tokens == 0
- assert token_usage_stats.completion_tokens == 0
+ assert token_usage_stats.input_tokens == 0
+ assert token_usage_stats.output_tokens == 0
assert token_usage_stats.total_tokens == 0
assert token_usage_stats.has_usage is False
- token_usage_stats.extend(prompt_tokens=10, completion_tokens=20)
- assert token_usage_stats.prompt_tokens == 10
- assert token_usage_stats.completion_tokens == 20
+ token_usage_stats.extend(input_tokens=10, output_tokens=20)
+ assert token_usage_stats.input_tokens == 10
+ assert token_usage_stats.output_tokens == 20
assert token_usage_stats.total_tokens == 30
assert token_usage_stats.has_usage is True
@@ -34,31 +34,31 @@ def test_request_usage_stats():
def test_model_usage_stats():
model_usage_stats = ModelUsageStats()
- assert model_usage_stats.token_usage.prompt_tokens == 0
- assert model_usage_stats.token_usage.completion_tokens == 0
+ assert model_usage_stats.token_usage.input_tokens == 0
+ assert model_usage_stats.token_usage.output_tokens == 0
assert model_usage_stats.request_usage.successful_requests == 0
assert model_usage_stats.request_usage.failed_requests == 0
assert model_usage_stats.has_usage is False
assert model_usage_stats.get_usage_stats(total_time_elapsed=10) == {
- "token_usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+ "token_usage": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0},
"request_usage": {"successful_requests": 0, "failed_requests": 0, "total_requests": 0},
"tokens_per_second": 0,
"requests_per_minute": 0,
}
model_usage_stats.extend(
- token_usage=TokenUsageStats(prompt_tokens=10, completion_tokens=20),
+ token_usage=TokenUsageStats(input_tokens=10, output_tokens=20),
request_usage=RequestUsageStats(successful_requests=2, failed_requests=1),
)
- assert model_usage_stats.token_usage.prompt_tokens == 10
- assert model_usage_stats.token_usage.completion_tokens == 20
+ assert model_usage_stats.token_usage.input_tokens == 10
+ assert model_usage_stats.token_usage.output_tokens == 20
assert model_usage_stats.request_usage.successful_requests == 2
assert model_usage_stats.request_usage.failed_requests == 1
assert model_usage_stats.has_usage is True
assert model_usage_stats.get_usage_stats(total_time_elapsed=2) == {
- "token_usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
+ "token_usage": {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30},
"request_usage": {"successful_requests": 2, "failed_requests": 1, "total_requests": 3},
"tokens_per_second": 15,
"requests_per_minute": 90,
diff --git a/tests/engine/processing/test_utils.py b/tests/engine/processing/test_utils.py
index a41e0ec2..dec0fe6a 100644
--- a/tests/engine/processing/test_utils.py
+++ b/tests/engine/processing/test_utils.py
@@ -9,6 +9,7 @@
from data_designer.engine.processing.utils import (
concat_datasets,
deserialize_json_values,
+ parse_list_string,
)
@@ -116,3 +117,19 @@ def test_concat_datasets_logging(mock_logger, stub_sample_dataframes):
def test_deserialize_json_values_scenarios(test_case, input_data, expected_result):
result = deserialize_json_values(input_data)
assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "input_string,expected_result",
+ [
+ ('["a", "b", "c"]', ["a", "b", "c"]), # valid stringified json array
+ ('[" a ", " b", "c "]', ["a", "b", "c"]), # valid stringified json array with whitespace
+ ('["a", "b", "c",]', ["a", "b", "c"]), # valid stringified json array with trailing comma
+ ("['a', 'b', 'c']", ["a", "b", "c"]), # valid python-style list with single quotes
+ ("['a', 'b', 'c', ]", ["a", "b", "c"]), # valid python-style list with trailing comma
+ ("simple string ", ["simple string"]), # simple string with whitespace
+ ],
+)
+def test_parse_list_string_scenarios(input_string, expected_result):
+ result = parse_list_string(input_string)
+ assert result == expected_result
diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py
index 35cad0f3..e7b6288a 100644
--- a/tests/essentials/test_init.py
+++ b/tests/essentials/test_init.py
@@ -14,6 +14,7 @@
BernoulliSamplerParams,
BinomialSamplerParams,
CategorySamplerParams,
+ ChatCompletionInferenceParams,
CodeLang,
CodeValidatorParams,
ColumnInequalityConstraint,
@@ -23,8 +24,10 @@
DatastoreSeedDatasetReference,
DatastoreSettings,
DatetimeSamplerParams,
+ EmbeddingInferenceParams,
ExpressionColumnConfig,
GaussianSamplerParams,
+ GenerationType,
ImageContext,
ImageFormat,
InferenceParameters,
@@ -109,6 +112,9 @@ def test_model_config_imports():
assert ImageContext is not None
assert ImageFormat is not None
assert InferenceParameters is not None
+ assert ChatCompletionInferenceParams is not None
+ assert EmbeddingInferenceParams is not None
+ assert GenerationType is not None
assert ManualDistribution is not None
assert ManualDistributionParams is not None
assert Modality is not None
@@ -232,6 +238,7 @@ def test_all_contains_column_configs():
assert "Score" in __all__
assert "SeedDatasetColumnConfig" in __all__
assert "ValidationColumnConfig" in __all__
+ assert "EmbeddingColumnConfig" in __all__
def test_all_contains_sampler_params():
@@ -250,6 +257,8 @@ def test_all_contains_sampler_params():
assert "TimeDeltaSamplerParams" in __all__
assert "UniformSamplerParams" in __all__
assert "UUIDSamplerParams" in __all__
+ assert "PersonFromFakerSamplerParams" in __all__
+ assert "ProcessorType" in __all__
def test_all_contains_constraints():
@@ -263,6 +272,9 @@ def test_all_contains_model_configs():
assert "ImageContext" in __all__
assert "ImageFormat" in __all__
assert "InferenceParameters" in __all__
+ assert "ChatCompletionInferenceParams" in __all__
+ assert "EmbeddingInferenceParams" in __all__
+ assert "GenerationType" in __all__
assert "ManualDistribution" in __all__
assert "ManualDistributionParams" in __all__
assert "Modality" in __all__