Skip to content

Commit 0431e74

Browse files
fix(compiler): Handle non-exhaustive record patterns (#2274)
Co-authored-by: Oscar Spencer <oscar.spen@gmail.com>
1 parent 4f4eee0 commit 0431e74

File tree

5 files changed

+239
-48
lines changed

5 files changed

+239
-48
lines changed

compiler/src/typed/parmatch.re

Lines changed: 137 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,12 @@ let all_coherent = column => {
177177
false
178178
}
179179
| (TPatTuple(l1), TPatTuple(l2)) => List.length(l1) == List.length(l2)
180+
| (
181+
TPatRecord([(_, l1, _), ..._], _),
182+
TPatRecord([(_, l2, _), ..._], _),
183+
) =>
184+
Array.length(l1.lbl_all) == Array.length(l2.lbl_all)
185+
| (TPatRecord([], _), TPatRecord([], _)) => true
180186
| (TPatArray(_), TPatArray(_))
181187
| (TPatAny, _)
182188
| (_, TPatAny) => true
@@ -305,6 +311,33 @@ let const_compare = (x, y) =>
305311
Stdlib.compare(x, y)
306312
};
307313

314+
/* Align the sub-patterns of two record patterns field by field.
   [l1] and [l2] are record field lists (triples whose second component is the
   label description and whose third is the sub-pattern). Both lists are walked
   in step, comparing [lbl_pos]; whenever a label is present on one side only,
   [omega] is inserted on the other side. Returns a pair of pattern lists of
   equal length in which the patterns at the same index refer to the same
   label. */
let records_args = (l1, l2) => {
315+
/* Invariant: fields are already sorted by Typecore.type_label_a_list.
   The merge below relies on both lists being ordered by [lbl_pos]. */
316+
let rec combine = (r1, r2, l1, l2) =>
317+
switch (l1, l2) {
318+
| ([], []) => (List.rev(r1), List.rev(r2)) /* accumulators are built in reverse */
319+
| ([], [(_, _, p2), ...rem2]) => /* only the right side has fields left */
320+
combine([omega, ...r1], [p2, ...r2], [], rem2)
321+
| ([(_, _, p1), ...rem1], []) => /* only the left side has fields left */
322+
combine([p1, ...r1], [omega, ...r2], rem1, [])
323+
| ([(_, lbl1, p1), ...rem1], [(_, lbl2, p2), ...rem2]) =>
324+
if (lbl1.lbl_pos < lbl2.lbl_pos) {
325+
combine([p1, ...r1], [omega, ...r2], rem1, l2); /* label missing on the right */
326+
} else if (lbl1.lbl_pos > lbl2.lbl_pos) {
327+
combine([omega, ...r1], [p2, ...r2], l1, rem2); /* label missing on the left */
328+
} else {
329+
// same label on both sides
330+
combine(
331+
[p1, ...r1],
332+
[p2, ...r2],
333+
rem1,
334+
rem2,
335+
);
336+
}
337+
};
338+
combine([], [], l1, l2);
339+
};
340+
308341
module Compat =
309342
(
310343
Constr: {
@@ -329,6 +362,9 @@ module Compat =
329362
/* More standard stuff */
330363
| (TPatConstant(c1), TPatConstant(c2)) => const_compare(c1, c2) == 0
331364
| (TPatTuple(ps), TPatTuple(qs)) => compats(ps, qs)
365+
| (TPatRecord(l1, _), TPatRecord(l2, _)) =>
366+
let (ps, qs) = records_args(l1, l2);
367+
compats(ps, qs);
332368
| (TPatArray(ps), TPatArray(qs)) =>
333369
List.length(ps) == List.length(qs) && compats(ps, qs)
334370
| (_, _) => false
@@ -390,23 +426,45 @@ let simple_match = (p1, p2) =>
390426
| (TPatConstruct(_, c1, _), TPatConstruct(_, c2, _)) =>
391427
Types.equal_tag(c1.cstr_tag, c2.cstr_tag)
392428
| (TPatConstant(c1), TPatConstant(c2)) => const_compare(c1, c2) == 0
429+
| (TPatRecord(_, _), TPatRecord(_, _)) => true
393430
| (TPatTuple(p1s), TPatTuple(p2s)) => List.length(p1s) == List.length(p2s)
394431
| (TPatArray(p1s), TPatArray(p2s)) => List.length(p1s) == List.length(p2s)
395432
| (_, TPatAny | TPatVar(_)) => true
396433
| (_, _) => false
397434
};
398435

436+
/* Extract the field list of a record pattern head as a whole.
   - [TPatAny] yields the empty list (no field is constrained);
   - [TPatRecord(args, _)] yields [args] unchanged;
   - any other pattern description is treated as an internal error and
     reported through [fatal_error]. */
437+
let record_arg = ph => {
438+
switch (ph.pat_desc) {
439+
| TPatAny => []
440+
| TPatRecord(args, _) => args
441+
| _ => fatal_error("Parmatch.record_arg")
442+
};
443+
};
444+
445+
/* Project the record fields [arg] onto the labels listed in [fields].
   For every entry of [fields], in order, the sub-pattern of the entry of
   [arg] with the same [lbl_pos] is returned; when [arg] has no such entry,
   [omega] is used instead. The result therefore has exactly one pattern per
   entry of [fields]. */
let extract_fields = (fields, arg) => {
446+
/* Look up the sub-pattern stored at label position [pos] in [arg]. */
let get_field = (pos, arg) => {
447+
switch (List.find(((_, lbl, _)) => pos == lbl.lbl_pos, arg)) {
448+
| (_, _, p) => p
449+
| exception Not_found => omega /* field not mentioned in [arg] */
450+
};
451+
};
452+
List.map(((_, lbl, _)) => get_field(lbl.lbl_pos, arg), fields);
453+
};
454+
399455
/* Build argument list when p2 >= p1, where p1 is a simple pattern */
400456
let rec simple_match_args = (p1, p2) =>
401457
switch (p2.pat_desc) {
402458
| TPatAlias(p2, _, _) => simple_match_args(p1, p2)
403459
| TPatConstruct(_, _, args) => args
460+
| TPatRecord(args, _) => extract_fields(record_arg(p1), args)
404461
| TPatTuple(args) => args
405462
| TPatArray(args) => args
406463
| TPatAny
407464
| TPatVar(_) =>
408465
switch (p1.pat_desc) {
409466
| TPatConstruct(_, _, args) => omega_list(args)
467+
| TPatRecord(args, _) => omega_list(args)
410468
| TPatTuple(args) => omega_list(args)
411469
| TPatArray(args) => omega_list(args)
412470
| _ => []
@@ -456,6 +514,11 @@ let rec normalize_pat = q =>
456514
whatever other pattern we might find, as well as the pattern we're threading
457515
along.
458516
517+
- when we find a [Record] then it is a bit more involved: it is also alone
518+
in its signature, however it might only be matching a subset of the
519+
record fields. We use these fields to refine our accumulator and keep going
520+
as another row might match on different fields.
521+
459522
- rows starting with a wildcard do not bring any information, so we ignore
460523
them and keep going
461524
@@ -472,19 +535,39 @@ let discr_pat = (q, pss) => {
472535
| TPatVar(_)
473536
| TPatAlias(_) => assert(false)
474537
| TPatAny => refine_pat(acc, rows)
475-
| TPatTuple(_)
476-
| TPatRecord(_) => normalize_pat(head)
538+
| TPatTuple(_) => normalize_pat(head)
539+
| TPatRecord(lbls, c) =>
540+
/* N.B. we could make this case "simpler" by refining the record case
541+
using [all_record_args].
542+
In which case we wouldn't need to fold over the first column for
543+
records.
544+
However it makes the witness we generate for the exhaustivity warning
545+
less pretty. */
546+
let fields =
547+
List.fold_right(
548+
((_, lbl, _) as field, r) =>
549+
if (List.exists(((_, l, _)) => l.lbl_pos == lbl.lbl_pos, r)) {
550+
r;
551+
} else {
552+
[field, ...r];
553+
},
554+
lbls,
555+
record_arg(acc),
556+
);
557+
let d = {...head, pat_desc: TPatRecord(fields, c)};
558+
refine_pat(d, rows);
477559
| TPatArray(_)
478560
| TPatConstant(_)
479561
| TPatConstruct(_) => acc
480562
};
481563

482564
let q = normalize_pat(q);
483-
/* short-circuiting: clearly if we have anything other than
484-
[Tpat_any] to start with, we're not going to be able refine at all. So
565+
/* short-circuiting: clearly if we have anything other than [Record] or
566+
[Any] to start with, we're not going to be able to refine at all. So
485567
there's no point going over the matrix. */
486568
switch (q.pat_desc) {
487-
| TPatAny => refine_pat(q, pss)
569+
| TPatAny
570+
| TPatRecord(_, _) => refine_pat(q, pss)
488571
| _ => q
489572
};
490573
};
@@ -508,6 +591,11 @@ let do_set_args = (erase_mutable, q, r) =>
508591
| {pat_desc: TPatTuple(omegas)} =>
509592
let (args, rest) = read_args(omegas, r);
510593
[make_pat(TPatTuple(args), q.pat_type, q.pat_env), ...rest];
594+
| {pat_desc: TPatRecord(omegas, closed)} =>
595+
let (args, rest) = read_args(omegas, r);
596+
let args =
597+
List.map2(((lid, lbl, _), arg) => (lid, lbl, arg), omegas, args);
598+
[make_pat(TPatRecord(args, closed), q.pat_type, q.pat_env), ...rest];
511599
| {pat_desc: TPatConstruct(lid, c, omegas)} =>
512600
let (args, rest) = read_args(omegas, r);
513601
[make_pat(TPatConstruct(lid, c, args), q.pat_type, q.pat_env), ...rest];
@@ -663,6 +751,7 @@ let build_specialized_submatrices = (~extend_row, q, rows) => {
663751
let (constr_groups, omega_tails) = {
664752
let initial_constr_group =
665753
switch (q.pat_desc) {
754+
| TPatRecord(_)
666755
| TPatTuple(_) =>
667756
/* [q] comes from [discr_pat], and in this case subsumes any of the
668757
patterns we could find on the first column of [rows]. So it is better
@@ -1771,6 +1860,9 @@ let rec le_pat = (p, q) =>
17711860
| (TPatArray(ps), TPatArray(qs)) =>
17721861
List.length(ps) == List.length(qs) && le_pats(ps, qs)
17731862
| (TPatTuple(ps), TPatTuple(qs)) => le_pats(ps, qs)
1863+
| (TPatRecord(l1, _), TPatRecord(l2, _)) =>
1864+
let (ps, qs) = records_args(l1, l2);
1865+
le_pats(ps, qs);
17741866
/* In all other cases, enumeration is performed */
17751867
| (_, _) => !satisfiable([[p]], [q])
17761868
}
@@ -1815,6 +1907,9 @@ let rec lub = (p, q) =>
18151907
when Types.equal_tag(c1.cstr_tag, c2.cstr_tag) =>
18161908
let rs = lubs(ps1, ps2);
18171909
make_pat(TPatConstruct(lid, c1, rs), p.pat_type, p.pat_env);
1910+
| (TPatRecord(l1, closed), TPatRecord(l2, _)) =>
1911+
let rs = record_lubs(l1, l2);
1912+
make_pat(TPatRecord(rs, closed), p.pat_type, p.pat_env);
18181913
| (TPatArray(ps1), TPatArray(ps2))
18191914
when List.length(ps1) == List.length(ps2) =>
18201915
let rs = lubs(ps1, ps2);
@@ -1832,6 +1927,24 @@ and orlub = (p1, p2, q) =>
18321927
| Empty => lub(p2, q)
18331928
}
18341929

1930+
/* Merge the field lists of two record patterns for [lub].
   The lists are walked in step by [lbl_pos]: a field present on one side only
   is copied to the result unchanged, and a field present on both sides is
   kept once (with the identifier and label of the left side) with its
   sub-pattern replaced by [lub(p1, p2)]. Once one list is exhausted, the rest
   of the other is returned as is.
   NOTE(review): this merge assumes both lists are sorted by [lbl_pos] --
   confirm against the producer of these field lists. */
and record_lubs = (l1, l2) => {
1931+
let rec lub_rec = (l1, l2) => {
1932+
switch (l1, l2) {
1933+
| ([], _) => l2
1934+
| (_, []) => l1
1935+
| ([(lid1, lbl1, p1), ...rem1], [(lid2, lbl2, p2), ...rem2]) =>
1936+
if (lbl1.lbl_pos < lbl2.lbl_pos) {
1937+
[(lid1, lbl1, p1), ...lub_rec(rem1, l2)]; /* label only on the left */
1938+
} else if (lbl2.lbl_pos < lbl1.lbl_pos) {
1939+
[(lid2, lbl2, p2), ...lub_rec(l1, rem2)]; /* label only on the right */
1940+
} else {
1941+
[(lid1, lbl1, lub(p1, p2)), ...lub_rec(rem1, rem2)]; /* same label: combine sub-patterns */
1942+
}
1943+
};
1944+
};
1945+
lub_rec(l1, l2);
1946+
}
1947+
18351948
and lubs = (ps, qs) =>
18361949
switch (ps, qs) {
18371950
| ([p, ...ps], [q, ...qs]) => [lub(p, q), ...lubs(ps, qs)]
@@ -1989,10 +2102,25 @@ module Conv = {
19892102
txt: Identifier.IdentName({...cstr_lid, txt: id}),
19902103
};
19912104
Hashtbl.add(constrs, id, cstr);
1992-
mkpat(
1993-
~loc=pat.pat_loc,
1994-
PPatConstruct(lid, PPatConstrTuple(List.map(loop, lst))),
1995-
); // record vs tuple should not matter at this point
2105+
switch (lst) {
2106+
| [{pat_desc: TPatRecord(fields, closed)}]
2107+
when cstr.cstr_inlined != None =>
2108+
mkpat(
2109+
~loc=pat.pat_loc,
2110+
PPatConstruct(
2111+
lid,
2112+
PPatConstrRecord(
2113+
List.map(((id, _, p)) => (id, loop(p)), fields),
2114+
closed,
2115+
),
2116+
),
2117+
)
2118+
| _ =>
2119+
mkpat(
2120+
~loc=pat.pat_loc,
2121+
PPatConstruct(lid, PPatConstrTuple(List.map(loop, lst))),
2122+
)
2123+
};
19962124
};
19972125

19982126
let ps = loop(typed);

compiler/src/typed/printpat.re

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ let rec pretty_val = (ppf, v) =>
7474
let filtered_lvs =
7575
List.filter(
7676
fun
77-
| (_, _, {pat_desc: TPatAny}) => false /* do not show lbl=_ */
77+
| (_, _, {pat_desc: TPatAny}) => false /* do not show lbl: _ */
7878
| _ => true,
7979
lvs,
8080
);
@@ -84,37 +84,40 @@ let rec pretty_val = (ppf, v) =>
8484
let elision_mark = ppf =>
8585
/* we assume that there are no label repetitions here */
8686
if (Array.length(lbl.lbl_all) > 1 + List.length(q)) {
87-
fprintf(ppf, ";@ _@ ");
87+
fprintf(ppf, ",@ _@ ");
8888
} else {
8989
();
9090
};
9191
fprintf(ppf, "@[{%a%t}@]", pretty_lvals, filtered_lvs, elision_mark);
9292
};
9393
| TPatConstant(c) => fprintf(ppf, "%s", pretty_const(c))
94-
| TPatConstruct(_, {cstr_name}, args) =>
95-
if (cstr_name == "[...]") {
96-
fprintf(
97-
ppf,
98-
"@[[%a]@]",
99-
pretty_vals(","),
100-
List.rev(
101-
List.fold_left(
102-
(acc, arg) =>
103-
switch (arg.pat_desc) {
104-
| TPatConstruct(_, {cstr_name: "[...]"}, args) =>
105-
List.concat([args, acc])
106-
| _ => [arg, ...acc]
107-
},
108-
[],
109-
args,
110-
),
94+
| TPatConstruct(_, {cstr_name: "[...]"}, args) =>
95+
fprintf(
96+
ppf,
97+
"@[[%a]@]",
98+
pretty_vals(","),
99+
List.rev(
100+
List.fold_left(
101+
(acc, arg) =>
102+
switch (arg.pat_desc) {
103+
| TPatConstruct(_, {cstr_name: "[...]"}, args) =>
104+
List.concat([args, acc])
105+
| _ => [arg, ...acc]
106+
},
107+
[],
108+
args,
111109
),
112-
);
113-
} else if (List.length(args) > 0) {
114-
fprintf(ppf, "@[%s(%a)@]", cstr_name, pretty_vals(","), args);
115-
} else {
116-
fprintf(ppf, "@[%s@]", cstr_name);
117-
}
110+
),
111+
)
112+
| TPatConstruct(_, {cstr_name}, []) => fprintf(ppf, "@[%s@]", cstr_name)
113+
| TPatConstruct(
114+
_,
115+
{cstr_name, cstr_inlined},
116+
[{pat_desc: TPatRecord(_, _)}] as args,
117+
) =>
118+
fprintf(ppf, "@[%s%a@]", cstr_name, pretty_vals(","), args)
119+
| TPatConstruct(_, {cstr_name}, args) =>
120+
fprintf(ppf, "@[%s(%a)@]", cstr_name, pretty_vals(","), args)
118121
| TPatAlias(v, x, _) =>
119122
fprintf(ppf, "@[(%a@ as %a)@]", pretty_val, v, Ident.print, x)
120123
| TPatOr(v, w) =>
@@ -153,11 +156,11 @@ and pretty_vals = (sep, ppf) =>
153156
and pretty_lvals = ppf =>
154157
fun
155158
| [] => ()
156-
| [(_, lbl, v)] => fprintf(ppf, "%s=%a", lbl.lbl_name, pretty_val, v)
159+
| [(_, lbl, v)] => fprintf(ppf, "%s: %a", lbl.lbl_name, pretty_val, v)
157160
| [(_, lbl, v), ...rest] =>
158161
fprintf(
159162
ppf,
160-
"%s=%a;@ %a",
163+
"%s: %a,@ %a",
161164
lbl.lbl_name,
162165
pretty_val,
163166
v,

compiler/test/input/mixedPatternMatching.gr

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,18 @@ let test = test => {
2626
PAssoc({ foo: [a, b], _ }) => a + b,
2727
PAssoc({ foo: [a, ..._], _ }) => a,
2828
PMulti([], a, { foo: [], _ }) => a,
29+
PMulti(_, a, { foo: [], _ }) => a,
2930
PMulti([_, ..._], a, { foo: [b, ..._], _ }) => a + b,
3031
PMulti([], _, { foo: [b, ..._], _ }) => b,
3132
PMulti([_, ..._], a, { foo: [b, c, ..._], _ }) => a + b + c,
32-
PPoly(
33-
[PList([a]), PAssoc({ foo: [b], _ }), PMulti([], c, { foo: [d], _ })]
34-
) =>
35-
a +
36-
b +
37-
c +
38-
d,
33+
PPoly([PList([a]), PAssoc({ foo: [b], _ }), PMulti([], c, { foo: [d], _ })]) =>
34+
a + b + c + d,
3935
PPoly([]) => 42,
4036
PPoly([PList([]) | PPoly([])]) => 50,
4137
PPoly([PList([a, b]) | PPoly([PList([a, b]) | PPoly([PList([a, b])])])]) =>
42-
a +
43-
b,
38+
a + b,
4439
PPoly([_, ..._]) => 43,
45-
PArray([> ]) => 44,
40+
PArray([>]) => 44,
4641
PArray([> PList([a, b])]) => a + b,
4742
PArray([> PPoly([PList([a, b])]), PList([c, d])]) => a + b + c + d,
4843
PArray(_) => 45,
@@ -74,7 +69,7 @@ let tests = [
7469
PList([1]),
7570
PAssoc({ foo: [2], bar: "bar" }),
7671
PMulti([], 3, { foo: [4], bar: "bar" }),
77-
]
72+
],
7873
),
7974
10,
8075
),

0 commit comments

Comments
 (0)