@@ -22,18 +22,16 @@ Configuration for TMaze agent experiments.
2222- `n_episodes::Int`: Number of episodes to run
2323- `n_iterations::Int`: Number of inference iterations per step
2424- `wait_time::Float64`: Time to wait between steps (for visualization)
25- - `number_type::Type{T}`: Numeric type for computations
2625- `seed::Int`: Random seed
2726- `record_episode::Bool`: Whether to record episode frames as individual PNG files
2827- `experiment_name::String`: Name of the experiment (for saving results)
2928- `parallel::Bool`: Whether to run episodes in parallel
3029"""
31- Base. @kwdef struct TMazeConfig{T <: AbstractFloat }
30+ Base. @kwdef struct TMazeConfig
3231 time_horizon:: Int
3332 n_episodes:: Int
3433 n_iterations:: Int
3534 wait_time:: Float64
36- number_type:: Type{T}
3735 seed:: Int
3836 record_episode:: Bool = false
3937 experiment_name:: String
4644Container for agent's beliefs about the TMaze environment.
4745
4846# Fields
49- - `location::Categorical{T }`: Belief about current location (5 possible states)
50- - `reward_location::Categorical{T }`: Belief about reward location (left or right)
47+ - `location::Categorical{Float64 }`: Belief about current location (5 possible states)
48+ - `reward_location::Categorical{Float64 }`: Belief about reward location (left or right)
5149"""
Base.@kwdef mutable struct TMazeBeliefs
    # Belief about the agent's current location (5 possible states).
    location::Categorical{Float64}
    # Belief about where the reward is (left or right).
    reward_location::Categorical{Float64}
end
5654
5755"""
@@ -67,15 +65,15 @@ function validate_config(config::TMazeConfig)
6765end
6866
"""
    initialize_beliefs_tmaze()

Initialize agent beliefs for the TMaze environment.

# Returns
- `TMazeBeliefs`: beliefs in which `location` is a uniform `Categorical` over the
  5 location states and `reward_location` is a uniform `Categorical` over the
  2 reward locations. All probabilities are `Float64`.
"""
function initialize_beliefs_tmaze()
    # Start from uniform beliefs: nothing is known before the first observation.
    uniform_location = Categorical(fill(1.0 / 5, 5))
    uniform_reward = Categorical([0.5, 0.5])
    return TMazeBeliefs(;
        location=uniform_location,
        reward_location=uniform_reward,
    )
end
8179
@@ -127,15 +125,15 @@ Takes current observations and returns the next planned action.
127125function execute_step (env, position_obs, reward_cue, beliefs, model, tensors, config, goal, callbacks, time_remaining, previous_result, previous_action;
128126 constraints_fn, initialization_fn, inference_kwargs... )
129127 # Convert previous action to one-hot encoding
130- previous_action_vec = zeros (config . number_type , 4 )
128+ previous_action_vec = zeros (Float64 , 4 )
131129 if previous_action. direction isa North
132- previous_action_vec[1 ] = one (config . number_type )
130+ previous_action_vec[1 ] = one (Float64 )
133131 elseif previous_action. direction isa East
134- previous_action_vec[2 ] = one (config . number_type )
132+ previous_action_vec[2 ] = one (Float64 )
135133 elseif previous_action. direction isa South
136- previous_action_vec[3 ] = one (config . number_type )
134+ previous_action_vec[3 ] = one (Float64 )
137135 elseif previous_action. direction isa West
138- previous_action_vec[4 ] = one (config . number_type )
136+ previous_action_vec[4 ] = one (Float64 )
139137 end
140138
141139 # Get initialization from previous results or initialize fresh
@@ -210,7 +208,7 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
210208 env = create_tmaze (reward_position, (2 , 2 )) # Start at middle junction (2,2)
211209
212210 # Initialize beliefs
213- beliefs = initialize_beliefs_tmaze (config . number_type )
211+ beliefs = initialize_beliefs_tmaze ()
214212
215213 # Initialize tracking variables
216214 total_reward = 0.0
@@ -247,8 +245,8 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
247245 )
248246
249247 # Initial position observation and reward cue
250- position_obs = convert .(config . number_type , get_position_observation (env))
251- reward_cue = convert .(config . number_type , get_reward_cue (env))
248+ position_obs = convert .(Float64 , get_position_observation (env))
249+ reward_cue = convert .(Float64 , get_reward_cue (env))
252250
253251 # Record initial state
254252 push! (episode_data[" positions" ], [env. agent_position... ])
@@ -283,8 +281,8 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
283281 position_obs, reward_cue, reward = step! (env, next_action)
284282
285283 # Convert to the required numeric type
286- position_obs = convert .(config . number_type , position_obs)
287- reward_cue = convert .(config . number_type , reward_cue)
284+ position_obs = convert .(Float64 , position_obs)
285+ reward_cue = convert .(Float64 , reward_cue)
288286
289287 # Update total reward
290288 episode_reward = reward isa Number ? reward : 0
@@ -408,7 +406,6 @@ function run_tmaze_agent(
408406 n_episodes= config. n_episodes,
409407 n_iterations= config. n_iterations,
410408 wait_time= config. wait_time,
411- number_type= config. number_type,
412409 seed= config. seed,
413410 record_episode= config. record_episode,
414411 experiment_name= config. experiment_name,
0 commit comments