@@ -51,7 +51,7 @@ public enum tensorType
5151 public string [ ] ObservationPlaceholderName ;
5252 /// Modify only in inspector : Name of the action node
5353 public string ActionPlaceholderName = "action" ;
54- #if ENABLE_TENSORFLOW
54+ #if ENABLE_TENSORFLOW
5555 TFGraph graph ;
5656 TFSession session ;
5757 bool hasRecurrent ;
@@ -62,7 +62,7 @@ public enum tensorType
6262 float [ , ] inputState ;
6363 List < float [ , , , ] > observationMatrixList ;
6464 float [ , ] inputOldMemories ;
65- #endif
65+ #endif
6666
6767 /// Reference to the brain that uses this CoreBrainInternal
6868 public Brain brain ;
@@ -190,13 +190,22 @@ public void DecideAction()
190190
191191 foreach ( TensorFlowAgentPlaceholder placeholder in graphPlaceholders )
192192 {
193- if ( placeholder . valueType == TensorFlowAgentPlaceholder . tensorType . FloatingPoint )
193+ try
194194 {
195- runner . AddInput ( graph [ graphScope + placeholder . name ] [ 0 ] , new float [ ] { Random . Range ( placeholder . minValue , placeholder . maxValue ) } ) ;
195+ if ( placeholder . valueType == TensorFlowAgentPlaceholder . tensorType . FloatingPoint )
196+ {
197+ runner . AddInput ( graph [ graphScope + placeholder . name ] [ 0 ] , new float [ ] { Random . Range ( placeholder . minValue , placeholder . maxValue ) } ) ;
198+ }
199+ else if ( placeholder . valueType == TensorFlowAgentPlaceholder . tensorType . Integer )
200+ {
201+ runner . AddInput ( graph [ graphScope + placeholder . name ] [ 0 ] , new int [ ] { Random . Range ( ( int ) placeholder . minValue , ( int ) placeholder . maxValue + 1 ) } ) ;
202+ }
196203 }
197- else if ( placeholder . valueType == TensorFlowAgentPlaceholder . tensorType . Integer )
204+ catch
198205 {
199- runner . AddInput ( graph [ graphScope + placeholder . name ] [ 0 ] , new int [ ] { Random . Range ( ( int ) placeholder . minValue , ( int ) placeholder . maxValue + 1 ) } ) ;
206+ throw new UnityAgentsException ( string . Format ( @"One of the Tensorflow placeholder cound nout be found.
207+ In brain {0}, there are no {1} placeholder named {2}." ,
208+ brain . gameObject . name , placeholder . valueType . ToString ( ) , graphScope + placeholder . name ) ) ;
200209 }
201210 }
202211
@@ -212,6 +221,26 @@ public void DecideAction()
212221 runner . AddInput ( graph [ graphScope + ObservationPlaceholderName [ obs_number ] ] [ 0 ] , observationMatrixList [ obs_number ] ) ;
213222 }
214223
224+ TFTensor [ ] runned ;
225+ try
226+ {
227+ runned = runner . Run ( ) ;
228+ }
229+ catch ( TFException e )
230+ {
231+ string errorMessage = e . Message ;
232+ try
233+ {
234+ errorMessage = string . Format ( @"The tensorflow graph needs an input for {0} of type {1}" ,
235+ e . Message . Split ( new string [ ] { "Node: " } , 0 ) [ 1 ] . Split ( '=' ) [ 0 ] ,
236+ e . Message . Split ( new string [ ] { "dtype=" } , 0 ) [ 1 ] . Split ( ',' ) [ 0 ] ) ;
237+ }
238+ finally
239+ {
240+ throw new UnityAgentsException ( errorMessage ) ;
241+ }
242+
243+ }
215244
216245 // Create the recurrent tensor
217246 if ( hasRecurrent )
@@ -220,7 +249,7 @@ public void DecideAction()
220249
221250 runner . AddInput ( graph [ graphScope + RecurrentInPlaceholderName ] [ 0 ] , inputOldMemories ) ;
222251 runner . Fetch ( graph [ graphScope + RecurrentOutPlaceholderName ] [ 0 ] ) ;
223- float [ , ] recurrent_tensor = runner . Run ( ) [ 1 ] . GetValue ( ) as float [ , ] ;
252+ float [ , ] recurrent_tensor = runned [ 1 ] . GetValue ( ) as float [ , ] ;
224253
225254 int i = 0 ;
226255 foreach ( int k in agentKeys )
@@ -241,7 +270,7 @@ public void DecideAction()
241270
242271 if ( brain . brainParameters . actionSpaceType == StateType . continuous )
243272 {
244- float [ , ] output = runner . Run ( ) [ 0 ] . GetValue ( ) as float [ , ] ;
273+ float [ , ] output = runned [ 0 ] . GetValue ( ) as float [ , ] ;
245274 int i = 0 ;
246275 foreach ( int k in agentKeys )
247276 {
@@ -256,7 +285,7 @@ public void DecideAction()
256285 }
257286 else if ( brain . brainParameters . actionSpaceType == StateType . discrete )
258287 {
259- long [ , ] output = runner . Run ( ) [ 0 ] . GetValue ( ) as long [ , ] ;
288+ long [ , ] output = runned [ 0 ] . GetValue ( ) as long [ , ] ;
260289 int i = 0 ;
261290 foreach ( int k in agentKeys )
262291 {
0 commit comments