From 1d20321fedc50be04de76f04fd5f95ffe001c0f6 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 11 Dec 2025 21:22:49 -0600 Subject: [PATCH 1/2] Refactor seed datasets --- Makefile | 2 +- docs/colab_notebooks/1-the-basics.ipynb | 64 ++-- ...ctured-outputs-and-jinja-expressions.ipynb | 60 ++-- .../3-seeding-with-a-dataset.ipynb | 61 ++-- .../4-providing-images-as-context.ipynb | 73 +++-- docs/notebook_source/1-the-basics.py | 2 +- ...tructured-outputs-and-jinja-expressions.py | 2 +- .../3-seeding-with-a-dataset.py | 7 +- .../4-providing-images-as-context.py | 5 +- src/data_designer/config/config_builder.py | 146 ++------- src/data_designer/config/datastore.py | 187 ------------ src/data_designer/config/errors.py | 3 + src/data_designer/config/exports.py | 12 +- src/data_designer/config/seed.py | 62 +--- src/data_designer/config/seed_source.py | 73 +++++ src/data_designer/config/utils/io_helpers.py | 20 -- .../generators/seed_dataset.py | 10 +- src/data_designer/engine/compiler.py | 69 +++++ .../dataset_builders/utils/config_compiler.py | 2 +- .../engine/resources/resource_provider.py | 18 +- .../resources/seed_dataset_data_store.py | 84 ----- .../engine/resources/seed_reader.py | 149 +++++++++ .../{config/utils => engine}/validation.py | 0 src/data_designer/interface/data_designer.py | 113 ++++--- src/data_designer/temp_nmp.py | 52 ++++ tests/config/test_config_builder.py | 202 +++---------- tests/config/test_datastore.py | 286 ------------------ tests/config/test_seed.py | 47 +-- tests/config/test_seed_source.py | 61 ++++ tests/config/utils/test_visualization.py | 29 +- tests/conftest.py | 38 ++- .../artifacts/dataset/column_configs.json | 5 +- .../generators/test_seed_dataset.py | 42 +-- tests/engine/conftest.py | 2 +- .../test_multi_column_configs.py | 5 +- .../utils/test_config_compiler.py | 11 +- .../resources/test_resource_provider.py | 22 +- tests/engine/resources/test_seed_reader.py | 51 ++++ tests/engine/test_compiler.py | 80 +++++ .../utils => engine}/test_validation.py | 18 +- tests/essentials/test_init.py | 12 +- tests/interface/test_data_designer.py | 224 ++------------ 42 files changed, 965 insertions(+), 1446 deletions(-) delete mode 100644 src/data_designer/config/datastore.py create mode 100644 src/data_designer/config/seed_source.py create mode 100644 src/data_designer/engine/compiler.py delete mode 100644 src/data_designer/engine/resources/seed_dataset_data_store.py create mode 100644 src/data_designer/engine/resources/seed_reader.py rename src/data_designer/{config/utils => engine}/validation.py (100%) create mode 100644 src/data_designer/temp_nmp.py delete mode 100644 tests/config/test_datastore.py create mode 100644 tests/config/test_seed_source.py create mode 100644 tests/engine/resources/test_seed_reader.py create mode 100644 tests/engine/test_compiler.py rename tests/{config/utils => engine}/test_validation.py (94%) diff --git a/Makefile b/Makefile index 295400ce..6085da6f 100644 --- a/Makefile +++ b/Makefile @@ -98,7 +98,7 @@ convert-execute-notebooks: generate-colab-notebooks: @echo "πŸ““ Generating Colab-compatible notebooks..." 
- uv run --group notebooks python docs/scripts/generate_colab_notebooks.py + uv run --group docs python docs/scripts/generate_colab_notebooks.py @echo "βœ… Colab notebooks created in docs/colab_notebooks/" serve-docs-locally: diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index 0eaef320..b548531d 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "2e6331ad", + "id": "32827bda", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -14,7 +14,7 @@ }, { "cell_type": "markdown", - "id": "af3aad47", + "id": "cbe6b81e", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -25,7 +25,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c39e2a5", + "id": "5b007eed", "metadata": {}, "outputs": [], "source": [ @@ -36,7 +36,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d8652e5e", + "id": "58dfdb07", "metadata": {}, "outputs": [], "source": [ @@ -53,7 +53,7 @@ }, { "cell_type": "markdown", - "id": "c51e6323", + "id": "03d07dff", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -64,7 +64,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2de6b279", + "id": "749a1536", "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ }, { "cell_type": "markdown", - "id": "6a484a8d", + "id": "5c31d723", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -98,7 +98,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7554bd1a", + "id": "826f2421", "metadata": {}, "outputs": [], "source": [ @@ -107,7 +107,7 @@ }, { "cell_type": "markdown", - "id": "dc1d9f84", + "id": "b6bfc01a", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -124,7 +124,7 @@ { "cell_type": "code", "execution_count": null, - "id": "76d22674", + "id": "11aa0146", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +154,7 @@ }, { "cell_type": "markdown", - "id": "187da050", + "id": "c10b93a6", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -169,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "977497d1", + "id": "02ae97ca", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +178,7 @@ }, { "cell_type": "markdown", - "id": "92c51ea0", + "id": "c4c780c8", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -195,7 +195,7 @@ { "cell_type": "code", "execution_count": null, - "id": "68d7a4e6", + "id": "685417ab", "metadata": {}, "outputs": [], "source": [ @@ -204,7 +204,7 @@ }, { "cell_type": "markdown", - "id": "314c4719", + "id": "6bae388c", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -213,7 +213,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1bcad060", + "id": "6122711b", "metadata": {}, "outputs": [], "source": [ @@ -289,12 +289,12 @@ ")\n", "\n", "# Optionally validate that the columns are configured correctly.\n", - "config_builder.validate()" + "data_designer.validate(config_builder)" ] }, { "cell_type": "markdown", - "id": "aab7414d", + "id": "12d8c063", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -303,7 +303,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f191f5bf", + "id": "c4c21fa0", "metadata": {}, "outputs": [], "source": [ @@ -340,7 +340,7 @@ }, { 
"cell_type": "markdown", - "id": "5d893b3d", + "id": "be6bd3c8", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -355,7 +355,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2abadac9", + "id": "99e953d8", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +391,7 @@ }, { "cell_type": "markdown", - "id": "2c9cb423", + "id": "ec30007b", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -408,7 +408,7 @@ { "cell_type": "code", "execution_count": null, - "id": "71e3a022", + "id": "7ac1f25b", "metadata": {}, "outputs": [], "source": [ @@ -418,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28f7913d", + "id": "67d8dff4", "metadata": {}, "outputs": [], "source": [ @@ -429,7 +429,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6621c80f", + "id": "f78e4d9f", "metadata": {}, "outputs": [], "source": [ @@ -439,7 +439,7 @@ }, { "cell_type": "markdown", - "id": "2b451ded", + "id": "ee838ed6", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -452,7 +452,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f7cb6cc", + "id": "4bea0fca", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +462,7 @@ }, { "cell_type": "markdown", - "id": "721b3c7d", + "id": "da9875d4", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -475,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1ad777d1", + "id": "8c7f20b4", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +485,7 @@ { "cell_type": "code", "execution_count": null, - "id": "df089509", + "id": "2f5fdaa8", "metadata": {}, "outputs": [], "source": [ @@ -498,7 +498,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e37fa65b", + "id": "5c2d935f", "metadata": {}, "outputs": [], "source": [ @@ -510,7 +510,7 @@ }, { "cell_type": "markdown", - "id": "84d1802b", + "id": "a80515b7", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index 78b27009..2069c47a 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "fabaacf3", + "id": "ba3eb53c", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs and Jinja Expressions\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "9c43b45f", + "id": "71e3de28", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -27,7 +27,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b8ec517f", + "id": "f12bda54", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9cda5d44", + "id": "c8e1f477", "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "markdown", - "id": "5cd6ad06", + "id": "efc5ac26", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3b68f1ed", + "id": "3ffd67ca", "metadata": {}, "outputs": [], "source": [ @@ -87,7 +87,7 @@ }, { "cell_type": "markdown", - "id": "2ccb4a3e", + "id": "782aaacd", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -100,7 +100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bc9a754a", + "id": "1b092ccc", "metadata": {}, "outputs": [], "source": [ @@ -109,7 
+109,7 @@ }, { "cell_type": "markdown", - "id": "4a1335a7", + "id": "cf4e19e2", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -126,7 +126,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6ae7a51e", + "id": "ff05156e", "metadata": {}, "outputs": [], "source": [ @@ -156,7 +156,7 @@ }, { "cell_type": "markdown", - "id": "891d5040", + "id": "bfbac905", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -171,7 +171,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d3a95b64", + "id": "d2c7a5d7", "metadata": {}, "outputs": [], "source": [ @@ -180,7 +180,7 @@ }, { "cell_type": "markdown", - "id": "7d67b2b1", + "id": "6fb54697", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -207,7 +207,7 @@ { "cell_type": "code", "execution_count": null, - "id": "78fbeae4", + "id": "0ea9bb0f", "metadata": {}, "outputs": [], "source": [ @@ -235,7 +235,7 @@ }, { "cell_type": "markdown", - "id": "69463e64", + "id": "987fbf33", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -244,7 +244,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5da9004", + "id": "9960d332", "metadata": {}, "outputs": [], "source": [ @@ -348,12 +348,12 @@ ")\n", "\n", "# Optionally validate that the columns are configured correctly.\n", - "config_builder.validate()" + "data_designer.validate(config_builder)" ] }, { "cell_type": "markdown", - "id": "63196378", + "id": "81f8d04a", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -370,7 +370,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36fe581f", + "id": "e9583bf7", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +423,7 @@ }, { "cell_type": "markdown", - "id": "998e907e", + "id": "c79829f3", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -440,7 +440,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2ec151e", + "id": "cceb2346", "metadata": {}, "outputs": [], "source": [ @@ -450,7 +450,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231e4b2d", + "id": "4c35a66e", "metadata": {}, "outputs": [], "source": [ @@ -461,7 +461,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4e9b1ecd", + "id": "c5cd23a7", "metadata": {}, "outputs": [], "source": [ @@ -471,7 +471,7 @@ }, { "cell_type": "markdown", - "id": "47a8e374", + "id": "2486fc8b", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -484,7 +484,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1cb581f6", + "id": "ef90239f", "metadata": {}, "outputs": [], "source": [ @@ -494,7 +494,7 @@ }, { "cell_type": "markdown", - "id": "aff26604", + "id": "d9cbd2f6", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -507,7 +507,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2804c0bb", + "id": "f0dcf922", "metadata": {}, "outputs": [], "source": [ @@ -517,7 +517,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e44d4940", + "id": "aac62046", "metadata": {}, "outputs": [], "source": [ @@ -530,7 +530,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1da63b8c", + "id": "33875a50", "metadata": {}, "outputs": [], "source": [ @@ -542,7 +542,7 @@ }, { "cell_type": "markdown", - "id": "ec74a797", + "id": "4d532955", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git 
a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index b913e4e0..01754359 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b126d755", + "id": "91286c3d", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -16,7 +16,7 @@ }, { "cell_type": "markdown", - "id": "94c8f4da", + "id": "95a57a7c", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -27,7 +27,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ebf683a3", + "id": "77a9bfa2", "metadata": {}, "outputs": [], "source": [ @@ -38,7 +38,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4c49f3ba", + "id": "fd15629b", "metadata": {}, "outputs": [], "source": [ @@ -55,7 +55,7 @@ }, { "cell_type": "markdown", - "id": "fe81f478", + "id": "7e15d3a6", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -66,7 +66,7 @@ { "cell_type": "code", "execution_count": null, - "id": "701e74f7", + "id": "e364d16c", "metadata": {}, "outputs": [], "source": [ @@ -74,13 +74,14 @@ " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", + " LocalFileSeedSource,\n", " ModelConfig,\n", ")" ] }, { "cell_type": "markdown", - "id": "381d35d6", + "id": "b54e3d85", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -93,7 +94,7 @@ { "cell_type": "code", "execution_count": null, - "id": "50ee9b75", + "id": "5ff1527c", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +103,7 @@ }, { "cell_type": "markdown", - "id": "9b121890", + "id": "8998e5ec", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -119,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c4c8ff05", + "id": "60e3fc08", "metadata": {}, "outputs": [], "source": [ @@ -149,7 +150,7 @@ }, { "cell_type": "markdown", - "id": "9397b044", + "id": "720b3858", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -164,7 +165,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fb47fe2f", + "id": "02960190", "metadata": {}, "outputs": [], "source": [ @@ -173,7 +174,7 @@ }, { "cell_type": "markdown", - "id": "3b49e4b6", + "id": "2b89b51a", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -198,7 +199,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8a41c518", + "id": "dfc7a563", "metadata": {}, "outputs": [], "source": [ @@ -209,14 +210,14 @@ "local_filename, _ = urllib.request.urlretrieve(url, \"gretelai_symptom_to_diagnosis.csv\")\n", "\n", "# Seed datasets are passed as reference objects to the config builder.\n", - "seed_dataset_reference = data_designer.make_seed_reference_from_file(local_filename)\n", + "seed_source = LocalFileSeedSource(path=local_filename)\n", "\n", - "config_builder.with_seed_dataset(seed_dataset_reference)" + "config_builder.with_seed_dataset(seed_source)" ] }, { "cell_type": "markdown", - "id": "136e3e86", + "id": "1dc057ff", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -233,7 +234,7 @@ { "cell_type": "code", "execution_count": null, - "id": "afa64628", + "id": "fd62c1b8", "metadata": {}, "outputs": [], "source": [ @@ -318,12 +319,12 @@ " model_alias=MODEL_ALIAS,\n", ")\n", "\n", - "config_builder.validate()" + "data_designer.validate(config_builder)" ] }, { 
"cell_type": "markdown", - "id": "ba709fc1", + "id": "5e4681ea", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -340,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "43634cef", + "id": "db464de0", "metadata": {}, "outputs": [], "source": [ @@ -350,7 +351,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee95a73b", + "id": "89c1f822", "metadata": {}, "outputs": [], "source": [ @@ -361,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c1ad69b6", + "id": "81e5714e", "metadata": {}, "outputs": [], "source": [ @@ -371,7 +372,7 @@ }, { "cell_type": "markdown", - "id": "2998f51e", + "id": "e9013506", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -384,7 +385,7 @@ { "cell_type": "code", "execution_count": null, - "id": "78b26418", + "id": "69cc94e5", "metadata": {}, "outputs": [], "source": [ @@ -394,7 +395,7 @@ }, { "cell_type": "markdown", - "id": "b87a61fc", + "id": "b343f984", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -407,7 +408,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30c5565c", + "id": "8004b0a8", "metadata": {}, "outputs": [], "source": [ @@ -417,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b384d9c", + "id": "6fde5cdd", "metadata": {}, "outputs": [], "source": [ @@ -430,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fa60ec0f", + "id": "d27f0d95", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +443,7 @@ }, { "cell_type": "markdown", - "id": "b7c06123", + "id": "afac9559", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index c7bf04d5..b5856522 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "c0a5ad9c", + "id": "9a23e054", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -10,7 +10,7 @@ }, { "cell_type": "markdown", - "id": "5f390e88", + "id": "10fbb133", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -25,7 +25,7 @@ }, { "cell_type": "markdown", - "id": "3b5d2be3", + "id": "11777ee9", "metadata": {}, "source": [ "### ⚑ Colab Setup\n", @@ -36,7 +36,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7975769", + "id": "69ee8b75", "metadata": {}, "outputs": [], "source": [ @@ -47,7 +47,7 @@ { "cell_type": "code", "execution_count": null, - "id": "75793ed8", + "id": "123d0751", "metadata": {}, "outputs": [], "source": [ @@ -64,7 +64,7 @@ }, { "cell_type": "markdown", - "id": "09868c88", + "id": "83e5d174", "metadata": {}, "source": [ "### πŸ“¦ Import the essentials\n", @@ -75,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c505114", + "id": "eec65d47", "metadata": {}, "outputs": [], "source": [ @@ -96,6 +96,7 @@ " ChatCompletionInferenceParams,\n", " DataDesigner,\n", " DataDesignerConfigBuilder,\n", + " DataFrameSeedSource,\n", " ImageContext,\n", " ImageFormat,\n", " LLMTextColumnConfig,\n", @@ -106,7 +107,7 @@ }, { "cell_type": "markdown", - "id": "831df7e9", + "id": "51b4f7fb", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -119,7 +120,7 @@ { "cell_type": "code", "execution_count": null, - "id": "433cdf1a", + "id": "5c06aed6", "metadata": {}, "outputs": [], 
"source": [ @@ -128,7 +129,7 @@ }, { "cell_type": "markdown", - "id": "eddf1262", + "id": "5d9f3e5e", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -145,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7001a15d", + "id": "fc21c884", "metadata": {}, "outputs": [], "source": [ @@ -168,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "73bb672c", + "id": "8f0a5b51", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -183,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "231e84e9", + "id": "3fb63c9b", "metadata": {}, "outputs": [], "source": [ @@ -192,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "dd756993", + "id": "647c040f", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -209,7 +210,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0e218552", + "id": "b4891ab3", "metadata": {}, "outputs": [], "source": [ @@ -224,7 +225,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cf33a7ac", + "id": "8b9f6efd", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f7b9d525", + "id": "243347f3", "metadata": {}, "outputs": [], "source": [ @@ -290,7 +291,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7474c081", + "id": "e26a5cf9", "metadata": {}, "outputs": [], "source": [ @@ -300,21 +301,19 @@ { "cell_type": "code", "execution_count": null, - "id": "8c56d461", + "id": "d13af281", "metadata": {}, "outputs": [], "source": [ "# Add the seed dataset containing our processed images\n", "df_seed = pd.DataFrame(img_dataset)[[\"uuid\", \"image_filename\", \"base64_image\", \"page\", \"options\", \"source\"]]\n", - "config_builder.with_seed_dataset(\n", - " DataDesigner.make_seed_reference_from_dataframe(df_seed, file_path=\"colpali_train_set.csv\")\n", - ")" + "config_builder.with_seed_dataset(DataFrameSeedSource(df=df_seed))" ] }, { "cell_type": "code", "execution_count": null, - "id": "ba32ed33", + "id": "fbd517f5", "metadata": { "lines_to_next_cell": 2 }, @@ -343,7 +342,7 @@ }, { "cell_type": "markdown", - "id": "6d93b821", + "id": "980fbfbd", "metadata": { "lines_to_next_cell": 2 }, @@ -351,7 +350,7 @@ }, { "cell_type": "markdown", - "id": "516f5380", + "id": "6cffa555", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -368,7 +367,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0db8df1c", + "id": "26f980ac", "metadata": {}, "outputs": [], "source": [ @@ -378,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0064e2ca", + "id": "adf246c4", "metadata": {}, "outputs": [], "source": [ @@ -389,7 +388,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2f39f30", + "id": "8ee56c67", "metadata": {}, "outputs": [], "source": [ @@ -399,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "e1460b2d", + "id": "e543a7bf", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -412,7 +411,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22786db9", + "id": "f6bebd0b", "metadata": {}, "outputs": [], "source": [ @@ -422,7 +421,7 @@ }, { "cell_type": "markdown", - "id": "a249f2ef", + "id": "6e5566cb", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -433,7 +432,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86fcc082", + "id": "1b5210b8", "metadata": { "lines_to_next_cell": 2 }, @@ -457,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "d404d91c", + "id": 
"64693378", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -470,7 +469,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6da5b84b", + "id": "b5e0fb82", "metadata": {}, "outputs": [], "source": [ @@ -480,7 +479,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0a9f69a5", + "id": "eb09bc3c", "metadata": {}, "outputs": [], "source": [ @@ -493,7 +492,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d980ba69", + "id": "0c4d34ba", "metadata": {}, "outputs": [], "source": [ @@ -505,7 +504,7 @@ }, { "cell_type": "markdown", - "id": "bb4f8c1c", + "id": "fce1cc65", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/notebook_source/1-the-basics.py b/docs/notebook_source/1-the-basics.py index 42d02fa8..92cd3ae1 100644 --- a/docs/notebook_source/1-the-basics.py +++ b/docs/notebook_source/1-the-basics.py @@ -193,7 +193,7 @@ ) # Optionally validate that the columns are configured correctly. -config_builder.validate() +data_designer.validate(config_builder) # %% [markdown] # Next, let's add samplers to generate data related to the customer and their review. diff --git a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py index b64a507d..a719a794 100644 --- a/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py +++ b/docs/notebook_source/2-structured-outputs-and-jinja-expressions.py @@ -253,7 +253,7 @@ class ProductReview(BaseModel): ) # Optionally validate that the columns are configured correctly. -config_builder.validate() +data_designer.validate(config_builder) # %% [markdown] # Next, we will use more advanced Jinja expressions to create new columns. diff --git a/docs/notebook_source/3-seeding-with-a-dataset.py b/docs/notebook_source/3-seeding-with-a-dataset.py index b866de50..3665fb33 100644 --- a/docs/notebook_source/3-seeding-with-a-dataset.py +++ b/docs/notebook_source/3-seeding-with-a-dataset.py @@ -33,6 +33,7 @@ ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, + LocalFileSeedSource, ModelConfig, ) @@ -124,9 +125,9 @@ local_filename, _ = urllib.request.urlretrieve(url, "gretelai_symptom_to_diagnosis.csv") # Seed datasets are passed as reference objects to the config builder. -seed_dataset_reference = data_designer.make_seed_reference_from_file(local_filename) +seed_source = LocalFileSeedSource(path=local_filename) -config_builder.with_seed_dataset(seed_dataset_reference) +config_builder.with_seed_dataset(seed_source) # %% [markdown] # ## 🎨 Designing our synthetic patient notes dataset @@ -222,7 +223,7 @@ model_alias=MODEL_ALIAS, ) -config_builder.validate() +data_designer.validate(config_builder) # %% [markdown] # ### πŸ” Iteration is key – preview the dataset! 
diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index 70f806ae..9bcc4b29 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -50,6 +50,7 @@ ChatCompletionInferenceParams, DataDesigner, DataDesignerConfigBuilder, + DataFrameSeedSource, ImageContext, ImageFormat, LLMTextColumnConfig, @@ -189,9 +190,7 @@ def convert_image_to_chat_format(record, height: int) -> dict: # %% # Add the seed dataset containing our processed images df_seed = pd.DataFrame(img_dataset)[["uuid", "image_filename", "base64_image", "page", "options", "source"]] -config_builder.with_seed_dataset( - DataDesigner.make_seed_reference_from_dataframe(df_seed, file_path="colpali_train_set.csv") -) +config_builder.with_seed_dataset(DataFrameSeedSource(df=df_seed)) # %% # Add a column to generate detailed document summaries diff --git a/src/data_designer/config/config_builder.py b/src/data_designer/config/config_builder.py index 8aeeeae8..1852dbc9 100644 --- a/src/data_designer/config/config_builder.py +++ b/src/data_designer/config/config_builder.py @@ -24,9 +24,8 @@ ) from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.dataset_builders import BuildStage -from data_designer.config.datastore import DatastoreSettings, fetch_seed_dataset_column_names from data_designer.config.default_model_settings import get_default_model_configs -from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError +from data_designer.config.errors import BuilderConfigurationError, BuilderSerializationError, InvalidColumnTypeError from data_designer.config.models import ModelConfig, load_model_configs from data_designer.config.processors import ProcessorConfigT, ProcessorType, get_processor_config_from_kwargs from data_designer.config.sampler_constraints import ( @@ -36,20 +35,17 @@ ScalarInequalityConstraint, ) from data_designer.config.seed import ( - DatastoreSeedDatasetReference, IndexRange, - LocalSeedDatasetReference, PartitionBlock, SamplingStrategy, SeedConfig, - SeedDatasetReference, ) +from data_designer.config.seed_source import DataFrameSeedSource, SeedSource from data_designer.config.utils.constants import DEFAULT_REPR_HTML_STYLE, REPR_HTML_TEMPLATE from data_designer.config.utils.info import ConfigBuilderInfo from data_designer.config.utils.io_helpers import serialize_data, smart_load_yaml from data_designer.config.utils.misc import can_run_data_designer_locally, json_indent_list_of_strings, kebab_to_snake from data_designer.config.utils.type_helpers import resolve_string_enum -from data_designer.config.utils.validation import ViolationLevel, rich_print_violations, validate_data_designer_config logger = logging.getLogger(__name__) @@ -63,12 +59,9 @@ class BuilderConfig(ExportableConfigBase): Attributes: data_designer: The main Data Designer configuration containing columns, constraints, profilers, and other settings. - datastore_settings: Optional datastore settings for accessing external - datasets. 
""" data_designer: DataDesignerConfig - datastore_settings: DatastoreSettings | None class DataDesignerConfigBuilder: @@ -101,31 +94,19 @@ def from_config(cls, config: dict | str | Path | BuilderConfig) -> Self: builder_config = BuilderConfig.model_validate(json_config) builder = cls(model_configs=builder_config.data_designer.model_configs) - config = builder_config.data_designer + data_designer_config = builder_config.data_designer - for col in config.columns: - if not isinstance(col, SeedDatasetColumnConfig): - builder.add_column(col) + for col in data_designer_config.columns: + builder.add_column(col) - for constraint in config.constraints or []: + for constraint in data_designer_config.constraints or []: builder.add_constraint(constraint=constraint) - if config.seed_config: - if builder_config.datastore_settings is None: - if can_run_data_designer_locally(): - seed_dataset_reference = LocalSeedDatasetReference(dataset=config.seed_config.dataset) - else: - raise BuilderConfigurationError("πŸ›‘ Datastore settings are required.") - else: - seed_dataset_reference = DatastoreSeedDatasetReference( - dataset=config.seed_config.dataset, - datastore_settings=builder_config.datastore_settings, - ) - builder.set_seed_datastore_settings(builder_config.datastore_settings) + if (seed_config := data_designer_config.seed_config) is not None: builder.with_seed_dataset( - seed_dataset_reference, - sampling_strategy=config.seed_config.sampling_strategy, - selection_strategy=config.seed_config.selection_strategy, + seed_config.source, + sampling_strategy=seed_config.sampling_strategy, + selection_strategy=seed_config.selection_strategy, ) return builder @@ -145,7 +126,6 @@ def __init__(self, model_configs: list[ModelConfig] | str | Path | None = None): self._seed_config: SeedConfig | None = None self._constraints: list[ColumnConstraintT] = [] self._profilers: list[ColumnProfilerConfigT] = [] - self._datastore_settings: DatastoreSettings | None = None @property def model_configs(self) -> list[ModelConfig]: @@ -244,13 +224,6 @@ def add_column( f"{', '.join([t.__name__ for t in allowed_column_configs])}" ) - existing_config = self._column_configs.get(column_config.name) - if existing_config is not None and isinstance(existing_config, SeedDatasetColumnConfig): - raise BuilderConfigurationError( - f"πŸ›‘ Column {column_config.name!r} already exists as a seed dataset column. " - "Please use a different column name or update the seed dataset." - ) - self._column_configs[column_config.name] = column_config return self @@ -372,19 +345,12 @@ def get_profilers(self) -> list[ColumnProfilerConfigT]: """ return self._profilers - def build(self, *, skip_validation: bool = False, raise_exceptions: bool = False) -> DataDesignerConfig: + def build(self) -> DataDesignerConfig: """Build a DataDesignerConfig instance based on the current builder configuration. - Args: - skip_validation: Whether to skip validation of the configuration. - raise_exceptions: Whether to raise an exception if the configuration is invalid. - Returns: The current Data Designer config object. """ - if not skip_validation: - self.validate(raise_exceptions=raise_exceptions) - return DataDesignerConfig( model_configs=self._model_configs, seed_config=self._seed_config, @@ -513,14 +479,6 @@ def get_seed_config(self) -> SeedConfig | None: """ return self._seed_config - def get_seed_datastore_settings(self) -> DatastoreSettings | None: - """Get most recent datastore settings for the current Data Designer configuration. 
- - Returns: - The datastore settings if configured, None otherwise. - """ - return None if not self._datastore_settings else DatastoreSettings.model_validate(self._datastore_settings) - def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """Get the count of columns of the specified type. @@ -532,85 +490,33 @@ def num_columns_of_type(self, column_type: DataDesignerColumnType) -> int: """ return len(self.get_columns_of_type(column_type)) - def set_seed_datastore_settings(self, datastore_settings: DatastoreSettings | None) -> Self: - """Set the datastore settings for the seed dataset. - - Args: - datastore_settings: The datastore settings to use for the seed dataset. - """ - self._datastore_settings = datastore_settings - return self - - def validate(self, *, raise_exceptions: bool = False) -> Self: - """Validate the current Data Designer configuration. - - Args: - raise_exceptions: Whether to raise an exception if the configuration is invalid. - - Returns: - The current Data Designer config builder instance. - - Raises: - InvalidConfigError: If the configuration is invalid and raise_exceptions is True. - """ - - violations = validate_data_designer_config( - columns=list(self._column_configs.values()), - processor_configs=self._processor_configs, - allowed_references=self.allowed_references, - ) - rich_print_violations(violations) - if raise_exceptions and len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0: - raise InvalidConfigError( - "πŸ›‘ Your configuration contains validation errors. Please address the indicated issues and try again." - ) - if len(violations) == 0: - logger.info("βœ… Validation passed") - return self - def with_seed_dataset( self, - dataset_reference: SeedDatasetReference, + seed_source: SeedSource, *, sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED, selection_strategy: IndexRange | PartitionBlock | None = None, ) -> Self: """Add a seed dataset to the current Data Designer configuration. - This method sets the seed dataset for the configuration and automatically creates - SeedDatasetColumnConfig objects for each column found in the dataset. The column - names are fetched from the dataset source, which can be the Hugging Face Hub, the - NeMo Microservices Datastore, or in the case of direct library usage, a local file. + This method sets the seed dataset for the configuration, but columns are not resolved until + compilation (including validation) is performed by the engine using a SeedReader. Args: - dataset_reference: Seed dataset reference for fetching from the datastore. + seed_source: The pointer to the seed dataset. sampling_strategy: The sampling strategy to use when generating data from the seed dataset. Defaults to ORDERED sampling. + selection_strategy: An optional selection strategy to use when generating data from the seed dataset. + Defaults to None. Returns: The current Data Designer config builder instance. - - Raises: - BuilderConfigurationError: If any seed dataset column name collides with an existing column. """ - seed_column_names = fetch_seed_dataset_column_names(dataset_reference) - colliding_columns = [name for name in seed_column_names if name in self._column_configs] - if colliding_columns: - raise BuilderConfigurationError( - f"πŸ›‘ Seed dataset column(s) {colliding_columns} collide with existing column(s). " - "Please remove the conflicting columns or use a seed dataset with different column names." 
- ) - self._seed_config = SeedConfig( - dataset=dataset_reference.dataset, + source=seed_source, sampling_strategy=sampling_strategy, selection_strategy=selection_strategy, ) - self.set_seed_datastore_settings( - dataset_reference.datastore_settings if hasattr(dataset_reference, "datastore_settings") else None - ) - for column_name in seed_column_names: - self._column_configs[column_name] = SeedDatasetColumnConfig(name=column_name) return self def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> None: @@ -623,7 +529,17 @@ def write_config(self, path: str | Path, indent: int | None = 2, **kwargs) -> No Raises: BuilderConfigurationError: If the file format is unsupported. - """ + BuilderSerializationError: If the configuration cannot be serialized. + """ + if (seed_config := self.get_seed_config()) is not None and isinstance(seed_config.source, DataFrameSeedSource): + raise BuilderSerializationError( + "This builder was configured with a DataFrame seed dataset. " + "DataFrame seeds cannot be serialized to config files. " + "To serialize this configuration, change your seed dataset to a more persistent, serializable source format. " + "For example, you could make a local file seed source from the dataframe:\n\n" + "LocalFileSeedSource.from_dataframe(my_dataframe, '/path/to/data.parquet')" + ) + cfg = self.get_builder_config() suffix = Path(path).suffix if suffix in {".yaml", ".yml"}: @@ -639,7 +555,7 @@ def get_builder_config(self) -> BuilderConfig: Returns: The builder config. """ - return BuilderConfig(data_designer=self.build(), datastore_settings=self._datastore_settings) + return BuilderConfig(data_designer=self.build()) def __repr__(self) -> str: """Generates a string representation of the DataDesignerConfigBuilder instance. @@ -651,7 +567,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}()" props_to_repr = { - "seed_dataset": (None if self._seed_config is None else f"'{self._seed_config.dataset}'"), + "seed_dataset": (None if self._seed_config is None else f"{self._seed_config.source.seed_type} seed"), } for column_type in get_column_display_order(): diff --git a/src/data_designer/config/datastore.py b/src/data_designer/config/datastore.py deleted file mode 100644 index ab78bae4..00000000 --- a/src/data_designer/config/datastore.py +++ /dev/null @@ -1,187 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -from pathlib import Path -from typing import TYPE_CHECKING - -import pandas as pd -import pyarrow.parquet as pq -from huggingface_hub import HfApi, HfFileSystem -from pydantic import BaseModel, Field - -from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError -from data_designer.config.utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type - -if TYPE_CHECKING: - from data_designer.config.seed import SeedDatasetReference - -logger = logging.getLogger(__name__) - - -class DatastoreSettings(BaseModel): - """Configuration for interacting with a datastore.""" - - endpoint: str = Field( - ..., - description="Datastore endpoint. 
Use 'https://huggingface.co' for the Hugging Face Hub.", - ) - token: str | None = Field(default=None, description="If needed, token to use for authentication.") - - -def get_file_column_names(file_reference: str | Path | HfFileSystem, file_type: str) -> list[str]: - """Get column names from a dataset file. - - Args: - file_reference: Path to the dataset file, or an HfFileSystem object. - file_type: Type of the dataset file. Must be one of: 'parquet', 'json', 'jsonl', 'csv'. - - Raises: - InvalidFilePathError: If the file type is not supported. - - Returns: - List of column names. - """ - if file_type == "parquet": - try: - schema = pq.read_schema(file_reference) - if hasattr(schema, "names"): - return schema.names - else: - return [field.name for field in schema] - except Exception as e: - logger.warning(f"Failed to process parquet file {file_reference}: {e}") - return [] - elif file_type in ["json", "jsonl"]: - return pd.read_json(file_reference, orient="records", lines=True, nrows=1).columns.tolist() - elif file_type == "csv": - try: - df = pd.read_csv(file_reference, nrows=1) - return df.columns.tolist() - except (pd.errors.EmptyDataError, pd.errors.ParserError) as e: - logger.warning(f"Failed to process CSV file {file_reference}: {e}") - return [] - else: - raise InvalidFilePathError(f"πŸ›‘ Unsupported file type: {file_type!r}") - - -def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]: - if hasattr(seed_dataset_reference, "datastore_settings"): - return fetch_seed_dataset_column_names_from_datastore( - seed_dataset_reference.repo_id, - seed_dataset_reference.filename, - seed_dataset_reference.datastore_settings, - ) - return fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset) - - -def fetch_seed_dataset_column_names_from_datastore( - repo_id: str, - filename: str, - datastore_settings: DatastoreSettings | dict | None = None, -) -> list[str]: - file_type = filename.split(".")[-1] - if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS: - raise InvalidFileFormatError(f"πŸ›‘ Unsupported file type: {filename!r}") - - datastore_settings = resolve_datastore_settings(datastore_settings) - fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True) - - file_path = _extract_single_file_path_from_glob_pattern_if_present(f"datasets/{repo_id}/{filename}", fs=fs) - - with fs.open(file_path) as f: - return get_file_column_names(f, file_type) - - -def fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]: - dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True) - dataset_path = _extract_single_file_path_from_glob_pattern_if_present(dataset_path) - return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1]) - - -def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings: - if datastore_settings is None: - raise InvalidConfigError("πŸ›‘ Datastore settings are required in order to upload datasets to the datastore.") - if isinstance(datastore_settings, DatastoreSettings): - return datastore_settings - elif isinstance(datastore_settings, dict): - return DatastoreSettings.model_validate(datastore_settings) - else: - raise InvalidConfigError( - "πŸ›‘ Invalid datastore settings format. Must be DatastoreSettings object or dictionary." 
- ) - - -def upload_to_hf_hub( - dataset_path: str | Path, - filename: str, - repo_id: str, - datastore_settings: DatastoreSettings, - **kwargs, -) -> str: - datastore_settings = resolve_datastore_settings(datastore_settings) - dataset_path = _validate_dataset_path(dataset_path) - filename_ext = filename.split(".")[-1].lower() - if dataset_path.suffix.lower()[1:] != filename_ext: - raise InvalidFileFormatError( - f"πŸ›‘ Dataset file extension {dataset_path.suffix!r} does not match `filename` extension .{filename_ext!r}" - ) - - hfapi = HfApi(endpoint=datastore_settings.endpoint, token=datastore_settings.token) - hfapi.create_repo(repo_id, exist_ok=True, repo_type="dataset") - hfapi.upload_file( - path_or_fileobj=dataset_path, - path_in_repo=filename, - repo_id=repo_id, - repo_type="dataset", - **kwargs, - ) - return f"{repo_id}/{filename}" - - -def _extract_single_file_path_from_glob_pattern_if_present( - file_path: str | Path, - fs: HfFileSystem | None = None, -) -> Path: - file_path = Path(file_path) - - # no glob pattern - if "*" not in str(file_path): - return file_path - - # glob pattern with HfFileSystem - if fs is not None: - file_to_check = None - file_extension = file_path.name.split(".")[-1] - for file in fs.ls(str(file_path.parent)): - filename = file["name"] - if filename.endswith(f".{file_extension}"): - file_to_check = filename - if file_to_check is None: - raise InvalidFilePathError(f"πŸ›‘ No files found matching pattern: {str(file_path)!r}") - logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset") - return Path(file_to_check) - - # glob pattern with local file system - if not (matching_files := sorted(file_path.parent.glob(file_path.name))): - raise InvalidFilePathError(f"πŸ›‘ No files found matching pattern: {str(file_path)!r}") - logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset") - return matching_files[0] - - -def _validate_dataset_path(dataset_path: str | Path, allow_glob_pattern: bool = False) -> Path: - if allow_glob_pattern and "*" in str(dataset_path): - parts = str(dataset_path).split("*.") - file_path = parts[0] - file_extension = parts[-1] - validate_path_contains_files_of_type(file_path, file_extension) - return Path(dataset_path) - if not Path(dataset_path).is_file(): - raise InvalidFilePathError("πŸ›‘ To upload a dataset to the datastore, you must provide a valid file path.") - if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)): - raise InvalidFileFormatError( - "πŸ›‘ Dataset files must be in `parquet`, `csv`, or `json` (orient='records', lines=True) format." - ) - return Path(dataset_path) diff --git a/src/data_designer/config/errors.py b/src/data_designer/config/errors.py index f70576b0..e60a9347 100644 --- a/src/data_designer/config/errors.py +++ b/src/data_designer/config/errors.py @@ -7,6 +7,9 @@ class BuilderConfigurationError(DataDesignerError): ... +class BuilderSerializationError(DataDesignerError): ... + + class InvalidColumnTypeError(DataDesignerError): ... 
diff --git a/src/data_designer/config/exports.py b/src/data_designer/config/exports.py index 8b17e0fd..0df38ae4 100644 --- a/src/data_designer/config/exports.py +++ b/src/data_designer/config/exports.py @@ -18,7 +18,6 @@ from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.dataset_builders import BuildStage -from data_designer.config.datastore import DatastoreSettings from data_designer.config.models import ( ChatCompletionInferenceParams, EmbeddingInferenceParams, @@ -59,12 +58,16 @@ UUIDSamplerParams, ) from data_designer.config.seed import ( - DatastoreSeedDatasetReference, IndexRange, PartitionBlock, SamplingStrategy, SeedConfig, ) +from data_designer.config.seed_source import ( + DataFrameSeedSource, + HuggingFaceSeedSource, + LocalFileSeedSource, +) from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.info import InfoType from data_designer.config.validator_params import ( @@ -88,9 +91,8 @@ def get_config_exports() -> list[str]: DataDesignerColumnType.__name__, DataDesignerConfig.__name__, DataDesignerConfigBuilder.__name__, + DataFrameSeedSource.__name__, BuildStage.__name__, - DatastoreSeedDatasetReference.__name__, - DatastoreSettings.__name__, DatetimeSamplerParams.__name__, DropColumnsProcessorConfig.__name__, EmbeddingColumnConfig.__name__, @@ -98,6 +100,7 @@ def get_config_exports() -> list[str]: ExpressionColumnConfig.__name__, GaussianSamplerParams.__name__, GenerationType.__name__, + HuggingFaceSeedSource.__name__, IndexRange.__name__, InfoType.__name__, ImageContext.__name__, @@ -107,6 +110,7 @@ def get_config_exports() -> list[str]: LLMJudgeColumnConfig.__name__, LLMStructuredColumnConfig.__name__, LLMTextColumnConfig.__name__, + LocalFileSeedSource.__name__, ManualDistribution.__name__, ManualDistributionParams.__name__, Modality.__name__, diff --git a/src/data_designer/config/seed.py b/src/data_designer/config/seed.py index a49f73e2..86070ff7 100644 --- a/src/data_designer/config/seed.py +++ b/src/data_designer/config/seed.py @@ -1,19 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from abc import ABC from enum import Enum -from pydantic import Field, field_validator, model_validator +from pydantic import Field, model_validator from typing_extensions import Self from data_designer.config.base import ConfigBase -from data_designer.config.datastore import DatastoreSettings -from data_designer.config.utils.io_helpers import ( - VALID_DATASET_FILE_EXTENSIONS, - validate_dataset_file_path, - validate_path_contains_files_of_type, -) +from data_designer.config.seed_source import SeedSourceT class SamplingStrategy(str, Enum): @@ -62,7 +56,7 @@ class SeedConfig(ConfigBase): """Configuration for sampling data from a seed dataset. Args: - dataset: Path or identifier for the seed dataset. + source: A SeedSource defining where the seed data exists sampling_strategy: Strategy for how to sample rows from the dataset. - ORDERED: Read rows sequentially in their original order. - SHUFFLE: Randomly shuffle rows before sampling. 
When used with @@ -75,70 +69,46 @@ class SeedConfig(ConfigBase): Examples: Read rows sequentially from start to end: - SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.ORDERED) + SeedConfig( + source=LocalFileSeedSource(path="my_data.parquet"), + sampling_strategy=SamplingStrategy.ORDERED + ) Read rows in random order: - SeedConfig(dataset="my_data.parquet", sampling_strategy=SamplingStrategy.SHUFFLE) + SeedConfig( + source=LocalFileSeedSource(path="my_data.parquet"), + sampling_strategy=SamplingStrategy.SHUFFLE + ) Read specific index range (rows 100-199): SeedConfig( - dataset="my_data.parquet", + source=LocalFileSeedSource(path="my_data.parquet"), sampling_strategy=SamplingStrategy.ORDERED, selection_strategy=IndexRange(start=100, end=199) ) Read random rows from a specific index range (shuffles within rows 100-199): SeedConfig( - dataset="my_data.parquet", + source=LocalFileSeedSource(path="my_data.parquet"), sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=IndexRange(start=100, end=199) ) Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset): SeedConfig( - dataset="my_data.parquet", + source=LocalFileSeedSource(path="my_data.parquet"), sampling_strategy=SamplingStrategy.ORDERED, selection_strategy=PartitionBlock(index=2, num_partitions=5) ) Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition): SeedConfig( - dataset="my_data.parquet", + source=LocalFileSeedSource(path="my_data.parquet"), sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=PartitionBlock(index=0, num_partitions=10) ) """ - dataset: str + source: SeedSourceT sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED selection_strategy: IndexRange | PartitionBlock | None = None - - -class SeedDatasetReference(ABC, ConfigBase): - dataset: str - - -class DatastoreSeedDatasetReference(SeedDatasetReference): - datastore_settings: DatastoreSettings - - @property - def repo_id(self) -> str: - return "/".join(self.dataset.split("/")[:-1]) - - @property - def filename(self) -> str: - return self.dataset.split("/")[-1] - - -class LocalSeedDatasetReference(SeedDatasetReference): - @field_validator("dataset", mode="after") - def validate_dataset_is_file(cls, v: str) -> str: - valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS} - if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions): - parts = v.split("*.") - file_path = parts[0] - file_extension = parts[-1] - validate_path_contains_files_of_type(file_path, file_extension) - else: - validate_dataset_file_path(v) - return v diff --git a/src/data_designer/config/seed_source.py b/src/data_designer/config/seed_source.py new file mode 100644 index 00000000..6d62d773 --- /dev/null +++ b/src/data_designer/config/seed_source.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC +from typing import Annotated, Literal + +import pandas as pd +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Self + +from data_designer.config.utils.io_helpers import ( + VALID_DATASET_FILE_EXTENSIONS, + validate_dataset_file_path, + validate_path_contains_files_of_type, +) + + +class SeedSource(BaseModel, ABC): + """Base class for seed dataset configurations. + + All subclasses must define a `seed_type` field with a Literal value. + This serves as a discriminated union discriminator. 
+ """ + + seed_type: str + + +class LocalFileSeedSource(SeedSource): + seed_type: Literal["local"] = "local" + + path: str + + @field_validator("path", mode="after") + def validate_path(cls, v: str) -> str: + valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS} + if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions): + parts = v.split("*.") + file_path = parts[0] + file_extension = parts[-1] + validate_path_contains_files_of_type(file_path, file_extension) + else: + validate_dataset_file_path(v) + return v + + @classmethod + def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self: + df.to_parquet(path, index=False) + return cls(path=path) + + +class HuggingFaceSeedSource(SeedSource): + seed_type: Literal["hf"] = "hf" + + path: str = Field( + ..., + description="Path to the seed data in HuggingFace. Wildcards are allowed. Examples include 'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', 'datasets/my-username/my-dataset/**/*.parquet'", + ) + token: str | None = None + endpoint: str = "https://huggingface.co" + + +class DataFrameSeedSource(SeedSource): + seed_type: Literal["df"] = "df" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + df: pd.DataFrame + + +SeedSourceT = Annotated[ + LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource, + Field(discriminator="seed_type"), +] diff --git a/src/data_designer/config/utils/io_helpers.py b/src/data_designer/config/utils/io_helpers.py index 57a0c9c2..d6d0da50 100644 --- a/src/data_designer/config/utils/io_helpers.py +++ b/src/data_designer/config/utils/io_helpers.py @@ -108,26 +108,6 @@ def read_parquet_dataset(path: Path) -> pd.DataFrame: raise e -def write_seed_dataset(dataframe: pd.DataFrame, file_path: Path) -> None: - """Write a seed dataset to a file in the specified format. - - Supported file extensions: .parquet, .csv, .json, .jsonl - - Args: - dataframe: The pandas DataFrame to write. - file_path: The path where the dataset should be saved. - Format is inferred from the file extension. - """ - file_path = validate_dataset_file_path(file_path, should_exist=False) - logger.info(f"πŸ’Ύ Saving seed dataset to {file_path}") - if file_path.suffix.lower() == ".parquet": - dataframe.to_parquet(file_path, index=False) - elif file_path.suffix.lower() == ".csv": - dataframe.to_csv(file_path, index=False) - elif file_path.suffix.lower() in {".json", ".jsonl"}: - dataframe.to_json(file_path, orient="records", lines=True) - - def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True) -> Path: """Validate that a dataset file path has a valid extension and optionally exists. 
diff --git a/src/data_designer/engine/column_generators/generators/seed_dataset.py b/src/data_designer/engine/column_generators/generators/seed_dataset.py index 35578602..08c2fed6 100644 --- a/src/data_designer/engine/column_generators/generators/seed_dataset.py +++ b/src/data_designer/engine/column_generators/generators/seed_dataset.py @@ -30,7 +30,7 @@ def metadata() -> GeneratorMetadata: name="seed_dataset_column_generator", description="Sample columns from a seed dataset.", generation_strategy=GenerationStrategy.FULL_COLUMN, - required_resources=[ResourceType.DATASTORE], + required_resources=[ResourceType.SEED_READER], ) @property @@ -39,10 +39,10 @@ def num_records_sampled(self) -> int: @functools.cached_property def duckdb_conn(self) -> duckdb.DuckDBPyConnection: - return self.resource_provider.datastore.create_duckdb_connection() + return self.resource_provider.seed_reader.create_duckdb_connection() - def generate(self, dataset: pd.DataFrame) -> pd.DataFrame: - return concat_datasets([self.generate_from_scratch(len(dataset)), dataset]) + def generate(self, data: pd.DataFrame) -> pd.DataFrame: + return concat_datasets([self.generate_from_scratch(len(data)), data]) def generate_from_scratch(self, num_records: int) -> pd.DataFrame: if num_records <= 0: @@ -57,7 +57,7 @@ def _initialize(self) -> None: self._num_records_sampled = 0 self._batch_reader = None self._df_remaining = None - self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset) + self._dataset_uri = self.resource_provider.seed_reader.get_dataset_uri() self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0] self._index_range = self._resolve_index_range() diff --git a/src/data_designer/engine/compiler.py b/src/data_designer/engine/compiler.py new file mode 100644 index 00000000..841f2f84 --- /dev/null +++ b/src/data_designer/engine/compiler.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging + +from data_designer.config.column_configs import SeedDatasetColumnConfig +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.data_designer_config import DataDesignerConfig +from data_designer.config.errors import InvalidConfigError +from data_designer.engine.resources.resource_provider import ResourceProvider +from data_designer.engine.resources.seed_reader import SeedReader +from data_designer.engine.validation import ViolationLevel, rich_print_violations, validate_data_designer_config + +logger = logging.getLogger(__name__) + + +def compile_data_designer_config( + config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider +) -> DataDesignerConfig: + config = config_builder.build() + _resolve_and_add_seed_columns(config, resource_provider.seed_reader) + _validate(config) + + return config + + +def _resolve_and_add_seed_columns(config: DataDesignerConfig, seed_reader: SeedReader | None) -> None: + """Fetches the seed dataset column names, ensures there are no conflicts + with other columns, and adds seed column configs to the DataDesignerConfig. 
+ """ + + if not seed_reader: + return + + seed_col_names = seed_reader.get_column_names() + existing_columns = {column.name for column in config.columns} + colliding_columns = {name for name in seed_col_names if name in existing_columns} + if colliding_columns: + raise InvalidConfigError( + f"πŸ›‘ Seed dataset column(s) {colliding_columns} collide with existing column(s). " + "Please remove the conflicting columns or use a seed dataset with different column names." + ) + + config.columns.extend([SeedDatasetColumnConfig(name=col_name) for col_name in seed_col_names]) + + +def _validate(config: DataDesignerConfig) -> None: + allowed_references = _get_allowed_references(config) + violations = validate_data_designer_config( + columns=config.columns, + processor_configs=config.processors or [], + allowed_references=allowed_references, + ) + rich_print_violations(violations) + if len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0: + raise InvalidConfigError( + "πŸ›‘ Your configuration contains validation errors. Please address the indicated issues and try again." + ) + if len(violations) == 0: + logger.info("βœ… Validation passed") + + +def _get_allowed_references(config: DataDesignerConfig) -> list[str]: + refs = set[str]() + for column_config in config.columns: + refs.add(column_config.name) + for side_effect_column in column_config.side_effect_columns: + refs.add(side_effect_column) + return list(refs) diff --git a/src/data_designer/engine/dataset_builders/utils/config_compiler.py b/src/data_designer/engine/dataset_builders/utils/config_compiler.py index 5359bfa1..85922990 100644 --- a/src/data_designer/engine/dataset_builders/utils/config_compiler.py +++ b/src/data_designer/engine/dataset_builders/utils/config_compiler.py @@ -34,7 +34,7 @@ def compile_dataset_builder_column_configs(config: DataDesignerConfig) -> list[D compiled_column_configs.append( SeedDatasetMultiColumnConfig( columns=seed_column_configs, - dataset=config.seed_config.dataset, + source=config.seed_config.source, sampling_strategy=config.seed_config.sampling_strategy, selection_strategy=config.seed_config.selection_strategy, ) diff --git a/src/data_designer/engine/resources/resource_provider.py b/src/data_designer/engine/resources/resource_provider.py index 98c6576f..c28225d3 100644 --- a/src/data_designer/engine/resources/resource_provider.py +++ b/src/data_designer/engine/resources/resource_provider.py @@ -3,26 +3,27 @@ from data_designer.config.base import ConfigBase from data_designer.config.models import ModelConfig +from data_designer.config.seed_source import SeedSource from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.registry import ModelRegistry, create_model_registry from data_designer.engine.resources.managed_storage import ManagedBlobStorage, init_managed_blob_storage -from data_designer.engine.resources.seed_dataset_data_store import SeedDatasetDataStore +from data_designer.engine.resources.seed_reader import SeedReader, SeedReaderRegistry from data_designer.engine.secret_resolver import SecretResolver class ResourceType(StrEnum): BLOB_STORAGE = "blob_storage" - DATASTORE = "datastore" MODEL_REGISTRY = "model_registry" + SEED_READER = "seed_reader" class ResourceProvider(ConfigBase): artifact_storage: ArtifactStorage blob_storage: ManagedBlobStorage | None = None - datastore: SeedDatasetDataStore | 
None = None model_registry: ModelRegistry | None = None + seed_reader: SeedReader | None = None def create_resource_provider( @@ -31,16 +32,23 @@ def create_resource_provider( model_configs: list[ModelConfig], secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, - datastore: SeedDatasetDataStore | None = None, + seed_reader_registry: SeedReaderRegistry, blob_storage: ManagedBlobStorage | None = None, + seed_dataset_source: SeedSource | None = None, ) -> ResourceProvider: + seed_reader = None + if seed_dataset_source: + seed_reader = seed_reader_registry.get_reader( + seed_dataset_source, + secret_resolver, + ) return ResourceProvider( artifact_storage=artifact_storage, - datastore=datastore, model_registry=create_model_registry( model_configs=model_configs, secret_resolver=secret_resolver, model_provider_registry=model_provider_registry, ), blob_storage=blob_storage or init_managed_blob_storage(), + seed_reader=seed_reader, ) diff --git a/src/data_designer/engine/resources/seed_dataset_data_store.py b/src/data_designer/engine/resources/seed_dataset_data_store.py deleted file mode 100644 index b74076ad..00000000 --- a/src/data_designer/engine/resources/seed_dataset_data_store.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from abc import ABC, abstractmethod - -import duckdb -from huggingface_hub import HfApi, HfFileSystem - -from data_designer.logging import quiet_noisy_logger - -quiet_noisy_logger("httpx") - -_HF_DATASETS_PREFIX = "hf://datasets/" - - -class MalformedFileIdError(Exception): - """Raised when file_id format is invalid.""" - - -class SeedDatasetDataStore(ABC): - """Abstract base class for dataset storage implementations.""" - - @abstractmethod - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ... - - @abstractmethod - def get_dataset_uri(self, file_id: str) -> str: ... - - -class LocalSeedDatasetDataStore(SeedDatasetDataStore): - """Local filesystem-based dataset storage.""" - - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: - return duckdb.connect() - - def get_dataset_uri(self, file_id: str) -> str: - return file_id - - -class HfHubSeedDatasetDataStore(SeedDatasetDataStore): - """Hugging Face and Data Store dataset storage.""" - - def __init__(self, endpoint: str, token: str | None): - self.hfapi = HfApi(endpoint=endpoint, token=token) - self.endpoint = endpoint - self.token = token - - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: - """Create a DuckDB connection with a fresh HfFileSystem registered. - - Creates a new HfFileSystem instance for each connection to ensure file metadata - is fetched fresh from the datastore, avoiding cache-related issues when reading - recently updated parquet files. - - Returns: - A DuckDB connection with the HfFileSystem registered for hf:// URI support. - """ - # Use skip_instance_cache to avoid fsspec-level caching - hffs = HfFileSystem(endpoint=self.endpoint, token=self.token, skip_instance_cache=True) - - # Clear all internal caches to avoid stale metadata issues - # HfFileSystem caches file metadata (size, etc.) 
which can become stale when files are re-uploaded
-        if hasattr(hffs, "dircache"):
-            hffs.dircache.clear()
-
-        conn = duckdb.connect()
-        conn.register_filesystem(hffs)
-        return conn
-
-    def get_dataset_uri(self, file_id: str) -> str:
-        identifier = file_id.removeprefix(_HF_DATASETS_PREFIX)
-        repo_id, filename = self._get_repo_id_and_filename(identifier)
-        return f"{_HF_DATASETS_PREFIX}{repo_id}/{filename}"
-
-    def _get_repo_id_and_filename(self, identifier: str) -> tuple[str, str]:
-        """Extract repo_id and filename from identifier."""
-        parts = identifier.split("/", 2)
-        if len(parts) < 3:
-            raise MalformedFileIdError(
-                "Could not extract repo id and filename from file_id, "
-                "expected 'hf://datasets/{repo-namespace}/{repo-name}/{filename}'"
-            )
-        repo_ns, repo_name, filename = parts
-        return f"{repo_ns}/{repo_name}", filename
diff --git a/src/data_designer/engine/resources/seed_reader.py b/src/data_designer/engine/resources/seed_reader.py
new file mode 100644
index 00000000..f534e430
--- /dev/null
+++ b/src/data_designer/engine/resources/seed_reader.py
@@ -0,0 +1,149 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Generic, TypeVar, get_args, get_origin
+
+import duckdb
+from huggingface_hub import HfFileSystem
+from typing_extensions import Self
+
+from data_designer.config.seed_source import (
+    DataFrameSeedSource,
+    HuggingFaceSeedSource,
+    LocalFileSeedSource,
+    SeedSource,
+)
+from data_designer.engine.secret_resolver import SecretResolver
+from data_designer.errors import DataDesignerError
+
+
+class SeedReaderError(DataDesignerError): ...
+
+
+SourceT = TypeVar("SourceT", bound=SeedSource)
+
+
+class SeedReader(ABC, Generic[SourceT]):
+    """Base class for reading a seed dataset.
+
+    Seeds are read using duckdb. Reader implementations define duckdb connection setup details
+    and how to get a URI that can be queried with duckdb (i.e. "... FROM ...").
+
+    The Data Designer engine automatically supplies the appropriate SeedSource
+    and a SecretResolver to use for any secret fields in the config.
+    """
+
+    source: SourceT
+    secret_resolver: SecretResolver
+
+    @abstractmethod
+    def get_dataset_uri(self) -> str: ...
+
+    @abstractmethod
+    def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: ...
+
+    def attach(self, source: SourceT, secret_resolver: SecretResolver):
+        """Attach a source and secret resolver to the instance.
+
+        This is called internally by the engine so that these objects do not
+        need to be provided in the reader's constructor.
+ """ + self.source = source + self.secret_resolver = secret_resolver + + def get_column_names(self) -> list[str]: + """Returns the seed dataset's column names""" + conn = self.create_duckdb_connection() + describe_query = f"DESCRIBE SELECT * FROM '{self.get_dataset_uri()}'" + column_descriptions = conn.execute(describe_query).fetchall() + return [col[0] for col in column_descriptions] + + def get_seed_type(self) -> str: + """Return the seed_type of the source class this reader is generic over.""" + # Get the generic type arguments from the reader class + # Check __orig_bases__ for the generic base class + for base in getattr(type(self), "__orig_bases__", []): + origin = get_origin(base) + if origin is SeedReader: + args = get_args(base) + if args: + source_cls = args[0] + # Extract seed_type from the source class + if hasattr(source_cls, "model_fields") and "seed_type" in source_cls.model_fields: + field = source_cls.model_fields["seed_type"] + default_value = field.default + if isinstance(default_value, str): + return default_value + + raise SeedReaderError("Reader does not have a valid generic source type with seed_type") + + +class LocalFileSeedReader(SeedReader[LocalFileSeedSource]): + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + return duckdb.connect() + + def get_dataset_uri(self) -> str: + return self.source.path + + +class HuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]): + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + token = self.secret_resolver.resolve(self.source.token) if self.source.token else None + + # Use skip_instance_cache to avoid fsspec-level caching + hffs = HfFileSystem(endpoint=self.source.endpoint, token=token, skip_instance_cache=True) + + # Clear all internal caches to avoid stale metadata issues + # HfFileSystem caches file metadata (size, etc.) which can become stale when files are re-uploaded + if hasattr(hffs, "dircache"): + hffs.dircache.clear() + + conn = duckdb.connect() + conn.register_filesystem(hffs) + return conn + + def get_dataset_uri(self) -> str: + return f"hf://{self.source.path}" + + +class DataFrameSeedReader(SeedReader[DataFrameSeedSource]): + # This is a "magic string" that gets registered in the duckdb connection to make the dataframe directly queryable. 
+ _table_name = "df" + + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + conn = duckdb.connect() + conn.register(self._table_name, self.source.df) + return conn + + def get_dataset_uri(self) -> str: + return self._table_name + + +class SeedReaderRegistry: + def __init__(self, readers: Sequence[SeedReader]): + self._readers: dict[str, SeedReader] = {} + for reader in readers: + self.add_reader(reader) + + def add_reader(self, reader: SeedReader) -> Self: + seed_type = reader.get_seed_type() + + if seed_type in self._readers: + raise SeedReaderError(f"A reader for seed_type {seed_type!r} already exists") + + self._readers[seed_type] = reader + return self + + def get_reader(self, seed_dataset_source: SeedSource, secret_resolver: SecretResolver) -> SeedReader: + reader = self._get_reader_for_source(seed_dataset_source) + reader.attach(seed_dataset_source, secret_resolver) + return reader + + def _get_reader_for_source(self, seed_dataset_source: SeedSource) -> SeedReader: + seed_type = seed_dataset_source.seed_type + try: + return self._readers[seed_type] + except KeyError: + raise SeedReaderError(f"No reader found for seed_type {seed_type!r}") diff --git a/src/data_designer/config/utils/validation.py b/src/data_designer/engine/validation.py similarity index 100% rename from src/data_designer/config/utils/validation.py rename to src/data_designer/engine/validation.py diff --git a/src/data_designer/interface/data_designer.py b/src/data_designer/interface/data_designer.py index 0458dc60..2dba3d21 100644 --- a/src/data_designer/interface/data_designer.py +++ b/src/data_designer/interface/data_designer.py @@ -20,7 +20,6 @@ ModelProvider, ) from data_designer.config.preview_results import PreviewResults -from data_designer.config.seed import LocalSeedDatasetReference from data_designer.config.utils.constants import ( DEFAULT_NUM_RECORDS, MANAGED_ASSETS_PATH, @@ -29,21 +28,23 @@ PREDEFINED_PROVIDERS, ) from data_designer.config.utils.info import InfoType, InterfaceInfo -from data_designer.config.utils.io_helpers import write_seed_dataset from data_designer.engine.analysis.dataset_profiler import ( DataDesignerDatasetProfiler, DatasetProfilerConfig, ) +from data_designer.engine.compiler import compile_data_designer_config from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.model_provider import resolve_model_provider_registry -from data_designer.engine.models.registry import create_model_registry from data_designer.engine.resources.managed_storage import init_managed_blob_storage -from data_designer.engine.resources.resource_provider import ResourceProvider -from data_designer.engine.resources.seed_dataset_data_store import ( - HfHubSeedDatasetDataStore, - LocalSeedDatasetDataStore, +from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider +from data_designer.engine.resources.seed_reader import ( + DataFrameSeedReader, + HuggingFaceSeedReader, + LocalFileSeedReader, + SeedReader, + SeedReaderRegistry, ) from data_designer.engine.secret_resolver import ( CompositeResolver, @@ -61,6 +62,14 @@ DEFAULT_BUFFER_SIZE = 1000 +DEFAULT_SECRET_RESOLVER = CompositeResolver([EnvironmentResolver(), PlaintextResolver()]) + +DEFAULT_SEED_READERS = [ + HuggingFaceSeedReader(), + 
LocalFileSeedReader(), + DataFrameSeedReader(), +] + logger = logging.getLogger(__name__) @@ -79,6 +88,7 @@ class DataDesigner(DataDesignerInterface[DatasetCreationResults]): uses default providers. secret_resolver: Resolver for handling secrets and credentials. Defaults to EnvironmentResolver which reads secrets from environment variables. + seed_readers: Optional list of seed readers. If None, uses default readers. managed_assets_path: Path to the managed assets directory. This is used to point to the location of managed datasets and other assets used during dataset generation. If not provided, will check for an environment variable called DATA_DESIGNER_MANAGED_ASSETS_PATH. @@ -92,9 +102,10 @@ def __init__( *, model_providers: list[ModelProvider] | None = None, secret_resolver: SecretResolver | None = None, + seed_readers: list[SeedReader] | None = None, managed_assets_path: Path | str | None = None, ): - self._secret_resolver = secret_resolver or CompositeResolver([EnvironmentResolver(), PlaintextResolver()]) + self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts" self._buffer_size = DEFAULT_BUFFER_SIZE self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH) @@ -102,42 +113,7 @@ def __init__( self._model_provider_registry = resolve_model_provider_registry( self._model_providers, get_default_provider_name() ) - - @staticmethod - def make_seed_reference_from_file(file_path: str | Path) -> LocalSeedDatasetReference: - """Create a seed dataset reference from an existing file. - - Supported file extensions: .parquet (recommended), .csv, .json, .jsonl - - Args: - file_path: Path to an existing dataset file. - - Returns: - A LocalSeedDatasetReference pointing to the specified file. - """ - return LocalSeedDatasetReference(dataset=str(file_path)) - - @classmethod - def make_seed_reference_from_dataframe( - cls, dataframe: pd.DataFrame, file_path: str | Path - ) -> LocalSeedDatasetReference: - """Create a seed dataset reference from a pandas DataFrame. - - This method writes the DataFrame to disk and returns a reference that can - be passed to the config builder's `with_seed_dataset` method. If the file - already exists, it will be overwritten. - - Supported file extensions: .parquet (recommended), .csv, .json, .jsonl - - Args: - dataframe: Pandas DataFrame to use as seed data. - file_path: Path where to save dataset. - - Returns: - A LocalSeedDatasetReference pointing to the written file. - """ - write_seed_dataset(dataframe, Path(file_path)) - return cls.make_seed_reference_from_file(file_path) + self._seed_reader_registry = SeedReaderRegistry(readers=seed_readers or DEFAULT_SEED_READERS) @property def info(self) -> InterfaceInfo: @@ -274,6 +250,23 @@ def preview( config_builder=config_builder, ) + def validate(self, config_builder: DataDesignerConfigBuilder) -> None: + """Validate the Data Designer configuration as defined by the DataDesignerConfigBuilder + with the configured engine components (SecretResolver, SeedReaders, etc.). + + Args: + config_builder: The DataDesignerConfigBuilder containing the dataset + configuration (columns, constraints, seed data, etc.). + + Returns: + None if the configuration is valid. + + Raises: + InvalidConfigError: If the configuration is invalid. 
+ """ + resource_provider = self._create_resource_provider("validate-configuration", config_builder) + compile_data_designer_config(config_builder, resource_provider) + def get_default_model_configs(self) -> list[ModelConfig]: """Get the default model configurations. @@ -336,9 +329,11 @@ def _resolve_model_providers(self, model_providers: list[ModelProvider] | None) def _create_dataset_builder( self, config_builder: DataDesignerConfigBuilder, resource_provider: ResourceProvider ) -> ColumnWiseDatasetBuilder: + config = compile_data_designer_config(config_builder, resource_provider) + return ColumnWiseDatasetBuilder( - column_configs=compile_dataset_builder_column_configs(config_builder.build(raise_exceptions=True)), - processor_configs=config_builder.get_processor_configs(), + column_configs=compile_dataset_builder_column_configs(config), + processor_configs=config.processors or [], resource_provider=resource_provider, ) @@ -356,24 +351,20 @@ def _create_dataset_profiler( def _create_resource_provider( self, dataset_name: str, config_builder: DataDesignerConfigBuilder ) -> ResourceProvider: - model_configs = config_builder.model_configs ArtifactStorage.mkdir_if_needed(self._artifact_path) - return ResourceProvider( + + seed_dataset_source = None + if (seed_config := config_builder.get_seed_config()) is not None: + seed_dataset_source = seed_config.source + + return create_resource_provider( artifact_storage=ArtifactStorage(artifact_path=self._artifact_path, dataset_name=dataset_name), - model_registry=create_model_registry( - model_configs=model_configs, - model_provider_registry=self._model_provider_registry, - secret_resolver=self._secret_resolver, - ), + model_configs=config_builder.model_configs, + secret_resolver=self._secret_resolver, + model_provider_registry=self._model_provider_registry, blob_storage=init_managed_blob_storage(str(self._managed_assets_path)), - datastore=( - LocalSeedDatasetDataStore() - if (settings := config_builder.get_seed_datastore_settings()) is None - else HfHubSeedDatasetDataStore( - endpoint=settings.endpoint, - token=settings.token, - ) - ), + seed_dataset_source=seed_dataset_source, + seed_reader_registry=self._seed_reader_registry, ) def _get_interface_info(self, model_providers: list[ModelProvider]) -> InterfaceInfo: diff --git a/src/data_designer/temp_nmp.py b/src/data_designer/temp_nmp.py new file mode 100644 index 00000000..b5330cc1 --- /dev/null +++ b/src/data_designer/temp_nmp.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +This module will actually exist in NMP as a plugin, but providing it here temporarily as a sketch. + +We need to make a few updates to plugins: +1. Dependency management (the "engine" problem); separate PR with an approach already up for this +2. Add support for PluginType.SEED_DATASET +""" + +from typing import Literal + +import duckdb + +from data_designer.config.seed_source import SeedSource +from data_designer.engine.resources.seed_reader import SeedReader + + +class NMPFileSeedConfig(SeedSource): + seed_type: Literal["nmp"] = "nmp" # or "fileset" since that's the fsspec client protocol? + + # Check with MG: what do we expect to be more common in scenarios like this? + # 1. Just "fileset", with optional workspace prefix, e.g. "myworkspace/myfileset", "myfileset" (implicit default workspace), etc. + # 2. 
Separate "fileset" and "workspace" fields + fileset: str + path: str + + +class NMPFileSeedReader(SeedReader[NMPFileSeedConfig]): + def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: + # NMP helper function + sdk = get_platform_sdk() # noqa:F821 + # New fsspec client for Files service + fs = FilesetFileSystem(sdk) # noqa:F821 + + conn = duckdb.connect() + conn.register_filesystem(fs) + return conn + + def get_dataset_uri(self) -> str: + workspace, fileset_name = self._get_workspace_and_fileset_name() + return f"fileset://{workspace}/{fileset_name}/{self.config.path}" + + def _get_workspace_and_fileset_name(self) -> tuple[str, str]: + match self.config.fileset.split("/"): + case [fileset_name]: + return ("default", fileset_name) + case [workspace, fileset_name]: + return (workspace, fileset_name) + case _: + raise ValueError("Malformed fileset") diff --git a/tests/config/test_config_builder.py b/tests/config/test_config_builder.py index 40dc21d3..549d316b 100644 --- a/tests/config/test_config_builder.py +++ b/tests/config/test_config_builder.py @@ -6,6 +6,7 @@ from pathlib import Path from unittest.mock import patch +import pandas as pd import pytest import yaml from pydantic import BaseModel, ValidationError @@ -24,12 +25,17 @@ from data_designer.config.column_types import DataDesignerColumnType, get_column_config_from_kwargs from data_designer.config.config_builder import BuilderConfig, DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.datastore import DatastoreSettings -from data_designer.config.errors import BuilderConfigurationError, InvalidColumnTypeError, InvalidConfigError +from data_designer.config.errors import ( + BuilderConfigurationError, + BuilderSerializationError, + InvalidColumnTypeError, + InvalidConfigError, +) from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.config.sampler_constraints import ColumnInequalityConstraint, ScalarInequalityConstraint from data_designer.config.sampler_params import SamplerType, UUIDSamplerParams -from data_designer.config.seed import DatastoreSeedDatasetReference, SamplingStrategy +from data_designer.config.seed import SamplingStrategy +from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.info import ConfigBuilderInfo from data_designer.config.validator_params import CodeValidatorParams @@ -39,18 +45,9 @@ class DummyStructuredModel(BaseModel): stub: str -@pytest.fixture -def mock_fetch_seed_dataset_column_names(): - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch_seed: - mock_fetch_seed.return_value = ["id", "name", "city", "country"] - yield mock_fetch_seed - - @pytest.fixture def stub_data_designer_builder(stub_data_designer_builder_config_str): - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch_seed: - mock_fetch_seed.return_value = ["id", "name", "city", "country"] - yield DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) + yield DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) def test_loading_model_configs_in_constructor(stub_model_configs): @@ -76,7 +73,7 @@ def test_loading_model_configs_in_constructor(stub_model_configs): builder = DataDesignerConfigBuilder(model_configs=tmp_file.name) -def 
test_from_config(stub_data_designer_builder_config_str, mock_fetch_seed_dataset_column_names): +def test_from_config(stub_data_designer_builder_config_str): builder = DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) assert isinstance(builder.get_column_config(name="code_id"), SamplerColumnConfig) @@ -362,10 +359,6 @@ def test_build(stub_data_designer_builder): ndd_config = stub_data_designer_builder.build() assert isinstance(ndd_config, DataDesignerConfig) - with patch("data_designer.config.config_builder.DataDesignerConfigBuilder.validate") as mock_validate: - stub_data_designer_builder.build(skip_validation=True) - mock_validate.assert_not_called() - def test_config_export_to_files(stub_data_designer_builder): """Test config export to JSON and YAML files via DataDesignerConfig methods.""" @@ -408,20 +401,15 @@ def test_delete_column(stub_data_designer_builder): stub_data_designer_builder.delete_column(column_name="code_id") assert len(stub_data_designer_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)) == 3 - with pytest.raises( - BuilderConfigurationError, match="Seed columns cannot be deleted. Please update the seed dataset instead." - ): - stub_data_designer_builder.delete_column(column_name="id") - def test_getters(stub_data_designer_builder): - assert len(stub_data_designer_builder.get_column_configs()) == 12 + assert len(stub_data_designer_builder.get_column_configs()) == 8 assert stub_data_designer_builder.get_column_config(name="code_id").name == "code_id" assert len(stub_data_designer_builder.get_constraints(target_column="age")) == 1 assert len(stub_data_designer_builder.get_llm_gen_columns()) == 3 assert len(stub_data_designer_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)) == 4 - assert len(stub_data_designer_builder.get_columns_excluding_type(DataDesignerColumnType.SAMPLER)) == 8 - assert stub_data_designer_builder.get_seed_config().dataset == "test-repo/testing/data.csv" + assert len(stub_data_designer_builder.get_columns_excluding_type(DataDesignerColumnType.SAMPLER)) == 4 + assert stub_data_designer_builder.get_seed_config().source.path == "datasets/test-repo/testing/data.csv" assert stub_data_designer_builder.num_columns_of_type(DataDesignerColumnType.SAMPLER) == 4 @@ -446,32 +434,6 @@ def test_write_config(stub_data_designer_builder): stub_data_designer_builder.write_config(temp_path.with_suffix(".txt")) -def test_validate(stub_empty_builder): - try: - stub_empty_builder.validate(raise_exceptions=False) - except Exception: - pytest.fail("Validate should not raise an exception if raise_exceptions is False") - - with pytest.raises( - InvalidConfigError, - match="Your configuration contains validation errors. 
Please address the indicated issues and try again.", - ): - stub_empty_builder.validate(raise_exceptions=True) - - with patch("data_designer.config.config_builder.logger") as mock_logger: - stub_empty_builder.add_column( - name="test_column", - column_type=DataDesignerColumnType.SAMPLER, - sampler_type=SamplerType.UUID, - params={"prefix": "test_", "short_form": True, "uppercase": True}, - ) - try: - stub_empty_builder.validate(raise_exceptions=True) - except Exception: - pytest.fail("Validate should not raise an exception for valid configuration") - mock_logger.info.assert_called_once_with("βœ… Validation passed") - - def test_get_column_config_from_kwargs(): # Test column creation and serialization @@ -629,34 +591,28 @@ def test_get_column_config_from_kwargs(): def test_seed_config(stub_complete_builder): seed_config = stub_complete_builder.get_seed_config() assert seed_config is not None - assert seed_config.dataset == "test-repo/testing/data.csv" + assert seed_config.source.path == "datasets/test-repo/testing/data.csv" assert seed_config.sampling_strategy == SamplingStrategy.SHUFFLE -def test_with_seed_dataset_basic(stub_empty_builder, mock_fetch_seed_dataset_column_names): +def test_with_seed_dataset_basic(stub_empty_builder): """Test with_seed_dataset method with basic parameters.""" - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] - result = stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings) - ) + path = "datasets/test-repo/testing/data.csv" + source = HuggingFaceSeedSource(path=path) + result = stub_empty_builder.with_seed_dataset(source) assert result is stub_empty_builder - assert stub_empty_builder.get_seed_config().dataset == "test-repo/test-data.parquet" - assert len(stub_empty_builder.get_columns_of_type(DataDesignerColumnType.SEED_DATASET)) == 4 + assert stub_empty_builder.get_seed_config().source.path == path -def test_with_seed_dataset_sampling_strategy(stub_empty_builder, mock_fetch_seed_dataset_column_names): +def test_with_seed_dataset_sampling_strategy(stub_empty_builder): """Test with_seed_dataset with different sampling strategies.""" - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") + config = HuggingFaceSeedSource(path="datasets/test-repo/test-data.parquet", token="test-token") - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] - stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings), - sampling_strategy=SamplingStrategy.SHUFFLE, - ) + stub_empty_builder.with_seed_dataset( + config, + sampling_strategy=SamplingStrategy.SHUFFLE, + ) seed_config = stub_empty_builder.get_seed_config() assert seed_config.sampling_strategy == SamplingStrategy.SHUFFLE @@ -761,103 +717,21 @@ def test_delete_model_config(stub_empty_builder): assert len(stub_empty_builder.model_configs) == 2 -def test_add_column_collision_with_seed_dataset(stub_empty_builder: DataDesignerConfigBuilder) -> None: - """Test that adding a column that collides with a seed dataset column raises an error.""" - datastore_settings = 
DatastoreSettings(endpoint="https://huggingface.co", token="test-token") - - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age"] - stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings) - ) - - with pytest.raises( - BuilderConfigurationError, - match="Column 'id' already exists as a seed dataset column", - ): - stub_empty_builder.add_column( - name="id", - column_type=DataDesignerColumnType.SAMPLER, - sampler_type=SamplerType.UUID, - ) - - with pytest.raises( - BuilderConfigurationError, - match="Column 'name' already exists as a seed dataset column", - ): - stub_empty_builder.add_column( - LLMTextColumnConfig( - name="name", - prompt="Write a name", - model_alias="stub-model", - ) - ) - - -def test_with_seed_dataset_collision_with_existing_columns(stub_empty_builder: DataDesignerConfigBuilder) -> None: - """Test that adding a seed dataset with columns that collide with existing columns raises an error.""" - stub_empty_builder.add_column( - name="name", - column_type=DataDesignerColumnType.LLM_TEXT, - prompt="Write a name", - model_alias="stub-model", - ) - stub_empty_builder.add_column( - name="age", - column_type=DataDesignerColumnType.SAMPLER, - sampler_type=SamplerType.UNIFORM, - params={"low": 1, "high": 100}, - ) - - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") - - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age", "city"] - with pytest.raises( - BuilderConfigurationError, - match=r"Seed dataset column\(s\) \['name', 'age'\] collide with existing column\(s\)", - ): - stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference( - dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings - ) - ) - - assert stub_empty_builder.get_seed_config() is None - assert len(stub_empty_builder.get_columns_of_type(DataDesignerColumnType.SEED_DATASET)) == 0 +def test_cannot_write_config_with_dataframe_seed(stub_model_configs): + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + df = pd.DataFrame(data={"hello": [1, 2], "world": [10, 20]}) + df_seed = DataFrameSeedSource(df=df) + builder.with_seed_dataset(df_seed) -def test_with_seed_dataset_no_collision(stub_empty_builder: DataDesignerConfigBuilder) -> None: - """Test that adding a seed dataset with non-colliding columns works fine.""" - stub_empty_builder.add_column( - name="unique_column", - column_type=DataDesignerColumnType.SAMPLER, + sampler_column = SamplerColumnConfig( + name="test_id", sampler_type=SamplerType.UUID, + params=UUIDSamplerParams(prefix="code_", short_form=True, uppercase=True), ) + builder.add_column(sampler_column) - datastore_settings = DatastoreSettings(endpoint="https://huggingface.co", token="test-token") - - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "age"] - stub_empty_builder.with_seed_dataset( - DatastoreSeedDatasetReference(dataset="test-repo/test-data.parquet", datastore_settings=datastore_settings) - ) - - assert stub_empty_builder.get_seed_config() is not None - assert len(stub_empty_builder.get_columns_of_type(DataDesignerColumnType.SEED_DATASET)) == 3 - assert len(stub_empty_builder.get_columns_of_type(DataDesignerColumnType.SAMPLER)) == 1 - 
- -def test_from_config_does_not_duplicate_seed_dataset_columns( - stub_data_designer_builder: DataDesignerConfigBuilder, -) -> None: - """Regression test: seed dataset columns should not be duplicated during deserialization.""" - with tempfile.TemporaryDirectory() as temp_dir: - config_path = Path(temp_dir) / "config.json" - stub_data_designer_builder.write_config(config_path) - - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "city", "country"] - reloaded_builder = DataDesignerConfigBuilder.from_config(config_path) + with pytest.raises(BuilderSerializationError) as excinfo: + builder.write_config("./config.json") - assert reloaded_builder.num_columns_of_type(DataDesignerColumnType.SEED_DATASET) == 4 + assert "DataFrame seed dataset" in str(excinfo.value) diff --git a/tests/config/test_datastore.py b/tests/config/test_datastore.py deleted file mode 100644 index a2fe4d5f..00000000 --- a/tests/config/test_datastore.py +++ /dev/null @@ -1,286 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from unittest.mock import MagicMock, patch - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq -import pytest - -from data_designer.config.datastore import ( - DatastoreSettings, - fetch_seed_dataset_column_names, - fetch_seed_dataset_column_names_from_local_file, - get_file_column_names, - resolve_datastore_settings, - upload_to_hf_hub, -) -from data_designer.config.errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError -from data_designer.config.seed import DatastoreSeedDatasetReference, LocalSeedDatasetReference - - -@pytest.fixture -def datastore_settings(): - return DatastoreSettings(endpoint="https://testing.com", token="stub-token") - - -def _write_file(df, path, file_type): - if file_type == "parquet": - df.to_parquet(path) - elif file_type in {"json", "jsonl"}: - df.to_json(path, orient="records", lines=True) - else: - df.to_csv(path, index=False) - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_basic_parquet(tmp_path, file_type): - """Test get_file_column_names with basic parquet file.""" - test_data = { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "age": [25, 30, 35], - "city": ["NYC", "LA", "Chicago"], - } - df = pd.DataFrame(test_data) - - parquet_path = tmp_path / f"test_data.{file_type}" - _write_file(df, parquet_path, file_type) - assert get_file_column_names(str(parquet_path), file_type) == df.columns.tolist() - - -def test_get_file_column_names_nested_fields(tmp_path): - """Test get_file_column_names with nested fields in parquet.""" - schema = pa.schema( - [ - pa.field( - "nested", pa.struct([pa.field("col1", pa.list_(pa.int32())), pa.field("col2", pa.list_(pa.int32()))]) - ), - ] - ) - - # For PyArrow, we need to structure the data as a list of records - nested_data = {"nested": [{"col1": [1, 2, 3], "col2": [4, 5, 6]}]} - nested_path = tmp_path / "nested_fields.parquet" - pq.write_table(pa.Table.from_pydict(nested_data, schema=schema), nested_path) - - column_names = get_file_column_names(str(nested_path), "parquet") - - assert column_names == ["nested"] - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_empty_parquet(tmp_path, file_type): - """Test get_file_column_names with empty parquet 
file.""" - empty_df = pd.DataFrame() - empty_path = tmp_path / f"empty.{file_type}" - _write_file(empty_df, empty_path, file_type) - - column_names = get_file_column_names(str(empty_path), file_type) - assert column_names == [] - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_large_schema(tmp_path, file_type): - """Test get_file_column_names with many columns.""" - num_columns = 50 - test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)} - df = pd.DataFrame(test_data) - - large_path = tmp_path / f"large_schema.{file_type}" - _write_file(df, large_path, file_type) - - column_names = get_file_column_names(str(large_path), file_type) - assert len(column_names) == num_columns - assert column_names == [f"col_{i}" for i in range(num_columns)] - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_special_characters(tmp_path, file_type): - """Test get_file_column_names with special characters in column names.""" - special_data = { - "column with spaces": [1], - "column-with-dashes": [2], - "column_with_underscores": [3], - "column.with.dots": [4], - "column123": [5], - "123column": [6], - "column!@#$%^&*()": [7], - } - df_special = pd.DataFrame(special_data) - special_path = tmp_path / f"special_chars.{file_type}" - _write_file(df_special, special_path, file_type) - - assert get_file_column_names(str(special_path), file_type) == df_special.columns.tolist() - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_get_file_column_names_unicode(tmp_path, file_type): - """Test get_file_column_names with unicode column names.""" - unicode_data = {"cafΓ©": [1], "rΓ©sumΓ©": [2], "naΓ―ve": [3], "faΓ§ade": [4], "garΓ§on": [5], "ΓΌber": [6], "schΓΆn": [7]} - df_unicode = pd.DataFrame(unicode_data) - - unicode_path = tmp_path / f"unicode_columns.{file_type}" - _write_file(df_unicode, unicode_path, file_type) - assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist() - - -def test_get_file_column_names_with_filesystem_parquet(): - """Test get_file_column_names with filesystem parameter for parquet files.""" - mock_schema = MagicMock() - mock_schema.names = ["col1", "col2", "col3"] - - with patch("data_designer.config.datastore.pq.read_schema") as mock_read_schema: - mock_read_schema.return_value = mock_schema - result = get_file_column_names("datasets/test/file.parquet", "parquet") - - assert result == ["col1", "col2", "col3"] - mock_read_schema.assert_called_once_with("datasets/test/file.parquet") - - -@pytest.mark.parametrize("file_type", ["json", "jsonl", "csv"]) -def test_get_file_column_names_with_filesystem_non_parquet(tmp_path, file_type): - """Test get_file_column_names with file-like objects for non-parquet files.""" - test_data = pd.DataFrame({"col1": [1], "col2": [2], "col3": [3]}) - - # Create a real temporary file - file_path = tmp_path / f"test_file.{file_type}" - if file_type in ["json", "jsonl"]: - test_data.to_json(file_path, orient="records", lines=True) - else: - test_data.to_csv(file_path, index=False) - - result = get_file_column_names(str(file_path), file_type) - - assert result == ["col1", "col2", "col3"] - - -def test_get_file_column_names_error_handling(): - with pytest.raises(InvalidFilePathError, match="πŸ›‘ Unsupported file type: 'txt'"): - get_file_column_names("test.txt", "txt") - - with patch("data_designer.config.datastore.pq.read_schema") as mock_read_schema: - 
mock_read_schema.side_effect = Exception("Test error") - assert get_file_column_names("test.txt", "parquet") == [] - - with patch("data_designer.config.datastore.pq.read_schema") as mock_read_schema: - mock_col1 = MagicMock() - mock_col1.name = "col1" - mock_col2 = MagicMock() - mock_col2.name = "col2" - mock_read_schema.return_value = [mock_col1, mock_col2] - assert get_file_column_names("test.txt", "parquet") == ["col1", "col2"] - - -def test_fetch_seed_dataset_column_names_parquet_error_handling(datastore_settings): - with pytest.raises(InvalidFileFormatError, match="πŸ›‘ Unsupported file type: 'test.txt'"): - fetch_seed_dataset_column_names( - DatastoreSeedDatasetReference( - dataset="test/repo/test.txt", - datastore_settings=datastore_settings, - ) - ) - - -@patch("data_designer.config.datastore.get_file_column_names", autospec=True) -def test_fetch_seed_dataset_column_names_local_file(mock_get_file_column_names, datastore_settings): - mock_get_file_column_names.return_value = ["col1", "col2"] - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - assert fetch_seed_dataset_column_names(LocalSeedDatasetReference(dataset="test.parquet")) == ["col1", "col2"] - - -@patch("data_designer.config.datastore.HfFileSystem") -@patch("data_designer.config.datastore.get_file_column_names", autospec=True) -def test_fetch_seed_dataset_column_names_remote_file(mock_get_file_column_names, mock_hf_fs, datastore_settings): - mock_get_file_column_names.return_value = ["col1", "col2"] - mock_fs_instance = MagicMock() - mock_hf_fs.return_value = mock_fs_instance - - assert fetch_seed_dataset_column_names( - DatastoreSeedDatasetReference( - dataset="test/repo/test.parquet", - datastore_settings=datastore_settings, - ) - ) == ["col1", "col2"] - - mock_hf_fs.assert_called_once_with( - endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True - ) - - # The get_file_column_names is called with a file-like object from fs.open() - assert mock_get_file_column_names.call_count == 1 - call_args = mock_get_file_column_names.call_args - assert call_args[0][1] == "parquet" - - -def test_resolve_datastore_settings(datastore_settings): - with pytest.raises(InvalidConfigError, match="Datastore settings are required"): - resolve_datastore_settings(None) - - with pytest.raises(InvalidConfigError, match="Invalid datastore settings format"): - resolve_datastore_settings("invalid_settings") - - assert resolve_datastore_settings(datastore_settings) == datastore_settings - assert resolve_datastore_settings(datastore_settings.model_dump()) == datastore_settings - - -@patch("data_designer.config.datastore.HfApi.upload_file", autospec=True) -@patch("data_designer.config.datastore.HfApi.create_repo", autospec=True) -def test_upload_to_hf_hub(mock_create_repo, mock_upload_file, datastore_settings): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - - assert ( - upload_to_hf_hub("test.parquet", "test.parquet", "test/repo", datastore_settings) - == "test/repo/test.parquet" - ) - mock_create_repo.assert_called_once() - mock_upload_file.assert_called_once() - - -def test_upload_to_hf_hub_error_handling(datastore_settings): - with pytest.raises( - InvalidFilePathError, match="To upload a dataset to the datastore, you must provide a valid file path." 
- ): - upload_to_hf_hub("test.txt", "test.txt", "test/repo", datastore_settings) - - with pytest.raises( - InvalidFileFormatError, match="Dataset file extension '.parquet' does not match `filename` extension .'csv'" - ): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - upload_to_hf_hub("test.parquet", "test.csv", "test/repo", datastore_settings) - - with pytest.raises(InvalidFileFormatError, match="Dataset files must be in "): - with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file: - mock_is_file.return_value = True - upload_to_hf_hub("test.text", "test.txt", "test/repo", datastore_settings) - - -@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"]) -def test_fetch_seed_dataset_column_names_from_local_file_with_glob(tmp_path, file_type): - """Test fetch_seed_dataset_column_names_from_local_file with glob pattern matching multiple files.""" - test_data = pd.DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 6]}) - - # Create multiple files with the same schema - for i in range(3): - file_path = tmp_path / f"data_{i}.{file_type}" - _write_file(test_data, file_path, file_type) - - # Test glob pattern that matches all files - glob_pattern = str(tmp_path / f"*.{file_type}") - result = fetch_seed_dataset_column_names_from_local_file(glob_pattern) - - assert result == ["col1", "col2", "col3"] - - -@pytest.mark.parametrize("file_type", ["parquet", "csv"]) -def test_fetch_seed_dataset_column_names_from_local_file_with_glob_no_matches(tmp_path, file_type): - """Test fetch_seed_dataset_column_names_from_local_file with glob pattern that matches no files.""" - glob_pattern = str(tmp_path / f"nonexistent_*.{file_type}") - - with pytest.raises(InvalidFilePathError, match="does not contain files of type"): - fetch_seed_dataset_column_names_from_local_file(glob_pattern) diff --git a/tests/config/test_seed.py b/tests/config/test_seed.py index 42acf189..2e7a2e1b 100644 --- a/tests/config/test_seed.py +++ b/tests/config/test_seed.py @@ -1,29 +1,9 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path - -import pandas as pd import pytest -from data_designer.config.errors import InvalidFilePathError -from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock - - -def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path: - df = pd.DataFrame({"col": [1, 2, 3]}) - - for i in range(num_files): - file_path = temp_dir / f"partition_{i}.{extension}" - if extension == "parquet": - df.to_parquet(file_path) - elif extension == "csv": - df.to_csv(file_path, index=False) - elif extension == "json": - df.to_json(file_path, orient="records", lines=True) - elif extension == "jsonl": - df.to_json(file_path, orient="records", lines=True) - return temp_dir +from data_designer.config.seed import IndexRange, PartitionBlock def test_index_range_validation(): @@ -74,28 +54,3 @@ def test_partition_block_to_index_range(): assert index_range.start == 90 assert index_range.end == 104 assert index_range.size == 15 - - -def test_local_seed_dataset_reference_validation(tmp_path: Path): - with pytest.raises(InvalidFilePathError, match="πŸ›‘ Path test/dataset.parquet is not a file."): - LocalSeedDatasetReference(dataset="test/dataset.parquet") - - # Should not raise an error when referencing supported extensions with wildcard pattern. - create_partitions_in_path(tmp_path, "parquet") - create_partitions_in_path(tmp_path, "csv") - create_partitions_in_path(tmp_path, "json") - create_partitions_in_path(tmp_path, "jsonl") - - test_cases = ["parquet", "csv", "json", "jsonl"] - try: - for extension in test_cases: - reference = LocalSeedDatasetReference(dataset=f"{tmp_path}/*.{extension}") - assert reference.dataset == f"{tmp_path}/*.{extension}" - except Exception as e: - pytest.fail(f"Expected no exception, but got {e}") - - -def test_local_seed_dataset_reference_validation_error(tmp_path: Path): - create_partitions_in_path(tmp_path, "parquet") - with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"): - LocalSeedDatasetReference(dataset=f"{tmp_path}/*.csv") diff --git a/tests/config/test_seed_source.py b/tests/config/test_seed_source.py new file mode 100644 index 00000000..1f0339c7 --- /dev/null +++ b/tests/config/test_seed_source.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pandas as pd +import pytest + +from data_designer.config.errors import InvalidFilePathError +from data_designer.config.seed_source import LocalFileSeedSource + + +def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path: + df = pd.DataFrame({"col": [1, 2, 3]}) + + for i in range(num_files): + file_path = temp_dir / f"partition_{i}.{extension}" + if extension == "parquet": + df.to_parquet(file_path) + elif extension == "csv": + df.to_csv(file_path, index=False) + elif extension == "json": + df.to_json(file_path, orient="records", lines=True) + elif extension == "jsonl": + df.to_json(file_path, orient="records", lines=True) + return temp_dir + + +def test_local_seed_dataset_reference_validation(tmp_path: Path): + with pytest.raises(InvalidFilePathError, match="πŸ›‘ Path test/dataset.parquet is not a file."): + LocalFileSeedSource(path="test/dataset.parquet") + + # Should not raise an error when referencing supported extensions with wildcard pattern. 
+ create_partitions_in_path(tmp_path, "parquet") + create_partitions_in_path(tmp_path, "csv") + create_partitions_in_path(tmp_path, "json") + create_partitions_in_path(tmp_path, "jsonl") + + test_cases = ["parquet", "csv", "json", "jsonl"] + try: + for extension in test_cases: + config = LocalFileSeedSource(path=f"{tmp_path}/*.{extension}") + assert config.path == f"{tmp_path}/*.{extension}" + except Exception as e: + pytest.fail(f"Expected no exception, but got {e}") + + +def test_local_seed_dataset_reference_validation_error(tmp_path: Path): + create_partitions_in_path(tmp_path, "parquet") + with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"): + LocalFileSeedSource(path=f"{tmp_path}/*.csv") + + +def test_local_source_from_dataframe(tmp_path: Path): + df = pd.DataFrame({"col": [1, 2, 3]}) + filepath = f"{tmp_path}/data.parquet" + + source = LocalFileSeedSource.from_dataframe(df, filepath) + + assert source.path == filepath + pd.testing.assert_frame_equal(df, pd.read_parquet(filepath)) diff --git a/tests/config/utils/test_visualization.py b/tests/config/utils/test_visualization.py index bb77e895..99caaeaa 100644 --- a/tests/config/utils/test_visualization.py +++ b/tests/config/utils/test_visualization.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import patch - import pandas as pd import pytest @@ -30,21 +28,18 @@ def validation_output(): @pytest.fixture def config_builder_with_validation(stub_model_configs): """Fixture providing a DataDesignerConfigBuilder with a validation column.""" - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["code"] - - builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) - - # Add a validation column configuration - builder.add_column( - name="code_validation_result", - column_type="validation", - target_columns=["code"], - validator_type="code", - validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), - ) - - return builder + builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + + # Add a validation column configuration + builder.add_column( + name="code_validation_result", + column_type="validation", + target_columns=["code"], + validator_type="code", + validator_params=CodeValidatorParams(code_lang=CodeLang.PYTHON), + ) + + return builder def test_display_sample_record_twice_no_errors(validation_output, config_builder_with_validation): diff --git a/tests/conftest.py b/tests/conftest.py index 72f5a8f2..9e858690 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import tarfile import tempfile import textwrap -from unittest.mock import patch import pandas as pd import pytest @@ -16,8 +15,9 @@ from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.data_designer_config import DataDesignerConfig -from data_designer.config.datastore import DatastoreSettings from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider +from data_designer.config.seed_source import HuggingFaceSeedSource +from data_designer.engine.resources.seed_reader import SeedReader @pytest.fixture @@ -39,7 +39,9 @@ def stub_data_designer_config_str() -> str: weights: [0.3, 0.2, 0.50] seed_config: - dataset: test-repo/testing/data.csv + source: + 
seed_type: hf + path: datasets/test-repo/testing/data.csv sampling_strategy: shuffle columns: @@ -116,10 +118,6 @@ def stub_data_designer_builder_config_str(stub_data_designer_config_str: str) -> return f""" data_designer: {textwrap.indent(stub_data_designer_config_str, prefix=" ")} - -datastore_settings: - endpoint: http://test-endpoint:3000/v1/hf - token: stub-token """ @@ -163,15 +161,7 @@ def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerCon @pytest.fixture def stub_complete_builder(stub_data_designer_builder_config_str: str) -> DataDesignerConfigBuilder: - with patch("data_designer.config.config_builder.fetch_seed_dataset_column_names") as mock_fetch: - mock_fetch.return_value = ["id", "name", "city", "country"] - return DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) - - -@pytest.fixture -def stub_datastore_settings(): - """Test datastore settings with testing endpoint and token.""" - return DatastoreSettings(endpoint="https://testing.com", token="stub-token") + return DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str) @pytest.fixture @@ -310,3 +300,19 @@ def stub_sampler_only_config_builder(stub_model_configs: list[ModelConfig]) -> D SamplerColumnConfig(name="uniform", sampler_type="uniform", params={"low": 1, "high": 100}) ) return config_builder + + +class StubHuggingFaceSeedReader(SeedReader[HuggingFaceSeedSource]): + def get_column_names(self) -> list[str]: + return ["age", "city"] + + def get_dataset_uri(self) -> str: + return "unused in these tests" + + def create_duckdb_connection(self): + pass + + +@pytest.fixture +def stub_seed_reader(): + return StubHuggingFaceSeedReader() diff --git a/tests/engine/analysis/test_data/artifacts/dataset/column_configs.json b/tests/engine/analysis/test_data/artifacts/dataset/column_configs.json index 8062b1f7..449358f1 100644 --- a/tests/engine/analysis/test_data/artifacts/dataset/column_configs.json +++ b/tests/engine/analysis/test_data/artifacts/dataset/column_configs.json @@ -8,7 +8,10 @@ { "name": "uniform", "drop": false }, { "name": "poisson", "drop": false } ], - "dataset": "local_df.csv", + "source": { + "path": "datasets/test-repo/testing/data.csv", + "seed_type": "hf" + }, "sampling_strategy": "shuffle" }, { diff --git a/tests/engine/column_generators/generators/test_seed_dataset.py b/tests/engine/column_generators/generators/test_seed_dataset.py index 37e69a26..792632c2 100644 --- a/tests/engine/column_generators/generators/test_seed_dataset.py +++ b/tests/engine/column_generators/generators/test_seed_dataset.py @@ -11,6 +11,7 @@ from data_designer.config.column_configs import SeedDatasetColumnConfig from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy +from data_designer.config.seed_source import HuggingFaceSeedSource, LocalFileSeedSource from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.seed_dataset import ( MAX_ZERO_RECORD_RESPONSE_FACTOR, @@ -30,15 +31,18 @@ def stub_duckdb_conn(): @pytest.fixture def stub_seed_dataset_config(): - return SeedDatasetMultiColumnConfig(columns=[SeedDatasetColumnConfig(name="col1")], dataset="test/dataset") + return SeedDatasetMultiColumnConfig( + columns=[SeedDatasetColumnConfig(name="col1")], + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), + ) @pytest.fixture def stub_seed_dataset_generator(stub_resource_provider, stub_duckdb_conn, stub_seed_dataset_config): 
mock_provider = stub_resource_provider - mock_datastore = mock_provider.datastore - mock_datastore.create_duckdb_connection.return_value = stub_duckdb_conn - mock_datastore.get_dataset_uri.return_value = "test_uri" + mock_seed_reader = mock_provider.seed_reader + mock_seed_reader.create_duckdb_connection.return_value = stub_duckdb_conn + mock_seed_reader.get_dataset_uri.return_value = "test_uri" return SeedDatasetColumnGenerator(config=stub_seed_dataset_config, resource_provider=mock_provider) @@ -107,17 +111,17 @@ def seed_dataset_jsonl(sample_dataframe): def test_seed_dataset_column_generator_metadata(): metadata = SeedDatasetColumnGenerator.metadata() assert metadata.generation_strategy == GenerationStrategy.FULL_COLUMN - assert ResourceType.DATASTORE in metadata.required_resources + assert ResourceType.SEED_READER in metadata.required_resources def test_seed_dataset_column_generator_config_structure(): config = SeedDatasetMultiColumnConfig( columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")], - dataset="test/dataset", + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), sampling_strategy=SamplingStrategy.SHUFFLE, ) - assert config.dataset == "test/dataset" + assert config.source.path == "hf://datasets/test/dataset" assert config.sampling_strategy == SamplingStrategy.SHUFFLE assert len(config.columns) == 2 assert config.columns[0].name == "col1" @@ -129,7 +133,7 @@ def test_seed_dataset_column_generator_config_structure(): # Test PartitionBlock selection strategy config = SeedDatasetMultiColumnConfig( columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")], - dataset="test/dataset", + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=PartitionBlock(index=1, num_partitions=3), ) @@ -140,7 +144,7 @@ def test_seed_dataset_column_generator_config_structure(): # Test IndexRange selection strategy config = SeedDatasetMultiColumnConfig( columns=[SeedDatasetColumnConfig(name="col1"), SeedDatasetColumnConfig(name="col2")], - dataset="test/dataset", + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), sampling_strategy=SamplingStrategy.SHUFFLE, selection_strategy=IndexRange(start=0, end=1), ) @@ -373,7 +377,7 @@ def create_generator_with_real_file( SeedDatasetColumnConfig(name="city"), SeedDatasetColumnConfig(name="score"), ], - dataset=f"test/{os.path.basename(file_path)}", + source=LocalFileSeedSource(path=file_path), sampling_strategy=sampling_strategy, selection_strategy=selection_strategy, ) @@ -382,9 +386,9 @@ def create_generator_with_real_file( real_conn = duckdb.connect() mock_provider = stub_resource_provider - mock_datastore = mock_provider.datastore - mock_datastore.create_duckdb_connection.return_value = real_conn - mock_datastore.get_dataset_uri.return_value = file_path + mock_seed_reader = mock_provider.seed_reader + mock_seed_reader.create_duckdb_connection.return_value = real_conn + mock_seed_reader.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) return generator @@ -435,14 +439,14 @@ def test_seed_dataset_generator_ordered_sampling(fixture_name, stub_resource_pro config = SeedDatasetMultiColumnConfig( columns=[SeedDatasetColumnConfig(name="id"), SeedDatasetColumnConfig(name="name")], - dataset=f"test/{os.path.basename(file_path)}", + source=LocalFileSeedSource(path=file_path), sampling_strategy=SamplingStrategy.ORDERED, ) real_conn = 
duckdb.connect() mock_provider = stub_resource_provider - mock_provider.datastore.create_duckdb_connection.return_value = real_conn - mock_provider.datastore.get_dataset_uri.return_value = file_path + mock_provider.seed_reader.create_duckdb_connection.return_value = real_conn + mock_provider.seed_reader.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) @@ -471,14 +475,14 @@ def test_seed_dataset_generator_shuffle_sampling(fixture_name, stub_resource_pro config = SeedDatasetMultiColumnConfig( columns=[SeedDatasetColumnConfig(name="id"), SeedDatasetColumnConfig(name="name")], - dataset=f"test/{os.path.basename(file_path)}", + source=LocalFileSeedSource(path=file_path), sampling_strategy=SamplingStrategy.SHUFFLE, ) real_conn = duckdb.connect() mock_provider = stub_resource_provider - mock_provider.datastore.create_duckdb_connection.return_value = real_conn - mock_provider.datastore.get_dataset_uri.return_value = file_path + mock_provider.seed_reader.create_duckdb_connection.return_value = real_conn + mock_provider.seed_reader.get_dataset_uri.return_value = file_path generator = SeedDatasetColumnGenerator(config=config, resource_provider=mock_provider) diff --git a/tests/engine/conftest.py b/tests/engine/conftest.py index dc30ba25..1c3393f9 100644 --- a/tests/engine/conftest.py +++ b/tests/engine/conftest.py @@ -35,7 +35,7 @@ def stub_resource_provider(tmp_path, stub_model_facade): mock_provider.model_registry = mock_model_registry mock_provider.artifact_storage = ArtifactStorage(artifact_path=tmp_path) mock_provider.blob_storage = Mock(spec=ManagedBlobStorage) - mock_provider.datastore = Mock() + mock_provider.seed_reader = Mock() return mock_provider diff --git a/tests/engine/dataset_builders/test_multi_column_configs.py b/tests/engine/dataset_builders/test_multi_column_configs.py index 6dbb7b0d..c7c2873f 100644 --- a/tests/engine/dataset_builders/test_multi_column_configs.py +++ b/tests/engine/dataset_builders/test_multi_column_configs.py @@ -16,6 +16,7 @@ GaussianSamplerParams, SamplerType, ) +from data_designer.config.seed_source import HuggingFaceSeedSource from data_designer.engine.dataset_builders.multi_column_configs import ( MultiColumnConfig, SamplerMultiColumnConfig, @@ -147,11 +148,11 @@ def test_seed_dataset_multi_column_config_creation(): ] config = SeedDatasetMultiColumnConfig( - dataset="test/dataset", + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), columns=columns, ) - assert config.dataset == "test/dataset" + assert config.source.path == "hf://datasets/test/dataset" assert len(config.columns) == 2 assert config.column_names == ["col1", "col2"] assert config.column_type == DataDesignerColumnType.SEED_DATASET diff --git a/tests/engine/dataset_builders/utils/test_config_compiler.py b/tests/engine/dataset_builders/utils/test_config_compiler.py index 00809e46..7675a7cb 100644 --- a/tests/engine/dataset_builders/utils/test_config_compiler.py +++ b/tests/engine/dataset_builders/utils/test_config_compiler.py @@ -7,6 +7,7 @@ from data_designer.config.column_types import DataDesignerColumnType from data_designer.config.data_designer_config import DataDesignerConfig from data_designer.config.seed import SamplingStrategy, SeedConfig +from data_designer.config.seed_source import HuggingFaceSeedSource from data_designer.engine.dataset_builders.utils.config_compiler import ( compile_dataset_builder_column_configs, ) @@ -16,7 +17,10 @@ def 
test_compile_dataset_builder_column_configs_with_seed_columns(): config = DataDesignerConfig( columns=[SeedDatasetColumnConfig(name="seed_col")], - seed_config=SeedConfig(dataset="test/dataset", sampling_strategy=SamplingStrategy.SHUFFLE), + seed_config=SeedConfig( + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), + sampling_strategy=SamplingStrategy.SHUFFLE, + ), ) compiled_configs = compile_dataset_builder_column_configs(config) @@ -54,7 +58,10 @@ def test_compile_dataset_builder_column_configs_mixed_column_types(): LLMTextColumnConfig(name="text_col", prompt="Generate text", model_alias="test_model"), SamplerColumnConfig(name="sampler_col", sampler_type="category", params={"values": ["col3", "col4"]}), ], - seed_config=SeedConfig(dataset="test/dataset", sampling_strategy=SamplingStrategy.SHUFFLE), + seed_config=SeedConfig( + source=HuggingFaceSeedSource(path="hf://datasets/test/dataset"), + sampling_strategy=SamplingStrategy.SHUFFLE, + ), ) compiled_configs = compile_dataset_builder_column_configs(config) diff --git a/tests/engine/resources/test_resource_provider.py b/tests/engine/resources/test_resource_provider.py index 5d7b07a8..a7de76d6 100644 --- a/tests/engine/resources/test_resource_provider.py +++ b/tests/engine/resources/test_resource_provider.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import inspect from unittest.mock import Mock, patch import pytest @@ -28,6 +27,7 @@ def test_create_resource_provider_error_cases(test_case, expected_error): mock_model_configs = [Mock(), Mock()] mock_secret_resolver = Mock() mock_model_provider_registry = Mock() + mock_seed_reader_registry = Mock() with patch("data_designer.engine.resources.resource_provider.create_model_registry") as mock_create_model_registry: mock_create_model_registry.side_effect = Exception(expected_error) @@ -38,23 +38,5 @@ def test_create_resource_provider_error_cases(test_case, expected_error): model_configs=mock_model_configs, secret_resolver=mock_secret_resolver, model_provider_registry=mock_model_provider_registry, + seed_reader_registry=mock_seed_reader_registry, ) - - -def test_create_resource_provider_function_exists(): - assert callable(create_resource_provider) - - sig = inspect.signature(create_resource_provider) - params = list(sig.parameters.keys()) - - expected_params = [ - "artifact_storage", - "model_configs", - "secret_resolver", - "model_provider_registry", - "datastore", - "blob_storage", - ] - - for param in expected_params: - assert param in params diff --git a/tests/engine/resources/test_seed_reader.py b/tests/engine/resources/test_seed_reader.py new file mode 100644 index 00000000..7bd1aa00 --- /dev/null +++ b/tests/engine/resources/test_seed_reader.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+
+import pandas as pd
+import pytest
+
+from data_designer.config.seed_source import DataFrameSeedSource
+from data_designer.engine.resources.seed_reader import (
+    DataFrameSeedReader,
+    LocalFileSeedReader,
+    SeedReaderError,
+    SeedReaderRegistry,
+)
+from data_designer.engine.secret_resolver import PlaintextResolver
+
+
+def test_one_reader_per_seed_type():
+    local_1 = LocalFileSeedReader()
+    local_2 = LocalFileSeedReader()
+
+    with pytest.raises(SeedReaderError):
+        SeedReaderRegistry([local_1, local_2])
+
+    registry = SeedReaderRegistry([local_1])
+
+    with pytest.raises(SeedReaderError):
+        registry.add_reader(local_2)
+
+
+def test_get_reader_basic():
+    local_reader = LocalFileSeedReader()
+    df_reader = DataFrameSeedReader()
+    registry = SeedReaderRegistry([local_reader, df_reader])
+
+    df = pd.DataFrame(data={"a": [1, 2, 3]})
+    df_seed_source = DataFrameSeedSource(df=df)
+
+    reader = registry.get_reader(df_seed_source, PlaintextResolver())
+
+    assert reader == df_reader
+
+
+def test_get_reader_missing():
+    local_reader = LocalFileSeedReader()
+    registry = SeedReaderRegistry([local_reader])
+
+    df = pd.DataFrame(data={"a": [1, 2, 3]})
+    df_seed_source = DataFrameSeedSource(df=df)
+
+    with pytest.raises(SeedReaderError):
+        registry.get_reader(df_seed_source, PlaintextResolver())
diff --git a/tests/engine/test_compiler.py b/tests/engine/test_compiler.py
new file mode 100644
index 00000000..933fc0bd
--- /dev/null
+++ b/tests/engine/test_compiler.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import patch
+
+import pytest
+
+from data_designer.config.column_configs import SamplerColumnConfig
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.errors import InvalidConfigError
+from data_designer.config.sampler_params import CategorySamplerParams, SamplerType
+from data_designer.config.seed_source import HuggingFaceSeedSource
+from data_designer.engine.compiler import compile_data_designer_config
+from data_designer.engine.resources.resource_provider import ResourceProvider
+from data_designer.engine.resources.seed_reader import SeedReader
+from data_designer.engine.validation import Violation, ViolationLevel, ViolationType
+
+
+@pytest.fixture
+def resource_provider(stub_resource_provider: ResourceProvider, stub_seed_reader: SeedReader) -> ResourceProvider:
+    stub_resource_provider.seed_reader = stub_seed_reader
+    return stub_resource_provider
+
+
+def test_adds_seed_columns(resource_provider: ResourceProvider):
+    builder = DataDesignerConfigBuilder()
+    builder.add_column(
+        SamplerColumnConfig(
+            name="language",
+            sampler_type=SamplerType.CATEGORY,
+            params=CategorySamplerParams(values=["english", "french"]),
+        )
+    )
+    builder.with_seed_dataset(HuggingFaceSeedSource(path="hf://datasets/test/data.csv"))
+
+    config = compile_data_designer_config(builder, resource_provider)
+
+    # One sampler column plus the stub reader's two seed columns ("age", "city").
+    assert len(config.columns) == 3
+
+
+def test_errors_on_seed_column_collisions(resource_provider: ResourceProvider):
+    builder = DataDesignerConfigBuilder()
+    # "city" collides with a column name supplied by the stub seed reader.
+    builder.add_column(
+        SamplerColumnConfig(
+            name="city",
+            sampler_type=SamplerType.CATEGORY,
+            params=CategorySamplerParams(values=["new york", "los angeles"]),
+        )
+    )
+    builder.with_seed_dataset(HuggingFaceSeedSource(path="hf://datasets/test/data.csv"))
+
+    with pytest.raises(InvalidConfigError) as excinfo:
+        
compile_data_designer_config(builder, resource_provider) + + assert "city" in str(excinfo) + + +def test_validation_errors(resource_provider: ResourceProvider): + builder = DataDesignerConfigBuilder() + builder.add_column( + SamplerColumnConfig( + name="language", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["english", "french"]), + ) + ) + + with patch("data_designer.engine.compiler.validate_data_designer_config") as patched_validate: + patched_validate.return_value = [ + Violation( + type=ViolationType.INVALID_COLUMN, + message="Some error", + level=ViolationLevel.ERROR, + ) + ] + + with pytest.raises(InvalidConfigError) as excinfo: + compile_data_designer_config(builder, resource_provider) + + assert "validation errors" in str(excinfo) diff --git a/tests/config/utils/test_validation.py b/tests/engine/test_validation.py similarity index 94% rename from tests/config/utils/test_validation.py rename to tests/engine/test_validation.py index 2045e650..f5396b17 100644 --- a/tests/config/utils/test_validation.py +++ b/tests/engine/test_validation.py @@ -19,7 +19,8 @@ SchemaTransformProcessorConfig, ) from data_designer.config.utils.code_lang import CodeLang -from data_designer.config.utils.validation import ( +from data_designer.config.validator_params import CodeValidatorParams +from data_designer.engine.validation import ( Violation, ViolationLevel, ViolationType, @@ -31,7 +32,6 @@ validate_prompt_templates, validate_schema_transform_processor, ) -from data_designer.config.validator_params import CodeValidatorParams STUB_MODEL_ALIAS = "stub-alias" @@ -115,12 +115,12 @@ ALLOWED_REFERENCE = [c.name for c in COLUMNS] -@patch("data_designer.config.utils.validation.validate_prompt_templates") -@patch("data_designer.config.utils.validation.validate_code_validation") -@patch("data_designer.config.utils.validation.validate_expression_references") -@patch("data_designer.config.utils.validation.validate_columns_not_all_dropped") -@patch("data_designer.config.utils.validation.validate_drop_columns_processor") -@patch("data_designer.config.utils.validation.validate_schema_transform_processor") +@patch("data_designer.engine.validation.validate_prompt_templates") +@patch("data_designer.engine.validation.validate_code_validation") +@patch("data_designer.engine.validation.validate_expression_references") +@patch("data_designer.engine.validation.validate_columns_not_all_dropped") +@patch("data_designer.engine.validation.validate_drop_columns_processor") +@patch("data_designer.engine.validation.validate_schema_transform_processor") def test_validate_data_designer_config( mock_validate_columns_not_all_dropped, mock_validate_expression_references, @@ -282,7 +282,7 @@ def test_validate_schema_transform_processor(): assert violations[0].level == ViolationLevel.ERROR -@patch("data_designer.config.utils.validation.Console.print") +@patch("data_designer.engine.validation.Console.print") def test_rich_print_violations(mock_console_print): rich_print_violations([]) mock_console_print.assert_not_called() diff --git a/tests/essentials/test_init.py b/tests/essentials/test_init.py index 3153ac13..78392f5e 100644 --- a/tests/essentials/test_init.py +++ b/tests/essentials/test_init.py @@ -21,13 +21,13 @@ DataDesignerColumnType, DataDesignerConfig, DataDesignerConfigBuilder, - DatastoreSeedDatasetReference, - DatastoreSettings, + DataFrameSeedSource, DatetimeSamplerParams, EmbeddingInferenceParams, ExpressionColumnConfig, GaussianSamplerParams, GenerationType, + HuggingFaceSeedSource, 
ImageContext, ImageFormat, JudgeScoreProfilerConfig, @@ -35,6 +35,7 @@ LLMJudgeColumnConfig, LLMStructuredColumnConfig, LLMTextColumnConfig, + LocalFileSeedSource, LoggingConfig, ManualDistribution, ManualDistributionParams, @@ -83,7 +84,6 @@ def test_config_imports(): """Test config-related imports""" assert DataDesignerConfig is not None assert DataDesignerConfigBuilder is not None - assert DatastoreSettings is not None assert isinstance(can_run_data_designer_locally(), bool) @@ -149,7 +149,9 @@ def test_sampler_params_imports(): def test_seed_config_imports(): """Test seed configuration imports""" - assert DatastoreSeedDatasetReference is not None + assert DataFrameSeedSource is not None + assert HuggingFaceSeedSource is not None + assert LocalFileSeedSource is not None assert SamplingStrategy is not None assert SeedConfig is not None @@ -221,7 +223,6 @@ def test_all_contains_config_classes(): """Test __all__ contains config classes""" assert "DataDesignerConfig" in __all__ assert "DataDesignerConfigBuilder" in __all__ - assert "DatastoreSettings" in __all__ def test_all_contains_column_configs(): @@ -284,7 +285,6 @@ def test_all_contains_model_configs(): def test_all_contains_seed_configs(): """Test __all__ contains seed configuration classes""" - assert "DatastoreSeedDatasetReference" in __all__ assert "SamplingStrategy" in __all__ assert "SeedConfig" in __all__ diff --git a/tests/interface/test_data_designer.py b/tests/interface/test_data_designer.py index c8e94000..c901faf2 100644 --- a/tests/interface/test_data_designer.py +++ b/tests/interface/test_data_designer.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import tempfile from pathlib import Path from unittest.mock import MagicMock, patch @@ -11,10 +10,11 @@ from data_designer.config.column_configs import SamplerColumnConfig from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.dataset_builders import BuildStage -from data_designer.config.errors import InvalidFileFormatError +from data_designer.config.errors import InvalidConfigError +from data_designer.config.models import ModelProvider from data_designer.config.processors import DropColumnsProcessorConfig -from data_designer.config.seed import LocalSeedDatasetReference -from data_designer.engine.model_provider import ModelProvider +from data_designer.config.sampler_params import CategorySamplerParams, SamplerType +from data_designer.config.seed_source import HuggingFaceSeedSource from data_designer.engine.secret_resolver import CompositeResolver, EnvironmentResolver, PlaintextResolver from data_designer.interface.data_designer import DataDesigner from data_designer.interface.errors import ( @@ -84,171 +84,6 @@ def test_init_with_path_object(stub_artifact_path, stub_model_providers): assert designer is not None -def test_make_seed_reference_from_dataframe(stub_dataframe): - """Test creating seed reference from DataFrame.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.parquet" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=file_path) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - assert file_path.exists() - - # Verify the file contains the correct data - loaded_df = pd.read_parquet(file_path) - pd.testing.assert_frame_equal(loaded_df, stub_dataframe) - - -def 
test_make_seed_reference_from_dataframe_writes_parquet_format(stub_dataframe): - """Test that seed reference writes DataFrame as parquet.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.parquet" - DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=file_path) - - # Verify we can read it back as parquet - loaded_df = pd.read_parquet(file_path) - assert len(loaded_df) == len(stub_dataframe) - assert list(loaded_df.columns) == list(stub_dataframe.columns) - - -def test_make_seed_reference_from_dataframe_writes_csv_format(stub_dataframe): - """Test that seed reference writes DataFrame as CSV.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.csv" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=file_path) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - assert file_path.exists() - - # Verify we can read it back as CSV - loaded_df = pd.read_csv(file_path) - assert len(loaded_df) == len(stub_dataframe) - assert list(loaded_df.columns) == list(stub_dataframe.columns) - - -def test_make_seed_reference_from_dataframe_writes_json_format(stub_dataframe): - """Test that seed reference writes DataFrame as JSON.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.json" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=file_path) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - assert file_path.exists() - - # Verify we can read it back as JSON - loaded_df = pd.read_json(file_path, orient="records", lines=True) - assert len(loaded_df) == len(stub_dataframe) - assert list(loaded_df.columns) == list(stub_dataframe.columns) - - -def test_make_seed_reference_from_dataframe_writes_jsonl_format(stub_dataframe): - """Test that seed reference writes DataFrame as JSONL.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.jsonl" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=file_path) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - assert file_path.exists() - - # Verify we can read it back as JSONL - loaded_df = pd.read_json(file_path, orient="records", lines=True) - assert len(loaded_df) == len(stub_dataframe) - assert list(loaded_df.columns) == list(stub_dataframe.columns) - - -def test_make_seed_reference_from_dataframe_raises_error_for_invalid_extension(stub_dataframe): - """Test that make_seed_reference_from_dataframe raises error for invalid file extensions.""" - with tempfile.TemporaryDirectory() as temp_dir: - # Test with .txt extension - txt_path = Path(temp_dir) / "seed.txt" - with pytest.raises(InvalidFileFormatError): - DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=txt_path) - - # Test with no extension - no_ext_path = Path(temp_dir) / "seed" - with pytest.raises(InvalidFileFormatError): - DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=no_ext_path) - - -def test_make_seed_reference_from_dataframe_accepts_uppercase_extensions(stub_dataframe): - """Test that make_seed_reference_from_dataframe accepts uppercase file extensions (case insensitive).""" - with tempfile.TemporaryDirectory() as temp_dir: - # Test .PARQUET - parquet_path = Path(temp_dir) / "seed.PARQUET" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, 
file_path=parquet_path) - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(parquet_path) - assert parquet_path.exists() - - # Test .CSV - csv_path = Path(temp_dir) / "seed.CSV" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=csv_path) - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(csv_path) - assert csv_path.exists() - - # Test .JSON - json_path = Path(temp_dir) / "seed.JSON" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=json_path) - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(json_path) - assert json_path.exists() - - # Test .JSONL - jsonl_path = Path(temp_dir) / "seed.JSONL" - ref = DataDesigner.make_seed_reference_from_dataframe(stub_dataframe, file_path=jsonl_path) - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(jsonl_path) - assert jsonl_path.exists() - - -def test_make_seed_reference_from_dataframe_overwrites_existing_file(stub_dataframe): - """Test that make_seed_reference_from_dataframe overwrites existing file.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "seed.parquet" - - # Create initial file with different data - initial_df = pd.DataFrame({"col": [999]}) - initial_df.to_parquet(file_path) - - # Overwrite with new data - new_df = pd.DataFrame({"col": [1, 2, 3]}) - _ = DataDesigner.make_seed_reference_from_dataframe(new_df, file_path=file_path) - - # Verify the file was overwritten - loaded_df = pd.read_parquet(file_path) - pd.testing.assert_frame_equal(loaded_df, new_df) - assert len(loaded_df) == 3 # Should have 3 rows, not 1 - - -def test_make_seed_reference_from_file_with_string_path(): - """Test creating seed reference from file path string.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "dataset.parquet" - df = pd.DataFrame({"col": [1, 2, 3]}) - df.to_parquet(file_path) - - ref = DataDesigner.make_seed_reference_from_file(str(file_path)) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - - -def test_make_seed_reference_from_file_with_path_object(stub_dataframe): - """Test creating seed reference from Path object.""" - with tempfile.TemporaryDirectory() as temp_dir: - file_path = Path(temp_dir) / "dataset.parquet" - stub_dataframe.to_parquet(file_path) - - ref = DataDesigner.make_seed_reference_from_file(file_path) - - assert isinstance(ref, LocalSeedDatasetReference) - assert ref.dataset == str(file_path) - - def test_buffer_size_setting_persists(stub_artifact_path, stub_model_providers): """Test that buffer size setting persists across multiple calls.""" data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers) @@ -268,28 +103,6 @@ def test_set_buffer_size_raises_error_for_invalid_buffer_size(stub_artifact_path data_designer.set_buffer_size(0) -def test_multiple_seed_references_can_be_created(): - """Test that multiple seed references can be created from different sources.""" - with tempfile.TemporaryDirectory() as temp_dir: - # Create seed reference from DataFrame - df1 = pd.DataFrame({"col": [1, 2, 3]}) - file_path_1 = Path(temp_dir) / "seed1.parquet" - ref1 = DataDesigner.make_seed_reference_from_dataframe(df1, file_path=file_path_1) - - # Create seed reference from another DataFrame - df2 = pd.DataFrame({"col": [4, 5, 6]}) - file_path_2 = Path(temp_dir) / "seed2.parquet" - ref2 = 
DataDesigner.make_seed_reference_from_dataframe(df2, file_path=file_path_2) - - # Create seed reference from existing file - ref3 = DataDesigner.make_seed_reference_from_file(file_path_1) - - # Verify all references are unique and valid - assert ref1.dataset != ref2.dataset - assert ref1.dataset == ref3.dataset - assert all(isinstance(ref, LocalSeedDatasetReference) for ref in [ref1, ref2, ref3]) - - def test_create_dataset_e2e_using_only_sampler_columns( stub_sampler_only_config_builder, stub_artifact_path, stub_model_providers, stub_managed_assets_path ): @@ -479,3 +292,32 @@ def test_preview_with_dropped_columns( assert "category" in analysis.side_effect_column_names, ( "Dropped column 'category' should be tracked in side_effect_column_names" ) + + +def test_validate_raises_error_when_seed_collides( + stub_artifact_path, + stub_model_providers, + stub_model_configs, + stub_managed_assets_path, + stub_seed_reader, +): + config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs) + config_builder.with_seed_dataset(HuggingFaceSeedSource(path="hf://datasets/test/data.csv")) + config_builder.add_column( + SamplerColumnConfig( + name="city", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["new york", "los angeles"]), + ) + ) + + data_designer = DataDesigner( + artifact_path=stub_artifact_path, + model_providers=stub_model_providers, + secret_resolver=PlaintextResolver(), + managed_assets_path=stub_managed_assets_path, + seed_readers=[stub_seed_reader], + ) + + with pytest.raises(InvalidConfigError): + data_designer.validate(config_builder) From 0ff21e73cac918f12f24c6e505de8d641a91ff66 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 8 Jan 2026 11:39:01 -0600 Subject: [PATCH 2/2] Drop temp notes --- src/data_designer/temp_nmp.py | 52 ----------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 src/data_designer/temp_nmp.py diff --git a/src/data_designer/temp_nmp.py b/src/data_designer/temp_nmp.py deleted file mode 100644 index b5330cc1..00000000 --- a/src/data_designer/temp_nmp.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -This module will actually exist in NMP as a plugin, but providing it here temporarily as a sketch. - -We need to make a few updates to plugins: -1. Dependency management (the "engine" problem); separate PR with an approach already up for this -2. Add support for PluginType.SEED_DATASET -""" - -from typing import Literal - -import duckdb - -from data_designer.config.seed_source import SeedSource -from data_designer.engine.resources.seed_reader import SeedReader - - -class NMPFileSeedConfig(SeedSource): - seed_type: Literal["nmp"] = "nmp" # or "fileset" since that's the fsspec client protocol? - - # Check with MG: what do we expect to be more common in scenarios like this? - # 1. Just "fileset", with optional workspace prefix, e.g. "myworkspace/myfileset", "myfileset" (implicit default workspace), etc. - # 2. 
Separate "fileset" and "workspace" fields - fileset: str - path: str - - -class NMPFileSeedReader(SeedReader[NMPFileSeedConfig]): - def create_duckdb_connection(self) -> duckdb.DuckDBPyConnection: - # NMP helper function - sdk = get_platform_sdk() # noqa:F821 - # New fsspec client for Files service - fs = FilesetFileSystem(sdk) # noqa:F821 - - conn = duckdb.connect() - conn.register_filesystem(fs) - return conn - - def get_dataset_uri(self) -> str: - workspace, fileset_name = self._get_workspace_and_fileset_name() - return f"fileset://{workspace}/{fileset_name}/{self.config.path}" - - def _get_workspace_and_fileset_name(self) -> tuple[str, str]: - match self.config.fileset.split("/"): - case [fileset_name]: - return ("default", fileset_name) - case [workspace, fileset_name]: - return (workspace, fileset_name) - case _: - raise ValueError("Malformed fileset")