Description
The number of states reported as reachable from a given initial state differs between PolicyIteration and ValueIteration.
When both planners were run on a graph-defined MDP with 100 states, ValueIteration reported 100 reachable states but PolicyIteration reported 99.
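For reference, here is a minimal, library-free sketch of the count I would expect from a plain breadth-first reachability pass over this kind of graph. It is not BURLAP's implementation: the class `ReachabilitySketch` and the method `countReachable` are hypothetical names, and it assumes that terminal states are counted as reachable but never expanded.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set;

// Hypothetical helper for illustration only; not part of BURLAP.
class ReachabilitySketch {

    /**
     * Counts the states reachable from start in a graph where every state
     * has a transition to every other state. Terminal states are counted
     * when first seen but their outgoing transitions are never followed
     * (an assumption about how a planner treats them).
     */
    static int countReachable(int numStates, int start, Set<Integer> terminals) {
        Set<Integer> seen = new HashSet<>();
        Deque<Integer> frontier = new ArrayDeque<>();
        seen.add(start);
        frontier.add(start);
        while (!frontier.isEmpty()) {
            int s = frontier.poll();
            if (terminals.contains(s)) {
                continue; // counted, but not expanded
            }
            for (int next = 0; next < numStates; next++) {
                // Set.add returns true only for states not seen before
                if (next != s && seen.add(next)) {
                    frontier.add(next);
                }
            }
        }
        return seen.size();
    }

    public static void main(String[] args) {
        Set<Integer> terminals = new HashSet<>();
        terminals.add(0); // assumed single goal state at index 0
        System.out.println(countReachable(100, 99, terminals)); // prints 100
    }
}
```

Under these assumptions the count from state 99 is 100, which matches the ValueIteration figure; starting from the terminal state itself would give 1, since nothing is expanded.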
Code snippet that creates the graph and runs policy and value iteration (warning: the code is untested; I'm working on an assignment, apologies):
public class GraphDefinedMDP {
    private static final int NUM_STATES = 100;

    private GraphDefinedDomain graphDefinedDomainGen;
    private SADomain domain;
    private State initialState;
    // Integer division: NUM_STATES / 75 == 1, so there is a single goal state
    private int[] goalStates = new int[NUM_STATES / 75];
    private HashableStateFactory hashingFactory;
    private Environment env;

    public GraphDefinedMDP(int initialStateNum) {
        graphDefinedDomainGen = new GraphDefinedDomain(NUM_STATES);

        // Deterministic goal states
        for (int i = 0; i < goalStates.length; i++) {
            goalStates[i] = i * goalStates.length + (goalStates.length / 2);
        }

        // Print goal states
        System.out.format("Goal States: %s\n",
                Util.intArrayToString(goalStates));

        // Set terminal states
        TerminalFunction tf = new GraphTF(goalStates);
        RewardFunction rf = new GraphRF() {
            @Override
            public double reward(int s, int a, int sprime) {
                // +2 for entering a goal state, -1 for every other transition
                for (int goalState : goalStates)
                    if (goalState == sprime)
                        return 2;
                return -1;
            }
        };
        graphDefinedDomainGen.setTf(tf);
        graphDefinedDomainGen.setRf(rf);

        // All nodes are equally reachable from every other node
        int action = 0;
        double probability = 1.0 / (NUM_STATES - 1);
        for (int srcNode = 0; srcNode < NUM_STATES; srcNode++) {
            for (int dstNode = 0; dstNode < NUM_STATES; dstNode++) {
                if (srcNode != dstNode) {
                    graphDefinedDomainGen.setTransition(srcNode, action,
                            dstNode, probability);
                    action = (action + 1) % goalStates.length;
                }
            }
        }
        if (!graphDefinedDomainGen.isValidMDPGraph()) {
            // Invalid MDP graph
            System.exit(1);
        }
        domain = graphDefinedDomainGen.generateDomain();
        initialState = new GraphStateNode(initialStateNum);
        System.out.println("initialState: " + initialState);
        hashingFactory = new SimpleHashableStateFactory();
        env = new SimulatedEnvironment(domain, initialState);
    }

    public void valueIteration() {
        Planner planner = new ValueIteration(domain, 0.99, hashingFactory,
                0.001, 200);
        Policy p = planner.planFromState(initialState);
        Episode episode = PolicyUtils.rollout(p, initialState,
                domain.getModel(), 500);
        printEpisodeStats(episode);
    }

    public void policyIteration() {
        Planner planner = new PolicyIteration(domain, 0.99, hashingFactory,
                0.001, 200, 10);
        Policy p = planner.planFromState(initialState);
        Episode episode = PolicyUtils.rollout(p, initialState,
                domain.getModel(), 500);
        printEpisodeStats(episode);
    }

    public static void main(String[] args) {
        int initialStateNum = 99;

        System.out.println("---Value iteration---");
        GraphDefinedMDP obj1 = new GraphDefinedMDP(initialStateNum);
        long startTime = System.currentTimeMillis();
        obj1.valueIteration();
        long endTime = System.currentTimeMillis();
        System.out.format("Time taken for value iteration: %d ms\n\n\n",
                (endTime - startTime));

        System.out.println("---Policy iteration---");
        GraphDefinedMDP obj2 = new GraphDefinedMDP(initialStateNum);
        startTime = System.currentTimeMillis();
        obj2.policyIteration();
        endTime = System.currentTimeMillis();
        System.out.format("Time taken for policy iteration: %d ms\n\n\n",
                (endTime - startTime));
    }
}
Diff of the performReachabilityFrom(State state) function: https://www.diffchecker.com/4ZsjNQx7
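As a sanity check on the graph construction in the snippet above, the following library-free sketch recomputes the goal states and the total outgoing probability of one node using the same arithmetic. The class `GraphSetupCheck` and its method names are hypothetical and exist only for illustration; nothing here calls BURLAP.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; mirrors the arithmetic in the snippet.
class GraphSetupCheck {

    /** Recomputes the goal-state array exactly as the constructor does. */
    static int[] goalStates(int numStates) {
        int[] goals = new int[numStates / 75]; // integer division
        for (int i = 0; i < goals.length; i++) {
            goals[i] = i * goals.length + (goals.length / 2);
        }
        return goals;
    }

    /** Sums the probabilities assigned to the transitions leaving srcNode. */
    static double outgoingProbabilitySum(int numStates, int srcNode) {
        double probability = 1.0 / (numStates - 1);
        double sum = 0.0;
        for (int dstNode = 0; dstNode < numStates; dstNode++) {
            if (dstNode != srcNode) {
                sum += probability;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(goalStates(100))); // prints [0]
        System.out.println(outgoingProbabilitySum(100, 99));
    }
}
```

With 100 states there is one goal state, state 0, so `(action + 1) % goalStates.length` is always 0 and every transition is registered under action 0. The outgoing probabilities of each node then sum to 1 up to floating-point rounding.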