Skip to content

Commit b0e236d

Browse files
mkannwischerhanno-becker
authored andcommitted
HOL-Light: Speed up NTT/INTT proofs
This commit ports @jargh's improvements to the NTT/INTT proofs in s2n-bignum which should result in vast proof performance improvements (1.6x- 4.3x according to the PR). - Resolves #1413 - Ports awslabs/s2n-bignum#325 Signed-off-by: Matthias J. Kannwischer <matthias@kannwischer.eu>
1 parent e6a9424 commit b0e236d

File tree

5 files changed

+147
-64
lines changed

5 files changed

+147
-64
lines changed

proofs/hol_light/arm/proofs/mlkem_intt.ml

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ let MLKEM_INTT_CORRECT = prove
655655
(*** Simulate all the way to the end, in effect unrolling loops ***)
656656

657657
MAP_UNTIL_TARGET_PC (fun n -> ARM_STEPS_TAC MLKEM_INTT_EXEC [n] THEN
658-
(SIMD_SIMPLIFY_TAC [barred; barmul]))
658+
(SIMD_SIMPLIFY_ABBREV_TAC[barred; barmul] []))
659659
1 THEN
660660
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
661661

@@ -666,19 +666,28 @@ let MLKEM_INTT_CORRECT = prove
666666
CONV_RULE(READ_MEMORY_SPLIT_CONV 3) o
667667
check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN
668668

669-
(*** Turn the conclusion into an explicit conjunction and split it up ***)
669+
(*** Expand and substitute in the conclusion we want to prove ***)
670670

671671
CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN REWRITE_TAC[INT_ABS_BOUNDS] THEN
672672
GEN_REWRITE_TAC (BINDER_CONV o RAND_CONV) [GSYM I_THM] THEN
673673
CONV_TAC(EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV) THEN
674674
ASM_REWRITE_TAC[I_THM; WORD_ADD_0] THEN DISCARD_STATE_TAC "s1153" THEN
675+
676+
(*** Perform congruence and bound propagation and finish ***)
677+
678+
W(fun (asl,w) ->
679+
let lfn = undefined
680+
and asms =
681+
map snd (filter (is_local_definition [barred; barmul] o concl o snd)
682+
asl) in
683+
let lfn' = LOCAL_CONGBOUND_RULE lfn (rev asms) in
684+
675685
REPEAT(W(fun (asl,w) ->
676686
if length(conjuncts w) > 3 then CONJ_TAC else NO_TAC)) THEN
677687

678-
(*** Get congruences and bounds for the result digits and finish ***)
679-
680-
(W(MP_TAC o CONGBOUND_RULE o rand o lhand o rator o lhand o snd) THEN
681-
MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
688+
W(MP_TAC o ASM_CONGBOUND_RULE lfn' o
689+
rand o lhand o rator o lhand o snd) THEN
690+
(MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
682691
[MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] INT_CONG_TRANS) THEN
683692
CONV_TAC(ONCE_DEPTH_CONV INVERSE_NTT_CONV) THEN
684693
REWRITE_TAC[GSYM INT_REM_EQ; o_THM] THEN CONV_TAC INT_REM_DOWN_CONV THEN
@@ -689,7 +698,7 @@ let MLKEM_INTT_CORRECT = prove
689698
MATCH_MP_TAC(INT_ARITH
690699
`l':int <= l /\ u <= u'
691700
==> l <= x /\ x <= u ==> l' <= x /\ x <= u'`) THEN
692-
CONV_TAC INT_REDUCE_CONV]));;
701+
CONV_TAC INT_REDUCE_CONV])));;
693702

694703
(*** Subroutine form, somewhat messy elaboration of the usual wrapper ***)
695704

proofs/hol_light/arm/proofs/mlkem_ntt.ml

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ let MLKEM_NTT_CORRECT = prove
590590
(*** Simulate all the way to the end, in effect unrolling loops ***)
591591

592592
MAP_UNTIL_TARGET_PC (fun n -> ARM_STEPS_TAC MLKEM_NTT_EXEC [n] THEN
593-
(SIMD_SIMPLIFY_TAC [barmul])) 1 THEN
593+
(SIMD_SIMPLIFY_ABBREV_TAC[barmul] [])) 1 THEN
594594
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
595595

596596
(*** Reverse the restructuring by splitting the 128-bit words up ***)
@@ -600,23 +600,28 @@ let MLKEM_NTT_CORRECT = prove
600600
CONV_RULE(READ_MEMORY_SPLIT_CONV 3) o
601601
check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN
602602

603-
(*** Turn the conclusion into an explicit conjunction and split it up ***)
603+
(*** Expand and substitute in the conclusion we want to prove ***)
604604

605605
DISCH_TAC THEN
606606
CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN REWRITE_TAC[INT_ABS_BOUNDS] THEN
607607
GEN_REWRITE_TAC (BINDER_CONV o RAND_CONV) [GSYM I_THM] THEN
608608
CONV_TAC(EXPAND_CASES_CONV THENC ONCE_DEPTH_CONV NUM_MULT_CONV) THEN
609609
ASM_REWRITE_TAC[I_THM; WORD_ADD_0] THEN DISCARD_STATE_TAC "s904" THEN
610-
REPEAT(W(fun (asl,w) ->
611-
if length(conjuncts w) > 3 then CONJ_TAC else NO_TAC)) THEN
612610

613-
(*** Get congruences and bounds for the result digits and finish ***)
611+
(*** Perform congruence and bound propagation and finish ***)
614612

615-
FIRST_X_ASSUM(MP_TAC o CONV_RULE EXPAND_CASES_CONV) THEN
616-
POP_ASSUM_LIST(K ALL_TAC) THEN
617-
DISCH_THEN(fun aboth ->
618-
W(MP_TAC o GEN_CONGBOUND_RULE (CONJUNCTS aboth) o
619-
rand o lhand o rator o lhand o snd)) THEN
613+
W(fun (asl,w) ->
614+
let lfn = PROCESS_BOUND_ASSUMPTIONS
615+
(CONJUNCTS(tryfind (CONV_RULE EXPAND_CASES_CONV o snd) asl))
616+
and asms =
617+
map snd (filter (is_local_definition [barmul] o concl o snd) asl) in
618+
let lfn' = LOCAL_CONGBOUND_RULE lfn (rev asms) in
619+
620+
REPEAT(W(fun (asl,w) ->
621+
if length(conjuncts w) > 3 then CONJ_TAC else NO_TAC)) THEN
622+
623+
W(MP_TAC o ASM_CONGBOUND_RULE lfn' o
624+
rand o lhand o rator o lhand o snd) THEN
620625
(MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
621626
[MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] INT_CONG_TRANS) THEN
622627
CONV_TAC(ONCE_DEPTH_CONV FORWARD_NTT_CONV) THEN
@@ -628,7 +633,7 @@ let MLKEM_NTT_CORRECT = prove
628633
MATCH_MP_TAC(INT_ARITH
629634
`l':int <= l /\ u <= u'
630635
==> l <= x /\ x <= u ==> l' <= x /\ x <= u'`) THEN
631-
CONV_TAC INT_REDUCE_CONV]));;
636+
CONV_TAC INT_REDUCE_CONV])));;
632637

633638
(*** Subroutine form, somewhat messy elaboration of the usual wrapper ***)
634639

proofs/hol_light/common/mlkem_specs.ml

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -957,9 +957,7 @@ let CONCL_BOUNDS_RULE =
957957
let SIDE_ELIM_RULE th =
958958
MP th (EQT_ELIM(DIMINDEX_INT_REDUCE_CONV(lhand(concl th))));;
959959

960-
let GEN_CONGBOUND_RULE aboths =
961-
let lfn = PROCESS_BOUND_ASSUMPTIONS aboths in
962-
let rec rule tm =
960+
let rec ASM_CONGBOUND_RULE lfn tm =
963961
try apply lfn tm with Failure _ ->
964962
match tm with
965963
Comb(Const("word",_),n) when is_numeral n ->
@@ -972,60 +970,76 @@ let GEN_CONGBOUND_RULE aboths =
972970
let th2 = WORD_RED_CONV(lhand(lhand(snd(strip_forall(concl th1))))) in
973971
SUBS[SYM th0] (MATCH_MP th1 th2)
974972
| Comb(Comb(Const("barmul",_),kb),t) ->
975-
let ktm,btm = dest_pair kb and th0 = rule t in
973+
let ktm,btm = dest_pair kb and th0 = ASM_CONGBOUND_RULE lfn t in
976974
let th0' = WEAKEN_INTCONG_RULE (num 3329) th0 in
977975
let th1 = SPECL [ktm;btm] (MATCH_MP CONGBOUND_BARMUL th0') in
978976
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
979977
| Comb(Comb(Const("montmul_x86",_),ltm),rtm) ->
980-
let lth = WEAKEN_INTCONG_RULE (num 3329) (rule ltm)
981-
and rth = WEAKEN_INTCONG_RULE (num 3329) (rule rtm) in
978+
let lth = WEAKEN_INTCONG_RULE (num 3329) (ASM_CONGBOUND_RULE lfn ltm)
979+
and rth = WEAKEN_INTCONG_RULE (num 3329) (ASM_CONGBOUND_RULE lfn rtm) in
982980
let th1 = MATCH_MP CONGBOUND_MONTMUL_X86
983981
(UNIFY_INTCONG_RULE lth rth) in
984982
CONCL_BOUNDS_RULE(th1)
985983
| Comb(Const("barred",_),t) ->
986-
let th1 = WEAKEN_INTCONG_RULE (num 3329) (rule t) in
984+
let th1 = WEAKEN_INTCONG_RULE (num 3329) (ASM_CONGBOUND_RULE lfn t) in
987985
MATCH_MP CONGBOUND_BARRED th1
988986
| Comb(Const("barred_x86",_),t) ->
989-
let th1 = WEAKEN_INTCONG_RULE (num 3329) (rule t) in
987+
let th1 = WEAKEN_INTCONG_RULE (num 3329) (ASM_CONGBOUND_RULE lfn t) in
990988
MATCH_MP CONGBOUND_BARRED_X86 th1
991989
| Comb(Const("montred",_),t) ->
992-
let th1 = WEAKEN_INTCONG_RULE (num 3329) (rule t) in
990+
let th1 = WEAKEN_INTCONG_RULE (num 3329) (ASM_CONGBOUND_RULE lfn t) in
993991
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE(MATCH_MP CONGBOUND_MONTRED th1))
994992
| Comb(Comb(Const("ntt_montmul",_),ab),t) ->
995-
let atm,btm = dest_pair ab and th0 = rule t in
993+
let atm,btm = dest_pair ab and th0 = ASM_CONGBOUND_RULE lfn t in
996994
let th0' = WEAKEN_INTCONG_RULE (num 3329) th0 in
997995
let th1 = SPECL [atm;btm] (MATCH_MP CONGBOUND_NTT_MONTMUL th0') in
998996
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
999997
| Comb(Const("word_sx",_),t) ->
1000-
let th0 = rule t in
998+
let th0 = ASM_CONGBOUND_RULE lfn t in
1001999
let tyin = type_match
10021000
(type_of(rator(rand(lhand(funpow 4 rand (snd(dest_forall
10031001
(concl CONGBOUND_WORD_SX)))))))) (type_of(rator tm)) [] in
10041002
let th1 = MATCH_MP (INST_TYPE tyin CONGBOUND_WORD_SX) th0 in
10051003
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
10061004
| Comb(Const("word_neg",_),t) ->
1007-
let th0 = rule t in
1005+
let th0 = ASM_CONGBOUND_RULE lfn t in
10081006
let th1 = MATCH_MP CONGBOUND_WORD_NEG th0 in
10091007
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
10101008
| Comb(Comb(Const("word_add",_),ltm),rtm) ->
1011-
let lth = rule ltm and rth = rule rtm in
1009+
let lth = ASM_CONGBOUND_RULE lfn ltm
1010+
and rth = ASM_CONGBOUND_RULE lfn rtm in
10121011
let th1 = MATCH_MP CONGBOUND_WORD_ADD (UNIFY_INTCONG_RULE lth rth) in
10131012
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
10141013
| Comb(Comb(Const("word_sub",_),ltm),rtm) ->
1015-
let lth = rule ltm and rth = rule rtm in
1014+
let lth = ASM_CONGBOUND_RULE lfn ltm
1015+
and rth = ASM_CONGBOUND_RULE lfn rtm in
10161016
let th1 = MATCH_MP CONGBOUND_WORD_SUB (UNIFY_INTCONG_RULE lth rth) in
10171017
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
10181018
| Comb(Comb(Const("word_mul",_),ltm),rtm) ->
1019-
let lth = rule ltm and rth = rule rtm in
1019+
let lth = ASM_CONGBOUND_RULE lfn ltm
1020+
and rth = ASM_CONGBOUND_RULE lfn rtm in
10201021
let th1 = MATCH_MP CONGBOUND_WORD_MUL (UNIFY_INTCONG_RULE lth rth) in
10211022
CONCL_BOUNDS_RULE(SIDE_ELIM_RULE th1)
1022-
| _ -> CONCL_BOUNDS_RULE(ISPEC tm CONGBOUND_ATOM) in
1023-
rule;;
1023+
| _ -> CONCL_BOUNDS_RULE(ISPEC tm CONGBOUND_ATOM);;
1024+
1025+
let GEN_CONGBOUND_RULE aboths =
1026+
ASM_CONGBOUND_RULE (PROCESS_BOUND_ASSUMPTIONS aboths);;
10241027

10251028
let CONGBOUND_RULE = GEN_CONGBOUND_RULE [];;
10261029

1030+
let rec LOCAL_CONGBOUND_RULE lfn asms =
1031+
match asms with
1032+
[] -> lfn
1033+
| th::ths ->
1034+
let bod,var = dest_eq (concl th) in
1035+
let th1 = ASM_CONGBOUND_RULE lfn bod in
1036+
let th2 = SUBS[th] th1 in
1037+
let lfn' = (var |-> th2) lfn in
1038+
LOCAL_CONGBOUND_RULE lfn' ths;;
1039+
10271040
(* ------------------------------------------------------------------------- *)
1028-
(* Simplify SIMD cruft and fold abbreviations when encountered. *)
1041+
(* Simplify SIMD cruft and fold relevant definitions when encountered. *)
1042+
(* The ABBREV form also introduces abbreviations for relevant subterms. *)
10291043
(* ------------------------------------------------------------------------- *)
10301044

10311045
let SIMD_SIMPLIFY_CONV unfold_defs =
@@ -1042,3 +1056,31 @@ let SIMD_SIMPLIFY_TAC unfold_defs =
10421056
(ASSUME_TAC o
10431057
CONV_RULE(RAND_CONV (SIMD_SIMPLIFY_CONV unfold_defs)) o
10441058
check (simdable o concl)));;
1059+
1060+
let is_local_definition unfold_defs =
1061+
let pats = map (lhand o snd o strip_forall o concl) unfold_defs in
1062+
let pam t = exists (fun p -> can(term_match [] p) t) pats in
1063+
fun tm -> is_eq tm && is_var(rand tm) && pam(lhand tm);;
1064+
1065+
let AUTO_ABBREV_TAC tm =
1066+
let gv = genvar(type_of tm) in
1067+
ABBREV_TAC(mk_eq(gv,tm));;
1068+
1069+
let SIMD_SIMPLIFY_ABBREV_TAC =
1070+
let arm_simdable =
1071+
can (term_match [] `read X (s:armstate):int128 = whatever`)
1072+
and x86_simdable =
1073+
can (term_match [] `read X (s:x86state):int256 = whatever`) in
1074+
let simdable tm = arm_simdable tm || x86_simdable tm in
1075+
fun unfold_defs unfold_aux ->
1076+
let pats = map (lhand o snd o strip_forall o concl) unfold_defs in
1077+
let pam t = exists (fun p -> can(term_match [] p) t) pats in
1078+
let ttac th (asl,w) =
1079+
let th' = CONV_RULE(RAND_CONV
1080+
(SIMD_SIMPLIFY_CONV (unfold_defs @ unfold_aux))) th in
1081+
let asms =
1082+
map snd (filter (is_local_definition unfold_defs o concl o snd) asl) in
1083+
let th'' = GEN_REWRITE_RULE (RAND_CONV o TOP_DEPTH_CONV) asms th' in
1084+
let tms = sort free_in (find_terms pam (rand(concl th''))) in
1085+
(MP_TAC th'' THEN MAP_EVERY AUTO_ABBREV_TAC tms THEN DISCH_TAC) (asl,w) in
1086+
TRY(FIRST_X_ASSUM(ttac o check (simdable o concl)));;

proofs/hol_light/x86/proofs/mlkem_intt.ml

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,8 @@ let MLKEM_INTT_CORRECT = prove
11021102
CONV_TAC(LAND_CONV WORD_REDUCE_CONV) THEN STRIP_TAC THEN
11031103

11041104
MAP_EVERY (fun n -> X86_STEPS_TAC MLKEM_INTT_TMC_EXEC [n] THEN
1105-
SIMD_SIMPLIFY_TAC[ntt_montmul; ntt_montmul_add; ntt_montmul_sub; barred_x86])
1105+
SIMD_SIMPLIFY_ABBREV_TAC[ntt_montmul; barred_x86]
1106+
[ntt_montmul_add; ntt_montmul_sub])
11061107
(1--663) THEN
11071108

11081109
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
@@ -1120,6 +1121,13 @@ let MLKEM_INTT_CORRECT = prove
11201121

11211122
ASM_REWRITE_TAC[] THEN DISCARD_STATE_TAC "s663" THEN
11221123

1124+
W(fun (asl,w) ->
1125+
let asms =
1126+
map snd (filter (is_local_definition
1127+
[ntt_montmul; barred_x86] o concl o snd) asl) in
1128+
MP_TAC(end_itlist CONJ (rev asms)) THEN
1129+
MAP_EVERY (fun t -> UNDISCH_THEN (concl t) (K ALL_TAC)) asms) THEN
1130+
11231131
REWRITE_TAC[WORD_BLAST `word_subword (x:int32) (0, 32) = x`] THEN
11241132
REWRITE_TAC[WORD_BLAST `word_subword (x:int64) (0, 64) = x`] THEN
11251133
REWRITE_TAC[WORD_BLAST
@@ -1143,13 +1151,22 @@ let MLKEM_INTT_CORRECT = prove
11431151
word_subword x (0, 16)`] THEN
11441152
CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN
11451153

1154+
STRIP_TAC THEN
1155+
11461156
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
11471157
REWRITE_TAC[GSYM CONJ_ASSOC] THEN
1148-
REPEAT(GEN_REWRITE_TAC I
1149-
[TAUT `p /\ q /\ r /\ s <=> (p /\ q /\ r) /\ s`] THEN CONJ_TAC) THEN
1150-
POP_ASSUM_LIST(K ALL_TAC) THEN
1151-
(W(MP_TAC o CONGBOUND_RULE o rand o lhand o rator o lhand o snd) THEN
1152-
MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
1158+
1159+
W(fun (asl,w) ->
1160+
let lfn = undefined
1161+
and asms =
1162+
map snd (filter (is_local_definition [ntt_montmul; barred_x86] o concl o snd) asl) in
1163+
let lfn' = LOCAL_CONGBOUND_RULE lfn (rev asms) in
1164+
1165+
REPEAT(GEN_REWRITE_TAC I
1166+
[TAUT `p /\ q /\ r /\ s <=> (p /\ q /\ r) /\ s`] THEN CONJ_TAC) THEN
1167+
1168+
W(MP_TAC o ASM_CONGBOUND_RULE lfn' o rand o lhand o rator o lhand o snd) THEN
1169+
(MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
11531170
[REWRITE_TAC[INVERSE_MOD_CONV `inverse_mod 3329 65536`] THEN
11541171
MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] INT_CONG_TRANS) THEN
11551172
CONV_TAC(ONCE_DEPTH_CONV AVX2_INVERSE_NTT_CONV) THEN
@@ -1161,7 +1178,7 @@ let MLKEM_INTT_CORRECT = prove
11611178
MATCH_MP_TAC(INT_ARITH
11621179
`l':int <= l /\ u <= u'
11631180
==> l <= x /\ x <= u ==> l' <= x /\ x <= u'`) THEN
1164-
CONV_TAC INT_REDUCE_CONV])
1181+
CONV_TAC INT_REDUCE_CONV]))
11651182
);;
11661183

11671184
let MLKEM_INTT_NOIBT_SUBROUTINE_CORRECT = prove

proofs/hol_light/x86/proofs/mlkem_ntt.ml

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ let MLKEM_NTT_CORRECT = prove
11141114
CONV_TAC(LAND_CONV WORD_REDUCE_CONV) THEN STRIP_TAC THEN
11151115

11161116
MAP_EVERY (fun n -> X86_STEPS_TAC MLKEM_NTT_TMC_EXEC [n] THEN
1117-
SIMD_SIMPLIFY_TAC[ntt_montmul; ntt_montmul_add; ntt_montmul_sub])
1117+
SIMD_SIMPLIFY_ABBREV_TAC[ntt_montmul] [ntt_montmul_add; ntt_montmul_sub])
11181118
(1--587) THEN
11191119
ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN
11201120

@@ -1131,6 +1131,12 @@ let MLKEM_NTT_CORRECT = prove
11311131

11321132
ASM_REWRITE_TAC[] THEN DISCARD_STATE_TAC "s587" THEN
11331133

1134+
W(fun (asl,w) ->
1135+
let asms =
1136+
map snd (filter (is_local_definition [ntt_montmul] o concl o snd) asl) in
1137+
MP_TAC(end_itlist CONJ (rev asms)) THEN
1138+
MAP_EVERY (fun t -> UNDISCH_THEN (concl t) (K ALL_TAC)) asms) THEN
1139+
11341140
REWRITE_TAC[WORD_BLAST `word_subword (x:int32) (0, 32) = x`] THEN
11351141
REWRITE_TAC[WORD_BLAST `word_subword (x:int64) (0, 64) = x`] THEN
11361142
REWRITE_TAC[WORD_BLAST
@@ -1152,31 +1158,35 @@ let MLKEM_NTT_CORRECT = prove
11521158

11531159
CONV_TAC(TOP_DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN
11541160

1161+
STRIP_TAC THEN
1162+
11551163
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
11561164
REWRITE_TAC[GSYM CONJ_ASSOC] THEN
1165+
1166+
W(fun (asl,w) ->
1167+
let lfn = PROCESS_BOUND_ASSUMPTIONS
1168+
(CONJUNCTS(tryfind (CONV_RULE EXPAND_CASES_CONV o snd) asl))
1169+
and asms =
1170+
map snd (filter (is_local_definition [ntt_montmul] o concl o snd) asl) in
1171+
let lfn' = LOCAL_CONGBOUND_RULE lfn (rev asms) in
1172+
11571173
REPEAT(GEN_REWRITE_TAC I
11581174
[TAUT `p /\ q /\ r /\ s <=> (p /\ q /\ r) /\ s`] THEN CONJ_TAC) THEN
11591175

1160-
FIRST_X_ASSUM(MP_TAC o CONV_RULE EXPAND_CASES_CONV) THEN
1161-
POP_ASSUM_LIST(K ALL_TAC) THEN
1162-
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
1163-
DISCH_THEN(fun aboth ->
1164-
W(MP_TAC o GEN_CONGBOUND_RULE (CONJUNCTS aboth) o
1165-
rand o lhand o rator o lhand o snd)) THEN
1166-
1167-
(MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
1168-
[REWRITE_TAC[INVERSE_MOD_CONV `inverse_mod 3329 65536`] THEN
1169-
MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] INT_CONG_TRANS) THEN
1170-
CONV_TAC(ONCE_DEPTH_CONV AVX2_FORWARD_NTT_CONV) THEN
1171-
REWRITE_TAC[GSYM INT_REM_EQ; o_THM] THEN CONV_TAC INT_REM_DOWN_CONV THEN
1172-
REWRITE_TAC[INT_REM_EQ] THEN
1173-
REWRITE_TAC[REAL_INT_CONGRUENCE; INT_OF_NUM_EQ; ARITH_EQ] THEN
1174-
REWRITE_TAC[GSYM REAL_OF_INT_CLAUSES] THEN
1175-
CONV_TAC(RAND_CONV REAL_POLY_CONV) THEN REAL_INTEGER_TAC;
1176-
MATCH_MP_TAC(INT_ARITH
1177-
`l':int <= l /\ u <= u'
1178-
==> l <= x /\ x <= u ==> l' <= x /\ x <= u'`) THEN
1179-
CONV_TAC INT_REDUCE_CONV])
1176+
W(MP_TAC o ASM_CONGBOUND_RULE lfn' o rand o lhand o rator o lhand o snd) THEN
1177+
(MATCH_MP_TAC MONO_AND THEN CONJ_TAC THENL
1178+
[REWRITE_TAC[INVERSE_MOD_CONV `inverse_mod 3329 65536`] THEN
1179+
MATCH_MP_TAC(REWRITE_RULE[IMP_CONJ_ALT] INT_CONG_TRANS) THEN
1180+
CONV_TAC(ONCE_DEPTH_CONV AVX2_FORWARD_NTT_CONV) THEN
1181+
REWRITE_TAC[GSYM INT_REM_EQ; o_THM] THEN CONV_TAC INT_REM_DOWN_CONV THEN
1182+
REWRITE_TAC[INT_REM_EQ] THEN
1183+
REWRITE_TAC[REAL_INT_CONGRUENCE; INT_OF_NUM_EQ; ARITH_EQ] THEN
1184+
REWRITE_TAC[GSYM REAL_OF_INT_CLAUSES] THEN
1185+
CONV_TAC(RAND_CONV REAL_POLY_CONV) THEN REAL_INTEGER_TAC;
1186+
MATCH_MP_TAC(INT_ARITH
1187+
`l':int <= l /\ u <= u'
1188+
==> l <= x /\ x <= u ==> l' <= x /\ x <= u'`) THEN
1189+
CONV_TAC INT_REDUCE_CONV]))
11801190
);;
11811191

11821192
let MLKEM_NTT_NOIBT_SUBROUTINE_CORRECT = prove

0 commit comments

Comments
 (0)