Skip to content

Commit 2107177

Browse files
committed
wip
1 parent aab9872 commit 2107177

File tree

4 files changed

+390
-50
lines changed

4 files changed

+390
-50
lines changed

constructed-while-output.json

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
{
2+
"version": "1.0",
3+
"nodes": [
4+
{
5+
"id": 1,
6+
"type": "input",
7+
"name": "initial_x",
8+
"value": 1
9+
},
10+
{
11+
"id": 2,
12+
"type": "input",
13+
"name": "initial_y",
14+
"value": 1
15+
},
16+
{
17+
"id": 3,
18+
"type": "function",
19+
"value": "mymodule.add"
20+
},
21+
{
22+
"id": 4,
23+
"type": "while",
24+
"conditionFunction": null,
25+
"conditionExpression": "ctx.n < 8",
26+
"conditionWorkflow": null,
27+
"bodyFunction": null,
28+
"bodyWorkflow": {
29+
"version": "1.0",
30+
"nodes": [
31+
{
32+
"id": 1,
33+
"type": "function",
34+
"value": "mymodule.add"
35+
},
36+
{
37+
"id": 2,
38+
"type": "function",
39+
"value": "mymodule.multiply"
40+
}
41+
],
42+
"edges": [
43+
{
44+
"target": 2,
45+
"targetPort": "x",
46+
"source": 1,
47+
"sourcePort": null
48+
}
49+
]
50+
},
51+
"contextVars": [
52+
"n"
53+
],
54+
"inputPorts": {
55+
"n": null
56+
},
57+
"outputPorts": {
58+
"n": "final_n"
59+
},
60+
"maxIterations": 10,
61+
"stateMapping": null
62+
},
63+
{
64+
"id": 5,
65+
"type": "function",
66+
"value": "mymodule.add"
67+
},
68+
{
69+
"id": 6,
70+
"type": "output",
71+
"name": "result"
72+
}
73+
],
74+
"edges": [
75+
{
76+
"target": 3,
77+
"targetPort": "x",
78+
"source": 1,
79+
"sourcePort": null
80+
},
81+
{
82+
"target": 3,
83+
"targetPort": "y",
84+
"source": 2,
85+
"sourcePort": null
86+
},
87+
{
88+
"target": 4,
89+
"targetPort": "n",
90+
"source": 3,
91+
"sourcePort": null
92+
},
93+
{
94+
"target": 5,
95+
"targetPort": "x",
96+
"source": 4,
97+
"sourcePort": "final_n"
98+
},
99+
{
100+
"target": 6,
101+
"targetPort": null,
102+
"source": 5,
103+
"sourcePort": null
104+
}
105+
]
106+
}

example_workflows/while_loop/test_aiida_example.py

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,101 @@
2222
sys.path.insert(0, str(while_loop_path))
2323

2424

25-
load_profile()
25+
from typing import Callable, Any
2626

27+
def my_while(
28+
input_ports: dict,
29+
condition_f: Callable[[dict, dict], bool],
30+
body_f: Callable[[dict, dict], dict],
31+
finalizer: Callable[[dict, dict], Any]
32+
) -> Any:
33+
ctx = {}
34+
while condition_f(input_ports, ctx):
35+
ctx = body_f(input_ports, ctx)
36+
return finalizer(input_ports, ctx)
2737

28-
@task
29-
def add(x, y):
30-
return x + y
3138

32-
@task
33-
def compare(x, y):
34-
return x < y
39+
# Example: Sum numbers from 0 to limit
40+
def condition_f(input_ports: dict, ctx: dict) -> bool:
41+
limit = input_ports.get('limit', 0)
42+
current = ctx.get('current', 0)
43+
return current < limit
3544

36-
@task
37-
def multiply(x, y):
38-
return x * y
45+
def body_f(input_ports: dict, ctx: dict) -> dict:
46+
current = ctx.get('current', 0)
47+
total = ctx.get('total', 0)
48+
return {
49+
'current': current + 1,
50+
'total': total + current
51+
}
3952

40-
@task.graph
41-
def WhileLoop(n, m):
42-
if m >= n:
43-
return m
44-
m = add(x=m, y=1).result
45-
return WhileLoop(n=n, m=m)
53+
def finalizer(input_ports: dict, ctx: dict) -> dict:
54+
return {
55+
'result': ctx.get('total', 0),
56+
'iterations': ctx.get('current', 0)
57+
}
4658

4759

48-
wg = WhileLoop.build(n=4, m=0)
60+
# Run it
61+
result = my_while(
62+
input_ports={'limit': 10},
63+
condition_f=condition_f,
64+
body_f=body_f,
65+
finalizer=finalizer
66+
)
4967

50-
wg.to_html()
68+
print(result) # {'result': 45, 'iterations': 10}
5169

52-
write_workflow_json(wg=wg, file_name='write_while_loop.json')
70+
raise SystemExit()
71+
72+
# from typing import Callable
73+
#
74+
# def condition_f(x, limit):
75+
# return limit > x
76+
#
77+
# def body_f(x):
78+
# return x + 1
79+
#
80+
# # def abstract_while(x, limit):
81+
# # if not condition(x=x, limit=limit):
82+
# # return x
83+
# # x = function_body(x=x)
84+
# # return abstract_while(x=x, limit=limit)
85+
#
86+
#
87+
# def my_while(input_ports: dict, condition_f=Callable, body_f=Callable, finalizer=Callable):
88+
# ctx = {}
89+
# while condition_f(input_ports, ctx):
90+
# ctx = body_f(input_ports, ctx)
91+
# return finalizer(input_ports, ctx) # these become output ports
92+
#
93+
# my_while()
94+
95+
# load_profile()
96+
#
97+
#
98+
# @task
99+
# def add(x, y):
100+
# return x + y
101+
#
102+
# @task
103+
# def compare(x, y):
104+
# return x < y
105+
#
106+
# @task
107+
# def multiply(x, y):
108+
# return x * y
109+
#
110+
# @task.graph
111+
# def WhileLoop(n, m):
112+
# if m >= n:
113+
# return m
114+
# m = add(x=m, y=1).result
115+
# return WhileLoop(n=n, m=m)
116+
#
117+
#
118+
# wg = WhileLoop.build(n=4, m=0)
119+
#
120+
# wg.to_html()
121+
#
122+
# write_workflow_json(wg=wg, file_name='write_while_loop.json')
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from python_workflow_definition.models import *
2+
3+
workflow = PythonWorkflowDefinitionWorkflow(
4+
version="1.0",
5+
nodes=[
6+
# Input node for initial value
7+
PythonWorkflowDefinitionInputNode(id=1, type="input", name="initial_x", value=1),
8+
PythonWorkflowDefinitionInputNode(id=2, type="input", name="initial_y", value=1),
9+
10+
# Initial add to set up context
11+
PythonWorkflowDefinitionFunctionNode(id=3, type="function", value="mymodule.add"),
12+
13+
# While loop node
14+
PythonWorkflowDefinitionWhileNode(
15+
id=4,
16+
type="while",
17+
conditionExpression="ctx.n < 8",
18+
bodyWorkflow=PythonWorkflowDefinitionWorkflow(
19+
version="1.0",
20+
nodes=[
21+
PythonWorkflowDefinitionFunctionNode(id=1, type="function", value="mymodule.add"),
22+
PythonWorkflowDefinitionFunctionNode(id=2, type="function", value="mymodule.multiply"),
23+
],
24+
edges=[
25+
PythonWorkflowDefinitionEdge(source=1, target=2, targetPort="x"),
26+
]
27+
),
28+
contextVars=["n"],
29+
inputPorts={"n": None}, # Will be connected via edge
30+
outputPorts={"n": "final_n"},
31+
maxIterations=10
32+
),
33+
34+
# Final add after loop
35+
PythonWorkflowDefinitionFunctionNode(id=5, type="function", value="mymodule.add"),
36+
37+
# Output node
38+
PythonWorkflowDefinitionOutputNode(id=6, type="output", name="result"),
39+
],
40+
edges=[
41+
PythonWorkflowDefinitionEdge(source=1, target=3, targetPort="x"),
42+
PythonWorkflowDefinitionEdge(source=2, target=3, targetPort="y"),
43+
PythonWorkflowDefinitionEdge(source=3, target=4, targetPort="n"),
44+
PythonWorkflowDefinitionEdge(source=4, sourcePort="final_n", target=5, targetPort="x"),
45+
PythonWorkflowDefinitionEdge(source=5, target=6),
46+
]
47+
)
48+
49+
breakpoint()
50+
51+
pass

0 commit comments

Comments
 (0)