1+ using DrWatson
2+ @quickactivate " EFEasVFE"
3+
4+ using RxInfer
5+ using ReactiveMP
6+ using ProgressMeter
7+ using Statistics
8+ using Distributions
9+ using StableRNGs
10+ using Dates
11+ using JSON
12+ using FileIO
13+ import RxInfer: Categorical
14+ using EFEasVFE
15+
# Rule for the `:in` interface of a `DiscreteTransition` node (Marginalisation)
# when both `q_out` and `q_a` are point masses.
#
# The position of the first entry of `probvec(q_out)` that equals one selects
# a row of `mean(q_a)`; that row is passed through `guard`, normalized in place
# with `normalize!(…, 1)` and wrapped in a `Categorical` built with
# `check_args=false`.
# NOTE(review): if `q_out` has no entry equal to one, `findfirst` returns
# `nothing` and the row lookup below throws — confirm callers always supply a
# one-hot vector.
@rule DiscreteTransition(:in, Marginalisation) (q_out::PointMass{<:AbstractVector}, q_a::PointMass, meta::Any) = begin
    transition_mean = mean(q_a)
    hot_index = findfirst(isone, probvec(q_out))
    # Indexing with `[hot_index, :]` copies the row, so the in-place
    # normalization below never writes into `q_a`'s own storage.
    selected_row = transition_mean[hot_index, :]
    return Categorical(normalize!(guard(selected_row), 1); check_args=false)
end
22+
# Return the configuration to use for episode `i`.
#
# In parallel mode with visualization requested, every episode except the last
# one gets a copy of `config` with visualization (and recording) switched off;
# in all other cases `config` itself is returned unchanged.
function _tmaze_episode_config(config, i::Integer, parallel::Bool)
    if !(parallel && config.visualize && i != config.n_episodes)
        return config
    end
    return TMazeConfig(
        time_horizon=config.time_horizon,
        n_episodes=config.n_episodes,
        n_iterations=config.n_iterations,
        wait_time=config.wait_time,
        number_type=config.number_type,
        visualize=false,
        seed=config.seed,
        # Only reached when `i != config.n_episodes`, so nothing is recorded.
        record_episode=false,
        experiment_name=config.experiment_name,
        parallel=parallel,
    )
end

# Run `config.n_episodes` episodes of `agent` and return `(rewards, episodic_data)`.
#
# Both outputs are indexed by episode number, so `episodic_data[i]` always
# belongs to `rewards[i]` regardless of the order in which threads finish.
# Any additional keyword arguments are forwarded to `run_tmaze_single_episode`.
function _run_tmaze_episodes(
    agent, tensors, config, goal, callbacks, episode_seeds;
    parallel::Bool, desc::AbstractString, episode_kwargs...
)
    n = config.n_episodes
    rewards = zeros(n)
    episodic_data = Vector{Any}(undef, n)
    progress = Progress(n; desc=desc)

    function run_episode(i)
        reward, episode_data = run_tmaze_single_episode(
            agent,
            tensors,
            _tmaze_episode_config(config, i, parallel),
            goal,
            callbacks,
            episode_seeds[i];
            # Only the final episode is recorded, and only on request.
            record=(i == n && config.record_episode),
            episode_kwargs...,
        )
        # Each episode writes to its own slot, so no lock is required.
        rewards[i] = reward
        episodic_data[i] = episode_data
        ProgressMeter.next!(progress)
        return nothing
    end

    if parallel
        # NOTE(review): `tensors` and `callbacks` are shared between threads;
        # confirm `run_tmaze_single_episode` does not mutate them.
        Threads.@threads for i in 1:n
            run_episode(i)
        end
    else
        foreach(run_episode, 1:n)
    end

    return rewards, episodic_data
end

"""
    run_tmaze_experiment(; kwargs...)

Run the T-maze experiment for the KL control agent and the EFE agent, log the
reward statistics of both, optionally write logs/results to disk and return the
metrics dictionary.

Both agents are evaluated on the same per-episode seeds, drawn from a
`StableRNG(seed)`. Both agents run in sequential and in parallel mode.

# Keywords
- `T`: passed to `TMazeConfig` as `time_horizon`.
- `n_episodes`: number of episodes run per agent.
- `n_iterations`, `wait_time`, `visualize`, `record_episode`: forwarded to
  `TMazeConfig`.
- `seed`: seed of the RNG that generates the per-episode seeds.
- `number_type`: element type of the goal vector; also forwarded to `TMazeConfig`.
- `experiment_name`: prefix of every file written to `log_dir`.
- `parallel`: when `true`, episodes are distributed with `Threads.@threads`.
- `log_dir`: directory receiving the log and JSON files.
- `save_results`: when `false`, nothing is written to disk.
- `debug_mode`: forwarded to `run_tmaze_single_episode`.

# Returns
A `Dict{String,Any}` with the experiment settings and, under `"models"`, the
mean, standard deviation and raw rewards for `"klcontrol"` and `"efe"`.
"""
function run_tmaze_experiment(;
    T::Int=10,
    n_episodes::Int=10,
    n_iterations::Int=10,
    visualize::Bool=false,
    wait_time::Float64=0.0,
    record_episode::Bool=false,
    seed::Int=123,
    number_type::Type{<:AbstractFloat}=Float64,
    experiment_name::String="tmaze_$(Dates.format(now(), "yyyymmdd_HHMMSS"))",
    parallel::Bool=false,
    log_dir::String=datadir("logs", "tmaze"),
    save_results::Bool=true,
    debug_mode::Bool=false
)
    # Building the paths has no side effects, so they are always defined; only
    # the directory creation depends on `save_results`.
    log_file = joinpath(log_dir, "$(experiment_name).log")
    results_file = joinpath(log_dir, "$(experiment_name)_results.json")
    episodic_results_file = joinpath(log_dir, "$(experiment_name)_episodes.json")
    if save_results
        mkpath(log_dir)
    end

    # Create goal distribution - prefer the left arm location (state 3)
    left_goal = zeros(number_type, 5)
    left_goal[3] = one(number_type)
    left_goal_distribution = Categorical(left_goal)

    # Create config
    config = TMazeConfig(
        time_horizon=T,
        n_episodes=n_episodes,
        n_iterations=n_iterations,
        wait_time=wait_time,
        number_type=number_type,
        visualize=visualize,
        seed=seed,
        record_episode=record_episode,
        experiment_name=experiment_name,
        parallel=parallel,
    )

    # Create tensors
    tensors = (
        reward_observation=create_reward_observation_tensor(),
        location_transition=create_location_transition_tensor(),
        reward_to_location=create_reward_to_location_mapping(),
    )

    # Create benchmark callbacks
    callbacks = RxInferBenchmarkCallbacks()

    # Initialize result metrics
    experiment_metrics = Dict{String,Any}(
        "experiment_name" => experiment_name,
        "date" => string(now()),
        "time_horizon" => T,
        "n_episodes" => n_episodes,
        "n_iterations" => n_iterations,
        "seed" => seed,
        "models" => Dict{String,Any}(),
    )

    # Log a message to the console and, when saving results, to the log file.
    function log_info(message)
        println(message)
        if save_results
            open(log_file, "a") do io
                println(io, "$(now()) - $message")
            end
        end
        return nothing
    end

    log_info("Running experiments with KL Control agent...")

    if parallel
        thread_count = Threads.nthreads()
        log_info("Running with parallelization using $thread_count threads")
    else
        log_info("Running sequentially")
    end

    # Both agents are evaluated on the same per-episode seeds.
    episode_seeds = rand(StableRNG(seed), UInt32, config.n_episodes)

    kl_rewards, kl_episodic_data = _run_tmaze_episodes(
        klcontrol_tmaze_agent,
        tensors,
        config,
        left_goal_distribution,
        callbacks,
        episode_seeds;
        parallel=parallel,
        desc="Running KL Control episodes: ",
        constraints_fn=klcontrol_tmaze_agent_constraints,
        initialization_fn=klcontrol_tmaze_agent_initialization,
        debug_mode=debug_mode,
    )
    log_info("KL Control episodes completed")

    # The EFE agent runs in both modes, so its statistics never come from an
    # untouched all-zero reward vector.
    efe_rewards, efe_episodic_data = _run_tmaze_episodes(
        efe_tmaze_agent,
        tensors,
        config,
        left_goal_distribution,
        callbacks,
        episode_seeds;
        parallel=parallel,
        desc="Running EFE episodes: ",
        constraints_fn=efe_tmaze_agent_constraints,
        initialization_fn=efe_tmaze_agent_initialization,
        debug_mode=debug_mode,
        # Force marginal computation
        options=(force_marginal_computation=true, limit_stack_depth=500),
    )

    # Calculate statistics
    kl_mean = mean(kl_rewards)
    kl_std = std(kl_rewards)
    efe_mean = mean(efe_rewards)
    efe_std = std(efe_rewards)

    # Record results
    log_info("KL Control results: mean reward = $kl_mean, std = $kl_std")
    experiment_metrics["models"]["klcontrol"] = Dict(
        "mean_reward" => kl_mean,
        "std_reward" => kl_std,
        "rewards" => kl_rewards,
    )
    log_info("EFE results: mean reward = $efe_mean, std = $efe_std")
    experiment_metrics["models"]["efe"] = Dict(
        "mean_reward" => efe_mean,
        "std_reward" => efe_std,
        "rewards" => efe_rewards,
    )

    if save_results
        # Write one JSON object holding both agents; two consecutive
        # `JSON.print` calls would produce a file that is not valid JSON.
        episodic_results = Dict(
            "klcontrol" => kl_episodic_data,
            "efe" => efe_episodic_data,
        )
        open(episodic_results_file, "w") do io
            JSON.print(io, episodic_results, 2)
        end

        open(results_file, "w") do io
            JSON.print(io, experiment_metrics, 2)
        end

        log_info("Results saved to $results_file")
        log_info("Episode data saved to $episodic_results_file")
    end

    # Run a single visual episode with KL control.
    # NOTE(review): this runs only when the caller asked for NEITHER
    # visualization NOR recording — confirm this condition is intended.
    if !visualize && !record_episode
        log_info("Running visualized episode...")
        visual_config = TMazeConfig(
            time_horizon=T,
            n_episodes=1,
            n_iterations=n_iterations,
            wait_time=0.1,  # Slower for visualization
            number_type=number_type,
            visualize=true,
            seed=seed + 1,  # Different seed
            record_episode=true,
            experiment_name=experiment_name,
            parallel=false,
        )

        vis_reward, vis_data = run_tmaze_single_episode(
            klcontrol_tmaze_agent,
            tensors,
            visual_config,
            left_goal_distribution,
            nothing,
            seed + 100;
            constraints_fn=klcontrol_tmaze_agent_constraints,
            initialization_fn=klcontrol_tmaze_agent_initialization,
            record=true,
            debug_mode=debug_mode,
        )

        log_info("Visualization complete. Reward: $vis_reward")

        if save_results
            # Save visualization data
            visualization_file = joinpath(log_dir, "$(experiment_name)_visualization.json")
            open(visualization_file, "w") do io
                JSON.print(io, vis_data, 2)
            end
            log_info("Visualization data saved to $visualization_file")
        end
    end

    return experiment_metrics
end
286+
# Script entry point: the experiment is launched only when the absolute path of
# the running program equals this file's path.
if abspath(PROGRAM_FILE) == @__FILE__
    # `let` keeps the timestamp local instead of creating a global binding.
    let run_stamp = Dates.format(now(), "yyyymmdd_HHMMSS")
        run_tmaze_experiment(;
            T=5,
            n_episodes=50,
            n_iterations=10,
            visualize=false,
            wait_time=0.0,
            record_episode=true,
            seed=123,
            number_type=Float64,
            experiment_name="tmaze_$(run_stamp)",
            parallel=false,
            debug_mode=true,
        )
    end
end
0 commit comments