-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patha_star_heuristic.py
More file actions
32 lines (27 loc) · 913 Bytes
/
a_star_heuristic.py
File metadata and controls
32 lines (27 loc) · 913 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import solver_helpers as sh
import rule_helpers as rh
import neural_network as nn
import torch as tr
# A baseline heuristic for A* Search
def simple_heuristic(state: tuple) -> int:
    """Count how many of the claim's operands are not yet present in ``args``.

    ``state`` is split with ``sh.unpack`` into ``(args, claim, hist)``; only
    ``args`` and ``claim`` are used here. The claim is parsed with
    ``rh.parse_expression`` and the items at indices 1 and 2 are treated as
    its operands. The first operand is always checked. The second is checked
    only when it is not the empty string (the original code called such a
    claim an "advanced" conclusion). Each operand missing from ``args`` adds
    one step, so the result is 0, 1, or 2.

    NOTE(review): assumes ``parse_expression`` returns a sequence with at
    least three items and that ``args`` supports ``in`` membership tests --
    confirm against rule_helpers / solver_helpers.
    """
    args, claim, _hist = sh.unpack(state)
    parsed = rh.parse_expression(claim)
    first, second = parsed[1], parsed[2]

    missing = 0 if first in args else 1
    # An empty second operand means there is no second part to establish.
    if second != "" and second not in args:
        missing += 1
    return missing
# Encapsulate this in a class so that net is only
# loaded once
class NeuralNetwork:
    """Hold a trained heuristic network so it is loaded only once.

    ``torch.load`` runs in ``__init__``; every later ``nn_heuristic`` call
    reuses the same ``self.net``.
    """

    def __init__(self, path: str = "saved_net.pt"):
        """Load the serialized network from *path* and switch it to eval mode.

        Args:
            path: File handed to ``torch.load``. Defaults to
                ``"saved_net.pt"`` (resolved against the current working
                directory), which is the file this class always loaded.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Since torch 2.6 the default is weights_only=True, which
        # rejects a fully pickled nn.Module; if loading fails with an
        # unpickling error, pass weights_only=False here (torch >= 1.13).
        self.net = tr.load(path)
        # Inference only: eval() disables dropout and makes batch-norm use its
        # running statistics (if the network has such layers), so the same
        # state always gets the same heuristic value.
        self.net.eval()

    # An advanced heuristic using the trained neural network
    def nn_heuristic(self, state: tuple) -> float:
        """Return the network's heuristic value for *state* as a Python number.

        The state is encoded with ``nn.one_hot_encoding``, wrapped into a
        batch of size one, and passed through ``self.net``.

        Returns:
            The single output value as a plain Python scalar (via
            ``Tensor.item()``). Previously a ``(1, ...)`` tensor was returned
            despite the numeric annotation. A plain number compares and adds
            cheaply inside the A* frontier and does not keep an autograd graph
            alive.

        NOTE(review): assumes the network emits exactly one value per state
        and that the encoded tensor lives on the same device as the net --
        confirm against neural_network.py.
        """
        # Batch of one: equivalent to the original stack over a 1-element map.
        batch = tr.stack([nn.one_hot_encoding(state)])
        # Pure inference; skip building an autograd graph on every call.
        with tr.no_grad():
            steps = self.net(batch)
        return steps.item()