Skip to content

Commit 8c79c74

Browse files
committed
Remove number type from TMaze and StochasticMaze
1 parent 8e2c8b3 commit 8c79c74

File tree

6 files changed

+51
-89
lines changed

6 files changed

+51
-89
lines changed

scripts/stochastic_maze.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ function parse_command_line()
4242
help = "Time to wait between steps in seconds (default: 0.0)"
4343
arg_type = Float64
4444
default = 0.0
45-
"--number-type", "-n"
46-
help = "Number type to use (default: Float64)"
47-
arg_type = Symbol
48-
default = :Float64
4945
"--seed", "-s"
5046
help = "Random seed for the experiment"
5147
arg_type = Int
@@ -79,17 +75,6 @@ function parse_command_line()
7975
args["n-iterations"] > 0 || throw(ArgumentError("n-iterations must be positive"))
8076
args["wait-time"] >= 0 || throw(ArgumentError("wait-time must be non-negative"))
8177

82-
# Convert number type string to actual type
83-
number_type = if args["number-type"] == :Float32
84-
Float32
85-
elseif args["number-type"] == :Float64
86-
Float64
87-
elseif args["number-type"] == :Float16
88-
Float16
89-
else
90-
throw(ArgumentError("Unsupported number type: $(args["number-type"])"))
91-
end
92-
9378
# Handle save_results argument logic
9479
save_results = true
9580
if args["no-save-results"]
@@ -103,7 +88,6 @@ function parse_command_line()
10388
n_episodes=args["n-episodes"],
10489
n_iterations=args["n-iterations"],
10590
wait_time=args["wait-time"],
106-
number_type=number_type,
10791
seed=args["seed"],
10892
record_episode=args["record-episode"],
10993
experiment_name=args["experiment-name"],
@@ -121,7 +105,6 @@ function run_stochastic_maze_experiment(;
121105
wait_time::Float64=0.0,
122106
record_episode::Bool=false,
123107
seed::Int=123,
124-
number_type::Type{<:AbstractFloat}=Float64,
125108
experiment_name::String="stochastic_maze_$(Dates.format(now(), "yyyymmdd_HHMMSS"))",
126109
log_dir::String=datadir("logs", "stochastic_maze"),
127110
save_results::Bool=true,
@@ -151,7 +134,7 @@ function run_stochastic_maze_experiment(;
151134
n_actions = 4 # NESW
152135

153136
# Create goal distribution focused on the goal state
154-
p_goal = zeros(number_type, n_states)
137+
p_goal = zeros(Float64, n_states)
155138
p_goal[goal_state] = 1.0
156139
goal_distribution = Categorical(p_goal)
157140

@@ -161,7 +144,6 @@ function run_stochastic_maze_experiment(;
161144
n_episodes=n_episodes,
162145
n_iterations=n_iterations,
163146
wait_time=wait_time,
164-
number_type=number_type,
165147
seed=seed,
166148
record_episode=record_episode,
167149
experiment_name=experiment_name
@@ -312,7 +294,6 @@ function main()
312294
wait_time=args.wait_time,
313295
record_episode=args.record_episode,
314296
seed=args.seed,
315-
number_type=args.number_type,
316297
experiment_name=args.experiment_name,
317298
save_results=args.save_results,
318299
debug_mode=args.debug_mode,

scripts/tmaze_experiments.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ function parse_command_line()
5050
help = "Time to wait between steps in seconds (default: 0.0)"
5151
arg_type = Float64
5252
default = 0.0
53-
"--number-type", "-n"
54-
help = "Number type to use (default: Float64)"
55-
arg_type = Symbol
56-
default = :Float64
5753
"--seed", "-s"
5854
help = "Random seed for the experiment"
5955
arg_type = Int
@@ -87,17 +83,6 @@ function parse_command_line()
8783
args["n-iterations"] > 0 || throw(ArgumentError("n-iterations must be positive"))
8884
args["wait-time"] >= 0 || throw(ArgumentError("wait-time must be non-negative"))
8985

90-
# Convert number type string to actual type
91-
number_type = if args["number-type"] == :Float32
92-
Float32
93-
elseif args["number-type"] == :Float64
94-
Float64
95-
elseif args["number-type"] == :Float16
96-
Float16
97-
else
98-
throw(ArgumentError("Unsupported number type: $(args["number-type"])"))
99-
end
100-
10186
# Handle save_results argument logic
10287
save_results = true
10388
if args["no-save-results"]
@@ -111,7 +96,6 @@ function parse_command_line()
11196
n_episodes=args["n-episodes"],
11297
n_iterations=args["n-iterations"],
11398
wait_time=args["wait-time"],
114-
number_type=number_type,
11599
seed=args["seed"],
116100
record_episode=args["record-episode"],
117101
experiment_name=args["experiment-name"],
@@ -128,7 +112,6 @@ function run_tmaze_experiment(;
128112
wait_time::Float64=0.0,
129113
record_episode::Bool=false,
130114
seed::Int=123,
131-
number_type::Type{<:AbstractFloat}=Float64,
132115
experiment_name::String="tmaze_$(Dates.format(now(), "yyyymmdd_HHMMSS"))",
133116
log_dir::String=datadir("logs", "tmaze"),
134117
save_results::Bool=true,
@@ -152,7 +135,7 @@ function run_tmaze_experiment(;
152135
end
153136

154137
# Create goal distribution - prefer the left arm location (state 3)
155-
left_goal = zeros(number_type, 5)
138+
left_goal = zeros(Float64, 5)
156139
left_goal[3] = 1.0
157140
left_goal_distribution = Categorical(left_goal)
158141

@@ -162,7 +145,6 @@ function run_tmaze_experiment(;
162145
n_episodes=n_episodes,
163146
n_iterations=n_iterations,
164147
wait_time=wait_time,
165-
number_type=number_type,
166148
seed=seed,
167149
record_episode=record_episode,
168150
experiment_name=experiment_name
@@ -312,7 +294,6 @@ function main()
312294
wait_time=args.wait_time,
313295
record_episode=args.record_episode,
314296
seed=args.seed,
315-
number_type=args.number_type,
316297
experiment_name=args.experiment_name,
317298
save_results=args.save_results,
318299
debug_mode=args.debug_mode,

src/agents/stochastic_maze_agent.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,16 @@ Configuration for StochasticMaze agent experiments.
2424
- `n_episodes::Int`: Number of episodes to run
2525
- `n_iterations::Int`: Number of inference iterations per step
2626
- `wait_time::Float64`: Time to wait between steps (for visualization)
27-
- `number_type::Type{T}`: Numeric type for computations
2827
- `seed::Int`: Random seed
2928
- `record_episode::Bool`: Whether to record episode frames as individual PNG files
3029
- `experiment_name::String`: Name of the experiment (for saving results)
3130
- `parallel::Bool`: Whether to run episodes in parallel
3231
"""
33-
Base.@kwdef struct StochasticMazeConfig{T<:AbstractFloat}
32+
Base.@kwdef struct StochasticMazeConfig
3433
time_horizon::Int
3534
n_episodes::Int
3635
n_iterations::Int
3736
wait_time::Float64
38-
number_type::Type{T}
3937
seed::Int
4038
record_episode::Bool = false
4139
experiment_name::String
@@ -48,10 +46,10 @@ end
4846
Container for agent's beliefs about the StochasticMaze environment.
4947
5048
# Fields
51-
- `state::Categorical{T}`: Belief about current state
49+
- `state::Categorical{Float64}`: Belief about current state
5250
"""
53-
Base.@kwdef mutable struct StochasticMazeBeliefs{T<:AbstractFloat}
54-
state::Categorical{T}
51+
Base.@kwdef mutable struct StochasticMazeBeliefs
52+
state::Categorical{Float64}
5553
end
5654

5755
"""
@@ -67,14 +65,14 @@ function validate_config(config::StochasticMazeConfig)
6765
end
6866

6967
"""
70-
initialize_beliefs_stochastic_maze(n_states::Int, T::Type{<:AbstractFloat})
68+
initialize_beliefs_stochastic_maze(n_states::Int)
7169
7270
Initialize agent beliefs for the StochasticMaze environment.
7371
"""
74-
function initialize_beliefs_stochastic_maze(n_states::Int, T::Type{<:AbstractFloat})
72+
function initialize_beliefs_stochastic_maze(n_states::Int)
7573
# Initialize with uniform beliefs over states
7674
return StochasticMazeBeliefs(
77-
state=Categorical(fill(T(1.0 / n_states), n_states))
75+
state=Categorical(fill(1.0 / n_states, n_states))
7876
)
7977
end
8078

@@ -103,18 +101,18 @@ function execute_step(env, observation, beliefs, model, tensors, config, goal, c
103101

104102
# Convert previous action to one-hot encoding
105103
n_actions = 4
106-
previous_action_vec = zeros(config.number_type, n_actions)
104+
previous_action_vec = zeros(Float64, n_actions)
107105
if !isnothing(previous_action)
108-
previous_action_vec[previous_action.index] = one(config.number_type)
106+
previous_action_vec[previous_action.index] = one(Float64)
109107
end
110108

111109
# Get initialization from previous results or initialize fresh
112110
n_states = size(tensors.transition_tensor, 1)
113111
initialization = initialization_fn(n_states)
114112

115113
# Create observation vector
116-
observation_vec = zeros(config.number_type, n_states)
117-
observation_vec[observation] = one(config.number_type)
114+
observation_vec = zeros(Float64, n_states)
115+
observation_vec[observation] = one(Float64)
118116

119117
# Run inference
120118
result = infer(
@@ -191,7 +189,7 @@ function run_stochastic_maze_single_episode(model, tensors, config, goal, callba
191189

192190
# Initialize beliefs
193191
n_states = size(tensors.transition_tensor, 1)
194-
beliefs = initialize_beliefs_stochastic_maze(n_states, config.number_type)
192+
beliefs = initialize_beliefs_stochastic_maze(n_states)
195193

196194
# Initialize tracking variables
197195
total_reward = 0.0

src/agents/tmaze_agent.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,16 @@ Configuration for TMaze agent experiments.
2222
- `n_episodes::Int`: Number of episodes to run
2323
- `n_iterations::Int`: Number of inference iterations per step
2424
- `wait_time::Float64`: Time to wait between steps (for visualization)
25-
- `number_type::Type{T}`: Numeric type for computations
2625
- `seed::Int`: Random seed
2726
- `record_episode::Bool`: Whether to record episode frames as individual PNG files
2827
- `experiment_name::String`: Name of the experiment (for saving results)
2928
- `parallel::Bool`: Whether to run episodes in parallel
3029
"""
31-
Base.@kwdef struct TMazeConfig{T<:AbstractFloat}
30+
Base.@kwdef struct TMazeConfig
3231
time_horizon::Int
3332
n_episodes::Int
3433
n_iterations::Int
3534
wait_time::Float64
36-
number_type::Type{T}
3735
seed::Int
3836
record_episode::Bool = false
3937
experiment_name::String
@@ -46,12 +44,12 @@ end
4644
Container for agent's beliefs about the TMaze environment.
4745
4846
# Fields
49-
- `location::Categorical{T}`: Belief about current location (5 possible states)
50-
- `reward_location::Categorical{T}`: Belief about reward location (left or right)
47+
- `location::Categorical{Float64}`: Belief about current location (5 possible states)
48+
- `reward_location::Categorical{Float64}`: Belief about reward location (left or right)
5149
"""
52-
Base.@kwdef mutable struct TMazeBeliefs{T<:AbstractFloat}
53-
location::Categorical{T}
54-
reward_location::Categorical{T}
50+
Base.@kwdef mutable struct TMazeBeliefs
51+
location::Categorical{Float64}
52+
reward_location::Categorical{Float64}
5553
end
5654

5755
"""
@@ -67,15 +65,15 @@ function validate_config(config::TMazeConfig)
6765
end
6866

6967
"""
70-
initialize_beliefs_tmaze(T::Type{<:AbstractFloat})
68+
initialize_beliefs_tmaze()
7169
7270
Initialize agent beliefs for the TMaze environment.
7371
"""
74-
function initialize_beliefs_tmaze(T::Type{<:AbstractFloat})
72+
function initialize_beliefs_tmaze()
7573
# Initialize with uniform beliefs over states
7674
return TMazeBeliefs(
77-
location=Categorical(fill(T(1 / 5), 5)),
78-
reward_location=Categorical([T(0.5), T(0.5)])
75+
location=Categorical(fill(1.0 / 5, 5)),
76+
reward_location=Categorical([0.5, 0.5])
7977
)
8078
end
8179

@@ -127,15 +125,15 @@ Takes current observations and returns the next planned action.
127125
function execute_step(env, position_obs, reward_cue, beliefs, model, tensors, config, goal, callbacks, time_remaining, previous_result, previous_action;
128126
constraints_fn, initialization_fn, inference_kwargs...)
129127
# Convert previous action to one-hot encoding
130-
previous_action_vec = zeros(config.number_type, 4)
128+
previous_action_vec = zeros(Float64, 4)
131129
if previous_action.direction isa North
132-
previous_action_vec[1] = one(config.number_type)
130+
previous_action_vec[1] = one(Float64)
133131
elseif previous_action.direction isa East
134-
previous_action_vec[2] = one(config.number_type)
132+
previous_action_vec[2] = one(Float64)
135133
elseif previous_action.direction isa South
136-
previous_action_vec[3] = one(config.number_type)
134+
previous_action_vec[3] = one(Float64)
137135
elseif previous_action.direction isa West
138-
previous_action_vec[4] = one(config.number_type)
136+
previous_action_vec[4] = one(Float64)
139137
end
140138

141139
# Get initialization from previous results or initialize fresh
@@ -210,7 +208,7 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
210208
env = create_tmaze(reward_position, (2, 2)) # Start at middle junction (2,2)
211209

212210
# Initialize beliefs
213-
beliefs = initialize_beliefs_tmaze(config.number_type)
211+
beliefs = initialize_beliefs_tmaze()
214212

215213
# Initialize tracking variables
216214
total_reward = 0.0
@@ -247,8 +245,8 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
247245
)
248246

249247
# Initial position observation and reward cue
250-
position_obs = convert.(config.number_type, get_position_observation(env))
251-
reward_cue = convert.(config.number_type, get_reward_cue(env))
248+
position_obs = convert.(Float64, get_position_observation(env))
249+
reward_cue = convert.(Float64, get_reward_cue(env))
252250

253251
# Record initial state
254252
push!(episode_data["positions"], [env.agent_position...])
@@ -283,8 +281,8 @@ function run_tmaze_single_episode(model, tensors, config, goal, callbacks, seed;
283281
position_obs, reward_cue, reward = step!(env, next_action)
284282

285283
# Convert to the required numeric type
286-
position_obs = convert.(config.number_type, position_obs)
287-
reward_cue = convert.(config.number_type, reward_cue)
284+
position_obs = convert.(Float64, position_obs)
285+
reward_cue = convert.(Float64, reward_cue)
288286

289287
# Update total reward
290288
episode_reward = reward isa Number ? reward : 0
@@ -408,7 +406,6 @@ function run_tmaze_agent(
408406
n_episodes=config.n_episodes,
409407
n_iterations=config.n_iterations,
410408
wait_time=config.wait_time,
411-
number_type=config.number_type,
412409
seed=config.seed,
413410
record_episode=config.record_episode,
414411
experiment_name=config.experiment_name,

src/environments/stochastic_maze.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ function visualize_stochastic_maze(env::StochasticMaze)
550550
x, y = state_to_xy(state, grid_size_x)
551551
color = reward > 0 ? MAZE_THEME.reward_positive : MAZE_THEME.reward_negative
552552
opacity = min(abs(reward), 1.0) # Use absolute value of reward for opacity, capped at 1.0
553-
scatter!(p, [x - 0.5], [grid_size_y - y + 0.5], color=color, alpha=opacity, markersize=ceil(Int, scale))
553+
scatter!(p, [x - 0.5], [grid_size_y - y + 0.5], color=color, alpha=opacity, markersize=ceil(Int, scale), markerstrokewidth=ceil(Int, scale / 15))
554554
end
555555

556556
# Plot observation noise
@@ -583,7 +583,7 @@ function visualize_stochastic_maze(env::StochasticMaze)
583583

584584
# Plot agent
585585
x, y = state_to_xy(env.agent_state, grid_size_x)
586-
scatter!(p, [x - 0.5], [grid_size_y - y + 0.5], color=MAZE_THEME.agent, markersize=ceil(Int, (2 / 3) * scale))
586+
scatter!(p, [x - 0.5], [grid_size_y - y + 0.5], color=MAZE_THEME.agent, markersize=ceil(Int, (2 / 3) * scale), markerstrokewidth=ceil(Int, scale / 15))
587587

588588
return p
589589
end

0 commit comments

Comments
 (0)