Skip to content

Commit 9cbc00d

Browse files
authored
Revert change of removing subflow context (#225)
* Revert change of removing subflow context. The extra nesting provided a nickname of the subflow in the context of the parent flow.
* Use Subflow type for subflows.
1 parent f3eaea7 commit 9cbc00d

File tree

4 files changed

+162
-144
lines changed

4 files changed

+162
-144
lines changed

src/routers/openml/flows.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import database.flows
88
from core.conversions import _str_to_num
99
from routers.dependencies import expdb_connection
10-
from schemas.flows import Flow, Parameter
10+
from schemas.flows import Flow, Parameter, Subflow
1111

1212
router = APIRouter(prefix="/flows", tags=["flows"])
1313

@@ -49,8 +49,14 @@ def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection
4949
]
5050

5151
tags = database.flows.get_tags(flow_id, expdb)
52-
flow_rows = database.flows.get_subflows(flow_id, expdb)
53-
subflows = [get_flow(flow_id=flow.child_id, expdb=expdb) for flow in flow_rows]
52+
subflow_rows = database.flows.get_subflows(flow_id, expdb)
53+
subflows = [
54+
Subflow(
55+
identifier=subflow.identifier,
56+
flow=get_flow(flow_id=subflow.child_id, expdb=expdb),
57+
)
58+
for subflow in subflow_rows
59+
]
5460

5561
return Flow(
5662
id_=flow.id,

src/schemas/flows.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from datetime import datetime
4-
from typing import Any, Self
4+
from typing import Any, TypedDict
55

66
from pydantic import BaseModel, ConfigDict, Field
77

@@ -25,7 +25,12 @@ class Flow(BaseModel):
2525
language: str | None = Field(max_length=128)
2626
dependencies: str | None
2727
parameter: list[Parameter]
28-
subflows: list[Self]
28+
subflows: list[Subflow]
2929
tag: list[str]
3030

3131
model_config = ConfigDict(arbitrary_types_allowed=True)
32+
33+
34+
class Subflow(TypedDict):
35+
identifier: str | None
36+
flow: Flow

tests/routers/openml/flows_test.py

Lines changed: 143 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -221,142 +221,149 @@ def test_get_flow_with_subflow(py_api: TestClient) -> None:
221221
],
222222
"subflows": [
223223
{
224-
"id": 4,
225-
"uploader": 16,
226-
"name": "weka.J48",
227-
"class_name": "weka.classifiers.trees.J48",
228-
"version": 1,
229-
"external_version": "Weka_3.9.0_11194",
230-
"description": (
231-
"Ross Quinlan (1993). C4.5: Programs for Machine Learning. "
232-
"Morgan Kaufmann Publishers, San Mateo, CA."
233-
),
234-
"upload_date": "2017-03-24T14:26:40",
235-
"language": "English",
236-
"dependencies": "Weka_3.9.0",
237-
"parameter": [
238-
{
239-
"name": "-do-not-check-capabilities",
240-
"data_type": "flag",
241-
"default_value": None,
242-
"description": (
243-
"If set, classifier capabilities are not checked"
244-
" before classifier is built\n\t(use with caution)."
245-
),
246-
},
247-
{
248-
"name": "-doNotMakeSplitPointActualValue",
249-
"data_type": "flag",
250-
"default_value": None,
251-
"description": "Do not make split point actual value.",
252-
},
253-
{
254-
"name": "A",
255-
"data_type": "flag",
256-
"default_value": None,
257-
"description": "Laplace smoothing for predicted probabilities.",
258-
},
259-
{
260-
"name": "B",
261-
"data_type": "flag",
262-
"default_value": None,
263-
"description": "Use binary splits only.",
264-
},
265-
{
266-
"name": "C",
267-
"data_type": "option",
268-
"default_value": 0.25,
269-
"description": ("Set confidence threshold for pruning.\n\t(default 0.25)"),
270-
},
271-
{
272-
"name": "J",
273-
"data_type": "flag",
274-
"default_value": None,
275-
"description": (
276-
"Do not use MDL correction for info gain on numeric attributes."
277-
),
278-
},
279-
{
280-
"name": "L",
281-
"data_type": "flag",
282-
"default_value": None,
283-
"description": "Do not clean up after the tree has been built.",
284-
},
285-
{
286-
"name": "M",
287-
"data_type": "option",
288-
"default_value": 2,
289-
"description": ("Set minimum number of instances per leaf.\n\t(default 2)"),
290-
},
291-
{
292-
"name": "N",
293-
"data_type": "option",
294-
"default_value": None,
295-
"description": (
296-
"Set number of folds for reduced error\n\t"
297-
"pruning. One fold is used as pruning set.\n\t(default 3)"
298-
),
299-
},
300-
{
301-
"name": "O",
302-
"data_type": "flag",
303-
"default_value": None,
304-
"description": "Do not collapse tree.",
305-
},
306-
{
307-
"name": "Q",
308-
"data_type": "option",
309-
"default_value": None,
310-
"description": "Seed for random data shuffling (default 1).",
311-
},
312-
{
313-
"name": "R",
314-
"data_type": "flag",
315-
"default_value": None,
316-
"description": "Use reduced error pruning.",
317-
},
318-
{
319-
"name": "S",
320-
"data_type": "flag",
321-
"default_value": None,
322-
"description": "Do not perform subtree raising.",
323-
},
324-
{
325-
"name": "U",
326-
"data_type": "flag",
327-
"default_value": None,
328-
"description": "Use unpruned tree.",
329-
},
330-
{
331-
"name": "batch-size",
332-
"data_type": "option",
333-
"default_value": None,
334-
"description": (
335-
"The desired batch size for batch prediction (default 100)."
336-
),
337-
},
338-
{
339-
"name": "num-decimal-places",
340-
"data_type": "option",
341-
"default_value": None,
342-
"description": (
343-
"The number of decimal places for the output of numbers"
344-
" in the model (default 2)."
345-
),
346-
},
347-
{
348-
"name": "output-debug-info",
349-
"data_type": "flag",
350-
"default_value": None,
351-
"description": (
352-
"If set, classifier is run in debug mode and\n\t"
353-
"may output additional info to the console"
354-
),
355-
},
356-
],
357-
"tag": ["OpenmlWeka", "weka"],
358-
"subflows": [],
359-
},
224+
"identifier": None,
225+
"flow": {
226+
"id": 4,
227+
"uploader": 16,
228+
"name": "weka.J48",
229+
"class_name": "weka.classifiers.trees.J48",
230+
"version": 1,
231+
"external_version": "Weka_3.9.0_11194",
232+
"description": (
233+
"Ross Quinlan (1993). C4.5: Programs for Machine Learning. "
234+
"Morgan Kaufmann Publishers, San Mateo, CA."
235+
),
236+
"upload_date": "2017-03-24T14:26:40",
237+
"language": "English",
238+
"dependencies": "Weka_3.9.0",
239+
"parameter": [
240+
{
241+
"name": "-do-not-check-capabilities",
242+
"data_type": "flag",
243+
"default_value": None,
244+
"description": (
245+
"If set, classifier capabilities are not checked"
246+
" before classifier is built\n\t(use with caution)."
247+
),
248+
},
249+
{
250+
"name": "-doNotMakeSplitPointActualValue",
251+
"data_type": "flag",
252+
"default_value": None,
253+
"description": "Do not make split point actual value.",
254+
},
255+
{
256+
"name": "A",
257+
"data_type": "flag",
258+
"default_value": None,
259+
"description": "Laplace smoothing for predicted probabilities.",
260+
},
261+
{
262+
"name": "B",
263+
"data_type": "flag",
264+
"default_value": None,
265+
"description": "Use binary splits only.",
266+
},
267+
{
268+
"name": "C",
269+
"data_type": "option",
270+
"default_value": 0.25,
271+
"description": (
272+
"Set confidence threshold for pruning.\n\t(default 0.25)"
273+
),
274+
},
275+
{
276+
"name": "J",
277+
"data_type": "flag",
278+
"default_value": None,
279+
"description": (
280+
"Do not use MDL correction for info gain on numeric attributes."
281+
),
282+
},
283+
{
284+
"name": "L",
285+
"data_type": "flag",
286+
"default_value": None,
287+
"description": "Do not clean up after the tree has been built.",
288+
},
289+
{
290+
"name": "M",
291+
"data_type": "option",
292+
"default_value": 2,
293+
"description": (
294+
"Set minimum number of instances per leaf.\n\t(default 2)"
295+
),
296+
},
297+
{
298+
"name": "N",
299+
"data_type": "option",
300+
"default_value": None,
301+
"description": (
302+
"Set number of folds for reduced error\n\t"
303+
"pruning. One fold is used as pruning set.\n\t(default 3)"
304+
),
305+
},
306+
{
307+
"name": "O",
308+
"data_type": "flag",
309+
"default_value": None,
310+
"description": "Do not collapse tree.",
311+
},
312+
{
313+
"name": "Q",
314+
"data_type": "option",
315+
"default_value": None,
316+
"description": "Seed for random data shuffling (default 1).",
317+
},
318+
{
319+
"name": "R",
320+
"data_type": "flag",
321+
"default_value": None,
322+
"description": "Use reduced error pruning.",
323+
},
324+
{
325+
"name": "S",
326+
"data_type": "flag",
327+
"default_value": None,
328+
"description": "Do not perform subtree raising.",
329+
},
330+
{
331+
"name": "U",
332+
"data_type": "flag",
333+
"default_value": None,
334+
"description": "Use unpruned tree.",
335+
},
336+
{
337+
"name": "batch-size",
338+
"data_type": "option",
339+
"default_value": None,
340+
"description": (
341+
"The desired batch size for batch prediction (default 100)."
342+
),
343+
},
344+
{
345+
"name": "num-decimal-places",
346+
"data_type": "option",
347+
"default_value": None,
348+
"description": (
349+
"The number of decimal places for the output of numbers"
350+
" in the model (default 2)."
351+
),
352+
},
353+
{
354+
"name": "output-debug-info",
355+
"data_type": "flag",
356+
"default_value": None,
357+
"description": (
358+
"If set, classifier is run in debug mode and\n\t"
359+
"may output additional info to the console"
360+
),
361+
},
362+
],
363+
"tag": ["OpenmlWeka", "weka"],
364+
"subflows": [],
365+
},
366+
}
360367
],
361368
"tag": ["OpenmlWeka", "weka"],
362369
}

tests/routers/openml/migration/flows_migration_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def convert_flow_naming_and_defaults(flow: dict[str, Any]) -> dict[str, Any]:
6565
if parameter["default_value"] is None:
6666
parameter["default_value"] = []
6767
for subflow in flow["subflows"]:
68-
convert_flow_naming_and_defaults(subflow)
68+
subflow["flow"] = convert_flow_naming_and_defaults(subflow["flow"])
69+
if subflow["identifier"] is None:
70+
subflow["identifier"] = []
6971
flow["component"] = flow.pop("subflows")
7072
if flow["component"] == []:
7173
flow.pop("component")
@@ -75,8 +77,6 @@ def convert_flow_naming_and_defaults(flow: dict[str, Any]) -> dict[str, Any]:
7577
new = nested_remove_single_element_list(new)
7678

7779
expected = php_api.get(f"/flow/{flow_id}").json()["flow"]
78-
if subflow := expected.get("component"):
79-
expected["component"] = subflow["flow"]
8080
# The reason we don't transform "new" to str is that it becomes harder to ignore numeric type
8181
# differences (e.g., '1.0' vs '1')
8282
expected = nested_str_to_num(expected)

0 commit comments

Comments
 (0)