-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathCountingOracle.lean
More file actions
348 lines (304 loc) · 16.3 KB
/
CountingOracle.lean
File metadata and controls
348 lines (304 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
/-
Copyright (c) 2024 Devon Tuma. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Devon Tuma, Quang Dao
-/
import VCVio.OracleComp.QueryTracking.Structures
import VCVio.OracleComp.SimSemantics.WriterT
import VCVio.OracleComp.SimSemantics.Constructions
/-!
# Counting Queries Made by a Computation
This file defines a simulation oracle `countingOracle` for counting the number of queries made
while running the computation. The count is represented by a function from oracle indices to
counts, allowing each oracle to be tracked individually.
Tracking each oracle individually is not strictly necessary, but it gives tighter security
bounds in some cases. It also lets things like seed values for a computation be generated
more tightly.
-/
open OracleSpec OracleComp
universe u v w
variable {ι : Type u} {spec : OracleSpec ι} {α β γ : Type u}
namespace QueryImpl
variable {m : Type u → Type v} [Monad m]
/-- Wrap an oracle implementation so that every query is also logged in a
`WriterT (QueryCount ι)` layer.

For a query `q`, the wrapper first emits `QueryCount.single q` with `tell`, and only
afterwards runs the wrapped implementation `so q`. Counting happens before the
implementation runs, so failed queries are still counted. -/
def withCounting [DecidableEq ι] (so : QueryImpl spec m) :
    QueryImpl spec (WriterT (QueryCount ι) m) :=
  fun q => do
    tell (QueryCount.single q)
    so q
/-- Unfolding lemma for `QueryImpl.withCounting`: a counted query is a `tell` of
`QueryCount.single t` followed by the original implementation `so t`. Holds definitionally. -/
@[simp, grind =]
lemma withCounting_apply [DecidableEq ι] (so : QueryImpl spec m) (t : spec.Domain) :
    so.withCounting t =
      (do tell (QueryCount.single t); so t) := rfl
/-- Forgetting the count: simulating `mx` with `so.withCounting`, running the `WriterT`
layer, and keeping only the first component agrees with simulating `mx` with `so` itself.
Proved by induction on the oracle computation. -/
lemma fst_map_run_withCounting [DecidableEq ι] [LawfulMonad m]
    (so : QueryImpl spec m) (mx : OracleComp spec α) :
    Prod.fst <$> (simulateQ so.withCounting mx).run = simulateQ so mx := by
  induction mx using OracleComp.inductionOn with
  | pure a => simp
  | query_bind q cont ih => simp [ih]
end QueryImpl
/-- Oracle for counting the number of queries made by a computation.

Each query is lifted back into `OracleComp spec` via `QueryImpl.ofLift`, and is also
recorded in a `WriterT (QueryCount ι)` layer via `QueryImpl.withCounting`. The count is kept
per oracle index rather than as a single total, giving finer-grained information. -/
def countingOracle [DecidableEq ι] :
    QueryImpl spec (WriterT (QueryCount ι) (OracleComp spec)) :=
  QueryImpl.withCounting (QueryImpl.ofLift spec (OracleComp spec))
namespace countingOracle
variable [DecidableEq ι]
/-- Dropping the count recovers the original computation: simulating `oa` with
`countingOracle` and projecting onto the first component gives back `oa`. -/
@[simp]
lemma fst_map_run_simulateQ (oa : OracleComp spec α) :
    Prod.fst <$> (simulateQ countingOracle oa).run = oa := by
  unfold countingOracle
  rw [QueryImpl.fst_map_run_withCounting, simulateQ_ofLift_eq_self]
/-- A continuation that only inspects the result (not the count) of a counted simulation
behaves as if it were bound directly to the original computation. -/
@[simp]
lemma run_simulateQ_bind_fst (oa : OracleComp spec α) (ob : α → OracleComp spec β) :
    ((simulateQ countingOracle oa).run >>= fun x => ob x.1) = oa >>= ob := by
  -- Turn `fun x => ob x.1` into a bind after `Prod.fst <$> _`, then drop the count.
  rw [← bind_map_left Prod.fst]
  rw [fst_map_run_simulateQ]
/-- Counting queries does not change the failure probability. Both sides are rewritten to
`0` using `HasEvalPMF.probFailure_eq_zero`, under the `Fintype`/`Inhabited` instances on
`spec₀`. -/
@[simp]
lemma probFailure_run_simulateQ {ι₀ : Type} {spec₀ : OracleSpec.{0,0} ι₀} [DecidableEq ι₀]
    [spec₀.Fintype] [spec₀.Inhabited] {α : Type} (oa : OracleComp spec₀ α) :
    Pr[⊥ | (simulateQ (countingOracle (spec := spec₀)) oa).run] = Pr[⊥ | oa] := by
  rw [HasEvalPMF.probFailure_eq_zero, HasEvalPMF.probFailure_eq_zero]
/-- The counted simulation never fails exactly when the original computation never fails.
Both sides are restated as failure probabilities, and `probFailure_run_simulateQ` is then
used to identify them. -/
@[simp]
lemma NeverFail_run_simulateQ_iff {ι₀ : Type} {spec₀ : OracleSpec.{0,0} ι₀} [DecidableEq ι₀]
    [spec₀.Fintype] [spec₀.Inhabited] {α : Type}
    (oa : OracleComp spec₀ α) :
    NeverFail (simulateQ (countingOracle (spec := spec₀)) oa).run ↔ NeverFail oa := by
  rw [← probFailure_eq_zero_iff, ← probFailure_eq_zero_iff, probFailure_run_simulateQ]
/-- An event that depends only on the output (first component) has the same probability
under the counted simulation as under the original computation. -/
@[simp]
lemma probEvent_fst_run_simulateQ {ι₀ : Type} {spec₀ : OracleSpec.{0,0} ι₀} [DecidableEq ι₀]
    [spec₀.Fintype] [spec₀.Inhabited] {α : Type}
    (oa : OracleComp spec₀ α) (p : α → Prop) :
    Pr[fun z => p z.1 | (simulateQ (countingOracle (spec := spec₀)) oa).run] = Pr[p | oa] := by
  -- Present the event as `p ∘ Prod.fst` so that `probEvent_map` can absorb the projection.
  have hcomp : (fun z : α × QueryCount ι₀ => p z.1) = p ∘ Prod.fst := rfl
  rw [hcomp, ← probEvent_map, fst_map_run_simulateQ]
/-- The output distribution of a counted simulation, with the count projected away, equals
the output distribution of the original computation.

Not tagged `@[simp]`. Its left-hand side contains
`Prod.fst <$> (simulateQ countingOracle oa).run`, which the `@[simp]` lemma
`fst_map_run_simulateQ` rewrites to `oa` before this lemma could ever fire. Tagging it would
therefore be redundant, and the `simpNF` linter flags it. `simp` still closes such goals via
`fst_map_run_simulateQ`. -/
lemma probOutput_fst_map_run_simulateQ {ι₀ : Type} {spec₀ : OracleSpec.{0,0} ι₀}
    [DecidableEq ι₀] [spec₀.Fintype] [spec₀.Inhabited] {α : Type}
    (oa : OracleComp spec₀ α) (x : α) :
    Pr[= x | Prod.fst <$> (simulateQ (countingOracle (spec := spec₀)) oa).run] =
      Pr[= x | oa] := by
  rw [fst_map_run_simulateQ]
-- -- lemma run_simulateT_eq_run_simulateT_zero (oa : OracleComp spec α) (qc : ι → ℕ) :
-- -- (simulateT countingOracle oa).run qc =
-- -- map id (qc + ·) <$> (simulateT countingOracle oa).run 0 := by
-- -- revert qc
-- -- induction oa using OracleComp.inductionOn with
-- -- | pure x => simp
-- -- | query_bind i t oa h =>
-- -- intro qc
-- -- simp [h _ (update qc i (qc i + 1)), h _ (update 0 i 1)]
-- -- refine funext λ y ↦ congr_arg (· <$> _) (funext λ x ↦ ?_)
-- -- simp only [eq_iff_fst_eq_snd_eq, map_fst, id_eq, map_snd, true_and]
-- -- cases x
-- -- ext j
-- -- by_cases hj : j = i
-- -- · induction hj
-- -- simp [add_assoc]
-- -- · simp [hj]
-- -- | failure => simp [StateT.monad_failure_def]
-- -- section support
-- -- /-- We can always reduce the initial state of simulation with a counting oracle to start with a
-- -- count of zero, and add the initial count back on at the end. -/
-- -- lemma support_simulate (oa : OracleComp spec α) (qc : ι → ℕ) :
-- -- (simulate countingOracle qc oa).support =
-- -- Prod.map id (qc + ·) '' (simulate countingOracle 0 oa).support := by
-- -- revert qc
-- -- induction oa using OracleComp.inductionOn with
-- -- | pure a => simp only [simulate_pure, support_pure, Set.image_singleton, Prod.map_apply,
-- -- id_eq,
-- -- add_zero, implies_true]
-- -- | query_bind i t oa hoa =>
-- -- refine λ qc ↦ ?_
-- -- -- simp
-- -- sorry -- refine λ qc ↦ ?_
-- -- -- simp only [simulate_bind, simulate_query,countingOracle.apply_eq,
-- -- support_bind,support_map,
-- -- -- support_query, Set.image_univ, Set.mem_range, Set.iUnion_exists,
-- -- Set.iUnion_iUnion_eq',
-- -- -- Prod.map_apply, id_eq, Pi.zero_apply, zero_add, Set.image_iUnion]
-- -- -- refine Set.iUnion_congr (λ u ↦ ?_)
-- -- -- simp only [hoa u (update qc i (qc i + 1)), hoa u (update 0 i 1),
-- -- -- ← Set.image_comp, Function.comp, Prod.map_apply, id_eq, ← add_assoc]
-- -- -- refine Set.image_congr' (λ z ↦ Prod.eq_iff_fst_eq_snd_eq.2 ⟨rfl, funext (λ j ↦ ?_)⟩)
-- -- -- by_cases hij : j = i <;> simp [hij, add_assoc]
-- -- | failure => simp only [simulate_failure, support_failure, Set.image_empty, implies_true]
-- -- /-- Reduce membership in the support of simulation with counting to membership in simulation
-- -- starting with the count at `0`.
-- -- TODO: lemmas like this suggest maybe support shouldn't auto reduce on the computation type?
-- -- TODO: implicit parameters, and add extra helper lemmas -/
-- -- lemma mem_support_simulate_iff (oa : OracleComp spec α) (qc : ι → ℕ) (z : α × (ι → ℕ)) :
-- -- z ∈ (simulate countingOracle qc oa).support ↔
-- -- ∃ qc', (z.1, qc') ∈ (simulate countingOracle 0 oa).support ∧ qc + qc' = z.2 := by
-- -- rw [support_simulate]
-- -- simp only [Prod.map_apply, id_eq, Set.mem_image, Prod.eq_iff_fst_eq_snd_eq, Prod.exists]
-- -- exact ⟨λ h ↦ let ⟨x, qc', h, hx, hqc'⟩ := h; ⟨qc', hx ▸ ⟨h, hqc'⟩⟩,
-- -- λ h ↦ let ⟨qc', h, hqc'⟩ := h; ⟨z.1, qc', h, rfl, hqc'⟩⟩
-- -- lemma mem_support_simulate_iff_of_le (oa : OracleComp spec α) (qc : ι → ℕ) (z : α × (ι → ℕ))
-- -- (hz : qc ≤ z.2) : z ∈ (simulate countingOracle qc oa).support ↔
-- -- (z.1, z.2 - qc) ∈ (simulate countingOracle 0 oa).support := by
-- -- rw [mem_support_simulate_iff oa 0]
-- -- simp only [mem_support_simulate_iff oa qc z, zero_add, exists_eq_right]
-- -- refine ⟨λ ⟨qc', h, hqc'⟩ ↦ ?_, λ h ↦ ?_⟩
-- -- · convert h
-- -- refine funext (λ x ↦ ?_)
-- -- rw [Pi.sub_apply, Nat.sub_eq_iff_eq_add' (hz x)]
-- -- exact symm (congr_fun hqc' x)
-- -- · refine ⟨z.2 - qc, h, ?_⟩
-- -- refine funext (λ x ↦ ?_)
-- -- refine Nat.add_sub_cancel' (hz x)
-- -- lemma le_of_mem_support_simulate {oa : OracleComp spec α} {qc : ι → ℕ} {z : α × (ι → ℕ)}
-- -- (h : z ∈ (simulate countingOracle qc oa).support) : qc ≤ z.2 := by
-- -- rw [mem_support_simulate_iff] at h
-- -- obtain ⟨qc'', _, h⟩ := h
-- -- exact le_of_le_of_eq le_self_add h
-- -- section snd_map
-- -- lemma mem_support_snd_map_simulate_iff {α ι : Type u} [DecidableEq ι] {spec : OracleSpec ι}
-- -- (oa : OracleComp spec α) (qc qc' : ι → ℕ) :
-- -- qc' ∈ (@snd α _ <$> simulate countingOracle qc oa).support ↔
-- -- ∃ qc'', ∃ x, (x, qc'') ∈ (simulate countingOracle 0 oa).support ∧ qc + qc'' = qc' := by
-- -- simp only [support_map, Set.mem_image, Prod.exists, exists_eq_right]
-- -- refine ⟨λ h ↦ ?_, λ h ↦ ?_⟩
-- -- · obtain ⟨x, hx⟩ := h
-- -- rw [mem_support_simulate_iff] at hx
-- -- obtain ⟨qc'', h, hqc''⟩ := hx
-- -- refine ⟨qc'', x, h, hqc''⟩
-- -- · obtain ⟨qc'', x, h, hqc''⟩ := h
-- -- refine ⟨x, ?_⟩
-- -- rw [mem_support_simulate_iff]
-- -- refine ⟨qc'', h, hqc''⟩
-- -- lemma mem_support_snd_map_simulate_iff_of_le (oa : OracleComp spec α) {qc qc' : ι → ℕ}
-- -- (hqc : qc ≤ qc') : qc' ∈ (@snd α _ <$> simulate countingOracle qc oa).support ↔
-- -- qc' - qc ∈ (@snd α _ <$> simulate countingOracle 0 oa).support := by
-- -- simp only [mem_support_snd_map_simulate_iff, zero_add]
-- -- refine exists_congr (λ qc'' ↦ exists_congr (λ x ↦ ?_))
-- -- refine and_congr_right' ⟨λ h ↦ funext (λ x ↦ ?_), λ h ↦ funext (λ x ↦ ?_)⟩
-- -- · simp only [← h, Pi.sub_apply, Pi.add_apply, add_tsub_cancel_left]
-- -- · simp [h, Nat.add_sub_cancel' (hqc x)]
-- -- lemma le_of_mem_support_snd_map_simulate {oa : OracleComp spec α} {qc qc' : ι → ℕ}
-- -- (h : qc' ∈ (@snd α _ <$> simulate countingOracle qc oa).support) : qc ≤ qc' := by
-- -- simp only [support_map, Set.mem_image, Prod.exists, exists_eq_right] at h
-- -- obtain ⟨y, hy⟩ := h
-- -- exact le_of_mem_support_simulate hy
-- -- lemma sub_mem_support_snd_map_simulate {oa : OracleComp spec α} {qc qc' : ι → ℕ}
-- -- (h : qc' ∈ (@snd α _ <$> simulate countingOracle qc oa).support) :
-- -- qc' - qc ∈ (@snd α _ <$> simulate countingOracle 0 oa).support := by
-- -- rwa [mem_support_snd_map_simulate_iff_of_le] at h
-- -- convert le_of_mem_support_snd_map_simulate h
-- -- end snd_map
-- -- lemma add_mem_support_simulate {oa : OracleComp spec α} {qc : ι → ℕ} {z : α × (ι → ℕ)}
-- -- (hz : z ∈ (simulate countingOracle qc oa).support) (qc' : ι → ℕ) :
-- -- (z.1, z.2 + qc') ∈ (simulate countingOracle (qc + qc') oa).support := by
-- -- obtain ⟨qc1, hqc', h⟩ := (mem_support_simulate_iff _ _ _).1 hz
-- -- exact (mem_support_simulate_iff _ _ _).2 ⟨qc1, hqc', h ▸ by ring⟩
-- -- @[simp]
-- -- lemma add_right_mem_support_simulate_iff (oa : OracleComp spec α) (qc qc' : ι → ℕ) (x : α) :
-- -- (x, qc + qc') ∈ (simulate countingOracle qc oa).support ↔
-- -- (x, qc') ∈ (simulate countingOracle 0 oa).support := by
-- -- rw [support_simulate, Set.mem_image]
-- -- simp only [Prod.exists, Prod.map_apply, id_eq, Prod.mk.injEq, add_right_inj,
-- -- exists_eq_right_right, exists_eq_right]
-- -- @[simp]
-- -- lemma add_left_mem_support_simulate_iff (oa : OracleComp spec α) (qc qc' : ι → ℕ) (x : α) :
-- -- (x, qc' + qc) ∈ (simulate countingOracle qc oa).support ↔
-- -- (x, qc') ∈ (simulate countingOracle 0 oa).support := by
-- -- rw [add_comm qc' qc, add_right_mem_support_simulate_iff]
-- -- lemma mem_support_simulate_pure_iff (x : α) (qc : ι → ℕ) (z : α × (ι → ℕ)) :
-- -- z ∈ (simulate countingOracle qc (pure x : OracleComp spec α)).support ↔ z = (x, qc) := by
-- -- simp only [simulate_pure, support_pure, Set.mem_singleton_iff]
-- -- lemma apply_ne_zero_of_mem_support_simulate_queryBind {i : ι} {t : spec.Domain i}
-- -- {oa : spec.Range i → OracleComp spec α} {qc : ι → ℕ} {z : α × (ι → ℕ)}
-- -- (hz : z ∈ (simulate countingOracle qc ((query i t : OracleComp spec _) >>= oa)).support) :
-- -- z.2 i ≠ 0 := by
-- -- rw [mem_support_simulate_iff, simulate_query_bind] at hz
-- -- rw [support_bind] at hz
-- -- simp at hz
-- -- obtain ⟨qc', ⟨⟨u, hu⟩, hqc⟩⟩ := hz
-- -- sorry
-- -- -- have := le_of_mem_support_simulate hu i
-- -- -- simp at this
-- -- -- refine Nat.pos_iff_ne_zero.1 ?_
-- -- -- rw [← hqc, Pi.add_apply]
-- -- -- refine Nat.add_pos_right ?_ ?_
-- -- -- refine Nat.lt_of_succ_le this
-- -- lemma exists_mem_support_of_mem_support_simulate_queryBind {i : ι} {t : spec.Domain i}
-- -- {oa : spec.Range i → OracleComp spec α} {qc : ι → ℕ} {z : α × (ι → ℕ)}
-- -- (hz : z ∈ (simulate countingOracle qc ((query i t : OracleComp spec _) >>= oa)).support) :
-- -- ∃ u, (z.1, Function.update z.2 i (z.2 i - 1)) ∈
-- -- (simulate countingOracle qc (oa u)).support := by
-- -- rw [mem_support_simulate_iff, simulate_query_bind, support_bind] at hz
-- -- simp at hz
-- -- obtain ⟨qc', ⟨⟨u, hu⟩, hqc⟩⟩ := hz
-- -- refine ⟨u, ?_⟩
-- -- simp [← hqc]
-- -- have hqc' : qc' i ≠ 0 := by {
-- -- sorry
-- -- -- have := le_of_mem_support_simulate hu i
-- -- -- refine Nat.pos_iff_ne_zero.1 ?_
-- -- -- refine lt_of_lt_of_le ?_ this
-- -- -- simp only [update_same, zero_lt_one]
-- -- }
-- -- rw [mem_support_simulate_iff_of_le]
-- -- · simp
-- -- sorry
-- -- -- rw [mem_support_simulate_iff_of_le] at hu
-- -- -- · simp at hu
-- -- -- convert hu using 2
-- -- -- refine funext (λ j ↦ ?_)
-- -- -- by_cases hj : j = i
-- -- -- · simp [hj]
-- -- -- refine Nat.sub_eq_of_eq_add ?_
-- -- -- rw [add_comm _ (qc i)]
-- -- -- rw [Nat.add_sub_assoc]
-- -- -- rw [Nat.one_le_iff_ne_zero]
-- -- -- exact hqc'
-- -- -- · simp [hj]
-- -- -- · intro j
-- -- -- by_cases hj : j = i
-- -- -- · induction hj
-- -- -- simp only [update_same]
-- -- -- have := le_of_mem_support_simulate hu j
-- -- -- refine le_trans ?_ this
-- -- -- simp
-- -- -- · simp [hj]
-- -- · intro j
-- -- by_cases hj : j = i
-- -- · induction hj
-- -- simp only [update_self]
-- -- rw [Nat.le_sub_one_iff_lt]
-- -- · refine Nat.lt_add_of_pos_right ?_
-- -- rw [pos_iff_ne_zero]
-- -- assumption
-- -- · refine Nat.add_pos_right _ ?_
-- -- rwa [pos_iff_ne_zero]
-- -- · simp [hj]
-- -- lemma mem_support_simulate_queryBind_iff (i : ι) (t : spec.Domain i)
-- -- (oa : spec.Range i → OracleComp spec α) (qc : ι → ℕ) (z : α × (ι → ℕ)) :
-- -- z ∈ (simulate countingOracle qc ((query i t : OracleComp spec _) >>= oa)).support ↔
-- -- z.2 i ≠ 0 ∧ ∃ u, (z.1, Function.update z.2 i (z.2 i - 1)) ∈
-- -- (simulate countingOracle qc (oa u)).support := by
-- -- refine ⟨λ h ↦ ⟨?_, ?_⟩, λ h ↦ ?_⟩
-- -- · refine apply_ne_zero_of_mem_support_simulate_queryBind h
-- -- · refine exists_mem_support_of_mem_support_simulate_queryBind h
-- -- · obtain ⟨hz0, ⟨u, hu⟩⟩ := h
-- -- simp only [simulate_bind, simulate_query, countingOracle.apply_eq, support_bind,
-- -- support_query, Set.image_univ, Set.mem_range, Set.iUnion_exists,
-- -- Set.iUnion_iUnion_eq', Set.mem_iUnion]
-- -- sorry
-- -- -- refine ⟨u, ?_⟩
-- -- -- have := add_mem_support_simulate hu (update 0 i 1)
-- -- -- convert this
-- -- -- · refine funext (λ j ↦ symm ?_)
-- -- -- by_cases hij : j = i
-- -- -- · simp [Function.update_apply, hij]
-- -- -- · simp [hij]
-- -- -- ·
-- -- -- refine funext (λ j ↦ ?_)
-- -- -- by_cases hij : j = i
-- -- -- · induction hij
-- -- -- simpa using (Nat.succ_pred_eq_of_ne_zero hz0).symm
-- -- -- · simp [hij]
-- -- lemma exists_mem_support_of_mem_support {oa : OracleComp spec α} {x : α} (hx : x ∈ oa.support)
-- -- (qc : ι → ℕ) : ∃ qc', (x, qc') ∈ (simulate countingOracle qc oa).support := by
-- -- rw [← SimOracle.IsTracking.run'_simulateT_eq_self countingOracle oa] at hx
-- -- sorry; sorry
-- -- -- simp at hx
-- -- -- exact hx
-- -- end support
end countingOracle