Skip to content

Commit 0ad7b47

Browse files
committed
Plan with particle filters
1 parent 008cd79 commit 0ad7b47

File tree

4 files changed

+72
-8
lines changed

4 files changed

+72
-8
lines changed

examples/gym-cartpole/smart_agent.zls

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ let proba model obs_gym = action where
2727
rec obs = simple_pendulum (obs_gym, (Right fby action))
2828
and action = controller (obs)
2929
and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle)))
30-
and display = draw_obs_back obs
3130

3231
let node smart_main () = () where
33-
rec reset action = Infer_pf.plan 10 10 model obs every true
32+
rec reset action = Infer_pf.plan_pf 30 10 10 model obs every true
3433
and obs, _, stop = cart_pole_gym true (Right fby action)
35-
and display = draw_obs_front obs

examples/gym-cartpole/smart_pid_agent.zls

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ let proba smart_model (obs_gym) = action where
5151
rec obs = simple_pendulum (obs_gym, (Right fby action))
5252
and action = smart_controller (obs)
5353
and () = Infer_pf.factor (-10. *. (abs_float (obs.pole_angle)))
54-
and display = draw_obs_back obs
5554

5655

5756
(** PID controller for the cart-pole example **)
@@ -72,11 +71,10 @@ let proba pid_model (obs, ctrl_action) = p, (i, d) where
7271

7372

7473
let node smart_pid_main () = () where
75-
rec reset action_smart = Infer_pf.plan 10 10 smart_model obs_smart
74+
rec reset action_smart = Infer_pf.plan_pf 30 10 10 smart_model obs_smart
7675
every true
7776
and obs_smart, _, _ = cart_pole_gym true (Right fby action_smart)
7877
and pid_dist = Infer_pf.infer 1000 pid_model (obs_smart, action_smart)
79-
and () = draw_obs_front obs_smart
8078
and (p, (i, d)) = Distribution.draw pid_dist
8179
and obs, _, stop = cart_pole_gym true (Right fby action)
8280
and reset action = pid_controller (obs.pole_angle, (p, i, d))

probzelus/inference/infer_pf.ml

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
open Ztypes
2828
open Owl
29+
open Printf
2930

3031
type pstate = {
3132
idx : int; (** particle index *)
@@ -203,6 +204,7 @@ let expectation scores =
203204
let s = Array.fold_left ( +. ) 0. scores in
204205
s /. float (Array.length scores)
205206

207+
(** [plan_step n k model_step model_copy] return a function [step] that duplicates the current particle [n] times and advances it forward [k] times.*)
206208
let plan_step n k model_step model_copy =
207209
let table = Hashtbl.create 7 in
208210
let rec expected_utility (state, score) (ttl, input) =
@@ -244,8 +246,49 @@ let plan_step n k model_step model_copy =
244246
in
245247
step
246248

247-
(* [plan n k f x] runs n instances of [f] on the input stream *)
248-
(* [x] but at each step, do a prediction of depth k *)
249+
(** [plan_step_pf n k model_step model_copy] returns a function [step] that duplicates the current particle [n] times, advances it forward, copies it [n] times, applies a particle filter of size [h], and repeats this process [k] times. *)
250+
let plan_step_pf n h k model_step model_copy =
251+
let table = Hashtbl.create 7 in
252+
let rec expected_utility state (ttl, input) =
253+
let states = Array.init h (fun _ -> Probzelus_utils.copy state) in
254+
let scores = Array.make h 0.0 in
255+
Array.iteri
256+
(fun i state -> ignore @@ model_step state ({ idx = i; scores }, input))
257+
states;
258+
let norm = Normalize.log_sum_exp scores in
259+
let probabilities = Array.map (fun score -> exp (score -. norm)) scores in
260+
let dist = Normalize.to_distribution (Array.init h id) probabilities in
261+
let index = Distribution.draw dist in
262+
let state, score = (states.(index), scores.(index)) in
263+
if ttl < 1 then score else norm +. expected_utility state (ttl - 1, input)
264+
in
265+
let state_value_copy (src_st, src_val) (dst_st, dst_val) =
266+
model_copy src_st dst_st;
267+
dst_val := !src_val
268+
in
269+
let step { infer_states = states; infer_scores = scores } input =
270+
let values =
271+
Array.mapi
272+
(fun i state ->
273+
let value = model_step state ({ idx = i; scores }, input) in
274+
scores.(i) <- expected_utility state (k, input);
275+
value)
276+
states
277+
in
278+
let states_values =
279+
Array.mapi (fun i state -> (state, ref values.(i))) states
280+
in
281+
let norm = Normalize.log_sum_exp scores in
282+
let probabilities = Array.map (fun score -> exp (score -. norm)) scores in
283+
Normalize.resample state_value_copy n probabilities states_values;
284+
Array.fill scores 0 n 0.0;
285+
Hashtbl.clear table;
286+
states_values
287+
in
288+
step
289+
290+
(** [plan n k f x] runs n instances of [f] on the input stream
291+
[x] but at each step, do a prediction of depth [k] *)
249292
let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
250293
let alloc () = ref (model.alloc ()) in
251294
let reset state = model.reset !state in
@@ -264,6 +307,26 @@ let plan n k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
264307
in
265308
Cnode { alloc; reset; copy; step }
266309

310+
(** [plan n k f x] runs n instances of [f] on the input stream
311+
[x] but at each step, do a prediction of depth [k] and use a particle filter of size [h] *)
312+
let plan_pf n h k (Cnode model : (pstate * 't1, 't2) Ztypes.cnode) =
313+
let alloc () = ref (model.alloc ()) in
314+
let reset state = model.reset !state in
315+
let copy src dst = model.copy !src !dst in
316+
let step_body = plan_step_pf n h k model.step model.copy in
317+
let step plan_state input =
318+
let states = Array.init n (fun _ -> Probzelus_utils.copy !plan_state) in
319+
let scores = Array.make n 0.0 in
320+
let states_values =
321+
step_body { infer_states = states; infer_scores = scores } input
322+
in
323+
let dist = Normalize.normalize states_values in
324+
let state', value = Distribution.draw dist in
325+
plan_state := state';
326+
!value
327+
in
328+
Cnode { alloc; reset; copy; step }
329+
267330
type 'state infd_state = {
268331
infd_states : 'state array;
269332
infd_scores : float array;

probzelus/inference/infer_pf.zli

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ val plan :
5959
('t1 ~D~> 't2) -S->
6060
't1 -D-> 't2
6161

62+
val plan_pf :
63+
int -S-> int -S-> int -S->
64+
('t1 ~D~> 't2) -S->
65+
't1 -D-> 't2
66+
6267
val infer_depth :
6368
int -S-> int -S->
6469
('t1 ~D~> 't2) -S->

0 commit comments

Comments
 (0)