Skip to content

Commit e959537

Browse files
authored
Change distance function (#2)
* Improve distance function * normalized distance * test * double-pointer * fix depth
1 parent 0b8319f commit e959537

File tree

4 files changed

+90
-50
lines changed

4 files changed

+90
-50
lines changed

planner.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"sync"
99
)
1010

11+
const maxDepth = 100
12+
1113
// Action represents an action that can be performed.
1214
type Action interface {
1315

@@ -24,7 +26,6 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
2426
start = start.Clone()
2527
start.node = node{
2628
heuristic: start.Distance(goal),
27-
stateCost: 0,
2829
}
2930

3031
heap := acquireHeap()
@@ -34,6 +35,14 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
3435
for heap.Len() > 0 {
3536
current, _ := heap.Pop()
3637

38+
/*fmt.Printf("- (%d) %s, cost=%v, heuristic=%v, total=%v\n",
39+
current.depth, current.action,
40+
current.stateCost, current.heuristic, current.totalCost)*/
41+
42+
if current.depth >= maxDepth {
43+
return reconstructPlan(current), nil
44+
}
45+
3746
// If we reached the goal, reconstruct the path.
3847
done, err := current.Match(goal)
3948
switch {
@@ -59,8 +68,6 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
5968
return nil, err
6069
}
6170

62-
//fmt.Printf("Action: %s, State: %s, New: %s\n", action.String(), current.String(), newState.String())
63-
6471
// Check if newState is already planned to be visited or if the newCost is lower
6572
newCost := current.stateCost + action.Cost()
6673
node, found := heap.Find(newState.Hash())
@@ -72,6 +79,7 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
7279
newState.heuristic = heuristic
7380
newState.stateCost = newCost
7481
newState.totalCost = newCost + heuristic
82+
newState.depth = current.depth + 1
7583
heap.Push(newState)
7684

7785
// In any of those cases, we need to release the new state
@@ -92,7 +100,7 @@ func Plan(start, goal *State, actions []Action) ([]Action, error) {
92100

93101
// reconstructPlan reconstructs the plan from the goal node to the start node.
94102
func reconstructPlan(goalNode *State) []Action {
95-
plan := make([]Action, 0, int(goalNode.index))
103+
plan := make([]Action, 0, int(goalNode.depth))
96104
for n := goalNode; n != nil; n = n.parent {
97105
if n.action != nil { // The start node has no action
98106
plan = append(plan, n.action)

planner_test.go

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ BenchmarkPlan/deep-24 380756 3103 ns/op 230 B/op 1
2929
BenchmarkPlan/deep-24 337836 3519 ns/op 230 B/op 1 allocs/op
3030
BenchmarkPlan/deep-24 420907 2831 ns/op 230 B/op 1 allocs/op
3131
BenchmarkPlan/deep-24 444250 2716 ns/op 230 B/op 1 allocs/op
32+
BenchmarkPlan/deep-24 499970 2345 ns/op 211 B/op 1 allocs/op
3233
3334
BenchmarkPlan/maze-24 37 31458708 ns/op 2702894 B/op 80711 allocs/op
3435
BenchmarkPlan/maze-24 63 18643352 ns/op 1569536 B/op 51464 allocs/op
@@ -93,40 +94,44 @@ func TestNumericPlan(t *testing.T) {
9394

9495
plan, err := Plan(start, goal, actions)
9596
assert.NoError(t, err)
96-
assert.Equal(t, []string{"Forage", "Forage", "Forage", "Sleep", "Forage", "Forage", "Sleep", "Forage", "Forage", "Forage", "Sleep", "Eat", "Forage"},
97+
assert.Equal(t, []string{"Forage", "Forage", "Forage", "Sleep", "Forage", "Forage", "Sleep", "Forage", "Forage", "Forage", "Sleep", "Forage"},
9798
planOf(plan))
99+
100+
//assert.Fail(t, "xxx")
98101
}
99102

100103
func TestMaze(t *testing.T) {
101-
start := StateOf("A")
102-
goal := StateOf("Z")
103-
actions := []Action{
104+
plan, err := Plan(StateOf("A"), StateOf("Z"), []Action{
104105
move("A->B"), move("B->C"), move("C->D"), move("D->E"), move("E->F"), move("F->G"),
105106
move("G->H"), move("H->I"), move("I->J"), move("C->X1"), move("E->X2"), move("G->X3"),
106107
move("X1->D"), move("X2->F"), move("X3->H"), move("B->Y1"), move("D->Y2"), move("F->Y3"),
107108
move("Y1->C"), move("Y2->E"), move("Y3->G"), move("J->K"), move("K->L"), move("L->M"),
108109
move("M->N"), move("N->O"), move("O->P"), move("P->Q"), move("Q->R"), move("R->S"),
109110
move("S->T"), move("T->U"), move("U->V"), move("V->W"), move("W->X"), move("X->Y"),
110111
move("Y->Z"), move("U->Z1"), move("W->Z2"), move("Z1->V"), move("Z2->X"), move("A->Z3"),
111-
}
112-
113-
plan, err := Plan(start, goal, actions)
112+
})
114113
assert.NoError(t, err)
115114
assert.Equal(t, []string{"A->B", "B->C", "C->D", "D->E", "E->F", "F->G", "G->H", "H->I", "I->J",
116115
"J->K", "K->L", "L->M", "M->N", "N->O", "O->P", "P->Q", "Q->R", "R->S", "S->T", "T->U", "U->V",
117116
"V->W", "W->X", "X->Y", "Y->Z"},
118117
planOf(plan))
118+
//assert.Fail(t, "xxx")
119119
}
120120

121-
func TestSimplePlan(t *testing.T) {
122-
start := StateOf("A", "B")
123-
goal := StateOf("C", "D")
124-
actions := []Action{move("A->C"), move("A->D"), move("B->C"), move("B->D")}
121+
func TestWeightedPlan(t *testing.T) {
122+
plan, err := Plan(StateOf("A", "B"), StateOf("C", "D"),
123+
[]Action{move("A->C"), move("A->D", 0.5), move("B->C"), move("B->D", 0.75)},
124+
)
125+
assert.NoError(t, err)
126+
assert.Equal(t, []string{"A->D", "B->C"}, planOf(plan))
127+
}
125128

126-
plan, err := Plan(start, goal, actions)
129+
func TestSimplePlan(t *testing.T) {
130+
plan, err := Plan(StateOf("A", "B"), StateOf("C", "D"),
131+
[]Action{move("A->C"), move("A->D"), move("B->C"), move("B->D")},
132+
)
127133
assert.NoError(t, err)
128-
assert.Equal(t, []string{"A->C", "B->D"},
129-
planOf(plan))
134+
assert.Equal(t, []string{"A->C", "B->D"}, planOf(plan))
130135
}
131136

132137
func TestNoPlanFound(t *testing.T) {
@@ -139,9 +144,13 @@ func TestNoPlanFound(t *testing.T) {
139144

140145
// ------------------------------------ Test Action ------------------------------------
141146

142-
func move(m string) Action {
147+
func move(m string, w ...float32) Action {
148+
if len(w) == 0 {
149+
w = append(w, 1.0)
150+
}
151+
143152
arr := strings.Split(m, "->")
144-
return actionOf(m, 1, StateOf(arr[0]), StateOf("!"+arr[0], arr[1]))
153+
return actionOf(m, w[0], StateOf(arr[0]), StateOf("!"+arr[0], arr[1]))
145154
}
146155

147156
func planOf(plan []Action) []string {

state.go

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"sync"
1111
)
1212

13-
const linearCutoff = 8 // 1 cache line
13+
const linearCutoff = 16 // 2 cache line
1414

1515
var pool = sync.Pool{
1616
New: func() any {
@@ -45,6 +45,7 @@ type node struct {
4545
stateCost float32 // Cost from the start state to this state
4646
totalCost float32 // Sum of cost and heuristic
4747
index int // Index of the state in the heap
48+
depth int // Depth of the state in the tree
4849
visited bool // Whether the state was visited
4950
}
5051

@@ -225,38 +226,47 @@ func (s *State) Apply(effects *State) error {
225226
return nil
226227
}
227228

228-
// Distance estimates the distance to the goal state as the number of differing keys.
229+
// Distance estimates the distance to the goal state.
229230
func (state *State) Distance(goal *State) (diff float32) {
230-
i, j := 0, 0
231-
for i < len(goal.vx) && j < len(state.vx) {
232-
f0 := goal.vx[i].Fact()
233-
f1 := state.vx[j].Fact()
231+
i := 0
232+
for _, g := range goal.vx {
233+
x := g.Expr().Value()
234+
v := float32(0)
235+
236+
// Find the value in the state
237+
for ; i < len(state.vx); i++ {
238+
if state.vx[i].Fact() == g.Fact() {
239+
v = state.vx[i].Expr().Value()
240+
break // Found
241+
}
242+
if state.vx[i].Fact() < g.Fact() {
243+
break // Not found
244+
}
245+
}
234246

235-
switch {
236-
case f1 == f0:
237-
x := goal.vx[i].Expr().Value()
238-
y := state.vx[j].Expr().Value()
247+
// Calculate the difference, normalized
248+
switch g.Expr().Operator() {
249+
case opEqual:
239250
switch {
240-
case x > y:
241-
diff += x - y
242-
case x < y:
243-
diff += y - x
251+
case v < x:
252+
diff += (x - v)
253+
case v > x:
254+
diff += (v - x)
255+
default: // v == x
244256
}
245257

246-
j++
247-
i++
248-
case f1 > f0:
249-
diff += 100
250-
j++
251-
case f1 < f0:
252-
diff += 100
253-
i++
258+
case opLess:
259+
if v > x {
260+
diff += (v - x)
261+
}
262+
263+
case opGreater:
264+
if v < x {
265+
diff += (x - v)
266+
}
254267
}
255268
}
256269

257-
// Add the remaining elements
258-
diff += float32(len(goal.vx)-i) * 100
259-
diff += float32(len(state.vx)-j) * 100
260270
return diff
261271
}
262272

state_test.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,32 @@ func TestDistance(t *testing.T) {
151151
{[]string{"A"}, []string{"A"}, 0},
152152
{[]string{"A=100"}, []string{"A=10"}, 90},
153153
{[]string{"A=100"}, []string{"A=90"}, 10},
154-
{[]string{"A"}, []string{"B"}, 200},
154+
{[]string{"A=25"}, []string{"A=50"}, 25},
155+
{[]string{"A=0"}, []string{"A=50"}, 50},
156+
{[]string{"A=75"}, []string{"A=50"}, 25},
157+
{[]string{"A"}, []string{"B"}, 100},
155158
{[]string{"A"}, []string{"A", "B"}, 100},
156-
{[]string{"A", "B"}, []string{"A"}, 100},
157-
{[]string{"A", "B"}, []string{"C", "D"}, 400},
159+
{[]string{"A", "B"}, []string{"A"}, 0},
160+
{[]string{"A", "B"}, []string{"C", "D"}, 200},
158161
{[]string{"A", "B"}, []string{"A", "B"}, 0},
159162
{[]string{"A", "B"}, []string{"A", "B", "C"}, 100},
160-
{[]string{"A", "B", "C"}, []string{"D", "B"}, 300},
163+
{[]string{"A", "B", "C"}, []string{"D", "B"}, 100},
164+
{[]string{"A=20"}, []string{"B=10"}, 10},
165+
{[]string{"A=20"}, []string{"B=70"}, 70},
166+
{[]string{"A=20", "C=40"}, []string{"B=5"}, 5},
167+
{[]string{"A=5", "C=40"}, []string{"A=10", "E=40"}, 45},
168+
{[]string{"A=10"}, []string{}, 0},
169+
{[]string{}, []string{"A=10"}, 10},
170+
{[]string{"A=10"}, []string{"A<50"}, 0},
171+
{[]string{"A=75"}, []string{"A<50"}, 25},
172+
{[]string{"A=10"}, []string{"A>50"}, 40},
173+
{[]string{"A=70"}, []string{"A>50"}, 0},
161174
}
162175

163176
for _, test := range tests {
164177
state1 := StateOf(test.state1...)
165178
state2 := StateOf(test.state2...)
166-
assert.Equal(t, test.expect, state1.Distance(state2),
179+
assert.InDelta(t, test.expect, state1.Distance(state2), 0.01,
167180
"state1=%s, state2=%s", state1, state2)
168181
}
169182
}

0 commit comments

Comments
 (0)