@@ -134,18 +134,18 @@ class TransformerDecoder
134134 */
135135 Sequential<>* AttentionBlock ()
136136 {
137- Sequential<>* decoderBlockBottom = new Sequential<>(false );
137+ Sequential<>* decoderBlockBottom = new Sequential<>();
138138 decoderBlockBottom->Add <Subview<>>(1 , 0 , dModel * tgtSeqLen - 1 , 0 , -1 );
139139
140140 // Broadcast the incoming decoder input,
141141 // i.e. the query, into (query, key, value).
142- Concat<>* decoderInput = new Concat<>();
142+ Concat<>* decoderInput = new Concat<>(true );
143143 decoderInput->Add <IdentityLayer<>>();
144144 decoderInput->Add <IdentityLayer<>>();
145145 decoderInput->Add <IdentityLayer<>>();
146146
147147 // Masked Self attention layer.
148- Sequential<>* maskedSelfAttention = new Sequential<>(false );
148+ Sequential<>* maskedSelfAttention = new Sequential<>();
149149 maskedSelfAttention->Add (decoderInput);
150150
151151 MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
@@ -157,7 +157,7 @@ class TransformerDecoder
157157 maskedSelfAttention->Add (mha1);
158158
159159 // Residual connection.
160- AddMerge<>* residualAdd1 = new AddMerge<>();
160+ AddMerge<>* residualAdd1 = new AddMerge<>(true );
161161 residualAdd1->Add (maskedSelfAttention);
162162 residualAdd1->Add <IdentityLayer<>>();
163163
@@ -167,19 +167,19 @@ class TransformerDecoder
167167 decoderBlockBottom->Add <LayerNorm<>>(dModel * tgtSeqLen);
168168
169169 // This layer broadcasts the output of the encoder, i.e. the key, into (key, value).
170- Concat<>* broadcastEncoderOutput = new Concat<>();
170+ Concat<>* broadcastEncoderOutput = new Concat<>(true );
171171 broadcastEncoderOutput->Add <Subview<>>(1 , dModel * tgtSeqLen, -1 , 0 , -1 );
172172 broadcastEncoderOutput->Add <Subview<>>(1 , dModel * tgtSeqLen, -1 , 0 , -1 );
173173
174174 // This layer concatenates the output of the bottom decoder block (query)
175175 // and the output of the encoder (key, value).
176- Concat<>* encoderDecoderAttentionInput = new Concat<>();
177- encoderDecoderAttentionInput ->Add (decoderBlockBottom );
178- encoderDecoderAttentionInput ->Add (broadcastEncoderOutput);
176+ Concat<>* encDecAttnInput = new Concat<>(true );
177+ encDecAttnInput ->Add <Subview<>>( 1 , 0 , dModel * tgtSeqLen - 1 , 0 , - 1 );
178+ encDecAttnInput ->Add (broadcastEncoderOutput);
179179
180180 // Encoder-decoder attention.
181- Sequential<>* encoderDecoderAttention = new Sequential<>(false );
182- encoderDecoderAttention->Add (encoderDecoderAttentionInput );
181+ Sequential<>* encoderDecoderAttention = new Sequential<>();
182+ encoderDecoderAttention->Add (encDecAttnInput );
183183
184184 MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
185185 srcSeqLen,
@@ -189,11 +189,11 @@ class TransformerDecoder
189189 encoderDecoderAttention->Add (mha2);
190190
191191 // Residual connection.
192- AddMerge<>* residualAdd2 = new AddMerge<>();
192+ AddMerge<>* residualAdd2 = new AddMerge<>(true );
193193 residualAdd2->Add (encoderDecoderAttention);
194- residualAdd2->Add <IdentityLayer<>>( );
194+ residualAdd2->Add (decoderBlockBottom );
195195
196- Sequential<>* decoderBlock = new Sequential<>(false );
196+ Sequential<>* decoderBlock = new Sequential<>();
197197 decoderBlock->Add (residualAdd2);
198198 decoderBlock->Add <LayerNorm<>>(dModel * tgtSeqLen);
199199 return decoderBlock;
@@ -204,18 +204,18 @@ class TransformerDecoder
204204 */
205205 Sequential<>* PositionWiseFFNBlock ()
206206 {
// Builds the position-wise feed-forward part of the decoder and returns it
// as a heap-allocated Sequential<> container.
// NOTE(review): this text is a diff rendering; a gutter of "NNN-" marks the
// pre-change form of a line and "NNN+" the post-change form.
// NOTE(review): the change drops the explicit `false` constructor argument
// from Sequential<>; what that flag controls is not visible here -- confirm
// against the Sequential<> declaration.
207- Sequential<>* positionWiseFFN = new Sequential<>(false );
207+ Sequential<>* positionWiseFFN = new Sequential<>();
// Layers are appended in order: Linear3D built with (dModel, dimFFN), the
// ActivationFunction layer, Linear3D built with (dimFFN, dModel), and a
// Dropout layer built with `dropout`.
// presumably Linear3D takes (input size, output size) -- TODO confirm.
208208 positionWiseFFN->Add <Linear3D<>>(dModel, dimFFN);
209209 positionWiseFFN->Add <ActivationFunction>();
210210 positionWiseFFN->Add <Linear3D<>>(dimFFN, dModel);
211211 positionWiseFFN->Add <Dropout<>>(dropout);
212212
213213 /* Residual connection. */
// The AddMerge<> container receives two children: the feed-forward stack
// built above and an IdentityLayer<>, i.e. the two branches of the residual.
// NOTE(review): the change passes `true` to the AddMerge<> constructor; the
// meaning of that argument is not shown in this view -- confirm.
214- AddMerge<>* residualAdd = new AddMerge<>();
214+ AddMerge<>* residualAdd = new AddMerge<>(true );
215215 residualAdd->Add (positionWiseFFN);
216216 residualAdd->Add <IdentityLayer<>>();
217217
// Wrap the residual merge in an outer Sequential<> and hand it back.
// NOTE(review): the returned pointer comes from `new`; which object is
// responsible for deleting it is not visible here -- verify at the caller.
218- Sequential<>* decoderBlock = new Sequential<>(false );
218+ Sequential<>* decoderBlock = new Sequential<>();
219219 decoderBlock->Add (residualAdd);
220220 return decoderBlock;
221221 }
0 commit comments