Skip to content

Commit 0efcb71

Browse files
committed
Preserve integer typing for input parameter
1 parent 9faa75f commit 0efcb71

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/routers/openml/tasks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
3030
def fill_template(
3131
template: str,
3232
task: RowMapping,
33-
task_inputs: dict[str, str],
33+
task_inputs: dict[str, str | int],
3434
connection: Connection,
3535
) -> dict[str, JSON]:
3636
"""Fill in the XML template as used for task descriptions and return the result,
@@ -96,7 +96,7 @@ def fill_template(
9696
def _fill_json_template(
9797
template: JSON,
9898
task: RowMapping,
99-
task_inputs: dict[str, str],
99+
task_inputs: dict[str, str | int],
100100
fetched_data: dict[str, str],
101101
connection: Connection,
102102
) -> JSON:
@@ -120,7 +120,7 @@ def _fill_json_template(
120120
if match.string == template:
121121
# How do we know the default value? probably ttype_io table?
122122
return task_inputs.get(field, [])
123-
template = template.replace(match.group(), task_inputs[field])
123+
template = template.replace(match.group(), str(task_inputs[field]))
124124
if match := re.search(r"\[LOOKUP:(.*)]", template):
125125
(field,) = match.groups()
126126
if field not in fetched_data:
@@ -163,7 +163,7 @@ def get_task(
163163
)
164164

165165
task_inputs = {
166-
row.input: str(int(row.value)) if row.value.isdigit() else row.value
166+
row.input: int(row.value) if row.value.isdigit() else row.value
167167
for row in database.tasks.get_input_for_task(task_id, expdb)
168168
}
169169
ttios = database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)

0 commit comments

Comments
 (0)