@@ -177,6 +177,12 @@ let all_coherent = column => {
177177 false
178178 }
179179 | (TPatTuple (l1 ), TPatTuple (l2 )) => List . length(l1) == List . length(l2)
180+ | (
181+ TPatRecord ([ (_ , l1 , _ ), ... _ ] , _ ),
182+ TPatRecord ([ (_ , l2 , _ ), ... _ ] , _ ),
183+ ) =>
184+ Array . length(l1. lbl_all) == Array . length(l2. lbl_all)
185+ | (TPatRecord ([] , _ ), TPatRecord ([] , _ )) => true
180186 | (TPatArray (_ ), TPatArray (_ ))
181187 | (TPatAny , _ )
182188 | (_ , TPatAny ) => true
@@ -305,6 +311,33 @@ let const_compare = (x, y) =>
305311 Stdlib . compare(x, y)
306312 };
307313
314+ let records_args = (l1, l2) => {
315+ /* Invariant: fields are already sorted by Typecore.type_label_a_list */
316+ let rec combine = (r1, r2, l1, l2) =>
317+ switch (l1, l2) {
318+ | ([] , [] ) => (List . rev(r1), List . rev(r2))
319+ | ([] , [ (_ , _ , p2 ), ... rem2 ] ) =>
320+ combine([ omega, ... r1] , [ p2, ... r2] , [] , rem2)
321+ | ([ (_ , _ , p1 ), ... rem1 ] , [] ) =>
322+ combine([ p1, ... r1] , [ omega, ... r2] , rem1, [] )
323+ | ([ (_ , lbl1 , p1 ), ... rem1 ] , [ (_ , lbl2 , p2 ), ... rem2 ] ) =>
324+ if (lbl1. lbl_pos < lbl2. lbl_pos) {
325+ combine([ p1, ... r1] , [ omega, ... r2] , rem1, l2);
326+ } else if (lbl1. lbl_pos > lbl2. lbl_pos) {
327+ combine([ omega, ... r1] , [ p2, ... r2] , l1, rem2);
328+ } else {
329+ // same label on both sides
330+ combine(
331+ [ p1, ... r1] ,
332+ [ p2, ... r2] ,
333+ rem1,
334+ rem2,
335+ );
336+ }
337+ };
338+ combine([] , [] , l1, l2);
339+ };
340+
308341module Compat =
309342 (
310343 Constr : {
@@ -329,6 +362,9 @@ module Compat =
329362 /* More standard stuff */
330363 | (TPatConstant (c1 ), TPatConstant (c2 )) => const_compare(c1, c2) == 0
331364 | (TPatTuple (ps ), TPatTuple (qs )) => compats(ps, qs)
365+ | (TPatRecord (l1 , _ ), TPatRecord (l2 , _ )) =>
366+ let (ps , qs ) = records_args(l1, l2);
367+ compats(ps, qs);
332368 | (TPatArray (ps ), TPatArray (qs )) =>
333369 List . length(ps) == List . length(qs) && compats(ps, qs)
334370 | (_ , _ ) => false
@@ -390,23 +426,45 @@ let simple_match = (p1, p2) =>
390426 | (TPatConstruct (_ , c1 , _ ), TPatConstruct (_ , c2 , _ )) =>
391427 Types . equal_tag(c1. cstr_tag, c2. cstr_tag)
392428 | (TPatConstant (c1 ), TPatConstant (c2 )) => const_compare(c1, c2) == 0
429+ | (TPatRecord (_ , _ ), TPatRecord (_ , _ )) => true
393430 | (TPatTuple (p1s ), TPatTuple (p2s )) => List . length(p1s) == List . length(p2s)
394431 | (TPatArray (p1s ), TPatArray (p2s )) => List . length(p1s) == List . length(p2s)
395432 | (_ , TPatAny | TPatVar (_ )) => true
396433 | (_ , _ ) => false
397434 };
398435
436+ /* extract record fields as a whole */
437+ let record_arg = ph => {
438+ switch (ph. pat_desc) {
439+ | TPatAny => []
440+ | TPatRecord (args , _ ) => args
441+ | _ => fatal_error("Parmatch.record_arg" )
442+ };
443+ };
444+
445+ let extract_fields = (fields, arg) => {
446+ let get_field = (pos, arg) => {
447+ switch (List . find(((_, lbl, _)) => pos == lbl. lbl_pos, arg)) {
448+ | (_ , _ , p ) => p
449+ | exception Not_found => omega
450+ };
451+ };
452+ List . map(((_, lbl, _)) => get_field(lbl. lbl_pos, arg), fields);
453+ };
454+
399455/* Build argument list when p2 >= p1, where p1 is a simple pattern */
400456let rec simple_match_args = (p1, p2) =>
401457 switch (p2. pat_desc) {
402458 | TPatAlias (p2 , _ , _ ) => simple_match_args(p1, p2)
403459 | TPatConstruct (_ , _ , args ) => args
460+ | TPatRecord (args , _ ) => extract_fields(record_arg(p1), args)
404461 | TPatTuple (args ) => args
405462 | TPatArray (args ) => args
406463 | TPatAny
407464 | TPatVar (_ ) =>
408465 switch (p1. pat_desc) {
409466 | TPatConstruct (_ , _ , args ) => omega_list(args)
467+ | TPatRecord (args , _ ) => omega_list(args)
410468 | TPatTuple (args ) => omega_list(args)
411469 | TPatArray (args ) => omega_list(args)
412470 | _ => []
@@ -456,6 +514,11 @@ let rec normalize_pat = q =>
456514 whatever other pattern we might find, as well as the pattern we're threading
457515 along.
458516
517+ - when we find a [Record] then it is a bit more involved: it is also alone
518+ in its signature, however it might only be matching a subset of the
519+ record fields. We use these fields to refine our accumulator and keep going
520+ as another row might match on different fields.
521+
459522 - rows starting with a wildcard do not bring any information, so we ignore
460523 them and keep going
461524
@@ -472,19 +535,39 @@ let discr_pat = (q, pss) => {
472535 | TPatVar (_ )
473536 | TPatAlias (_ ) => assert (false )
474537 | TPatAny => refine_pat(acc, rows)
475- | TPatTuple (_ )
476- | TPatRecord (_ ) => normalize_pat(head)
538+ | TPatTuple (_ ) => normalize_pat(head)
539+ | TPatRecord (lbls , c ) =>
540+ /* N.B. we could make this case "simpler" by refining the record case
541+ using [all_record_args].
542+ In which case we wouldn't need to fold over the first column for
543+ records.
544+ However it makes the witness we generate for the exhaustivity warning
545+ less pretty. */
546+ let fields =
547+ List . fold_right(
548+ ((_, lbl, _) as field, r) =>
549+ if (List . exists(((_, l, _)) => l. lbl_pos == lbl. lbl_pos, r)) {
550+ r;
551+ } else {
552+ [ field, ... r] ;
553+ },
554+ lbls,
555+ record_arg(acc),
556+ );
557+ let d = {... head, pat_desc: TPatRecord (fields, c)};
558+ refine_pat(d, rows);
477559 | TPatArray (_ )
478560 | TPatConstant (_ )
479561 | TPatConstruct (_ ) => acc
480562 };
481563
482564 let q = normalize_pat(q);
483- /* short-circuiting: clearly if we have anything other than
484- [Tpat_any ] to start with, we're not going to be able refine at all. So
565+ /* short-circuiting: clearly if we have anything other than [Record] or
566+ [Any ] to start with, we're not going to be able refine at all. So
485567 there's no point going over the matrix. */
486568 switch (q. pat_desc) {
487- | TPatAny => refine_pat(q, pss)
569+ | TPatAny
570+ | TPatRecord (_ , _ ) => refine_pat(q, pss)
488571 | _ => q
489572 };
490573};
@@ -508,6 +591,11 @@ let do_set_args = (erase_mutable, q, r) =>
508591 | {pat_desc: TPatTuple (omegas )} =>
509592 let (args , rest ) = read_args(omegas, r);
510593 [ make_pat(TPatTuple (args), q. pat_type, q. pat_env), ... rest] ;
594+ | {pat_desc: TPatRecord (omegas , closed )} =>
595+ let (args , rest ) = read_args(omegas, r);
596+ let args =
597+ List . map2(((lid, lbl, _), arg) => (lid, lbl, arg), omegas, args);
598+ [ make_pat(TPatRecord (args, closed), q. pat_type, q. pat_env), ... rest] ;
511599 | {pat_desc: TPatConstruct (lid , c , omegas )} =>
512600 let (args , rest ) = read_args(omegas, r);
513601 [ make_pat(TPatConstruct (lid, c, args), q. pat_type, q. pat_env), ... rest] ;
@@ -663,6 +751,7 @@ let build_specialized_submatrices = (~extend_row, q, rows) => {
663751 let (constr_groups , omega_tails ) = {
664752 let initial_constr_group =
665753 switch (q. pat_desc) {
754+ | TPatRecord (_ )
666755 | TPatTuple (_ ) =>
667756 /* [q] comes from [discr_pat], and in this case subsumes any of the
668757 patterns we could find on the first column of [rows]. So it is better
@@ -1771,6 +1860,9 @@ let rec le_pat = (p, q) =>
17711860 | (TPatArray (ps ), TPatArray (qs )) =>
17721861 List . length(ps) == List . length(qs) && le_pats(ps, qs)
17731862 | (TPatTuple (ps ), TPatTuple (qs )) => le_pats(ps, qs)
1863+ | (TPatRecord (l1 , _ ), TPatRecord (l2 , _ )) =>
1864+ let (ps , qs ) = records_args(l1, l2);
1865+ le_pats(ps, qs);
17741866 /* In all other cases, enumeration is performed */
17751867 | (_ , _ ) => ! satisfiable([[ p]] , [ q] )
17761868 }
@@ -1815,6 +1907,9 @@ let rec lub = (p, q) =>
18151907 when Types . equal_tag(c1. cstr_tag, c2. cstr_tag) =>
18161908 let rs = lubs(ps1, ps2);
18171909 make_pat(TPatConstruct (lid, c1, rs), p. pat_type, p. pat_env);
1910+ | (TPatRecord (l1 , closed ), TPatRecord (l2 , _ )) =>
1911+ let rs = record_lubs(l1, l2);
1912+ make_pat(TPatRecord (rs, closed), p. pat_type, p. pat_env);
18181913 | (TPatArray (ps1 ), TPatArray (ps2 ))
18191914 when List . length(ps1) == List . length(ps2) =>
18201915 let rs = lubs(ps1, ps2);
@@ -1832,6 +1927,24 @@ and orlub = (p1, p2, q) =>
18321927 | Empty => lub(p2, q)
18331928 }
18341929
1930+ and record_lubs = (l1, l2) => {
1931+ let rec lub_rec = (l1, l2) => {
1932+ switch (l1, l2) {
1933+ | ([] , _ ) => l2
1934+ | (_ , [] ) => l1
1935+ | ([ (lid1 , lbl1 , p1 ), ... rem1 ] , [ (lid2 , lbl2 , p2 ), ... rem2 ] ) =>
1936+ if (lbl1. lbl_pos < lbl2. lbl_pos) {
1937+ [ (lid1, lbl1, p1), ... lub_rec(rem1, l2)] ;
1938+ } else if (lbl2. lbl_pos < lbl1. lbl_pos) {
1939+ [ (lid2, lbl2, p2), ... lub_rec(l1, rem2)] ;
1940+ } else {
1941+ [ (lid1, lbl1, lub(p1, p2)), ... lub_rec(rem1, rem2)] ;
1942+ }
1943+ };
1944+ };
1945+ lub_rec(l1, l2);
1946+ }
1947+
18351948and lubs = (ps, qs) =>
18361949 switch (ps, qs) {
18371950 | ([ p , ... ps ] , [ q , ... qs ] ) => [ lub(p, q), ... lubs(ps, qs)]
@@ -1989,10 +2102,25 @@ module Conv = {
19892102 txt: Identifier . IdentName ({... cstr_lid, txt: id}),
19902103 };
19912104 Hashtbl . add(constrs, id, cstr);
1992- mkpat(
1993- ~loc= pat. pat_loc,
1994- PPatConstruct (lid, PPatConstrTuple (List . map(loop, lst))),
1995- ); // record vs tuple should not matter at this point
2105+ switch (lst) {
2106+ | [ {pat_desc: TPatRecord (fields , closed )}]
2107+ when cstr. cstr_inlined != None =>
2108+ mkpat(
2109+ ~loc= pat. pat_loc,
2110+ PPatConstruct (
2111+ lid,
2112+ PPatConstrRecord (
2113+ List . map(((id, _, p)) => (id, loop(p)), fields),
2114+ closed,
2115+ ),
2116+ ),
2117+ )
2118+ | _ =>
2119+ mkpat(
2120+ ~loc= pat. pat_loc,
2121+ PPatConstruct (lid, PPatConstrTuple (List . map(loop, lst))),
2122+ )
2123+ };
19962124 };
19972125
19982126 let ps = loop(typed);
0 commit comments