diff --git a/src/data_designer/engine/column_generators/generators/base.py b/src/data_designer/engine/column_generators/generators/base.py index b7ad4c6d..fa477b85 100644 --- a/src/data_designer/engine/column_generators/generators/base.py +++ b/src/data_designer/engine/column_generators/generators/base.py @@ -1,23 +1,27 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import functools import logging from abc import ABC, abstractmethod -from typing import overload +from enum import Enum +from typing import TYPE_CHECKING, overload import pandas as pd -from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP -from data_designer.config.models import BaseInferenceParams, ModelConfig -from data_designer.config.utils.type_helpers import StrEnum from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT -from data_designer.engine.models.facade import ModelFacade + +if TYPE_CHECKING: + from data_designer.config.models import BaseInferenceParams, ModelConfig + from data_designer.engine.models.facade import ModelFacade + logger = logging.getLogger(__name__) -class GenerationStrategy(StrEnum): +class GenerationStrategy(str, Enum): CELL_BY_CELL = "cell_by_cell" FULL_COLUMN = "full_column" @@ -82,8 +86,7 @@ def inference_parameters(self) -> BaseInferenceParams: return self.model_config.inference_parameters def log_pre_generation(self) -> None: - emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type] - logger.info(f"{emoji} Preparing {self.config.column_type} column generation") + logger.info(f"Preparing {self.config.column_type} column generation") logger.info(f" |-- column name: {self.config.name!r}") logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}") if self.model_config.provider is None: diff --git a/tests/engine/column_generators/generators/test_llm_completion_generators.py b/tests/engine/column_generators/generators/test_llm_completion_generators.py index 0b787b7e..6cba3bdc 100644 --- a/tests/engine/column_generators/generators/test_llm_completion_generators.py +++ b/tests/engine/column_generators/generators/test_llm_completion_generators.py @@ -102,7 +102,7 @@ def test_log_pre_generation(mock_logger): generator.log_pre_generation() assert mock_logger.info.call_count == 3 - mock_logger.info.assert_any_call("📝 Preparing llm-text column generation") + mock_logger.info.assert_any_call("Preparing llm-text column generation") mock_logger.info.assert_any_call(" |-- column name: 'test_column'") mock_logger.info.assert_any_call(' |-- model config:\n{"test": "config"}')