@@ -17,6 +17,7 @@ limitations under the License.
1717using System ;
1818using System . Collections . Generic ;
1919using System . Linq ;
20+ using Tensorflow . Eager ;
2021using Tensorflow . Framework ;
2122using static Tensorflow . Binding ;
2223
@@ -48,6 +49,7 @@ public class _EagerTensorArray : TensorArray
4849 public override Tensor flow => _flow ;
4950 bool _clear_after_read ;
5051 List < Tensor > _tensor_array ;
52+ List < int > _previous_read_indices ;
5153
5254 public _EagerTensorArray ( TF_DataType dtype , Tensor size , bool dynamic_size = false ,
5355 bool clear_after_read = true , string tensor_array_name = null , Tensor handle = null , Tensor flow = null ,
@@ -61,16 +63,20 @@ public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = fal
6163 _dtype = dtype . as_base_dtype ( ) ;
6264 _dynamic_size = dynamic_size ;
6365 _clear_after_read = clear_after_read ;
64- _tensor_array = new List < Tensor > ( ) ;
66+ _tensor_array = Enumerable . Repeat < Tensor > ( null , size . numpy ( ) ) . ToList ( ) ;
67+ _previous_read_indices = new ( ) ;
6568 }
6669
6770 public override TensorArray unstack ( Tensor value , string name = null )
6871 {
69- return tf_with ( ops . name_scope ( name , "TensorArrayUnstack" , new { _handle , value } ) , delegate
72+ var tensors = array_ops . unstack ( value , name : name ) ;
73+ if ( tensors . Length > _tensor_array . Count && ! _dynamic_size )
7074 {
71- var num_elements = array_ops . shape ( value ) [ 0 ] ;
72- return scatter ( indices : math_ops . range ( 0 , num_elements ) , value : value , name : name ) ;
73- } ) ;
75+ throw new ValueError ( $ "Cannot unstack { tensors . Length } tensors into a TensorArray of static size { _tensor_array . Count } ") ;
76+ }
77+ _tensor_array = tensors . ToList ( ) ;
78+ // TODO(Rinne): revise the implementation. Here we should return `parent()`.
79+ return this ;
7480 }
7581
7682 public TensorArray scatter ( Tensor indices , Tensor value , string name = null )
@@ -116,37 +122,95 @@ public void _maybe_colocate_with(Tensor value)
116122 _colocate_with . Add ( value ) ;
117123 }
118124
125+ private Tensor _maybe_zero ( int ix )
126+ {
127+ var val = _tensor_array [ ix ] ;
128+ if ( val is null )
129+ {
130+ val = _tensor_array [ ix ] = array_ops . zeros ( _element_shape , _dtype ) ;
131+ }
132+ return val ;
133+ }
134+
119135 public override Tensor read < T > ( T index , string name = null )
120136 {
121- int index_int = - 1 ;
137+ int index_int ;
122138 if ( index is int int_index )
123139 index_int = int_index ;
124140 else if ( index is Tensor tensor_index )
125141 index_int = tensor_index . numpy ( ) ;
126142 else
127143 throw new ValueError ( "" ) ;
128144
145+ if ( index_int >= _tensor_array . Count )
146+ {
147+ throw new OutOfRangeError ( $ "Tried to read from index { index_int } but array size is: { _tensor_array . Count } ") ;
148+ }
149+
150+ var res = _tensor_array [ index_int ] ;
151+ if ( res is null )
152+ {
153+ if ( _previous_read_indices . Contains ( index_int ) )
154+ {
155+ throw new InvalidArgumentError ( $ "Could not read index { index_int } twice because it was cleared after " +
156+ $ "a previous read (perhaps try setting clear_after_read = false?)") ;
157+ }
158+ else
159+ {
160+ res = _maybe_zero ( index_int ) ;
161+ }
162+ }
163+
129164 if ( _clear_after_read )
130165 {
131166 _tensor_array [ index_int ] = null ;
167+ _previous_read_indices . Add ( index_int ) ;
132168 }
133-
134- return _tensor_array [ index_int ] ;
169+ return res ;
135170 }
136171
137172 public override TensorArray write ( Tensor index , Tensor value , string name = null )
138173 {
139- if ( _infer_shape )
140- _element_shape = _element_shape . merge_with ( value . shape ) ;
141- _tensor_array . add ( value ) ;
142- return this ;
174+ int index_int ;
175+ if ( index is EagerTensor eager )
176+ {
177+ return write < Tensor > ( eager . numpy ( ) , value , name ) ;
178+ }
179+ throw new InvalidArgumentError ( "The index is supposed to be an EagerTensor" ) ;
143180 }
144181
145182 public override TensorArray write < T > ( int index , T value , string name = null )
146183 {
147- var value_tensor = ops . convert_to_tensor ( value , preferred_dtype : _dtype , name : "value" ) ;
148- var index_tensor = ops . convert_to_tensor ( index , name : "index" ) ;
149- return write ( index_tensor , value_tensor , name : name ) ;
184+ int size = _tensor_array . Count ;
185+ if ( index >= size )
186+ {
187+ if ( ! _dynamic_size )
188+ {
189+ throw new OutOfRangeError ( $ "Tried to write to index { index } but array is not resizeable and size " +
190+ $ "is: { size } ") ;
191+ }
192+ _tensor_array . AddRange ( Enumerable . Repeat < Tensor > ( null , index - size + 1 ) ) ;
193+ }
194+
195+ Tensor tensor = ops . convert_to_tensor ( value , preferred_dtype : _dtype , name : "value" ) ;
196+
197+ if ( _dtype != tensor . dtype )
198+ {
199+ throw new InvalidArgumentError ( $ "TensorArray dtype is { _dtype . as_python_name ( ) } but Op is " +
200+ $ "trying to write dtype { tensor . dtype . as_python_name ( ) } ") ;
201+ }
202+
203+ if ( ! _element_shape . is_compatible_with ( tensor . shape ) )
204+ {
205+ throw new ValueError ( $ "Incompatible shape for value ({ tensor . shape } ), expected ({ _element_shape } )") ;
206+ }
207+
208+ if ( _infer_shape )
209+ {
210+ _element_shape = _element_shape . merge_with ( tensor . shape ) ;
211+ }
212+ _tensor_array [ index ] = tensor ;
213+ return this ;
150214 }
151215
152216 private Tensor size ( string name = null )
@@ -156,11 +220,26 @@ private Tensor size(string name = null)
156220
157221 public override Tensor stack ( string name = null )
158222 {
159- ops . colocate_with ( _handle ) ;
160- return tf_with ( ops . name_scope ( name , "TensorArrayStack" , new { _handle } ) , delegate
223+ if ( _tensor_array . Count > 0 )
161224 {
162- return gather ( math_ops . range ( 0 , size ( ) ) , name : name ) ;
163- } ) ;
225+ for ( int i = 0 ; i < _tensor_array . Count ; i ++ )
226+ {
227+ _maybe_zero ( i ) ;
228+ }
229+ }
230+ if ( _tensor_array . Count == 0 && _element_shape . IsFullyDefined )
231+ {
232+ return ops . convert_to_tensor ( new Shape ( new long [ ] { 0 } . Concat ( _element_shape . dims ) . ToArray ( ) ) , name : name , dtype : _dtype ) ;
233+ }
234+ else
235+ {
236+ return ops . convert_to_tensor ( _tensor_array , name : name , dtype : _dtype ) ;
237+ }
238+ //ops.colocate_with(_handle);
239+ //return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
240+ //{
241+ // return gather(math_ops.range(0, size()), name: name);
242+ //});
164243 }
165244
166245 public override Tensor gather ( Tensor indices , string name = null )
0 commit comments