Skip to content

Commit e8b0872

Browse files
committed
Add tests for workflow frameworks
1 parent 4285d06 commit e8b0872

File tree

4 files changed

+113
-0
lines changed

4 files changed

+113
-0
lines changed

.github/workflows/pipeline.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ jobs:
216216
shell: bash -l {0}
217217
timeout-minutes: 30
218218
run: |
219+
verdi presto --profile-name pwd
219220
pip install . --no-deps --no-build-isolation
220221
coverage run
221222
coverage xml

tests/test_aiida.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
import os
3+
from aiida_workgraph import WorkGraph, task
4+
from aiida import orm, load_profile
5+
load_profile()
6+
7+
from python_workflow_definition.aiida import load_workflow_json, write_workflow_json
8+
9+
10+
def get_prod_and_div(x, y):
11+
return {"prod": x * y, "div": x / y}
12+
13+
14+
def get_sum(x, y):
15+
return x + y
16+
17+
18+
def get_square(x):
19+
return x ** 2
20+
21+
22+
class TestAiiDA(unittest.TestCase):
23+
def test_aiida(self):
24+
workflow_json_filename = "aiida_simple.json"
25+
wg = WorkGraph("arithmetic")
26+
get_prod_and_div_task = wg.add_task(
27+
task(outputs=['prod', 'div'])(get_prod_and_div),
28+
x=orm.Float(1),
29+
y=orm.Float(2),
30+
)
31+
get_sum_task = wg.add_task(
32+
get_sum,
33+
x=get_prod_and_div_task.outputs.prod,
34+
y=get_prod_and_div_task.outputs.div,
35+
)
36+
get_square_task = wg.add_task(
37+
get_square,
38+
x=get_sum_task.outputs.result,
39+
)
40+
write_workflow_json(wg=wg, file_name=workflow_json_filename)
41+
workgraph = load_workflow_json(file_name='workflow.json')
42+
workgraph.run()
43+
44+
self.assertTrue(os.path.exists(workflow_json_filename))

tests/test_jobflow.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
import os
3+
from jobflow import job, Flow
4+
from jobflow.managers.local import run_locally
5+
from python_workflow_definition.jobflow import load_workflow_json, write_workflow_json
6+
7+
8+
def get_prod_and_div(x, y):
9+
return {"prod": x * y, "div": x / y}
10+
11+
12+
def get_sum(x, y):
13+
return x + y
14+
15+
16+
def get_square(x):
17+
return x ** 2
18+
19+
20+
class TestJobflow(unittest.TestCase):
21+
def test_jobflow(self):
22+
workflow_json_filename = "jobflow_simple.json"
23+
get_sum_job = job(get_sum)
24+
get_prod_and_div_job = job(get_prod_and_div)
25+
get_square_job = job(get_square)
26+
prod_and_div = get_prod_and_div_job(x=1, y=2)
27+
tmp_sum = get_sum_job(x=prod_and_div.output.prod, y=prod_and_div.output.div)
28+
result = get_square_job(x=tmp_sum.output)
29+
flow = Flow([prod_and_div, tmp_sum, result])
30+
write_workflow_json(flow=flow, file_name=workflow_json_filename)
31+
flow = load_workflow_json(file_name=workflow_json_filename)
32+
result = run_locally(flow)
33+
34+
self.assertTrue(os.path.exists(workflow_json_filename))
35+
self.assertEqual(result[result.keys()[-1]][1].output, 6.25)

tests/test_pyiron_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import unittest
2+
import os
3+
from pyiron_base import job
4+
from python_workflow_definition.pyiron_base import load_workflow_json, write_workflow_json
5+
6+
7+
def get_prod_and_div(x, y):
8+
return {"prod": x * y, "div": x / y}
9+
10+
11+
def get_sum(x, y):
12+
return x + y
13+
14+
15+
def get_square(x):
16+
return x ** 2
17+
18+
19+
class TestPyironBase(unittest.TestCase):
20+
def test_pyiron_base(self):
21+
workflow_json_filename = "pyiron_arithmetic.json"
22+
get_sum_job_wrapper = job(get_sum)
23+
get_prod_and_div_job_wrapper = job(get_prod_and_div, output_key_lst=["prod", "div"])
24+
get_square_job_wrapper = job(get_square)
25+
26+
prod_and_div = get_prod_and_div_job_wrapper(x=1, y=2)
27+
tmp_sum = get_sum_job_wrapper(x=prod_and_div.output.prod, y=prod_and_div.output.div)
28+
result = get_square_job_wrapper(x=tmp_sum)
29+
write_workflow_json(delayed_object=result, file_name=workflow_json_filename)
30+
delayed_object_lst = load_workflow_json(file_name=workflow_json_filename)
31+
32+
self.assertTrue(os.path.exists(workflow_json_filename))
33+
self.assertEqual(delayed_object_lst[-1].pull(), 6.25)

0 commit comments

Comments
 (0)