Skip to content

Commit 4134404

Browse files
Fix invalid block handling (sorbet#9861)
1 parent 293d3e8 commit 4134404

File tree

8 files changed

+673
-37
lines changed

8 files changed

+673
-37
lines changed

parser/prism/Translator.cc

Lines changed: 100 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,14 @@ class Translator::DesugaredBlockArgument {
8282
// The literal block, if any, e.g. `{ ... }` or `do ... end`
8383
ast::ExpressionPtr literalBlockExpr;
8484

85-
// The expression passed as a block, if any, e.g. the `block` in `a.map(&block)`
85+
// The expression passed as a block, if any, e.g. the `block` in `a.map(&block)`,
86+
// or a `<fwd-block>` local var read, in the case of `foo(...)`.
8687
ast::ExpressionPtr blockPassExpr;
8788

8889
// The location of the entire block pass, including the `&`.
8990
// a.map(&:foo)
9091
// ^^^^^
92+
// or a zero-length loc for forwarding like `...`
9193
core::LocOffsets blockPassLoc;
9294

9395
private:
@@ -3580,36 +3582,78 @@ Translator::translateNumberedParametersNode(pm_numbered_parameters_node *numbere
35803582
return result;
35813583
}
35823584

3585+
// `parentLoc` is the location of the entire parent node, which we need for the `foo(...)` case
3586+
// (as the location used for the `<fwd-block>` argument).
35833587
Translator::DesugaredBlockArgument Translator::desugarBlock(pm_node_t *block, pm_arguments_node *otherArgs,
35843588
pm_location_t parentLoc) {
3585-
auto result = DesugaredBlockArgument::none();
3589+
auto hasFwdArgs = otherArgs != nullptr && PM_NODE_FLAG_P(otherArgs, PM_ARGUMENTS_NODE_FLAGS_CONTAINS_FORWARDING);
35863590

3587-
if (block) {
3588-
if (PM_NODE_TYPE_P(block, PM_BLOCK_NODE)) { // a literal block with `{ ... }` or `do ... end`
3589-
auto blockNode = down_cast<pm_block_node>(block);
3591+
// Check if there's a block argument in otherArgs (e.g., &block in `foo(&block) { }`)
3592+
pm_block_argument_node *blockArgInArgs = nullptr;
3593+
if (otherArgs != nullptr) {
3594+
auto args = absl::MakeSpan(otherArgs->arguments.nodes, otherArgs->arguments.size);
3595+
for (auto *arg : args) {
3596+
if (PM_NODE_TYPE_P(arg, PM_BLOCK_ARGUMENT_NODE)) {
3597+
blockArgInArgs = down_cast<pm_block_argument_node>(arg);
3598+
break;
3599+
}
3600+
}
3601+
}
35903602

3591-
result = DesugaredBlockArgument::literalBlock(desugarLiteralBlock(
3592-
blockNode->body, blockNode->parameters, blockNode->base.location, blockNode->opening_loc));
3593-
} else {
3594-
ENFORCE(PM_NODE_TYPE_P(block, PM_BLOCK_ARGUMENT_NODE)); // the `&b` in `a.map(&b)`
3603+
if (block == nullptr) {
3604+
// Desugar a call like `foo(...)` so it has a block argument like `foo(..., &<fwd-block>)`.
3605+
if (hasFwdArgs) {
3606+
// The local variable uses the full call location, but the Magic node uses zero-length at END
3607+
auto fullLoc = translateLoc(parentLoc);
3608+
auto magicLoc = fullLoc.copyEndWithZeroLength();
3609+
return DesugaredBlockArgument::blockPass(MK::Local(fullLoc, core::Names::fwdBlock()), magicLoc);
3610+
}
3611+
return DesugaredBlockArgument::none();
3612+
}
35953613

3596-
auto *bp = down_cast<pm_block_argument_node>(block);
3614+
if (PM_NODE_TYPE_P(block, PM_BLOCK_NODE)) { // a literal block with `{ ... }` or `do ... end`
3615+
auto blockNode = down_cast<pm_block_node>(block);
35973616

3598-
result = desugarBlockPassArgument(bp);
3617+
auto literalBlock = desugarLiteralBlock(blockNode->body, blockNode->parameters, blockNode->base.location,
3618+
blockNode->opening_loc);
3619+
3620+
// Handle combination of block pass argument AND a literal block.
3621+
// e.g., `foo(&block) { "literal" }` - both need to be kept.
3622+
if (blockArgInArgs != nullptr) {
3623+
auto blockPassResult = desugarBlockPassArgument(blockArgInArgs);
3624+
if (blockPassResult.hasBlockPass()) {
3625+
return DesugaredBlockArgument::both(move(literalBlock), move(blockPassResult.blockPassExpr),
3626+
blockPassResult.blockPassLoc);
3627+
} else if (blockPassResult.hasLiteralBlock()) {
3628+
// Handle an error case like `a.map(&:foo) { "literal" }`
3629+
// The Symbol proc would have been desuraged to a literal block.'
3630+
// We keep both, moving the Symbol proc to the literal block position.
3631+
auto symbolProc = move(blockPassResult.literalBlockExpr);
3632+
auto symbolProcLoc = symbolProc.loc();
3633+
return DesugaredBlockArgument::both(move(literalBlock), move(symbolProc), symbolProcLoc);
3634+
} else {
3635+
unreachable("Expected either a block pass or a literal block, but got neither");
3636+
}
35993637
}
3600-
}
36013638

3602-
auto hasFwdArgs = otherArgs != nullptr && PM_NODE_FLAG_P(otherArgs, PM_ARGUMENTS_NODE_FLAGS_CONTAINS_FORWARDING);
3639+
// Handle combination of forwarding args AND a literal block.
3640+
// e.g., `foo(...) { "literal" }` should pass both <fwd-block> AND the literal block.
3641+
if (hasFwdArgs) {
3642+
// The local variable uses the full call location, but the Magic node uses zero-length at END
3643+
auto fullLoc = translateLoc(parentLoc);
3644+
auto magicLoc = fullLoc.copyEndWithZeroLength();
3645+
return DesugaredBlockArgument::both(move(literalBlock), MK::Local(fullLoc, core::Names::fwdBlock()),
3646+
magicLoc);
3647+
}
36033648

3604-
if (hasFwdArgs) { // Desugar a call like `foo(...)` so it has a block argument like `foo(..., &b)`.
3605-
ENFORCE(!result.exists(), "The parser should have rejected a call with both a block pass "
3606-
"argument and forwarded args (e.g. `foo(&b, ...)`)");
3649+
return DesugaredBlockArgument::literalBlock(move(literalBlock));
36073650

3608-
result = DesugaredBlockArgument::blockPass(MK::Local(translateLoc(parentLoc), core::Names::fwdBlock()),
3609-
core::LocOffsets::none());
3610-
}
3651+
} else {
3652+
ENFORCE(PM_NODE_TYPE_P(block, PM_BLOCK_ARGUMENT_NODE)); // the `&b` in `a.map(&b)`
3653+
auto *bp = down_cast<pm_block_argument_node>(block);
36113654

3612-
return result;
3655+
return desugarBlockPassArgument(bp);
3656+
}
36133657
}
36143658

36153659
ast::ExpressionPtr Translator::desugarLiteralBlock(pm_node *blockBodyNode, pm_node *blockParameters,
@@ -3761,13 +3805,15 @@ ast::ExpressionPtr Translator::desugarMethodCall(ast::ExpressionPtr receiver, co
37613805
if (block.exists()) {
37623806
// There's a block, so we need to calculate the location of the "send" node, excluding it.
37633807
// Start with message location joined with receiver location
3808+
// (since receiverNode is null, we use the desugared receiver's location directly)
37643809
auto initialLoc = receiver.loc().join(messageLoc);
37653810
std::tie(sendLoc, blockLoc) = computeMethodCallLoc(initialLoc, receiverNode, prismArgs, closingLoc, block);
37663811
} else {
37673812
// There's no block, so the `sendLoc` and `sendWithBlockLoc` are the same, so we can just skip
37683813
// the finicky logic in `computeMethodCallLoc()`.
37693814
sendLoc = sendWithBlockLoc;
37703815
}
3816+
37713817
auto sendLoc0 = sendLoc.copyWithZeroLength();
37723818

37733819
if (methodName == core::Names::squareBrackets() || methodName == core::Names::squareBracketsEq()) {
@@ -3847,6 +3893,8 @@ ast::ExpressionPtr Translator::desugarMethodCall(ast::ExpressionPtr receiver, co
38473893
continue; // Skip anonymous splats (like `f(*)`), which are handled separately in `PM_CALL_NODE`
38483894
} else if (PM_NODE_TYPE_P(arg, PM_FORWARDING_ARGUMENTS_NODE)) {
38493895
continue; // Skip forwarded args (like `f(...)`), which are handled separately in `PM_CALL_NODE`
3896+
} else if (PM_NODE_TYPE_P(arg, PM_BLOCK_ARGUMENT_NODE)) {
3897+
continue; // Skip block args (like `f(&block)`), which are handled by `desugarBlock`
38503898
}
38513899

38523900
argExprs.emplace_back(desugar(arg));
@@ -3931,19 +3979,26 @@ ast::ExpressionPtr Translator::desugarMethodCall(ast::ExpressionPtr receiver, co
39313979
magicSendArgs.emplace_back(move(kwargsExpr));
39323980

39333981
if (block.hasBlockPass()) {
3934-
// Desugar a call with a splat, and any other expression as a block pass argument.
3935-
// E.g. `foo(*splat, &block)`
3982+
// Desugar a call with a splat and a block pass argument.
3983+
// E.g. `foo(*splat, &block)` or `foo(...) { "literal" }`
39363984

39373985
auto blockPassLoc = hasFwdArgs ? sendLoc.copyEndWithZeroLength() : block.blockPassLoc;
39383986

39393987
magicSendArgs.emplace_back(move(block.blockPassExpr));
39403988
numPosArgs++;
39413989

3990+
if (block.hasLiteralBlock()) {
3991+
// Both block pass AND literal block: `foo(...) { "literal" }`
3992+
magicSendArgs.emplace_back(move(block.literalBlockExpr));
3993+
flags.hasBlock = true;
3994+
}
3995+
39423996
return MK::Send(sendWithBlockLoc, MK::Magic(blockPassLoc), core::Names::callWithSplatAndBlockPass(),
39433997
messageLoc, numPosArgs, move(magicSendArgs), flags);
39443998
}
39453999

39464000
if (block.hasLiteralBlock()) {
4001+
// Just a literal block, no block pass
39474002
magicSendArgs.emplace_back(move(block.literalBlockExpr));
39484003
flags.hasBlock = true;
39494004
}
@@ -3955,46 +4010,56 @@ ast::ExpressionPtr Translator::desugarMethodCall(ast::ExpressionPtr receiver, co
39554010
numPosArgs, move(magicSendArgs), flags);
39564011
}
39574012

3958-
// Grab a copy of the argument count, before we concat in the kwargs key/value pairs. // huh?
3959-
int numPosArgs = prismArgs.size();
4013+
// Count args, excluding block arguments which are handled separately
4014+
int numPosArgs =
4015+
absl::c_count_if(prismArgs, [](auto *arg) { return !PM_NODE_TYPE_P(arg, PM_BLOCK_ARGUMENT_NODE); });
39604016

39614017
if (block.hasBlockPass()) {
3962-
// FIXME: move this comment
3963-
// Special handling for non-Symbol block pass args, like `a.map(&block)`
3964-
// Symbol procs like `a.map(:to_s)` are rewritten into literal block arguments,
3965-
// and handled separately below.
4018+
// Desugar a call (without splat) with a block pass argument.
4019+
// E.g. `a.each(&block)` or `foo(&block) { "literal" }`
39664020

3967-
// Desugar a call without a splat, and any other expression as a block pass argument.
3968-
// E.g. `a.each(&block)`
3969-
3970-
auto blockPassLoc = block.blockPassLoc;
4021+
auto blockPassLoc = hasFwdArgs ? sendLoc.copyEndWithZeroLength() : block.blockPassLoc;
39714022

39724023
ast::Send::ARGS_store magicSendArgs;
39734024
magicSendArgs.reserve(3 + prismArgs.size());
39744025
magicSendArgs.emplace_back(move(receiver));
39754026
magicSendArgs.emplace_back(MK::Symbol(sendLoc0, methodName));
39764027
magicSendArgs.emplace_back(move(block.blockPassExpr));
39774028

3978-
numPosArgs += 3;
4029+
// For block pass with literal block: pos_args = 3 (recv, sym, blockPass)
4030+
// For block pass without literal block: pos_args = 3 + numPosArgs (includes other args)
4031+
int magicNumPosArgs = block.hasLiteralBlock() ? 3 : numPosArgs + 3;
4032+
4033+
if (block.hasLiteralBlock()) {
4034+
// This supports the invalid case of having both a block pass AND a literal block
4035+
magicSendArgs.emplace_back(move(block.literalBlockExpr));
4036+
flags.hasBlock = true;
4037+
}
39794038

39804039
for (auto *arg : prismArgs) {
4040+
if (PM_NODE_TYPE_P(arg, PM_BLOCK_ARGUMENT_NODE)) {
4041+
continue; // Skip block args, handled above
4042+
}
39814043
magicSendArgs.emplace_back(desugar(arg));
39824044
}
39834045

39844046
if (kwargsHashNode) {
39854047
flattenKwargs(kwargsHashNode, magicSendArgs);
3986-
ast::desugar::DuplicateHashKeyCheck::checkSendArgs(ctx, numPosArgs, magicSendArgs);
4048+
ast::desugar::DuplicateHashKeyCheck::checkSendArgs(ctx, magicNumPosArgs, magicSendArgs);
39874049
}
39884050

39894051
return MK::Send(sendWithBlockLoc, MK::Magic(blockPassLoc), core::Names::callWithBlockPass(), messageLoc,
3990-
numPosArgs, move(magicSendArgs), flags);
4052+
magicNumPosArgs, move(magicSendArgs), flags);
39914053
}
39924054

39934055
ast::Send::ARGS_store sendArgs{};
39944056
// TODO: reserve size for kwargs Hash keys and values, if needed.
39954057
// TODO: reserve size for the block, if needed.
39964058
sendArgs.reserve(prismArgs.size());
39974059
for (auto *arg : prismArgs) {
4060+
if (PM_NODE_TYPE_P(arg, PM_BLOCK_ARGUMENT_NODE)) {
4061+
continue; // Skip block args, handled by `desugarBlock`
4062+
}
39984063
sendArgs.emplace_back(desugar(arg));
39994064
}
40004065

test/BUILD

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ pipeline_tests(
290290
# Tests having to do with tree differences in invalid Ruby code; will address later
291291
"testdata/parser/error_recovery/assign.rb",
292292
"testdata/parser/error_recovery/begin_1.rb",
293-
"testdata/parser/error_recovery/block_arg_and_block.rb",
294293
"testdata/parser/error_recovery/block_do_1.rb",
295294
"testdata/parser/error_recovery/case_1.rb",
296295
"testdata/parser/error_recovery/case_2.rb",
@@ -327,7 +326,6 @@ pipeline_tests(
327326
"testdata/parser/error_recovery/eof_7.rb",
328327
"testdata/parser/error_recovery/eof_9.rb",
329328
"testdata/parser/error_recovery/forward_args.rb",
330-
"testdata/parser/error_recovery/forward_args_with_block.rb",
331329
"testdata/parser/error_recovery/if_do_1.rb",
332330
"testdata/parser/error_recovery/if_do_2.rb",
333331
"testdata/parser/error_recovery/if_indent_1.rb",
@@ -516,6 +514,9 @@ pipeline_tests(
516514
"testdata/parser/error_recovery/eof_8.rb",
517515
"testdata/parser/error_recovery/other_missing_end.rb",
518516
"testdata/rbi/proc.rb",
517+
518+
# Prism preserves more of the structure than the legacy parser
519+
"testdata/parser/error_recovery/block_forwarding_invalid_def.rb",
519520
],
520521
),
521522
"PrismPosTests",
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# typed: false
2+
3+
# This file tests combinations of block forwarding and literal blocks.
4+
#
5+
# Summary of 12 combinations:
6+
# Valid:
7+
# 1. None - def f; bar; end
8+
# 2. & - def f(&b); bar(&b); end
9+
# 3. & (anonymous) - def f(&); bar(&); end
10+
# 4. ... - def f(...); bar(...); end
11+
# 5. { } - def f; bar { }; end
12+
# 6. ... + { } - def f(...); bar(...) { }; end
13+
# 7. & + { } - def f(&b); bar(&b) { }; end
14+
# 8. & (anonymous) + { } - def f(&); bar(&) { }; end
15+
# Error cases (definition - cannot be tested, see below):
16+
# 9. & + ... - def f(&b, ...); end
17+
# 10. & (anonymous) + ... - def f(&, ...); end
18+
# 11. & + ... + { } - def f(&b, ...); bar(...) { }; end
19+
# 12. & (anonymous) + ... + { } - def f(&, ...); bar(...) { }; end
20+
21+
# Cases 9-12 are excluded from this test.
22+
#
23+
# These cases have `&` and `...` together in the definition, which is a parser
24+
# error. The error recovery is fundamentally different between parsers:
25+
# - Whitequark: Aggressive recovery destroys the entire file from error point
26+
# - Prism: Graceful recovery keeps method definitions with error flags
27+
#
28+
# Since this tests parser error recovery (not desugaring), these cases cannot
29+
# be meaningfully compared between parsers.
30+
# Tests for these cases are in test/testdata/parser/error_recovery/block_forwarding_invalid_def.rb.
31+
32+
# ==============================================================================
33+
# VALID CASES - These should produce identical output
34+
# ==============================================================================
35+
36+
# Case 1: None
37+
def case1_none
38+
bar
39+
end
40+
41+
# Case 2: Just & - Explicit block parameter
42+
def case2_block_param(&block)
43+
bar(&block)
44+
end
45+
46+
# Case 3: Anonymous block parameter
47+
def case3_anonymous_block_param(&)
48+
bar(&)
49+
end
50+
51+
# Case 4: Just ... - Forwarding only
52+
def case4_forwarding(...)
53+
bar(...)
54+
end
55+
56+
# Case 5: Just { } - Block literal only
57+
def case5_block_literal
58+
bar { "block body" }
59+
end
60+
61+
# CASE 6: ... + { } at call site
62+
# Forwarding includes <fwd-block>, AND there's a literal block.
63+
# Both should be kept in the output.
64+
def case6_forwarding_and_literal(...)
65+
foo(...) { "literal block" }
66+
end
67+
68+
# CASE 7: & + { } at call site
69+
# Block pass argument AND a literal block.
70+
# Both should be kept in the output.
71+
def case7_block_pass_and_literal(&block)
72+
foo(&block) { "literal block" }
73+
end
74+
75+
# Case 8: Anonymous block param with literal block
76+
def case8_anonymous_block_pass_and_literal(&)
77+
foo(&) { "literal block" }
78+
end
79+
80+
# ==============================================================================
81+
# ERROR CASES - These cannot be tested, see above
82+
# ==============================================================================
83+
84+
# Case 9: & + ... in definition (named block param)
85+
# def case9_block_param_and_forwarding(&block, ...)
86+
# bar(...)
87+
# end
88+
89+
# Case 10: & + ... in definition (anonymous block param)
90+
# def case10_anonymous_block_param_and_forwarding(&, ...)
91+
# bar(...)
92+
# end
93+
94+
# Case 11: & + ... + { } (all three, named block param)
95+
# def case11_all_three(&block, ...)
96+
# foo(...) { "literal block" }
97+
# end
98+
99+
# Case 12: & + ... + { } (all three, anonymous block param)
100+
# def case12_all_three_anonymous(&, ...)
101+
# foo(...) { "literal block" }
102+
# end

0 commit comments

Comments
 (0)