@@ -651,13 +651,13 @@ object _get_input_tensor(int time)
651651 states = Nest . PackSequenceAs ( states , flat_final_states ) . ToTensors ( ) ;
652652 if ( return_all_outputs )
653653 {
654- successive_outputs . Add ( output ) ;
655- successive_states . Add ( states ) ;
654+ successive_outputs = successive_outputs . MergeWith ( output ) ;
655+ successive_outputs = successive_states . MergeWith ( states ) ;
656656 }
657657 else
658658 {
659- successive_outputs = new Tensors { output } ;
660- successive_states = new Tensors { states } ;
659+ successive_outputs = new Tensors ( output ) ;
660+ successive_states = new Tensors ( states ) ;
661661 }
662662
663663 }
@@ -722,16 +722,11 @@ object _get_input_tensor(int time)
722722 // Get the time(0) input and compute the output for that, the output will
723723 // be used to determine the dtype of output tensor array. Don't read from
724724 // input_ta due to TensorArray clear_after_read default to True.
725- var inps = new Tensors ( ) ;
726- foreach ( var inp in flatted_inptus )
727- {
728- inps . Add ( inp [ 0 ] ) ;
729- }
730- var input_time_zero = Nest . PackSequenceAs ( inputs , inps ) . ToTensors ( ) ;
725+ var input_time_zero = Nest . PackSequenceAs ( inputs , flatted_inptus . Select ( x => x [ 0 ] ) . ToArray ( ) ) . ToTensors ( ) ;
731726
732727 // output_time_zero is used to determine the cell output shape and its
733728 // dtype. the value is discarded.
734- ( output_time_zero , _ ) = step_function ( ( Tensor ) input_time_zero ,
729+ ( output_time_zero , _ ) = step_function ( input_time_zero ,
735730 constants is null ? initial_states : initial_states . MergeWith ( constants ) ) ;
736731
737732 int output_ta_size = return_all_outputs ? time_steps_t : 1 ;
@@ -816,6 +811,7 @@ object _get_input_tensor(int time)
816811
817812 Func < Tensor , Tensor > cond = ( time ) => ( time < time_steps_t ) ;
818813 int parallel_iterations = 32 ;
814+ new_states = states ;
819815 if ( masking_fn != null )
820816 {
821817 // Mask for the T output will be base on the output of T - 1. In the
@@ -846,7 +842,7 @@ RNN step function.
846842 // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
847843 var current_input = Nest . PackSequenceAs ( inputs , flat_current_input ) . ToTensors ( ) ;
848844 var mask_t = masking_fn ( time ) ;
849- var ( output , new_states_internal ) = step_function ( current_input , states . MergeWith ( constants ) ) ;
845+ var ( output , new_states_internal ) = step_function ( current_input , new_states . MergeWith ( constants ) ) ;
850846 // mask output
851847 var flat_output = Nest . Flatten ( output ) . ToList ( ) ;
852848
@@ -871,11 +867,12 @@ RNN step function.
871867 new_states_internal = Nest . PackSequenceAs ( new_states , flat_final_state ) . ToTensors ( ) ;
872868
873869 var ta_index_to_write = return_all_outputs ? time : tf . constant ( 0 ) ;
874- // TODO(Wanglongzhi2001),deal with zip output_ta_t
875- foreach ( var ( ta , Out ) in zip ( output_ta_t , flat_new_output ) )
870+ output_ta_t = zip ( output_ta_t , flat_new_output ) . Select ( item =>
876871 {
877- output_ta_t . Add ( ta . write ( ta_index_to_write , Out ) ) ;
878- }
872+ var ( ta , out_ ) = item ;
873+ return ta . write ( ta_index_to_write , out_ ) ;
874+ } ) . ToList ( ) ;
875+
879876
880877 new_states_internal = Nest . PackSequenceAs ( initial_states , flat_new_state ) . ToTensors ( ) ;
881878
@@ -921,15 +918,8 @@ Tensor _step(Tensor time)
921918 }
922919 var final_outputs = tf . while_loop ( cond : cond , body : _step , loop_vars : time , parallel_iterations : parallel_iterations ) ;
923920 }
924- //Tensors outputs = new Tensors();
925- foreach ( var o in output_ta )
926- {
927- outputs . Add ( o . stack ( ) ) ;
928- }
929- foreach ( var o in outputs )
930- {
931- last_output . Add ( o [ - 1 ] ) ;
932- }
921+ outputs = outputs . MergeWith ( output_ta . Select ( o => o . stack ( ) ) . ToTensors ( ) ) ;
922+ last_output = last_output . MergeWith ( outputs . Select ( o => o [ - 1 ] ) . ToTensors ( ) ) ;
933923 outputs = Nest . PackSequenceAs ( output_time_zero , outputs ) . ToTensors ( ) ;
934924 last_output = Nest . PackSequenceAs ( output_time_zero , last_output ) . ToTensors ( ) ;
935925
0 commit comments