|
3 | 3 |
|
4 | 4 | import math |
5 | 5 | import sys |
| 6 | +from functools import partial |
6 | 7 | from operator import add, floordiv, mul, sub |
7 | 8 | from pprint import pprint |
8 | 9 |
|
@@ -56,12 +57,9 @@ def mul_op1(a, b): |
56 | 57 |
|
57 | 58 | # Pow operation |
58 | 59 | @operation( |
59 | | - name="pow_op1", |
60 | | - needs="sum_ab", |
61 | | - provides=["sum_ab_p1", "sum_ab_p2", "sum_ab_p3"], |
62 | | - params={"exponent": 3}, |
| 60 | + name="pow_op1", needs="sum_ab", provides=["sum_ab_p1", "sum_ab_p2", "sum_ab_p3"] |
63 | 61 | ) |
64 | | - def pow_op1(a, exponent=2): |
| 62 | + def pow_op1(a, exponent=3): |
65 | 63 | return [math.pow(a, y) for y in range(1, exponent + 1)] |
66 | 64 |
|
67 | 65 | assert pow_op1.compute({"sum_ab": 2}, ["sum_ab_p2"]) == {"sum_ab_p2": 4.0} |
@@ -117,6 +115,70 @@ def pow_op1(a, exponent=2): |
117 | 115 | netop({"sum_ab": 1, "b": 2}, outputs=["b", "bad_node"]) |
118 | 116 |
|
119 | 117 |
|
| 118 | +def test_network_plan_execute(): |
| 119 | + def powers_in_trange(a, exponent): |
| 120 | + outputs = [] |
| 121 | + for y in range(1, exponent + 1): |
| 122 | + p = math.pow(a, y) |
| 123 | + outputs.append(p) |
| 124 | + return outputs |
| 125 | + |
| 126 | + sum_op1 = operation(name="sum1", provides=["sum_ab"], needs=["a", "b"])(add) |
| 127 | + mul_op1 = operation(name="mul", provides=["sum_ab_times_b"], needs=["sum_ab", "b"])( |
| 128 | + mul |
| 129 | + ) |
| 130 | + pow_op1 = operation( |
| 131 | + name="pow", |
| 132 | + needs=["sum_ab", "exponent"], |
| 133 | + provides=["sum_ab_p1", "sum_ab_p2", "sum_ab_p3"], |
| 134 | + )(powers_in_trange) |
| 135 | + sum_op2 = operation( |
| 136 | + name="sum2", provides=["p1_plus_p2"], needs=["sum_ab_p1", "sum_ab_p2"] |
| 137 | + )(add) |
| 138 | + |
| 139 | + net = network.Network() |
| 140 | + net.add_op(sum_op1) |
| 141 | + net.add_op(mul_op1) |
| 142 | + net.add_op(pow_op1) |
| 143 | + net.add_op(sum_op2) |
| 144 | + net.compile() |
| 145 | + |
| 146 | + # |
| 147 | + # Running the network |
| 148 | + # |
| 149 | + |
| 150 | + # get all outputs |
| 151 | + exp = { |
| 152 | + "a": 1, |
| 153 | + "b": 2, |
| 154 | + "exponent": 3, |
| 155 | + "p1_plus_p2": 12.0, |
| 156 | + "sum_ab": 3, |
| 157 | + "sum_ab_p1": 3.0, |
| 158 | + "sum_ab_p2": 9.0, |
| 159 | + "sum_ab_p3": 27.0, |
| 160 | + "sum_ab_times_b": 6, |
| 161 | + } |
| 162 | + |
| 163 | + inputs = {"a": 1, "b": 2, "exponent": 3} |
| 164 | + plan = net.compile(outputs=None, inputs=inputs.keys()) |
| 165 | + sol = plan.execute(named_inputs=inputs) |
| 166 | + assert sol == exp |
| 167 | + |
| 168 | + # get specific outputs |
| 169 | + exp = {"sum_ab_times_b": 6} |
| 170 | + plan = net.compile(outputs=["sum_ab_times_b"], inputs=list(inputs)) |
| 171 | + sol = plan.execute(named_inputs=inputs) |
| 172 | + assert sol == exp |
| 173 | + |
| 174 | + # start with inputs already computed |
| 175 | + inputs = {"sum_ab": 1, "b": 2, "exponent": 3} |
| 176 | + exp = {"sum_ab_times_b": 2} |
| 177 | + plan = net.compile(outputs=["sum_ab_times_b"], inputs=inputs) |
| 178 | + sol = plan.execute(named_inputs={"sum_ab": 1, "b": 2}) |
| 179 | + assert sol == exp |
| 180 | + |
| 181 | + |
120 | 182 | def test_network_simple_merge(): |
121 | 183 |
|
122 | 184 | sum_op1 = operation(name="sum_op1", needs=["a", "b"], provides="sum1")(add) |
@@ -178,11 +240,8 @@ def test_network_merge_in_doctests(): |
178 | 240 | operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), |
179 | 241 | operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub), |
180 | 242 | operation( |
181 | | - name="abspow1", |
182 | | - needs=["a_minus_ab"], |
183 | | - provides=["abs_a_minus_ab_cubed"], |
184 | | - params={"p": 3}, |
185 | | - )(abspow), |
| 243 | + name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"] |
| 244 | + )(partial(abspow, p=3)), |
186 | 245 | ) |
187 | 246 |
|
188 | 247 | another_graph = compose(name="another_graph")( |
@@ -853,11 +912,8 @@ def test_multithreading_plan_execution(): |
853 | 912 | operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), |
854 | 913 | operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub), |
855 | 914 | operation( |
856 | | - name="abspow1", |
857 | | - needs=["a_minus_ab"], |
858 | | - provides=["abs_a_minus_ab_cubed"], |
859 | | - params={"p": 3}, |
860 | | - )(abspow), |
| 915 | + name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"] |
| 916 | + )(partial(abspow, p=3)), |
861 | 917 | ) |
862 | 918 |
|
863 | 919 | pool = Pool(10) |
@@ -969,11 +1025,8 @@ def test_compose_another_network(bools): |
969 | 1025 | operation(name="mul1", needs=["a", "b"], provides=["ab"])(mul), |
970 | 1026 | operation(name="sub1", needs=["a", "ab"], provides=["a_minus_ab"])(sub), |
971 | 1027 | operation( |
972 | | - name="abspow1", |
973 | | - needs=["a_minus_ab"], |
974 | | - provides=["abs_a_minus_ab_cubed"], |
975 | | - params={"p": 3}, |
976 | | - )(abspow), |
| 1028 | + name="abspow1", needs=["a_minus_ab"], provides=["abs_a_minus_ab_cubed"] |
| 1029 | + )(partial(abspow, p=3)), |
977 | 1030 | ) |
978 | 1031 | if parallel1: |
979 | 1032 | graphop.set_execution_method("parallel") |
|
0 commit comments