Skip to content

Commit 4d2ba76

Browse files
committed
[GR-17457] Avoid executing arg nodes multiple times with op element assignment.
PullRequest: truffleruby/3372
2 parents 1c7474b + 3186b2f commit 4d2ba76

File tree

4 files changed

+115
-14
lines changed

4 files changed

+115
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Bug fixes:
4747
* Fix `/#{...}/o` to evaluate only once per context when splitting happens (@eregon).
4848
* Fix `Kernel#sprintf` formatting of floats to be like CRuby (@aardvark179).
4949
* Fix `Process.egid=` to accept `String`s (#2615, @ngtban)
50+
* Fix optional assignment to only evaluate index arguments once (#2658, @aardvark179).
5051

5152
Compatibility:
5253

spec/ruby/language/optional_assignments_spec.rb

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,44 @@ def []=(k, v)
300300
(@b[:k] ||= 12).should == 12
301301
end
302302

303+
it 'correctly handles a splatted argument for the index' do
304+
(@b[*[:k]] ||= 12).should == 12
305+
end
306+
307+
it "evaluates the index precisely once" do
308+
ary = [:x, :y]
309+
@a[:x] = 15
310+
@a[ary.pop] ||= 25
311+
ary.should == [:x]
312+
@a.should == { x: 15, y: 25 }
313+
end
314+
315+
it "evaluates the index arguments in the correct order" do
316+
ary = Class.new(Array) do
317+
def [](x, y)
318+
super(x + 3 * y)
319+
end
320+
321+
def []=(x, y, value)
322+
super(x + 3 * y, value)
323+
end
324+
end.new
325+
ary[0, 0] = 1
326+
ary[1, 0] = 1
327+
ary[2, 0] = nil
328+
ary[3, 0] = 1
329+
ary[4, 0] = 1
330+
ary[5, 0] = 1
331+
ary[6, 0] = nil
332+
333+
foo = [0, 2]
334+
335+
ary[foo.pop, foo.pop] ||= 2
336+
337+
ary[2, 0].should == 2
338+
ary[6, 0].should == nil
339+
end
340+
303341
it 'returns the assigned value, not the result of the []= method with +=' do
304342
@b[:k] = 17
305343
(@b[:k] += 12).should == 29

src/main/java/org/truffleruby/parser/BodyTranslator.java

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,58 +2474,98 @@ public RubyNode visitOpElementAsgnNode(OpElementAsgnParseNode node) {
24742474
tempName,
24752475
0,
24762476
node.getReceiverNode());
2477+
final ArrayList<ValueFromNode> argValues = argsToTemp(node);
24772478

24782479
final String op = node.getOperatorName();
24792480
final boolean logicalOperation = op.equals("&&") || op.equals("||");
24802481

24812482
if (logicalOperation) {
2482-
final ParseNode write = write(node, readReceiverFromTemp, value);
2483-
final ParseNode operation = operation(node, readReceiverFromTemp, op, write);
2483+
final ParseNode write = write(node, readReceiverFromTemp, argValues, value);
2484+
final ParseNode operation = operation(node, readReceiverFromTemp, argValues, op, write);
24842485

2485-
return block(node, writeReceiverToTemp, operation);
2486+
return block(node, writeReceiverToTemp, argValues, operation);
24862487
} else {
2487-
final ParseNode operation = operation(node, readReceiverFromTemp, op, value);
2488-
final ParseNode write = write(node, readReceiverFromTemp, operation);
2488+
final ParseNode operation = operation(node, readReceiverFromTemp, argValues, op, value);
2489+
final ParseNode write = write(node, readReceiverFromTemp, argValues, operation);
24892490

2490-
return block(node, writeReceiverToTemp, write);
2491+
return block(node, writeReceiverToTemp, argValues, write);
24912492
}
24922493
}
24932494

2494-
private RubyNode block(OpElementAsgnParseNode node, ParseNode writeReceiverToTemp, ParseNode main) {
2495+
private RubyNode block(OpElementAsgnParseNode node, ParseNode writeReceiverToTemp,
2496+
ArrayList<ValueFromNode> argValues, ParseNode main) {
24952497
final BlockParseNode block = new BlockParseNode(node.getPosition());
24962498
block.add(writeReceiverToTemp);
24972499
block.add(main);
24982500

2499-
final RubyNode ret = block.accept(this);
2501+
/* prepareAndThen is going to take an argument, and the action that comes after it, and return a node that does
2502+
* both of those things. We start off with ret being the block (our final action) and so the first node we
2503+
* should produce is one that evaluates the last argument, and then the block. The final value of ret should be
2504+
* a node that evaluates the first argument, and then any other arguments, and then the block. So, we must go
2505+
* through the argument list in reverse order. */
2506+
RubyNode ret = block.accept(this);
2507+
var listIterator = argValues.listIterator(argValues.size());
2508+
while (listIterator.hasPrevious()) {
2509+
ret = listIterator.previous().prepareAndThen(node.getPosition(), ret);
2510+
}
25002511
return addNewlineIfNeeded(node, ret);
25012512
}
25022513

2503-
private ParseNode write(OpElementAsgnParseNode node, ParseNode readReceiverFromTemp, ParseNode value) {
2504-
final ParseNode readArguments = node.getArgsNode();
2514+
private ParseNode write(OpElementAsgnParseNode node, ParseNode readReceiverFromTemp,
2515+
ArrayList<ValueFromNode> argValues, ParseNode value) {
25052516
final ParseNode writeArguments;
25062517
// Like ParserSupport#arg_add, but copy the first node
2507-
if (readArguments instanceof ArrayParseNode) {
2518+
if (node.getArgsNode() instanceof ArrayParseNode) {
25082519
final ArrayParseNode readArgsCopy = new ArrayParseNode(node.getPosition());
2509-
readArgsCopy.addAll((ArrayParseNode) readArguments).add(value);
2520+
for (var arg : argValues) {
2521+
readArgsCopy.add(arg.get(node.getPosition()));
2522+
}
2523+
readArgsCopy.add(value);
25102524
writeArguments = readArgsCopy;
25112525
} else {
2512-
writeArguments = new ArgsPushParseNode(node.getPosition(), readArguments, value);
2526+
writeArguments = new ArgsPushParseNode(node.getPosition(), argValues.get(0).get(node.getPosition()), value);
25132527
}
25142528

25152529
return new AttrAssignParseNode(node.getPosition(), readReceiverFromTemp, "[]=", writeArguments, false);
25162530
}
25172531

2532+
private ArrayList<ValueFromNode> argsToTemp(OpElementAsgnParseNode node) {
2533+
ArrayList<ValueFromNode> argValues = new ArrayList<>();
2534+
2535+
final ParseNode readArguments = node.getArgsNode();
2536+
if (readArguments instanceof ArrayParseNode) {
2537+
for (ParseNode child : ((ArrayParseNode) readArguments).children()) {
2538+
argValues.add(ValueFromNode.valueFromNode(this, child));
2539+
}
2540+
} else {
2541+
argValues.add(ValueFromNode.valueFromNode(this, readArguments));
2542+
}
2543+
2544+
return argValues;
2545+
}
2546+
25182547
private ParseNode operation(
25192548
OpElementAsgnParseNode node,
25202549
ParseNode readReceiverFromTemp,
2550+
ArrayList<ValueFromNode> argValues,
25212551
String op,
25222552
ParseNode right) {
2553+
ParseNode readArguments;
2554+
if (node.getArgsNode() instanceof ArrayParseNode) {
2555+
final ArrayParseNode readArgsArray = new ArrayParseNode(node.getPosition());
2556+
for (var arg : argValues) {
2557+
readArgsArray.add(arg.get(node.getPosition()));
2558+
}
2559+
readArguments = readArgsArray;
2560+
} else {
2561+
readArguments = argValues.get(0).get(node.getPosition());
2562+
}
25232563

25242564
final ParseNode read = new CallParseNode(
25252565
node.getPosition(),
25262566
readReceiverFromTemp,
25272567
"[]",
2528-
node.getArgsNode(),
2568+
readArguments,
25292569
null);
25302570
ParseNode operation;
25312571
switch (op) {

src/main/java/org/truffleruby/parser/ValueFromNode.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.truffleruby.parser.ast.LocalVarParseNode;
1616
import org.truffleruby.parser.ast.ParseNode;
1717
import org.truffleruby.parser.ast.SelfParseNode;
18+
import org.truffleruby.parser.ast.SplatParseNode;
1819

1920
import java.util.Arrays;
2021

@@ -80,9 +81,30 @@ public ParseNode get(SourceIndexLength sourceSection) {
8081

8182
}
8283

84+
class ValueFromSplatNode implements ValueFromNode {
85+
86+
private final ValueFromNode value;
87+
88+
public ValueFromSplatNode(BodyTranslator translator, SplatParseNode node) {
89+
value = valueFromNode(translator, node.getValue());
90+
}
91+
92+
@Override
93+
public RubyNode prepareAndThen(SourceIndexLength sourceSection, RubyNode subsequent) {
94+
return value.prepareAndThen(sourceSection, subsequent);
95+
}
96+
97+
@Override
98+
public ParseNode get(SourceIndexLength sourceSection) {
99+
return new SplatParseNode(sourceSection, value.get(sourceSection));
100+
}
101+
}
102+
83103
static ValueFromNode valueFromNode(BodyTranslator translator, ParseNode node) {
84104
if (node instanceof SelfParseNode) {
85105
return new ValueFromSelfNode();
106+
} else if (node instanceof SplatParseNode) {
107+
return new ValueFromSplatNode(translator, (SplatParseNode) node);
86108
} else {
87109
return new ValueFromEffectNode(translator, node);
88110
}

0 commit comments

Comments
 (0)