Skip to content

Commit 5310d8f

Browse files
authored
Merge pull request #10 from biaslab/experiments
update Experiments
2 parents 8d88094 + 1f852f3 commit 5310d8f

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
# Detect number of CPU cores and use (cores - 2) for Julia threads, minimum 1
44
NPROC := $(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 4)
5-
JULIA_THREADS := $(shell expr $(NPROC) - 2 || echo 1)
5+
JULIA_THREADS := $(shell expr $(NPROC) - 1 || echo 1)
66
# Ensure minimum of 1 thread
77
ifeq ($(shell expr $(JULIA_THREADS) \< 1), 1)
88
JULIA_THREADS := 1
99
endif
1010

1111
# Define experiment parameters
12-
MINIGRID_PARAMS := --save-results --parallel --save-video --n-iterations 40 --n-episodes 200 --time-horizon 25 --grid-size 4
12+
MINIGRID_PARAMS := --save-results --parallel --save-video --n-iterations 50 --n-episodes 200 --time-horizon 25 --grid-size 4
1313
DEBUG_MINIGRID_PARAMS := --grid-size 4 --time-horizon 25 --save-frame --iterations 40 --save-animation
1414
STOCHASTIC_MAZE_PARAMS := -r --save-results
1515
DEBUG_STOCHASTIC_MAZE_PARAMS := --save-frame --iterations 50

src/models/minigrid/efe.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ import RxInfer: Categorical
5454
key_door_state_observation_marginalcomponents = JointMarginalMetaComponent[]
5555

5656
for x in 1:7, y in 1:7
57-
marginalstorage = JointMarginalStorage(Contingency(ones(number_type, size(observation_tensors[x, y]))))
57+
marginalstorage = JointMarginalStorage(Contingency(collect(observation_tensors[x, y])))
5858
location_observation_marginalcomponent = JointMarginalMetaComponent(marginalstorage, 1, 2)
5959
push!(location_observation_marginalcomponents, location_observation_marginalcomponent)
6060
orientation_observation_marginalcomponent = JointMarginalMetaComponent(marginalstorage, 1, 3)
@@ -66,14 +66,14 @@ import RxInfer: Categorical
6666
key_door_state_observation_marginalcomponent = JointMarginalMetaComponent(marginalstorage, 1, 6)
6767
push!(key_door_state_observation_marginalcomponents, key_door_state_observation_marginalcomponent)
6868
decomposed_tensor = observation_tensors[x, y]
69-
future_observations[x, y] ~ DiscreteTransition(current_location, decomposed_tensor, current_orientation, key_location, door_location, current_key_door_state) where {meta=marginalstorage}
70-
future_observations[x, y] ~ Categorical(fill(number_type(1 / 5), 5))
69+
# future_observations[x, y] ~ DiscreteTransition(current_location, decomposed_tensor, current_orientation, key_location, door_location, current_key_door_state) where {meta=marginalstorage}
70+
# future_observations[x, y] ~ Categorical(fill(number_type(1 / 5), 5))
7171
end
72-
location[t] ~ Ambiguity(observations[1, 1]) where {meta=JointMarginalMeta(location_observation_marginalcomponents)}
73-
orientation[t] ~ Ambiguity(observations[1, 1]) where {meta=JointMarginalMeta(orientation_observation_marginalcomponents)}
74-
key_location ~ Ambiguity(observations[1, 1]) where {meta=JointMarginalMeta(key_location_observation_marginalcomponents)}
75-
door_location ~ Ambiguity(observations[1, 1]) where {meta=JointMarginalMeta(door_location_observation_marginalcomponents)}
76-
key_door_state[t] ~ Ambiguity(observations[1, 1]) where {meta=JointMarginalMeta(key_door_state_observation_marginalcomponents)}
72+
location[t] ~ Ambiguity(1) where {meta=JointMarginalMeta(location_observation_marginalcomponents)}
73+
orientation[t] ~ Ambiguity(1) where {meta=JointMarginalMeta(orientation_observation_marginalcomponents)}
74+
key_location ~ Ambiguity(1) where {meta=JointMarginalMeta(key_location_observation_marginalcomponents)}
75+
door_location ~ Ambiguity(1) where {meta=JointMarginalMeta(door_location_observation_marginalcomponents)}
76+
key_door_state[t] ~ Ambiguity(1) where {meta=JointMarginalMeta(key_door_state_observation_marginalcomponents)}
7777
end
7878
location[end] ~ goal
7979
orientation[end] ~ Categorical(fill(number_type(1 / 4), 4))

src/models/minigrid/klcontrol.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,14 @@ import RxInfer: Categorical
2626
end
2727
current_orientation ~ DiscreteTransition(orientation_observation, diageye(number_type, 4))
2828

29+
current_location_ow ~ OneWay(current_location)
30+
current_orientation_ow ~ OneWay(current_orientation)
31+
current_key_door_state_ow ~ OneWay(current_key_door_state)
32+
2933
# Planning (Active Inference)
30-
previous_location = current_location
31-
previous_orientation = current_orientation
32-
previous_key_door_state = current_key_door_state
34+
previous_location = current_location_ow
35+
previous_orientation = current_orientation_ow
36+
previous_key_door_state = current_key_door_state_ow
3337
for t in 1:T
3438
u[t] ~ Categorical(fill(number_type(1 / 5), 5))
3539
location[t] ~ DiscreteTransition(previous_location, location_transition_tensor, previous_orientation, key_location, door_location, previous_key_door_state, u[t])
@@ -58,6 +62,10 @@ RxInfer.GraphPPL.default_constraints(::typeof(klcontrol_minigrid_agent)) = klcon
5862
μ(current_orientation) = p_current_orientation
5963
μ(current_key_door_state) = p_current_key_door_state
6064

65+
μ(current_location_ow) = p_current_location
66+
μ(current_orientation_ow) = p_current_orientation
67+
μ(current_key_door_state_ow) = p_current_key_door_state
68+
6169
μ(location) = p_future_locations
6270
μ(orientation) = p_future_orientations
6371
μ(key_door_state) = p_future_key_door_states

src/rules/full_tensor/uncached.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ end
185185
end
186186

187187
# Rules for observation model for future observations (m_out is categorical)
188-
@rule DiscreteTransition(:out, Marginalisation) (m_in::Categorical, m_T1::Categorical, m_T2::Categorical, m_T3::Categorical, m_T4::Categorical, q_a::PointMass{<:AbstractArray{T,7}}, meta::Any) where {T} = begin
188+
@rule DiscreteTransition(:out, Marginalisation) (m_in::Categorical, m_T1::Categorical, m_T2::Categorical, m_T3::Categorical, m_T4::Categorical, q_a::PointMass{<:AbstractArray{T,6}}, meta::Any) where {T} = begin
189189
eloga = mean(q_a)
190190
size_result = size(eloga)[1]
191191
return Categorical(fill(T(1 / size_result), size_result); check_args=false)

0 commit comments

Comments
 (0)