|
| 1 | +# Copyright 1999-2021 Alibaba Group Holding Ltd. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import itertools |
| 15 | +import pytest |
| 16 | + |
| 17 | + |
| 18 | +from ..core import OptimizationRule, ReplaceSubgraphError |
| 19 | +from .... import tensor as mt |
| 20 | +from .... import dataframe as md |
| 21 | + |
| 22 | + |
| 23 | +class _MockRule(OptimizationRule): |
| 24 | + def apply(self) -> bool: |
| 25 | + pass |
| 26 | + |
| 27 | + def replace_subgraph(self, graph, removed_nodes, new_results=None): |
| 28 | + self._replace_subgraph(graph, removed_nodes, new_results) |
| 29 | + |
| 30 | + |
| 31 | +def test_replace_tileable_subgraph(): |
| 32 | + """ |
| 33 | + Original Graph: |
| 34 | + s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5 |
| 35 | + | ^ |
| 36 | + | | |
| 37 | + V | |
| 38 | + v3 ------| |
| 39 | + ^ |
| 40 | + | |
| 41 | + s2 ---> c2 ---> v2 |
| 42 | +
|
| 43 | + Target Graph: |
| 44 | + s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5 |
| 45 | + ^ |
| 46 | + | |
| 47 | + s2 ---> c2 ---> v2 |
| 48 | +
|
| 49 | + The nodes [v3, v4, v6] will be removed. |
| 50 | + Subgraph only contains [v7, v8] |
| 51 | + """ |
| 52 | + s1 = mt.random.randint(0, 100, size=(5, 4)) |
| 53 | + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) |
| 54 | + s2 = mt.random.randint(0, 100, size=(5, 4)) |
| 55 | + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) |
| 56 | + v3 = v1.add(v2) |
| 57 | + v4 = v3.add(v1) |
| 58 | + s5 = mt.random.randint(0, 100, size=(5, 4)) |
| 59 | + v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4) |
| 60 | + v6 = v5.sub(v4) |
| 61 | + g1 = v6.build_graph() |
| 62 | + v7 = v1.sub(v2) |
| 63 | + v8 = v7.add(v5) |
| 64 | + g2 = v8.build_graph() |
| 65 | + |
| 66 | + # Here we use a trick way to construct the subgraph for test only |
| 67 | + key_to_node = dict() |
| 68 | + for node in g2.iter_nodes(): |
| 69 | + key_to_node[node.key] = node |
| 70 | + for key, node in key_to_node.items(): |
| 71 | + if key != v7.key and key != v8.key: |
| 72 | + g2.remove_node(node) |
| 73 | + r = _MockRule(g1, None, None) |
| 74 | + for node in g1.iter_nodes(): |
| 75 | + key_to_node[node.key] = node |
| 76 | + |
| 77 | + c1 = g1.successors(key_to_node[s1.key])[0] |
| 78 | + c2 = g1.successors(key_to_node[s2.key])[0] |
| 79 | + c5 = g1.successors(key_to_node[s5.key])[0] |
| 80 | + |
| 81 | + expected_results = [v8.outputs[0]] |
| 82 | + r.replace_subgraph( |
| 83 | + g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results |
| 84 | + ) |
| 85 | + assert g1.results == expected_results |
| 86 | + |
| 87 | + expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8} |
| 88 | + assert set(g1) == {key_to_node[n.key] for n in expected_nodes} |
| 89 | + |
| 90 | + expected_edges = { |
| 91 | + s1: [c1], |
| 92 | + c1: [v1], |
| 93 | + v1: [v7], |
| 94 | + s2: [c2], |
| 95 | + c2: [v2], |
| 96 | + v2: [v7], |
| 97 | + s5: [c5], |
| 98 | + c5: [v5], |
| 99 | + v5: [v8], |
| 100 | + v7: [v8], |
| 101 | + v8: [], |
| 102 | + } |
| 103 | + for pred, successors in expected_edges.items(): |
| 104 | + pred_node = key_to_node[pred.key] |
| 105 | + assert g1.count_successors(pred_node) == len(successors) |
| 106 | + for successor in successors: |
| 107 | + assert g1.has_successor(pred_node, key_to_node[successor.key]) |
| 108 | + |
| 109 | + |
| 110 | +def test_replace_null_subgraph(): |
| 111 | + """ |
| 112 | + Original Graph: |
| 113 | + s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2 |
| 114 | +
|
| 115 | + Target Graph: |
| 116 | + c1 ---> v1 ---> v3 <--- v2 <--- c2 |
| 117 | +
|
| 118 | + The nodes [s1, s2] will be removed. |
| 119 | + Subgraph is None |
| 120 | + """ |
| 121 | + s1 = mt.random.randint(0, 100, size=(10, 4)) |
| 122 | + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) |
| 123 | + s2 = mt.random.randint(0, 100, size=(10, 4)) |
| 124 | + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) |
| 125 | + v3 = v1.add(v2) |
| 126 | + g1 = v3.build_graph() |
| 127 | + key_to_node = {node.key: node for node in g1.iter_nodes()} |
| 128 | + c1 = g1.successors(key_to_node[s1.key])[0] |
| 129 | + c2 = g1.successors(key_to_node[s2.key])[0] |
| 130 | + r = _MockRule(g1, None, None) |
| 131 | + expected_results = [v3.outputs[0]] |
| 132 | + # delete c5 s5 will fail |
| 133 | + with pytest.raises(ReplaceSubgraphError) as e: |
| 134 | + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) |
| 135 | + assert g1.results == expected_results |
| 136 | + assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}} |
| 137 | + expected_edges = { |
| 138 | + s1: [c1], |
| 139 | + c1: [v1], |
| 140 | + v1: [v3], |
| 141 | + s2: [c2], |
| 142 | + c2: [v2], |
| 143 | + v2: [v3], |
| 144 | + v3: [], |
| 145 | + } |
| 146 | + for pred, successors in expected_edges.items(): |
| 147 | + pred_node = key_to_node[pred.key] |
| 148 | + assert g1.count_successors(pred_node) == len(successors) |
| 149 | + for successor in successors: |
| 150 | + assert g1.has_successor(pred_node, key_to_node[successor.key]) |
| 151 | + |
| 152 | + c1.inputs.clear() |
| 153 | + c2.inputs.clear() |
| 154 | + r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]}) |
| 155 | + assert g1.results == expected_results |
| 156 | + assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}} |
| 157 | + expected_edges = { |
| 158 | + c1: [v1], |
| 159 | + v1: [v3], |
| 160 | + c2: [v2], |
| 161 | + v2: [v3], |
| 162 | + v3: [], |
| 163 | + } |
| 164 | + for pred, successors in expected_edges.items(): |
| 165 | + pred_node = key_to_node[pred.key] |
| 166 | + assert g1.count_successors(pred_node) == len(successors) |
| 167 | + for successor in successors: |
| 168 | + assert g1.has_successor(pred_node, key_to_node[successor.key]) |
| 169 | + |
| 170 | + |
| 171 | +def test_replace_subgraph_without_removing_nodes(): |
| 172 | + """ |
| 173 | + Original Graph: |
| 174 | + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 |
| 175 | +
|
| 176 | + Target Graph: |
| 177 | + s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2 |
| 178 | + s3 ---> c3 ---> v3 |
| 179 | +
|
| 180 | + Nothing will be removed. |
| 181 | + Subgraph only contains [s3, c3, v3] |
| 182 | + """ |
| 183 | + s1 = mt.random.randint(0, 100, size=(10, 4)) |
| 184 | + v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5) |
| 185 | + s2 = mt.random.randint(0, 100, size=(10, 4)) |
| 186 | + v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5) |
| 187 | + v4 = v1.add(v2) |
| 188 | + g1 = v4.build_graph() |
| 189 | + |
| 190 | + s3 = mt.random.randint(0, 100, size=(10, 4)) |
| 191 | + v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5) |
| 192 | + g2 = v3.build_graph() |
| 193 | + key_to_node = { |
| 194 | + node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes()) |
| 195 | + } |
| 196 | + expected_results = [v3.outputs[0], v4.outputs[0]] |
| 197 | + c1 = g1.successors(key_to_node[s1.key])[0] |
| 198 | + c2 = g1.successors(key_to_node[s2.key])[0] |
| 199 | + c3 = g2.successors(key_to_node[s3.key])[0] |
| 200 | + r = _MockRule(g1, None, None) |
| 201 | + r.replace_subgraph(g2, None, expected_results) |
| 202 | + assert g1.results == expected_results |
| 203 | + assert set(g1) == { |
| 204 | + key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4} |
| 205 | + } |
| 206 | + expected_edges = { |
| 207 | + s1: [c1], |
| 208 | + c1: [v1], |
| 209 | + v1: [v4], |
| 210 | + s2: [c2], |
| 211 | + c2: [v2], |
| 212 | + v2: [v4], |
| 213 | + s3: [c3], |
| 214 | + c3: [v3], |
| 215 | + v3: [], |
| 216 | + v4: [], |
| 217 | + } |
| 218 | + for pred, successors in expected_edges.items(): |
| 219 | + pred_node = key_to_node[pred.key] |
| 220 | + assert g1.count_successors(pred_node) == len(successors) |
| 221 | + for successor in successors: |
| 222 | + assert g1.has_successor(pred_node, key_to_node[successor.key]) |
0 commit comments