Skip to content

Commit fbdd4ff

Browse files
set model = true
1 parent a253a6a commit fbdd4ff

File tree

5 files changed

+29
-29
lines changed

5 files changed

+29
-29
lines changed

models/transformer/decoder.hpp

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -134,18 +134,18 @@ class TransformerDecoder
134134
*/
135135
Sequential<>* AttentionBlock()
136136
{
137-
Sequential<>* decoderBlockBottom = new Sequential<>(false);
137+
Sequential<>* decoderBlockBottom = new Sequential<>();
138138
decoderBlockBottom->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
139139

140140
// Broadcast the incoming input to decoder
141141
// i.e. query into (query, key, value).
142-
Concat<>* decoderInput = new Concat<>();
142+
Concat<>* decoderInput = new Concat<>(true);
143143
decoderInput->Add<IdentityLayer<>>();
144144
decoderInput->Add<IdentityLayer<>>();
145145
decoderInput->Add<IdentityLayer<>>();
146146

147147
// Masked Self attention layer.
148-
Sequential<>* maskedSelfAttention = new Sequential<>(false);
148+
Sequential<>* maskedSelfAttention = new Sequential<>();
149149
maskedSelfAttention->Add(decoderInput);
150150

151151
MultiheadAttention<>* mha1 = new MultiheadAttention<>(tgtSeqLen,
@@ -157,7 +157,7 @@ class TransformerDecoder
157157
maskedSelfAttention->Add(mha1);
158158

159159
// Residual connection.
160-
AddMerge<>* residualAdd1 = new AddMerge<>();
160+
AddMerge<>* residualAdd1 = new AddMerge<>(true);
161161
residualAdd1->Add(maskedSelfAttention);
162162
residualAdd1->Add<IdentityLayer<>>();
163163

@@ -167,19 +167,19 @@ class TransformerDecoder
167167
decoderBlockBottom->Add<LayerNorm<>>(dModel * tgtSeqLen);
168168

169169
// This layer broadcasts the output of encoder i.e. key into (key, value).
170-
Concat<>* broadcastEncoderOutput = new Concat<>();
170+
Concat<>* broadcastEncoderOutput = new Concat<>(true);
171171
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
172172
broadcastEncoderOutput->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
173173

174174
// This layer concatenates the output of the bottom decoder block (query)
175175
// and the output of the encoder (key, value).
176-
Concat<>* encoderDecoderAttentionInput = new Concat<>();
177-
encoderDecoderAttentionInput->Add(decoderBlockBottom);
178-
encoderDecoderAttentionInput->Add(broadcastEncoderOutput);
176+
Concat<>* encDecAttnInput = new Concat<>(true);
177+
encDecAttnInput->Add<Subview<>>(1, 0, dModel * tgtSeqLen - 1, 0, -1);
178+
encDecAttnInput->Add(broadcastEncoderOutput);
179179

180180
// Encoder-decoder attention.
181-
Sequential<>* encoderDecoderAttention = new Sequential<>(false);
182-
encoderDecoderAttention->Add(encoderDecoderAttentionInput);
181+
Sequential<>* encoderDecoderAttention = new Sequential<>();
182+
encoderDecoderAttention->Add(encDecAttnInput);
183183

184184
MultiheadAttention<>* mha2 = new MultiheadAttention<>(tgtSeqLen,
185185
srcSeqLen,
@@ -189,11 +189,11 @@ class TransformerDecoder
189189
encoderDecoderAttention->Add(mha2);
190190

191191
// Residual connection.
192-
AddMerge<>* residualAdd2 = new AddMerge<>();
192+
AddMerge<>* residualAdd2 = new AddMerge<>(true);
193193
residualAdd2->Add(encoderDecoderAttention);
194-
residualAdd2->Add<IdentityLayer<>>();
194+
residualAdd2->Add(decoderBlockBottom);
195195

196-
Sequential<>* decoderBlock = new Sequential<>(false);
196+
Sequential<>* decoderBlock = new Sequential<>();
197197
decoderBlock->Add(residualAdd2);
198198
decoderBlock->Add<LayerNorm<>>(dModel * tgtSeqLen);
199199
return decoderBlock;
@@ -204,18 +204,18 @@ class TransformerDecoder
204204
*/
205205
Sequential<>* PositionWiseFFNBlock()
206206
{
207-
Sequential<>* positionWiseFFN = new Sequential<>(false);
207+
Sequential<>* positionWiseFFN = new Sequential<>();
208208
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
209209
positionWiseFFN->Add<ActivationFunction>();
210210
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
211211
positionWiseFFN->Add<Dropout<>>(dropout);
212212

213213
/* Residual connection. */
214-
AddMerge<>* residualAdd = new AddMerge<>();
214+
AddMerge<>* residualAdd = new AddMerge<>(true);
215215
residualAdd->Add(positionWiseFFN);
216216
residualAdd->Add<IdentityLayer<>>();
217217

218-
Sequential<>* decoderBlock = new Sequential<>(false);
218+
Sequential<>* decoderBlock = new Sequential<>();
219219
decoderBlock->Add(residualAdd);
220220
return decoderBlock;
221221
}

models/transformer/decoder_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ TransformerDecoder<ActivationFunction, RegularizerType>::TransformerDecoder(
5555
keyPaddingMask(keyPaddingMask),
5656
ownMemory(ownMemory)
5757
{
58-
decoder = new Sequential<>(false);
58+
decoder = new Sequential<>();
5959

6060
for (size_t n = 0; n < numLayers; ++n)
6161
{
@@ -66,11 +66,11 @@ TransformerDecoder<ActivationFunction, RegularizerType>::TransformerDecoder(
6666
break;
6767
}
6868

69-
Sequential<>* decoderBlock = new Sequential<>(false);
69+
Sequential<>* decoderBlock = new Sequential<>();
7070
decoderBlock->Add(AttentionBlock());
7171
decoderBlock->Add(PositionWiseFFNBlock());
7272

73-
Concat<>* concatQueryKey = new Concat<>();
73+
Concat<>* concatQueryKey = new Concat<>(true);
7474
concatQueryKey->Add(decoderBlock);
7575
concatQueryKey->Add<Subview<>>(1, dModel * tgtSeqLen, -1, 0, -1);
7676

models/transformer/encoder.hpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -132,13 +132,13 @@ class TransformerEncoder
132132
*/
133133
void AttentionBlock()
134134
{
135-
Concat<>* input = new Concat<>();
135+
Concat<>* input = new Concat<>(true);
136136
input->Add<IdentityLayer<>>();
137137
input->Add<IdentityLayer<>>();
138138
input->Add<IdentityLayer<>>();
139139

140140
/* Self attention layer. */
141-
Sequential<>* selfAttn = new Sequential<>(false);
141+
Sequential<>* selfAttn = new Sequential<>();
142142
selfAttn->Add(input);
143143

144144
MultiheadAttention<>* mha = new MultiheadAttention<>(srcSeqLen,
@@ -150,7 +150,7 @@ class TransformerEncoder
150150
selfAttn->Add(mha);
151151

152152
/* This layer adds a residual connection. */
153-
AddMerge<>* residualAdd = new AddMerge<>();
153+
AddMerge<>* residualAdd = new AddMerge<>(true);
154154
residualAdd->Add(selfAttn);
155155
residualAdd->Add<IdentityLayer<>>();
156156

@@ -163,14 +163,14 @@ class TransformerEncoder
163163
*/
164164
void PositionWiseFFNBlock()
165165
{
166-
Sequential<>* positionWiseFFN = new Sequential<>(false);
166+
Sequential<>* positionWiseFFN = new Sequential<>();
167167
positionWiseFFN->Add<Linear3D<>>(dModel, dimFFN);
168168
positionWiseFFN->Add<ActivationFunction>();
169169
positionWiseFFN->Add<Linear3D<>>(dimFFN, dModel);
170170
positionWiseFFN->Add<Dropout<>>(dropout);
171171

172172
/* This layer adds a residual connection. */
173-
AddMerge<>* residualAdd = new AddMerge<>();
173+
AddMerge<>* residualAdd = new AddMerge<>(true);
174174
residualAdd->Add(positionWiseFFN);
175175
residualAdd->Add<IdentityLayer<>>();
176176

models/transformer/encoder_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ TransformerEncoder<ActivationFunction, RegularizerType>::TransformerEncoder(
4040
keyPaddingMask(keyPaddingMask),
4141
ownMemory(ownMemory)
4242
{
43-
encoder = new Sequential<>(false);
43+
encoder = new Sequential<>();
4444

4545
for (size_t n = 0; n < numLayers; ++n)
4646
{

models/transformer/transformer_impl.hpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,9 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(
4646
keyPaddingMask(keyPaddingMask),
4747
ownMemory(ownMemory)
4848
{
49-
transformer = new Sequential<>(false);
49+
transformer = new Sequential<>();
5050

51-
Sequential<>* encoder = new Sequential<>(false);
51+
Sequential<>* encoder = new Sequential<>();
5252

5353
// Pull out the sequences of source language which is stacked above in the
5454
// input matrix. Here 'lastCol = -1' denotes upto last batch of input matrix.
@@ -69,7 +69,7 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(
6969

7070
encoder->Add(encoderStack);
7171

72-
Sequential<>* decoderPE = new Sequential<>(false);
72+
Sequential<>* decoderPE = new Sequential<>();
7373

7474
// Pull out the sequences of target language which is stacked below in the
7575
// input matrix. Here 'lastRow = -1' and 'lastCol = -1' denotes upto last
@@ -78,7 +78,7 @@ Transformer<ActivationFunction, RegularizerType>::Transformer(
7878
decoderPE->Add<Lookup<>>(tgtVocabSize, dModel);
7979
decoderPE->Add<PositionalEncoding<>>(dModel, tgtSeqLen);
8080

81-
Concat<>* encoderDecoderConcat = new Concat<>();
81+
Concat<>* encoderDecoderConcat = new Concat<>(true);
8282
encoderDecoderConcat->Add(encoder);
8383
encoderDecoderConcat->Add(decoderPE);
8484

0 commit comments

Comments (0)