-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRankOpt.lean
More file actions
198 lines (170 loc) · 8.72 KB
/
RankOpt.lean
File metadata and controls
198 lines (170 loc) · 8.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
/-
Copyright (c) 2026 Gaëtan Serré. All rights reserved.
Released under GNU GPL 3.0 license as described in the file LICENSE.
Authors: Gaëtan Serré
-/
import LeanGO.Algorithm
import LeanGO.Examples.Uniform
import LeanGO.Examples.Utils
import Mathlib.Analysis.Normed.Lp.MeasurableSpace
open MeasureTheory ProbabilityTheory Set
/-!
# RankOpt: A Ranking Approach to Global Optimization
Implementation of the _RankOpt_ algorithm
[(_A Ranking Approach to Global Optimization_, Malherbe et al. 2017)](https://arxiv.org/pdf/1603.04381)
defined on a measurable subset of a Euclidean space, with finite and non-zero measure.
The algorithm samples from the uniform distribution on the set of potential maximizers of
the function at each iteration.
-/
section RankRule
/-- A rank rule on `α`: a two-argument function with values in `{-1, 0, 1} ⊆ ℝ`, bundled with a
proof that its uncurried form on `α × α` is measurable.
The intended reading is a comparison of two points: `1` when the first point is ranked higher,
`-1` when it is ranked lower, and `0` on a tie. -/
-- ANCHOR: RankRule
def RankRule (α : Type) [MeasurableSpace α] :=
  {f : α → α → ({-1, 0, 1} : Set ℝ) // Measurable (Function.uncurry f)}
-- ANCHOR_END: RankRule
end RankRule
namespace RankOpt
-- Ambient data for this namespace: `α` is a subset of `ℝᵈ d`, assumed measurable (`mes_α`) and of
-- finite volume (`mα₁`). The former leading binders `{α : Type} [MeasurableSpace α]` were shadowed
-- by `{α : Set (ℝᵈ d)}` on the same line and could never be referred to, so they are dropped.
variable {d : ℕ} {α : Set (ℝᵈ d)}
  (mes_α : MeasurableSet α) (mα₁ : ℙ α ≠ ⊤)
/-- Measure-space structure on the subtype `α`, given by `Measure.Subtype.measureSpace`. -/
noncomputable instance : MeasureSpace α := Measure.Subtype.measureSpace
/-- Measurable-space structure on the subtype `α`, obtained by instance resolution. -/
instance : MeasurableSpace α := inferInstance
/-- The volume measure on the subtype `α` is finite.
The hypotheses `mes_α` and `mα₁` are ordinary arguments, so the uses in this file bring the
instance into scope explicitly with `have := i₁ mes_α mα₁`. -/
instance i₁ : IsFiniteMeasure (ℙ : Measure α) := by
  -- Restate finiteness as `ℙ univ < ⊤`, then replace the volume of `univ` in the subtype by the
  -- volume of `α` itself; the null-measurability side condition is supplied up front.
  rw [isFiniteMeasure_iff ℙ, Measure.Subtype.volume_univ mes_α.nullMeasurableSet]
  exact mα₁.lt_top
/-- Measurable-space structure on `RankRule α`, taken from `Subtype.instMeasurableSpace`. -/
instance : MeasurableSpace (RankRule α) := Subtype.instMeasurableSpace
/-- Ranking induced by two observed function values, as an integer:
`1` if `y₂ < y₁`, `0` if `y₂ = y₁`, and `-1` otherwise (that is, when `y₁ < y₂`).
The result type `ℤ` is stated explicitly rather than left to numeral defaulting; the shape of
the body is kept as is because the measurability proof in this file unfolds it. -/
noncomputable def ranking_data (y₁ y₂ : ℝ) : ℤ :=
  if y₂ < y₁ then 1 else if y₂ = y₁ then 0 else -1
/-- Disagreement indicator for two rankings: `0` if `r₁ = r₂`, and `1` otherwise.
It is the summand of `ranking_loss`, so a pair on which a rule reproduces the observed ranking
contributes nothing to the loss. -/
noncomputable abbrev rindicator (r₁ r₂ : ℝ) :=
  if r₁ = r₂ then (0 : ℝ) else 1

variable {n : ℕ} (data : prod_iter_image α ℝ n)

/-- Pairs of indices `(i, j)` in `Finset.Iic n` with `i ≤ j`. Since `Finset.Iic n` has `n + 1`
elements, there are `(n + 1) * (n + 2) / 2` such pairs. -/
abbrev s := {(i, j) : Finset.Iic n × Finset.Iic n | i ≤ j}

/-- Empirical ranking loss of a rank rule `r` on the observed `data`.
For every pair of indices `i ≤ j` it compares `r (data.1 i) (data.1 j)` with the ranking
`ranking_data (data.2 i) (data.2 j)` of the observed values, counts the pairs on which the two
differ, and divides by the number of pairs, `(n + 1) * (n + 2) / 2`.
The normalising factor is never zero, so the loss is `0` exactly when `r` agrees with the
observed rankings on every pair. -/
noncomputable def ranking_loss (r : RankRule α) :=
  2 * ((n + 1) * (n + 2) : ℝ)⁻¹ * ∑ ij ∈ s,
    rindicator (r.1 (data.1 ij.1) (data.1 ij.2)) (ranking_data (data.2 ij.1) (data.2 ij.2))
--variable (𝓡 : Set (RankRule α))
/-- The observed input `data.1 i`, where `i` is the index that `Tuple.argmax id` selects for the
tuple of observed values `data.2`. -/
noncomputable abbrev argmax_f := data.1 (Tuple.argmax id data.2)
/-- The set of potential maximizers used by RankOpt, for observed `data` and a hypothesis class
`𝓡` of rank rules.
A point `x` belongs to the set when some rule `r` in `𝓡` satisfies both
(1) `ranking_loss data r = 0`, and
(2) `0 ≤ r x (argmax_f data)`, i.e. the value that `r` assigns to the pair formed by `x` and the
data point `argmax_f data` is nonnegative. -/
def potential_max (𝓡 : Set (RankRule α)) :=
{x | ∃ (r : 𝓡), ranking_loss data r = 0 ∧ 0 ≤ (r.1.1 x (argmax_f data)).1}
/-- For a countable class `𝓡`, the set of pairs `(data, x)` with `x ∈ potential_max data 𝓡` is a
measurable subset of `prod_iter_image α ℝ n × α`. -/
lemma measurableSet_potential_max_prod {𝓡 : Set (RankRule α)} (h𝓡 : 𝓡.Countable) :
MeasurableSet {p : prod_iter_image α ℝ n × α | p.2 ∈ potential_max p.1 𝓡} := by
-- Unfold membership in `potential_max` into the predicate `∃ r, _ ∧ _`.
simp only [potential_max, mem_setOf_eq, measurableSet_setOf]
-- Countability of `𝓡` as a `Countable` instance on the subtype, so that the existential over
-- `r : 𝓡` can be treated one rule at a time.
have : Countable (𝓡) := h𝓡.to_subtype
refine Measurable.exists fun r ↦ (.and ?_ ?_)
-- First conjunct: `ranking_loss p.1 r = 0`.
· simp only [ranking_loss]
-- The right-hand side `0` is constant, so it suffices that the loss itself is measurable.
refine Measurable.eq ?_ measurable_const
-- Strip the constant factor and the finite sum; one summand remains.
refine Measurable.const_mul (Finset.measurable_sum _ fun i hi ↦ ?_) _
simp only [rindicator]
-- A summand is an `if _ = _` between two constants: both compared quantities must be measurable.
refine Measurable.ite (measurableSet_eq_fun ?_ ?_) measurable_const measurable_const
-- Left side: the rule applied to two data points, using measurability of the uncurried rule.
· have := r.1.2
fun_prop
-- Right side: the integer ranking of the two observed values, cast to `ℝ`.
· simp only [ranking_data]
have : Measurable (fun (z : ℤ) ↦ (z : ℝ)) := by fun_prop
refine this.comp ?_
refine Measurable.ite ?_ measurable_const <| .ite ?_ measurable_const measurable_const
· measurability
· measurability
-- Second conjunct: `0 ≤ r x (argmax_f data)`.
· refine Measurable.le' measurable_const ?_
-- The coercion from `{-1, 0, 1}` to `ℝ` is measurable.
have : Measurable (fun x : ({-1, 0, 1} : Set ℝ) ↦ (x : ℝ)) := by fun_prop
refine this.comp (r.1.2.comp (measurable_snd.prodMk ?_))
-- It remains to show that picking the argmax data point is measurable as a function of `data`.
suffices Measurable (fun p : prod_iter_image α ℝ n ↦ p.1 (Tuple.argmax id p.2)) by
exact this.comp measurable_fst
-- Evaluation `(u, i) ↦ u i` is measurable: split the preimage according to the index.
have h_eval : Measurable (fun p : iter α n × Finset.Iic n ↦ p.1 p.2) := by
intro s hs
simp only [preimage]
have : {x | x.1 x.2 ∈ s} =
⋃ i : Finset.Iic n, {p : iter α n × Finset.Iic n | p.2 = i ∧ p.1 i ∈ s} := by
ext ⟨p, i⟩
simp
rw [this]
refine MeasurableSet.iUnion fun i ↦ (.inter ?_ ?_)
· exact measurableSet_eq_fun (by fun_prop) (by fun_prop)
· change MeasurableSet {p : iter α n × Finset.Iic n | (p.1 i) ∈ s}
refine Measurable.setOf ?_
exact hs.mem.comp (by fun_prop)
-- Compose evaluation with `data ↦ (data.1, Tuple.argmax id data.2)`.
refine h_eval.comp (Measurable.prodMk ?_ ?_)
· fun_prop
-- Measurability of the argmax index as a function of the tuple of values.
· change Measurable (fun p : prod_iter_image α ℝ n ↦ Tuple.argmax id p.2)
suffices Measurable (fun u : iter ℝ n ↦ Tuple.argmax id u) from this.comp measurable_snd
-- Check measurability fibre by fibre: one preimage for each index `i`.
refine measurable_to_countable' fun i ↦ ?_
simp only [preimage, mem_singleton_iff]
-- Indices at which `u` takes the value `Tuple.max u`.
let Maximizers {n : ℕ} (u : iter ℝ n) : Set (Finset.Iic n) := {i | u i = Tuple.max u}
-- The fibre over `i` is the union, over those sets `S` that force the argmax to be `i`, of the
-- tuples whose set of maximizing indices is exactly `S`.
have : {u : iter ℝ n | Tuple.argmax id u = i} = ⋃ (S)
(hS : ∀ x, Maximizers x = S → Tuple.argmax id x = i), {u | Maximizers u = S} := by
ext u
simp only [mem_setOf_eq, mem_iUnion, exists_prop, exists_eq_right']
constructor
· intro hu x hx
rw [← hu]
unfold Tuple.argmax
exact Classical.choose.congr_simp hx (Tuple.exists_argmax id x)
· intro h
exact h u rfl
rw [this]
refine MeasurableSet.iUnion fun S ↦ (.iUnion fun hS ↦ ?_)
exact measurableSet_eq_fun (by fun_prop) measurable_const
include mes_α mα₁ in
/-- For a countable class `𝓡` and a measurable set `s`, the volume of `potential_max data 𝓡 ∩ s`
is a measurable function of `data`. -/
lemma measurable_volume_potential_max_inter {𝓡 : Set (RankRule α)} (h𝓡 : 𝓡.Countable)
    (s : Set α) (hs : MeasurableSet s) :
    Measurable (fun data : prod_iter_image α ℝ n ↦ ℙ (potential_max data 𝓡 ∩ s)) := by
  -- Finiteness of the volume on `α`, needed as an instance by the final step.
  have := i₁ mes_α mα₁
  -- The set of pairs `(data, x)` with `x ∈ potential_max data 𝓡 ∩ s` is measurable in the product:
  -- it is the joint potential-maximizer set intersected with the preimage of `s`.
  have hE : MeasurableSet
      {p : prod_iter_image α ℝ n × α | p.2 ∈ potential_max p.1 𝓡 ∩ s} :=
    (measurableSet_potential_max_prod h𝓡).inter (measurableSet_preimage measurable_snd hs)
  -- The measure of the sections of a measurable set depends measurably on the first coordinate.
  exact measurable_measure_prodMk_left hE
/-- Kernel sending observed `data` to `uniform (potential_max data 𝓡)`.
This is the sampling step of RankOpt: given the observed data, the next query point is drawn from
the uniform distribution on `potential_max`. The probability-measure property of each value is
not part of this definition; `RankOpt` proves it separately under a non-zero-measure assumption. -/
noncomputable def potential_max_kernel {𝓡 : Set (RankRule α)} (h𝓡 : 𝓡.Countable) :
Kernel (prod_iter_image α ℝ n) α := by
-- Provide the underlying map; the remaining goal is its measurability.
refine ⟨fun data ↦ uniform <| @potential_max d α n data 𝓡, ?_⟩
-- Measurability of a measure-valued map is checked on each measurable set `s`.
rw [Measure.measurable_measure]
intro s hs
-- NOTE(review): this step presumably relies on `uniform S` unfolding to a scalar multiple of the
-- restriction of `ℙ` to `S` — confirm against `LeanGO.Examples.Uniform`.
simp only [Measure.smul_apply, MeasureTheory.Measure.restrict_apply hs, smul_eq_mul]
refine Measurable.mul ?_ ?_
-- Scalar factor: inverse of the volume of `potential_max data 𝓡` (intersection with `univ`).
· refine Measurable.inv ?_
convert measurable_volume_potential_max_inter mes_α mα₁ h𝓡 Set.univ (MeasurableSet.univ)
simp [Set.inter_univ]
-- Restricted measure of `s`: the volume of an intersection with `potential_max data 𝓡`.
· convert measurable_volume_potential_max_inter mes_α mα₁ h𝓡 s hs using 1
simp [Set.inter_comm]
end RankOpt
open RankOpt
-- Standing data for `RankOpt`: a measurable subset `α` of `ℝᵈ d` with non-zero (`mα₀`) and finite
-- (`mα₁`) volume, and a countable class `𝓡` of rank rules on it. The former leading binders
-- `{α : Type} [MeasurableSpace α]` were shadowed by `{α : Set (ℝᵈ d)}` and are dropped.
-- ANCHOR: RankOptvars
variable {d : ℕ} {α : Set (ℝᵈ d)}
  (mes_α : MeasurableSet α) (mα₀ : ℙ α ≠ 0) (mα₁ : ℙ α ≠ ⊤)
  {𝓡 : Set (RankRule α)} (h𝓡 : 𝓡.Countable)
-- ANCHOR_END: RankOptvars
/- Standing assumption: for every `n` and every observed `data`, the set of potential maximizers
`potential_max data 𝓡` has non-zero measure, so that the algorithm can sample from it.
`RankOpt` passes `h n data` to `uniform_is_prob_measure` when proving its `markov_kernel` field. -/
variable (h : ∀ n (data : prod_iter_image α ℝ n), ℙ (potential_max data 𝓡) ≠ 0)
/-- The RankOpt algorithm for global optimization
[(Malherbe et al., 2017)](https://arxiv.org/pdf/1603.04381), packaged as an `Algorithm α ℝ`.
The algorithm works with a hypothesis class `𝓡` of ranking rules:
* the initial distribution `ν` is `uniform Set.univ` on `α`;
* every iteration uses `potential_max_kernel`, i.e. it samples from the uniform distribution on
  the set of potential maximizers determined by the observed data and `𝓡`.
The probability-measure obligations use `mα₀` for the initial distribution and `h` for the
kernels. -/
-- ANCHOR: RankOpt
noncomputable def RankOpt : Algorithm α ℝ where
  ν := uniform Set.univ
  prob_measure := by
    -- Finiteness of the volume on `α`, brought into scope as an instance.
    have := i₁ mes_α mα₁
    refine uniform_is_prob_measure ?_
    -- The volume of `univ` in the subtype is the volume of `α`, which is non-zero by `mα₀`.
    rw [Measure.Subtype.volume_univ mes_α.nullMeasurableSet]
    exact mα₀
  kernel_iter _ := potential_max_kernel mes_α mα₁ h𝓡
  markov_kernel n := by
    have := i₁ mes_α mα₁
    -- One probability measure per observed `data`, using the non-zero-measure assumption `h`.
    constructor
    intro data
    exact uniform_is_prob_measure (h n data)
-- ANCHOR_END: RankOpt