|
79 | 79 | import com.oracle.graal.python.nodes.expression.UnaryArithmetic;
|
80 | 80 | import com.oracle.graal.python.nodes.frame.DeleteGlobalNode;
|
81 | 81 | import com.oracle.graal.python.nodes.frame.DestructuringAssignmentNode;
|
| 82 | +import com.oracle.graal.python.nodes.frame.FrameSlotIDs; |
82 | 83 | import com.oracle.graal.python.nodes.frame.ReadGlobalOrBuiltinNode;
|
83 | 84 | import com.oracle.graal.python.nodes.frame.WriteGlobalNode;
|
| 85 | +import com.oracle.graal.python.nodes.frame.WriteLocalVariableNode; |
84 | 86 | import com.oracle.graal.python.nodes.frame.WriteNode;
|
85 | 87 | import com.oracle.graal.python.nodes.function.FunctionDefinitionNode;
|
86 | 88 | import com.oracle.graal.python.nodes.function.GeneratorExpressionNode;
|
@@ -135,18 +137,32 @@ <T> T getChild(Node result, int num, Class<? extends T> klass) {
|
135 | 137 | if (++i <= num) {
|
136 | 138 | continue;
|
137 | 139 | }
|
138 |
| - if (n instanceof ExpressionNode.ExpressionStatementNode) { |
139 |
| - n = n.getChildren().iterator().next(); |
140 |
| - } else if (n instanceof ExpressionNode.ExpressionWithSideEffects) { |
141 |
| - n = n.getChildren().iterator().next(); |
142 |
| - } |
| 140 | + n = unpackModuleBodyWrappers(n); |
143 | 141 | assertTrue("Expected an instance of " + klass + ", got " + n.getClass(), klass.isInstance(n));
|
144 | 142 | return klass.cast(n);
|
145 | 143 | }
|
146 | 144 | assertFalse("Expected an instance of " + klass + ", got null", true);
|
147 | 145 | return null;
|
148 | 146 | }
|
149 | 147 |
|
| 148 | + private Node unpackModuleBodyWrappers(Node n) { |
| 149 | + Node actual = n; |
| 150 | + if (n instanceof ExpressionNode.ExpressionStatementNode) { |
| 151 | + actual = n.getChildren().iterator().next(); |
| 152 | + } else if (n instanceof ExpressionNode.ExpressionWithSideEffects) { |
| 153 | + actual = n.getChildren().iterator().next(); |
| 154 | + } else if (n instanceof WriteLocalVariableNode) { |
| 155 | + if (((WriteLocalVariableNode) n).getIdentifier().equals(FrameSlotIDs.RETURN_SLOT_ID)) { |
| 156 | + actual = ((WriteLocalVariableNode) n).getRhs(); |
| 157 | + } |
| 158 | + } |
| 159 | + if (actual == n) { |
| 160 | + return n; |
| 161 | + } else { |
| 162 | + return unpackModuleBodyWrappers(actual); |
| 163 | + } |
| 164 | + } |
| 165 | + |
150 | 166 | <T> T getFirstChild(Node result, Class<? extends T> klass) {
|
151 | 167 | return getChild(result, 0, klass);
|
152 | 168 | }
|
|
0 commit comments