Skip to content

Commit 45dd5c4

Browse files
committed
Solve some numerical instabilities and properly implement EFE msg
Also add some T-maze experiment as validation
1 parent eacbb62 commit 45dd5c4

File tree

22 files changed

+2951
-513
lines changed

22 files changed

+2951
-513
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,6 @@ __pycache__/
424424
*/*/*/__pycache__/
425425
*.py[cod]
426426
LocalPreferences.toml
427+
428+
429+
.cursor

scripts/debug_minigrid.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,13 @@ function main()
172172
session_id = env_response["session_id"]
173173

174174
# Create results directory with grid size, seed, iterations and sparse-tensor info
175+
timestamp = Dates.format(now(), "yyyy-mm-dd_HH-MM-SS")
175176
results_dir = mkpath(datadir("debug",
176-
"grid$(config.grid_size)_seed$(config.seed)_iter$(config.n_iterations)_sparsetensor$(args["sparse-tensor"])"
177+
timestamp * "_" *
178+
"gridsize_$(config.grid_size)_" *
179+
"seed_$(config.seed)_" *
180+
"iterations_$(config.n_iterations)_" *
181+
"sparse_tensor_$(args["sparse-tensor"])"
177182
))
178183
@info "Initialized environment"
179184
# Initialize beliefs and tensors
@@ -198,28 +203,27 @@ function main()
198203
end
199204
@info "Starting inference..."
200205
# Execute a single step with debug options
201-
action, new_env_state, inference_result = execute_step(
206+
next_action, new_env_state, inference_result = execute_step(
202207
env_state,
203208
action,
204209
beliefs,
205-
klcontrol_minigrid_agent,
210+
efe_minigrid_agent,
206211
tensors,
207212
config,
208213
goal,
209214
nothing, # no callbacks
210215
config.time_horizon,
211216
nothing, # no previous result
212217
session_id;
213-
constraints_fn=klcontrol_minigrid_agent_constraints,
214-
initialization_fn=klcontrol_minigrid_agent_initialization,
218+
constraints_fn=efe_minigrid_agent_constraints,
219+
initialization_fn=efe_minigrid_agent_initialization,
215220
free_energy=true, # Enable free energy tracking
216221
showprogress=true, # Show inference progress,
217222
options=(force_marginal_computation=true,
218223
limit_stack_depth=500), # Force marginal computation
219224
# Add any other inference kwargs as needed
220225
)
221226
@info "Inference completed"
222-
next_action = mode(first(inference_result.posteriors[:u]))
223227
env_action = EFEasVFE.convert_action(next_action)
224228

225229
# Plot and save inference results
@@ -233,12 +237,12 @@ function main()
233237
# Create and save animation if requested
234238
if args["save-animation"]
235239
@info "Creating belief evolution animation..."
236-
# animate_belief_evolution(
237-
# inference_result,
238-
# config.grid_size,
239-
# fps=2,
240-
# save_path=joinpath(results_dir, "belief_evolution.gif")
241-
# )
240+
animate_belief_evolution(
241+
inference_result,
242+
config.grid_size,
243+
fps=2,
244+
save_path=joinpath(results_dir, "belief_evolution.gif")
245+
)
242246
animate_trajectory_belief(
243247
inference_result,
244248
config.grid_size,

scripts/tmaze_experiments.jl

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
using DrWatson
2+
@quickactivate "EFEasVFE"
3+
4+
using RxInfer
5+
using ReactiveMP
6+
using ProgressMeter
7+
using Statistics
8+
using Distributions
9+
using StableRNGs
10+
using Dates
11+
using JSON
12+
using FileIO
13+
import RxInfer: Categorical
14+
using EFEasVFE
15+
16+
@rule DiscreteTransition(:in, Marginalisation) (q_out::PointMass{<:AbstractVector}, q_a::PointMass, meta::Any) = begin
17+
eloga = mean(q_a)
18+
out_idx = findfirst(isone, probvec(q_out))
19+
result = eloga[out_idx, :]
20+
return Categorical(normalize!(guard(result), 1); check_args=false)
21+
end
22+
23+
function run_tmaze_experiment(;
24+
T::Int=10,
25+
n_episodes::Int=10,
26+
n_iterations::Int=10,
27+
visualize::Bool=false,
28+
wait_time::Float64=0.0,
29+
record_episode::Bool=false,
30+
seed::Int=123,
31+
number_type::Type{<:AbstractFloat}=Float64,
32+
experiment_name::String="tmaze_$(Dates.format(now(), "yyyymmdd_HHMMSS"))",
33+
parallel::Bool=false,
34+
log_dir::String=datadir("logs", "tmaze"),
35+
save_results::Bool=true,
36+
debug_mode::Bool=false
37+
)
38+
# Create directory for logs
39+
if save_results
40+
mkpath(log_dir)
41+
log_file = joinpath(log_dir, "$(experiment_name).log")
42+
results_file = joinpath(log_dir, "$(experiment_name)_results.json")
43+
episodic_results_file = joinpath(log_dir, "$(experiment_name)_episodes.json")
44+
end
45+
46+
# Create goal distribution - prefer the left arm location (state 3)
47+
left_goal = zeros(number_type, 5)
48+
left_goal[3] = 1.0
49+
left_goal_distribution = Categorical(left_goal)
50+
51+
# Create config
52+
config = TMazeConfig(
53+
time_horizon=T,
54+
n_episodes=n_episodes,
55+
n_iterations=n_iterations,
56+
wait_time=wait_time,
57+
number_type=number_type,
58+
visualize=visualize,
59+
seed=seed,
60+
record_episode=record_episode,
61+
experiment_name=experiment_name,
62+
parallel=parallel
63+
)
64+
65+
# Create tensors
66+
tensors = (
67+
reward_observation=create_reward_observation_tensor(),
68+
location_transition=create_location_transition_tensor(),
69+
reward_to_location=create_reward_to_location_mapping()
70+
)
71+
72+
# Create benchmark callbacks
73+
callbacks = RxInferBenchmarkCallbacks()
74+
75+
# Initialize result metrics
76+
experiment_metrics = Dict{String,Any}(
77+
"experiment_name" => experiment_name,
78+
"date" => string(now()),
79+
"time_horizon" => T,
80+
"n_episodes" => n_episodes,
81+
"n_iterations" => n_iterations,
82+
"seed" => seed,
83+
"models" => Dict{String,Any}()
84+
)
85+
86+
# Function to log information to both console and log file
87+
function log_info(message)
88+
println(message)
89+
if save_results
90+
open(log_file, "a") do io
91+
println(io, "$(now()) - $message")
92+
end
93+
end
94+
end
95+
96+
# Run experiments with KL control model
97+
log_info("Running experiments with KL Control agent...")
98+
99+
# Initialize storage for episodic data
100+
kl_episodic_data = []
101+
kl_rewards = zeros(config.n_episodes)
102+
103+
efe_episodic_data = []
104+
efe_rewards = zeros(config.n_episodes)
105+
106+
# Run episodes
107+
if parallel
108+
thread_count = Threads.nthreads()
109+
log_info("Running with parallelization using $thread_count threads")
110+
episode_seeds = rand(StableRNG(seed), UInt32, config.n_episodes)
111+
112+
# Use Threads.@threads for parallelization
113+
progress = Progress(config.n_episodes; desc="Running KL Control episodes: ")
114+
115+
episodic_data_lock = ReentrantLock()
116+
117+
Threads.@threads for i in 1:config.n_episodes
118+
episode_seed = episode_seeds[i]
119+
local_config = config
120+
121+
if local_config.visualize && parallel && i != config.n_episodes
122+
# Turn off visualization for all but last episode in parallel mode
123+
local_config = TMazeConfig(
124+
time_horizon=config.time_horizon,
125+
n_episodes=config.n_episodes,
126+
n_iterations=config.n_iterations,
127+
wait_time=config.wait_time,
128+
number_type=config.number_type,
129+
visualize=false,
130+
seed=config.seed,
131+
record_episode=i == config.n_episodes && config.record_episode,
132+
experiment_name=config.experiment_name,
133+
parallel=parallel
134+
)
135+
end
136+
137+
reward, episode_data = run_tmaze_single_episode(
138+
klcontrol_tmaze_agent,
139+
tensors,
140+
local_config,
141+
left_goal_distribution,
142+
callbacks,
143+
episode_seed;
144+
constraints_fn=klcontrol_tmaze_agent_constraints,
145+
initialization_fn=klcontrol_tmaze_agent_initialization,
146+
record=i == config.n_episodes && config.record_episode,
147+
debug_mode=debug_mode
148+
)
149+
150+
kl_rewards[i] = reward
151+
152+
# Thread-safe update of episodic data
153+
lock(episodic_data_lock) do
154+
push!(kl_episodic_data, episode_data)
155+
end
156+
157+
# Update progress
158+
ProgressMeter.next!(progress)
159+
end
160+
else
161+
# Sequential execution
162+
log_info("Running sequentially")
163+
episode_seeds = rand(StableRNG(seed), UInt32, config.n_episodes)
164+
165+
@showprogress desc = "Running KL Control episodes: " for i in 1:config.n_episodes
166+
episode_seed = episode_seeds[i]
167+
168+
reward, episode_data = run_tmaze_single_episode(
169+
klcontrol_tmaze_agent,
170+
tensors,
171+
config,
172+
left_goal_distribution,
173+
callbacks,
174+
episode_seed;
175+
constraints_fn=klcontrol_tmaze_agent_constraints,
176+
initialization_fn=klcontrol_tmaze_agent_initialization,
177+
record=i == config.n_episodes && config.record_episode,
178+
debug_mode=debug_mode
179+
)
180+
181+
kl_rewards[i] = reward
182+
push!(kl_episodic_data, episode_data)
183+
end
184+
log_info("KL Control episodes completed")
185+
@showprogress desc = "Running EFE episodes: " for i in 1:config.n_episodes
186+
episode_seed = episode_seeds[i]
187+
188+
reward, episode_data = run_tmaze_single_episode(
189+
efe_tmaze_agent,
190+
tensors,
191+
config,
192+
left_goal_distribution,
193+
callbacks,
194+
episode_seed;
195+
constraints_fn=efe_tmaze_agent_constraints,
196+
initialization_fn=efe_tmaze_agent_initialization,
197+
record=i == config.n_episodes && config.record_episode,
198+
debug_mode=debug_mode,
199+
options=(force_marginal_computation=true,
200+
limit_stack_depth=500), # Force marginal computation
201+
)
202+
203+
efe_rewards[i] = reward
204+
push!(efe_episodic_data, episode_data)
205+
end
206+
end
207+
208+
# Calculate statistics
209+
kl_mean = mean(kl_rewards)
210+
kl_std = std(kl_rewards)
211+
efe_mean = mean(efe_rewards)
212+
efe_std = std(efe_rewards)
213+
214+
# Record results
215+
log_info("KL Control results: mean reward = $kl_mean, std = $kl_std")
216+
experiment_metrics["models"]["klcontrol"] = Dict(
217+
"mean_reward" => kl_mean,
218+
"std_reward" => kl_std,
219+
"rewards" => kl_rewards,
220+
)
221+
log_info("EFE results: mean reward = $efe_mean, std = $efe_std")
222+
experiment_metrics["models"]["efe"] = Dict(
223+
"mean_reward" => efe_mean,
224+
"std_reward" => efe_std,
225+
"rewards" => efe_rewards,
226+
)
227+
228+
# Save episodic data
229+
if save_results
230+
open(episodic_results_file, "w") do io
231+
JSON.print(io, Dict("klcontrol" => kl_episodic_data), 2)
232+
JSON.print(io, Dict("efe" => efe_episodic_data), 2)
233+
end
234+
235+
open(results_file, "w") do io
236+
JSON.print(io, experiment_metrics, 2)
237+
end
238+
239+
log_info("Results saved to $results_file")
240+
log_info("Episode data saved to $episodic_results_file")
241+
end
242+
243+
# Run a single visual episode with KL control
244+
if !visualize && !record_episode
245+
log_info("Running visualized episode...")
246+
visual_config = TMazeConfig(
247+
time_horizon=T,
248+
n_episodes=1,
249+
n_iterations=n_iterations,
250+
wait_time=0.1, # Slower for visualization
251+
number_type=number_type,
252+
visualize=true,
253+
seed=seed + 1, # Different seed
254+
record_episode=true,
255+
experiment_name=experiment_name,
256+
parallel=false
257+
)
258+
259+
vis_reward, vis_data = run_tmaze_single_episode(
260+
klcontrol_tmaze_agent,
261+
tensors,
262+
visual_config,
263+
left_goal_distribution,
264+
nothing,
265+
seed + 100;
266+
constraints_fn=klcontrol_tmaze_agent_constraints,
267+
initialization_fn=klcontrol_tmaze_agent_initialization,
268+
record=true,
269+
debug_mode=debug_mode
270+
)
271+
272+
log_info("Visualization complete. Reward: $vis_reward")
273+
274+
if save_results
275+
# Save visualization data
276+
visualization_file = joinpath(log_dir, "$(experiment_name)_visualization.json")
277+
open(visualization_file, "w") do io
278+
JSON.print(io, vis_data, 2)
279+
end
280+
log_info("Visualization data saved to $visualization_file")
281+
end
282+
end
283+
284+
return experiment_metrics
285+
end
286+
287+
# Run the experiment with default parameters
288+
if abspath(PROGRAM_FILE) == @__FILE__
289+
run_tmaze_experiment(
290+
T=5,
291+
n_episodes=50,
292+
n_iterations=10,
293+
visualize=false,
294+
wait_time=0.0,
295+
record_episode=true,
296+
seed=123,
297+
number_type=Float64,
298+
experiment_name="tmaze_$(Dates.format(now(), "yyyymmdd_HHMMSS"))",
299+
parallel=false,
300+
debug_mode=true
301+
)
302+
end

0 commit comments

Comments
 (0)