Skip to content

Commit d7e93c5

Browse files
authored
fix: limit imports in base generators module (#166)
* limit imports in base module * new line
1 parent 478358c commit d7e93c5

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/data_designer/engine/column_generators/generators/base.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from __future__ import annotations
5+
46
import functools
57
import logging
68
from abc import ABC, abstractmethod
7-
from typing import overload
9+
from enum import Enum
10+
from typing import TYPE_CHECKING, overload
811

912
import pandas as pd
1013

11-
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
12-
from data_designer.config.models import BaseInferenceParams, ModelConfig
13-
from data_designer.config.utils.type_helpers import StrEnum
1414
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
15-
from data_designer.engine.models.facade import ModelFacade
15+
16+
if TYPE_CHECKING:
17+
from data_designer.config.models import BaseInferenceParams, ModelConfig
18+
from data_designer.engine.models.facade import ModelFacade
19+
1620

1721
logger = logging.getLogger(__name__)
1822

1923

20-
class GenerationStrategy(StrEnum):
24+
class GenerationStrategy(str, Enum):
2125
CELL_BY_CELL = "cell_by_cell"
2226
FULL_COLUMN = "full_column"
2327

@@ -82,8 +86,7 @@ def inference_parameters(self) -> BaseInferenceParams:
8286
return self.model_config.inference_parameters
8387

8488
def log_pre_generation(self) -> None:
85-
emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
86-
logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
89+
logger.info(f"Preparing {self.config.column_type} column generation")
8790
logger.info(f" |-- column name: {self.config.name!r}")
8891
logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
8992
if self.model_config.provider is None:

tests/engine/column_generators/generators/test_llm_completion_generators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_log_pre_generation(mock_logger):
102102
generator.log_pre_generation()
103103

104104
assert mock_logger.info.call_count == 3
105-
mock_logger.info.assert_any_call("📝 Preparing llm-text column generation")
105+
mock_logger.info.assert_any_call("Preparing llm-text column generation")
106106
mock_logger.info.assert_any_call(" |-- column name: 'test_column'")
107107
mock_logger.info.assert_any_call(' |-- model config:\n{"test": "config"}')
108108

0 commit comments

Comments
 (0)