@@ -122,31 +122,27 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
122122
123123 Application.applicationGraphDelta.Clear()
124124
125- if stepsToPlay = stepsPlayed then
125+ let toPredict =
126+ match aiMode with
127+ | TrainingSendEachStep
128+ | TrainingSendModel ->
129+ if stepsPlayed > 0 u< step> then
130+ gameStateDelta
131+ else
132+ gameState.Value
133+ | Runner -> gameState.Value
134+
135+ let stateId = oracle.Predict toPredict
136+ afterFirstAIPeek <- true
137+ let state = availableStates |> Seq.tryFind ( fun s -> s.internalId = stateId)
138+ lastCollectedStatistics <- statistics
139+ match state with
140+ | Some state -> Some state
141+ | None ->
142+ incorrectPredictedStateId <- true
143+ oracle.Feedback( Feedback.IncorrectPredictedStateId stateId)
126144 None
127- else
128- let toPredict =
129- match aiMode with
130- | TrainingSendEachStep
131- | TrainingSendModel ->
132- if stepsPlayed > 0 u< step> then
133- gameStateDelta
134- else
135- gameState.Value
136- | Runner -> gameState.Value
137-
138- let stateId = oracle.Predict toPredict
139- afterFirstAIPeek <- true
140- let state = availableStates |> Seq.tryFind ( fun s -> s.internalId = stateId)
141- lastCollectedStatistics <- statistics
142- stepsPlayed <- stepsPlayed + 1 u< step>
143-
144- match state with
145- | Some state -> Some state
146- | None ->
147- incorrectPredictedStateId <- true
148- oracle.Feedback( Feedback.IncorrectPredictedStateId stateId)
149- None
145+
150146 static member updateGameState ( delta : GameState ) ( gameState : Option < GameState >) =
151147 match gameState with
152148 | None -> Some delta
@@ -199,11 +195,11 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
199195 Some
200196 <| GameState( vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
201197
202- static member convertOutputToJson ( output : IDisposableReadOnlyCollection < OrtValue >) =
198+ static member convertOutputToJson ( output : IDisposableReadOnlyCollection < OrtValue >) =
203199 seq { 0 .. output.Count - 1 }
204200 |> Seq.map ( fun i -> output[ i]. GetTensorDataAsSpan< float32>() .ToArray())
205201
206-
202+
207203
208204 new
209205 (
0 commit comments