Skip to content

Commit fe007ed

Browse files
committed
fix: Allow AISearcher to pick a step without relying on stepsPlayed
1 parent 4e90c51 commit fe007ed

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

VSharp.Explorer/AISearcher.fs

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -122,31 +122,27 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
122122

123123
Application.applicationGraphDelta.Clear()
124124

125-
if stepsToPlay = stepsPlayed then
125+
let toPredict =
126+
match aiMode with
127+
| TrainingSendEachStep
128+
| TrainingSendModel ->
129+
if stepsPlayed > 0u<step> then
130+
gameStateDelta
131+
else
132+
gameState.Value
133+
| Runner -> gameState.Value
134+
135+
let stateId = oracle.Predict toPredict
136+
afterFirstAIPeek <- true
137+
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
138+
lastCollectedStatistics <- statistics
139+
match state with
140+
| Some state -> Some state
141+
| None ->
142+
incorrectPredictedStateId <- true
143+
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
126144
None
127-
else
128-
let toPredict =
129-
match aiMode with
130-
| TrainingSendEachStep
131-
| TrainingSendModel ->
132-
if stepsPlayed > 0u<step> then
133-
gameStateDelta
134-
else
135-
gameState.Value
136-
| Runner -> gameState.Value
137-
138-
let stateId = oracle.Predict toPredict
139-
afterFirstAIPeek <- true
140-
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
141-
lastCollectedStatistics <- statistics
142-
stepsPlayed <- stepsPlayed + 1u<step>
143-
144-
match state with
145-
| Some state -> Some state
146-
| None ->
147-
incorrectPredictedStateId <- true
148-
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
149-
None
145+
150146
static member updateGameState (delta: GameState) (gameState: Option<GameState>) =
151147
match gameState with
152148
| None -> Some delta
@@ -199,11 +195,11 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
199195
Some
200196
<| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
201197

202-
static member convertOutputToJson (output: IDisposableReadOnlyCollection<OrtValue>) =
198+
static member convertOutputToJson(output: IDisposableReadOnlyCollection<OrtValue>) =
203199
seq { 0 .. output.Count - 1 }
204200
|> Seq.map (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray())
205201

206-
202+
207203

208204
new
209205
(

0 commit comments

Comments
 (0)