Skip to content

Commit 88dd40d

Browse files
author
Kent Sommer
authored
Update README.md
1 parent d3939cb commit 88dd40d

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --
4646
- `l_q`: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
4747
- `batch_size`: Batch size. Default: 128
4848

49-
## How to visualize / test paths (requires training first)
49+
## How to test / visualize paths (requires training first)
5050
#### 8x8 gridworld
5151
```bash
5252
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
@@ -59,10 +59,15 @@ python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20
5959
```bash
6060
python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
6161
```
62+
To visualize the optimal and predicted paths simply pass:
63+
```bash
64+
--plot
65+
```
6266

6367
**Flags**:
6468
- `weights`: Path to trained weights.
6569
- `imsize`: The size of input images. One of: [8, 16, 28]
70+
- `plot`: If supplied, the optimal and predicted paths will be plotted
6671
- `k`: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
6772
- `l_i`: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
6873
- `l_h`: Number of channels in first convolutional layer. Default: 150, described in paper.
@@ -85,6 +90,13 @@ Test set | 13846 | 77203 | 251755
8590

8691
The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the ```dataset/make_training_data.py``` script. Note that this script is not optimized and runs rather slowly (also uses a lot of memory :D)
8792

93+
## Performance: Success Rate
94+
This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).
95+
96+
Success Rate | 8x8 | 16x16 | 28x28
97+
-- | -- | -- | --
98+
PyTorch | 99.69% | 96.99% | 91.07%
99+
88100
## Performance: Test Accuracy
89101

90102
**NOTE**: This is the **accuracy on test set**. It is different from the table in the paper, which indicates the **success rate** from rollouts of the learned policy in the environment.

0 commit comments

Comments
 (0)