Skip to content

Commit d77f170

Browse files
authored
Merge pull request #102 from herbie-fp/dynamic-min-max
Dynamic path reduction
2 parents 3cb2cc5 + e04b342 commit d77f170

File tree

2 files changed

+121
-92
lines changed

2 files changed

+121
-92
lines changed

eval/adjust.rkt

Lines changed: 102 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,9 @@
2929
(define vhint (make-vector (vector-length ivec) #f))
3030
(define converged? #t)
3131

32-
; helper function
33-
(define (vhint-set! idx val)
34-
(when (>= idx varc)
35-
(vector-set! vhint (- idx varc) val)))
36-
3732
; roots always should be executed
3833
(for ([root-reg (in-vector rootvec)])
39-
(vhint-set! root-reg #t))
34+
(vector-shift-set! vhint varc root-reg #t))
4035
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
4136
[hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)]
4237
[o-hint (in-vector old-hint (- (vector-length old-hint) 1) -1 -1)]
@@ -48,89 +43,102 @@
4843
[(? integer? ref) ; instr is already "hinted" by old hint,
4944
(define idx (list-ref instr ref)) ; however, one child needs to be recomputed
5045
(when (>= idx varc)
51-
(vhint-set! idx #t))
46+
(vector-shift-set! vhint varc idx #t))
5247
o-hint]
5348
[#t
54-
(case (object-name (car instr))
55-
[(ival-assert)
56-
(match-define (list _ bool-idx) instr)
57-
(define bool-reg (vector-ref vregs bool-idx))
58-
(match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg))
59-
; assert and its children should not be reexecuted if it is true already
60-
[(#t #t #f) (ival-bool #t)]
61-
; assert and its children should not be reexecuted if it is false already
62-
[(#f #f #f) (ival-bool #f)]
63-
[(_ _ _) ; assert and its children should be reexecuted
64-
(vhint-set! bool-idx #t)
65-
(set! converged? #f)
66-
#t])]
67-
[(ival-if)
68-
(match-define (list _ cond tru fls) instr)
69-
(define cond-reg (vector-ref vregs cond))
70-
(match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg))
71-
[(#t #t #f) ; only true path should be executed
72-
(vhint-set! tru #t)
73-
2]
74-
[(#f #f #f) ; only false path should be executed
75-
(vhint-set! fls #t)
76-
3]
77-
[(_ _ _) ; execute both paths and cond as well
78-
(vhint-set! cond #t)
79-
(vhint-set! tru #t)
80-
(vhint-set! fls #t)
81-
(set! converged? #f)
82-
#t])]
83-
[(ival-fmax)
84-
(match-define (list _ arg1 arg2) instr)
85-
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
86-
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
87-
[(#t #t #f) ; only arg1 should be executed
88-
(vhint-set! arg1 #t)
89-
1]
90-
[(#f #f #f) ; only arg2 should be executed
91-
(vhint-set! arg2 #t)
92-
2]
93-
[(_ _ _) ; both paths should be executed
94-
(vhint-set! arg1 #t)
95-
(vhint-set! arg2 #t)
96-
(set! converged? #f)
97-
#t])]
98-
[(ival-fmin)
99-
(match-define (list _ arg1 arg2) instr)
100-
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
101-
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
102-
[(#t #t #f) ; only arg2 should be executed
103-
(vhint-set! arg2 #t)
104-
2]
105-
[(#f #f #f) ; only arg1 should be executed
106-
(vhint-set! arg1 #t)
107-
1]
108-
[(_ _ _) ; both paths should be executed
109-
(vhint-set! arg1 #t)
110-
(vhint-set! arg2 #t)
111-
(set! converged? #f)
112-
#t])]
113-
[(ival-< ival-<= ival-> ival->= ival-== ival-!= ival-and ival-or ival-not)
114-
(define cmp (vector-ref vregs (+ varc n)))
115-
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
116-
; result is known
117-
[(#t #t #f) (ival-bool #t)]
118-
; result is known
119-
[(#f #f #f) (ival-bool #f)]
120-
[(_ _ _) ; all the paths should be executed
121-
(define srcs (rest instr))
122-
(for-each (λ (x) (vhint-set! x #t)) srcs)
123-
(set! converged? #f)
124-
#t])]
125-
[else ; at this point we are given that the current instruction should be executed
126-
(define srcs
127-
(drop-self-pointer (rest instr)
128-
(+ n varc))) ; then, children instructions should be executed as well
129-
(for-each (λ (x) (vhint-set! x #t)) srcs)
130-
#t])]))
49+
(define-values (hint* converged?*) (path-reduction vhint vregs varc instr n))
50+
(set! converged? (and converged?* converged?))
51+
hint*]))
13152
(vector-set! vhint n hint*))
13253
(values vhint converged?))
13354

55+
; helper function
56+
(define (vector-shift-set! vec varc idx val)
57+
(when (>= idx varc)
58+
(vector-set! vec (- idx varc) val)))
59+
60+
(define (path-reduction vpath vregs varc instr n #:reexec-val [reexec-val #t])
61+
(define converged? #t)
62+
(define hint
63+
(case (object-name (car instr))
64+
[(ival-assert)
65+
(match-define (list _ bool-idx) instr)
66+
(define bool-reg (vector-ref vregs bool-idx))
67+
(match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg))
68+
; assert and its children should not be reexecuted if it is true already
69+
[(#t #t #f) (ival-bool #t)]
70+
; assert and its children should not be reexecuted if it is false already
71+
[(#f #f #f) (ival-bool #f)]
72+
[(_ _ _) ; assert and its children should be reexecuted
73+
(vector-shift-set! vpath varc bool-idx reexec-val)
74+
(set! converged? #f)
75+
#t])]
76+
[(ival-if)
77+
(match-define (list _ cond tru fls) instr)
78+
(define cond-reg (vector-ref vregs cond))
79+
(match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg))
80+
[(#t #t #f) ; only true path should be executed
81+
(vector-shift-set! vpath varc tru reexec-val)
82+
2]
83+
[(#f #f #f) ; only false path should be executed
84+
(vector-shift-set! vpath varc fls reexec-val)
85+
3]
86+
[(_ _ _) ; execute both paths and cond as well
87+
(vector-shift-set! vpath varc cond reexec-val)
88+
(vector-shift-set! vpath varc tru reexec-val)
89+
(vector-shift-set! vpath varc fls reexec-val)
90+
(set! converged? #f)
91+
#t])]
92+
[(ival-fmax)
93+
(match-define (list _ arg1 arg2) instr)
94+
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
95+
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
96+
[(#t #t #f) ; only arg1 should be executed
97+
(vector-shift-set! vpath varc arg1 reexec-val)
98+
1]
99+
[(#f #f #f) ; only arg2 should be executed
100+
(vector-shift-set! vpath varc arg2 reexec-val)
101+
2]
102+
[(_ _ _) ; both paths should be executed
103+
(vector-shift-set! vpath varc arg1 reexec-val)
104+
(vector-shift-set! vpath varc arg2 reexec-val)
105+
(set! converged? #f)
106+
#t])]
107+
[(ival-fmin)
108+
(match-define (list _ arg1 arg2) instr)
109+
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
110+
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
111+
[(#t #t #f) ; only arg2 should be executed
112+
(vector-shift-set! vpath varc arg2 reexec-val)
113+
2]
114+
[(#f #f #f) ; only arg1 should be executed
115+
(vector-shift-set! vpath varc arg1 reexec-val)
116+
1]
117+
[(_ _ _) ; both paths should be executed
118+
(vector-shift-set! vpath varc arg1 reexec-val)
119+
(vector-shift-set! vpath varc arg2 reexec-val)
120+
(set! converged? #f)
121+
#t])]
122+
[(ival-< ival-<= ival-> ival->= ival-== ival-!= ival-and ival-or ival-not)
123+
(define cmp (vector-ref vregs (+ varc n)))
124+
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
125+
; result is known
126+
[(#t #t #f) (ival-bool #t)]
127+
; result is known
128+
[(#f #f #f) (ival-bool #f)]
129+
[(_ _ _) ; all the paths should be executed
130+
(define srcs (rest instr))
131+
(for-each (λ (x) (vector-shift-set! vpath varc x reexec-val)) srcs)
132+
(set! converged? #f)
133+
#t])]
134+
[else ; at this point we are given that the current instruction should be executed
135+
(define srcs
136+
(drop-self-pointer (rest instr)
137+
(+ n varc))) ; then, children instructions should be executed as well
138+
(for-each (λ (x) (vector-shift-set! vpath varc x reexec-val)) srcs)
139+
#t]))
140+
(values hint converged?))
141+
134142
(define (backward-pass machine vhint)
135143
; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3
136144
(define args (rival-machine-arguments machine))
@@ -147,6 +155,7 @@
147155
(define current-iter (rival-machine-iteration machine))
148156
(define bumps (rival-machine-bumps machine))
149157

158+
(define first-tuning-pass? (equal? 1 current-iter))
150159
(define varc (vector-length args))
151160
(define vprecs-new (make-vector (vector-length ivec) 0)) ; new vprecs vector
152161

@@ -159,21 +168,23 @@
159168
(vector-set! vprecs-new (- root-reg varc) (get-slack)))
160169

161170
; Step 1b. Checking if a operation should be computed again at all
162-
(vector-fill! vrepeats #t) ; #t - means it WON'T be REEXECUTED
171+
; This step traverses instructions top-down and check whether a reevaluation is needed
172+
; Reevaluation can be skipped if:
173+
; 1) the result is already exact
174+
; 2) the operation is a part of the path that has been reduced
175+
; Once the path is fully reduced - vrepeats will store this path - no more reduction is needed
176+
(vector-fill! vrepeats #t)
163177
(for ([root (in-vector rootvec)]
164178
#:when (>= root varc))
165-
(vector-set! vrepeats (- root varc) #f)) ; #f - means it WILL be REEXECUTED
179+
(vector-set! vrepeats (- root varc) #f))
166180
(for ([reg (in-vector vregs (- (vector-length vregs) 1) (- varc 1) -1)]
167181
[instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
168182
[i (in-range (- (vector-length ivec) 1) -1 -1)]
169183
[repeat? (in-vector vrepeats (- (vector-length vrepeats) 1) -1 -1)]
170184
#:unless repeat?)
171185
(cond
172186
[(and (ival-lo-fixed? reg) (ival-hi-fixed? reg)) (vector-set! vrepeats i #t)]
173-
[else
174-
(for ([arg (in-list (drop-self-pointer (cdr instr) (+ i varc)))]
175-
#:when (>= arg varc))
176-
(vector-set! vrepeats (- arg varc) #f))]))
187+
[else (path-reduction vrepeats vregs varc instr i #:reexec-val #f)]))
177188

178189
; Step 2. Precision tuning
179190
(precision-tuning ivec vregs vprecs-new varc vstart-precs vrepeats vhint)
@@ -185,7 +196,7 @@
185196
(define any-reevaluation? #f)
186197
(for ([instr (in-vector ivec)]
187198
[result-is-exact-already (in-vector vrepeats)]
188-
[prec-old (in-vector (if (equal? 1 current-iter) vstart-precs vprecs))]
199+
[prec-old (in-vector (if first-tuning-pass? vstart-precs vprecs))]
189200
[prec-new (in-vector vprecs-new)]
190201
[best-known-precision (in-vector vbest-precs)]
191202
[constant? (in-vector vinitial-repeats)]

eval/tests.rkt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
(match-define (list res* hint* converged?*) (rival-analyze machine hyperrect hint))
105105
(check-equal? hint hint*)
106106
(check-equal? res res*)
107-
(check equal? converged? converged?*)
107+
(check-equal? converged? converged?*)
108108

109109
(for ([_ (in-range number-of-random-pts-per-rect)])
110110
(define pt (sample-pts hyperrect))
@@ -159,3 +159,21 @@
159159
(>= (log y) (log x)))
160160
x
161161
y))))
162+
163+
(module+ test
164+
(require rackunit
165+
"machine.rkt"
166+
"../main.rkt")
167+
; Test that checks correctness of early exit!
168+
(define expr '(fmin (exp (pow 1e100 1e200)) (- (cos (+ 1 1e200)) (cos 1e200))))
169+
(define machine (rival-compile (list expr) '() (list flonum-discretization)))
170+
(define out (vector-ref (rival-apply machine (vector)) 0))
171+
(check-equal? out 0.19018843355136827)
172+
173+
; Test that checks correctness of path reducing!
174+
(define machine2
175+
(rival-compile (list '(+ (fmin (- (cos (+ 1 1e200)) (cos 1e200)) -3)
176+
(- (cos (+ 1 1e200)) (cos 1e200))))
177+
'()
178+
(list flonum-discretization)))
179+
(check-equal? (vector-ref (rival-apply machine2 (vector)) 0) -2.809811566448632))

0 commit comments

Comments
 (0)