Skip to content

Commit ab20e4c

Browse files
committed
pr feedback
1 parent ac4539b commit ab20e4c

File tree

3 files changed

+24
-17
lines changed

3 files changed

+24
-17
lines changed

src/data_designer/config/columns.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,26 +163,28 @@ class SeedDatasetColumnConfig(SingleColumnConfig):
163163
}
164164

165165

166-
def column_type_is_in_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
166+
def column_type_used_in_execution_dag(column_type: Union[str, DataDesignerColumnType]) -> bool:
167+
"""Return True if the column type is used in the workflow execution DAG."""
167168
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
168-
return column_type in [
169+
return column_type in {
169170
DataDesignerColumnType.EXPRESSION,
170171
DataDesignerColumnType.LLM_CODE,
171172
DataDesignerColumnType.LLM_JUDGE,
172173
DataDesignerColumnType.LLM_STRUCTURED,
173174
DataDesignerColumnType.LLM_TEXT,
174175
DataDesignerColumnType.VALIDATION,
175-
]
176+
}
176177

177178

178179
def column_type_is_llm_generated(column_type: Union[str, DataDesignerColumnType]) -> bool:
180+
"""Return True if the column type is an LLM-generated column."""
179181
column_type = resolve_string_enum(column_type, DataDesignerColumnType)
180-
return column_type in [
182+
return column_type in {
181183
DataDesignerColumnType.LLM_TEXT,
182184
DataDesignerColumnType.LLM_CODE,
183185
DataDesignerColumnType.LLM_STRUCTURED,
184186
DataDesignerColumnType.LLM_JUDGE,
185-
]
187+
}
186188

187189

188190
def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType, **kwargs) -> ColumnConfigT:
@@ -217,6 +219,7 @@ def get_column_config_from_kwargs(name: str, column_type: DataDesignerColumnType
217219

218220

219221
def get_column_display_order() -> list[DataDesignerColumnType]:
222+
"""Return the preferred display order of the column types."""
220223
return [
221224
DataDesignerColumnType.SEED_DATASET,
222225
DataDesignerColumnType.SAMPLER,

src/data_designer/engine/dataset_builders/utils/dag.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import networkx as nx
77

8-
from data_designer.config.columns import ColumnConfigT, column_type_is_in_dag
8+
from data_designer.config.columns import ColumnConfigT, column_type_used_in_execution_dag
99
from data_designer.engine.dataset_builders.utils.errors import DAGCircularDependencyError
1010

1111
logger = logging.getLogger(__name__)
@@ -14,8 +14,12 @@
1414
def topologically_sort_column_configs(column_configs: list[ColumnConfigT]) -> list[ColumnConfigT]:
1515
dag = nx.DiGraph()
1616

17-
non_dag_column_config_list = [col for col in column_configs if not column_type_is_in_dag(col.column_type)]
18-
dag_column_config_dict = {col.name: col for col in column_configs if column_type_is_in_dag(col.column_type)}
17+
non_dag_column_config_list = [
18+
col for col in column_configs if not column_type_used_in_execution_dag(col.column_type)
19+
]
20+
dag_column_config_dict = {
21+
col.name: col for col in column_configs if column_type_used_in_execution_dag(col.column_type)
22+
}
1923

2024
if len(dag_column_config_dict) == 0:
2125
return non_dag_column_config_list

tests/config/test_columns.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
Score,
1616
SeedDatasetColumnConfig,
1717
ValidationColumnConfig,
18-
column_type_is_in_dag,
1918
column_type_is_llm_generated,
19+
column_type_used_in_execution_dag,
2020
get_column_config_from_kwargs,
2121
get_column_display_order,
2222
)
@@ -56,14 +56,14 @@ def test_data_designer_column_type_is_llm_generated():
5656

5757

5858
def test_data_designer_column_type_is_in_dag():
59-
assert column_type_is_in_dag(DataDesignerColumnType.EXPRESSION)
60-
assert column_type_is_in_dag(DataDesignerColumnType.LLM_CODE)
61-
assert column_type_is_in_dag(DataDesignerColumnType.LLM_JUDGE)
62-
assert column_type_is_in_dag(DataDesignerColumnType.LLM_STRUCTURED)
63-
assert column_type_is_in_dag(DataDesignerColumnType.LLM_TEXT)
64-
assert column_type_is_in_dag(DataDesignerColumnType.VALIDATION)
65-
assert not column_type_is_in_dag(DataDesignerColumnType.SAMPLER)
66-
assert not column_type_is_in_dag(DataDesignerColumnType.SEED_DATASET)
59+
assert column_type_used_in_execution_dag(DataDesignerColumnType.EXPRESSION)
60+
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_CODE)
61+
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_JUDGE)
62+
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_STRUCTURED)
63+
assert column_type_used_in_execution_dag(DataDesignerColumnType.LLM_TEXT)
64+
assert column_type_used_in_execution_dag(DataDesignerColumnType.VALIDATION)
65+
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SAMPLER)
66+
assert not column_type_used_in_execution_dag(DataDesignerColumnType.SEED_DATASET)
6767

6868

6969
def test_sampler_column_config():

0 commit comments

Comments
 (0)