Skip to content

Commit f22ea72

Browse files
authored
Support nested splat patterns by matching native lowerer algorithm (#91)
Adds support for nested splat expressions like `tuple((xs...)...)` by restructuring the splat expansion to match the native lowerer's recursive algorithm. The native lowerer unwraps only one layer of `...` per pass and relies on recursive expansion to handle nested cases. This approach naturally builds the nested `_apply_iterate` structure through multiple expansion passes, avoiding the need for explicit depth tracking and normalization. Changes: - Refactor `_wrap_unsplatted_args` to unwrap only one layer of `...` - Refactor `expand_splat` to construct unevaluated `_apply_iterate` call then recursively expand it - Add test cases for nested splats including triple-nested and mixed-depth
1 parent 593b9ac commit f22ea72

File tree

3 files changed

+148
-26
lines changed

3 files changed

+148
-26
lines changed

src/desugaring.jl

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -947,13 +947,58 @@ function expand_comprehension_to_loops(ctx, ex)
947947
]
948948
end
949949

950+
# Mimics native lowerer's tuple-wrap function (julia-syntax.scm:2723-2736)
951+
# Unwraps only ONE layer of `...` and wraps sequences of non-splat args in tuples.
952+
# Example: `[a, b, xs..., c]` -> `[tuple(a, b), xs, tuple(c)]`
953+
function _wrap_unsplatted_args(ctx, call_ex, args)
954+
result = SyntaxList(ctx)
955+
non_splat_run = SyntaxList(ctx)
956+
for arg in args
957+
if kind(arg) == K"..."
958+
# Flush any accumulated non-splat args
959+
if !isempty(non_splat_run)
960+
push!(result, @ast ctx call_ex [K"call" "tuple"::K"core" non_splat_run...])
961+
non_splat_run = SyntaxList(ctx)
962+
end
963+
# Unwrap only ONE layer of `...` (corresponds to (cadr x) in native lowerer)
964+
push!(result, arg[1])
965+
else
966+
# Accumulate non-splat args
967+
push!(non_splat_run, arg)
968+
end
969+
end
970+
# Flush any remaining non-splat args
971+
if !isempty(non_splat_run)
972+
push!(result, @ast ctx call_ex [K"call" "tuple"::K"core" non_splat_run...])
973+
end
974+
result
975+
end
976+
950977
function expand_splat(ctx, ex, topfunc, args)
951-
return @ast ctx ex [K"call"
978+
# Matches native lowerer's algorithm
979+
# https://github.com/JuliaLang/julia/blob/f362f47338de099cdeeb1b2d81b3ec1948443274/src/julia-syntax.scm#L2761-2762:
980+
# 1. Unwrap one layer of `...` from each argument (via _wrap_unsplatted_args)
981+
# 2. Create `_apply_iterate(iterate, f, wrapped_args...)` WITHOUT expanding args yet
982+
# 3. Recursively expand the entire call - if any wrapped_arg still contains `...`,
983+
# the recursive expansion will handle it, naturally building nested structure
984+
#
985+
# Example: tuple((xs...)...) recursion:
986+
# Pass 1: unwrap outer `...` -> _apply_iterate(iterate, tuple, (xs...))
987+
# Pass 2: expand sees (xs...) in call context, unwraps again
988+
# -> _apply_iterate(iterate, _apply_iterate, tuple(iterate, tuple), xs)
989+
990+
wrapped_args = _wrap_unsplatted_args(ctx, ex, args)
991+
992+
# Construct the unevaluated _apply_iterate call
993+
result = @ast ctx ex [K"call"
952994
"_apply_iterate"::K"core"
953995
"iterate"::K"top"
954996
topfunc
955-
expand_forms_2(ctx, _wrap_unsplatted_args(ctx, ex, args))...
997+
wrapped_args...
956998
]
999+
1000+
# Recursively expand the entire call (matching native's expand-forms)
1001+
return expand_forms_2(ctx, result)
9571002
end
9581003

9591004
function expand_array(ctx, ex, topfunc)
@@ -1812,29 +1857,6 @@ function expand_ccall(ctx, ex)
18121857
]
18131858
end
18141859

1815-
# Wrap unsplatted arguments in `tuple`:
1816-
# `[a, b, xs..., c]` -> `[(a, b), xs, (c,)]`
1817-
function _wrap_unsplatted_args(ctx, call_ex, args)
1818-
wrapped = SyntaxList(ctx)
1819-
i = 1
1820-
while i <= length(args)
1821-
if kind(args[i]) == K"..."
1822-
splatarg = args[i]
1823-
@chk numchildren(splatarg) == 1
1824-
push!(wrapped, splatarg[1])
1825-
else
1826-
i1 = i
1827-
# Find range of non-splatted args
1828-
while i < length(args) && kind(args[i+1]) != K"..."
1829-
i += 1
1830-
end
1831-
push!(wrapped, @ast ctx call_ex [K"call" "tuple"::K"core" args[i1:i]...])
1832-
end
1833-
i += 1
1834-
end
1835-
wrapped
1836-
end
1837-
18381860
function remove_kw_args!(ctx, args::SyntaxList)
18391861
kws = nothing
18401862
j = 0

test/function_calls_ir.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,58 @@ function A.ccall()
597597
# └─────┘ ── Invalid function name
598598
end
599599

600+
########################################
601+
# Nested splat: simple case
602+
tuple((xs...)...)
603+
#---------------------
604+
1 TestMod.tuple
605+
2 (call core.tuple top.iterate %₁)
606+
3 TestMod.xs
607+
4 (call core._apply_iterate top.iterate core._apply_iterate %%₃)
608+
5 (return %₄)
609+
610+
########################################
611+
# Nested splat: with mixed arguments
612+
tuple(a, (xs...)..., b)
613+
#---------------------
614+
1 TestMod.tuple
615+
2 TestMod.a
616+
3 (call core.tuple %₂)
617+
4 (call core.tuple top.iterate %%₃)
618+
5 TestMod.xs
619+
6 TestMod.b
620+
7 (call core.tuple %₆)
621+
8 (call core.tuple %₇)
622+
9 (call core._apply_iterate top.iterate core._apply_iterate %%%₈)
623+
10 (return %₉)
624+
625+
########################################
626+
# Nested splat: multiple nested splats
627+
tuple((xs...)..., (ys...)...)
628+
#---------------------
629+
1 TestMod.tuple
630+
2 (call core.tuple top.iterate %₁)
631+
3 TestMod.xs
632+
4 TestMod.ys
633+
5 (call core._apply_iterate top.iterate core._apply_iterate %%%₄)
634+
6 (return %₅)
635+
636+
########################################
637+
# Nested splat: triple nesting
638+
tuple(((xs...)...)...)
639+
#---------------------
640+
1 TestMod.tuple
641+
2 (call core.tuple top.iterate %₁)
642+
3 (call core.tuple top.iterate core._apply_iterate %₂)
643+
4 TestMod.xs
644+
5 (call core._apply_iterate top.iterate core._apply_iterate %%₄)
645+
6 (return %₅)
646+
647+
########################################
648+
# Error: Standalone splat expression
649+
(xs...)
650+
#---------------------
651+
LoweringError:
652+
(xs...)
653+
#└───┘ ── `...` expression outside call
654+

test/functions.jl

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,52 @@ end
2020
(2,3,4),
2121
(1,2,3,4,5))
2222

23+
# Nested splatting
24+
@test JuliaLowering.include_string(test_mod, """
25+
let
26+
xs = [[1, 2], [3, 4]]
27+
tuple((xs...)...)
28+
end
29+
""") == (1, 2, 3, 4)
30+
31+
@test JuliaLowering.include_string(test_mod, """
32+
let
33+
xs = [[1, 2]]
34+
ys = [[3, 4]]
35+
tuple((xs...)..., (ys...)...)
36+
end
37+
""") == (1, 2, 3, 4)
38+
39+
# Multiple (>2) nested splat
40+
@test JuliaLowering.include_string(test_mod, """
41+
let
42+
xs = [[[1, 2]]]
43+
tuple(((xs...)...)...)
44+
end
45+
""") == (1, 2)
46+
@test JuliaLowering.include_string(test_mod, """
47+
let
48+
xs = [[[1, 2]]]
49+
ys = [[[3, 4]]]
50+
tuple(((xs...)...)..., ((ys...)...)...)
51+
end
52+
""") == (1, 2, 3, 4)
53+
@test JuliaLowering.include_string(test_mod, """
54+
let
55+
xs = [[[1, 2]]]
56+
ys = [[[3, 4]]]
57+
tuple(((xs...)...)..., ((ys...)...))
58+
end
59+
""") == (1, 2, [3, 4])
60+
61+
# Trailing comma case should still work (different semantics)
62+
@test JuliaLowering.include_string(test_mod, """
63+
let
64+
xs = [[1, 2], [3, 4]]
65+
tuple((xs...,)...)
66+
end
67+
""") == ([1, 2], [3, 4])
68+
2369
# Keyword calls
2470
Base.eval(test_mod, :(
2571
begin
@@ -36,7 +82,6 @@ begin
3682
end
3783
))
3884

39-
4085
@test JuliaLowering.include_string(test_mod, """
4186
let
4287
kws = (c=3,d=4)

0 commit comments

Comments
 (0)