@@ -108,37 +108,35 @@ class ContextProjectionForwardFunc : public FunctionBase {
108
108
}
109
109
110
110
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
111
- CHECK (1 == inputs.size () || 2 == inputs.size ());
112
- CHECK_EQ (( size_t ) 1 , outputs.size ());
111
+ CHECK (1UL == inputs.size () || 2UL == inputs.size ());
112
+ CHECK_EQ (1UL , outputs.size ());
113
113
CHECK (inputs[0 ].isSequenceArg () && outputs[0 ].isSequenceArg ())
114
114
<< " SequenceArg required here" ;
115
115
const auto val_seqs = dynamic_cast <const SequenceArg&>(inputs[0 ]);
116
116
auto out_seq = dynamic_cast <const SequenceArg&>(outputs[0 ]);
117
117
118
118
CHECK (out_seq.data () && val_seqs.data () && val_seqs.getSequenceId ().data ());
119
- CHECK_EQ (out_seq.shape ().ndims (), (size_t )2 );
120
- CHECK_EQ (val_seqs.shape ().ndims (), (size_t )2 );
121
- CHECK_EQ (val_seqs.getSequenceId ().shape ().ndims (), (size_t )1 );
122
- if (2 == inputs.size ()) {
123
- CHECK_EQ (inputs[1 ].shape ().ndims (), (size_t )2 );
124
- }
119
+ CHECK_EQ (out_seq.shape ().ndims (), 2UL );
120
+ CHECK_EQ (val_seqs.shape ().ndims (), 2UL );
125
121
// / dim of output = dim of input * context_length
126
122
CHECK_EQ (out_seq.shape ()[1 ], val_seqs.shape ()[1 ] * context_length_);
127
123
// / input and output have the same batch_size
128
124
CHECK_EQ (val_seqs.shape ()[0 ], out_seq.shape ()[0 ]);
129
- // / dim of input == dim of weight
130
- if (2 == inputs.size ()) {
125
+ if (2UL == inputs.size ()) {
126
+ CHECK_EQ (inputs[1 ].shape ().ndims (), 2UL );
127
+ // / dim of input == dim of weight
131
128
CHECK_EQ (val_seqs.shape ()[1 ], inputs[1 ].shape ()[1 ]);
132
129
}
133
130
134
131
CHECK_EQ (out_seq.getArgType (), ADD_TO);
135
132
auto out_mat = out_seq.matrix <Device>();
136
133
const auto in_mat = val_seqs.matrix <Device>();
137
134
const auto w_mat =
138
- (2 == inputs.size ())
135
+ (2UL == inputs.size () && inputs[ 1 ]. data ())
139
136
? inputs[1 ].matrix <Device>()
140
137
: typename Tensor<real, Device>::Matrix (nullptr , 0 , 0 );
141
138
const auto seq_vec = val_seqs.getSequenceId ().vector <int , Device>();
139
+
142
140
ContextProjectionForward<Device>(out_mat,
143
141
in_mat,
144
142
w_mat,
@@ -235,36 +233,40 @@ class ContextProjectionBackwardFunc : public FunctionBase {
235
233
}
236
234
237
235
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
238
- CHECK_EQ (( size_t ) 1 , inputs.size ());
239
- CHECK_EQ (( size_t ) 2 , outputs.size ());
236
+ CHECK_EQ (1UL , inputs.size ());
237
+ CHECK ( 1UL == outputs. size () || 2UL == outputs.size ());
240
238
CHECK (inputs[0 ].isSequenceArg () && outputs[0 ].isSequenceArg ())
241
239
<< " SequenceArg required here" ;
242
240
const auto in_seq = dynamic_cast <const SequenceArg&>(inputs[0 ]);
243
241
auto out_seq = dynamic_cast <const SequenceArg&>(outputs[0 ]);
244
242
CHECK (in_seq.data () && in_seq.getSequenceId ().data ());
245
- CHECK_EQ (in_seq.shape ().ndims (), (size_t )2 );
246
- CHECK_EQ (in_seq.getSequenceId ().shape ().ndims (), (size_t )1 );
247
- CHECK_EQ (out_seq.shape ().ndims (), (size_t )2 );
248
- CHECK_EQ (out_seq.getSequenceId ().shape ().ndims (), (size_t )1 );
249
- CHECK_EQ (outputs[1 ].shape ().ndims (), (size_t )2 );
243
+ CHECK_EQ (in_seq.shape ().ndims (), 2UL );
244
+ CHECK_EQ (out_seq.shape ().ndims (), 2UL );
245
+ CHECK_EQ (out_seq.getSequenceId ().shape ().ndims (), 1UL );
250
246
251
- // / dim of input grad == dim of weight
252
- CHECK_EQ (out_seq.shape ()[1 ], outputs[1 ].shape ()[1 ]);
253
247
// / input grad and output grad have the same batch_size
254
248
CHECK_EQ (out_seq.shape ()[0 ], in_seq.shape ()[0 ]);
255
249
// / dim of output grad = dim of input grad * context_length
256
250
CHECK_EQ (in_seq.shape ()[1 ], out_seq.shape ()[1 ] * context_length_);
257
251
CHECK_EQ (out_seq.getArgType (), ADD_TO);
258
- CHECK_EQ (outputs[1 ].getArgType (), ADD_TO);
252
+
253
+ if (2UL == outputs.size ()) {
254
+ CHECK_EQ (outputs[1 ].shape ().ndims (), 2UL );
255
+ // / dim of input grad == dim of weight
256
+ CHECK_EQ (out_seq.shape ()[1 ], outputs[1 ].shape ()[1 ]);
257
+ CHECK_EQ (outputs[1 ].getArgType (), ADD_TO);
258
+ }
259
259
260
260
const auto seq_vec = in_seq.getSequenceId ().vector <int , Device>();
261
261
const auto out_grad_mat = in_seq.matrix <Device>();
262
262
auto in_grad_mat =
263
263
!out_seq.data () ? typename Tensor<real, Device>::Matrix (nullptr , 0 , 0 )
264
264
: out_seq.matrix <Device>();
265
- auto w_grad_mat = !outputs[1 ].data ()
266
- ? typename Tensor<real, Device>::Matrix (nullptr , 0 , 0 )
267
- : outputs[1 ].matrix <Device>();
265
+ auto w_grad_mat =
266
+ (2UL == outputs.size () && outputs[1 ].data ())
267
+ ? outputs[1 ].matrix <Device>()
268
+ : typename Tensor<real, Device>::Matrix (nullptr , 0 , 0 );
269
+
268
270
ContextProjectionBackward<Device>(out_grad_mat,
269
271
in_grad_mat,
270
272
w_grad_mat,
@@ -304,17 +306,17 @@ class ContextProjectionBackwardDataFunc : public FunctionBase {
304
306
}
305
307
306
308
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
307
- CHECK_EQ (1 , static_cast < int >( inputs.size () ));
308
- CHECK_EQ (1 , static_cast < int >( outputs.size () ));
309
+ CHECK_EQ (1UL , inputs.size ());
310
+ CHECK_EQ (1UL , outputs.size ());
309
311
CHECK (inputs[0 ].isSequenceArg () && outputs[0 ].isSequenceArg ())
310
312
<< " SequenceArg required here" ;
311
313
const auto in_seq = dynamic_cast <const SequenceArg&>(inputs[0 ]);
312
314
const auto out_seq = dynamic_cast <const SequenceArg&>(outputs[0 ]);
313
315
314
316
CHECK (in_seq.data () && out_seq.data () && in_seq.getSequenceId ().data ());
315
- CHECK_EQ (static_cast < int >( out_seq.shape ().ndims ()), 2 );
316
- CHECK_EQ (static_cast < int >( in_seq.shape ().ndims ()), 2 );
317
- CHECK_EQ (static_cast < int >( in_seq.getSequenceId ().shape ().ndims ()), 1 );
317
+ CHECK_EQ (out_seq.shape ().ndims (), 2UL );
318
+ CHECK_EQ (in_seq.shape ().ndims (), 2UL );
319
+ CHECK_EQ (in_seq.getSequenceId ().shape ().ndims (), 1UL );
318
320
// / output layer grad dim == input layer grad dim * context_length_
319
321
CHECK_EQ (in_seq.shape ().ndims (), out_seq.shape ().ndims () * context_length_);
320
322
// / input and output have the same batch_size
@@ -355,14 +357,14 @@ class ContextProjectionBackwardWeightFunc : public FunctionBase {
355
357
}
356
358
357
359
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
358
- CHECK_EQ (1 , static_cast < int >( inputs.size () ));
359
- CHECK_EQ (1 , static_cast < int >( outputs.size () ));
360
+ CHECK_EQ (1UL , inputs.size ());
361
+ CHECK_EQ (1UL , outputs.size ());
360
362
CHECK (inputs[0 ].isSequenceArg ()) << " SequenceArg required here" ;
361
363
const auto in_seq = dynamic_cast <const SequenceArg&>(inputs[0 ]);
362
364
CHECK (in_seq.data () && in_seq.getSequenceId ().data () && outputs[0 ].data ());
363
- CHECK_EQ (static_cast < int >( outputs[0 ].shape ().ndims ()), 2 );
364
- CHECK_EQ (static_cast < int >( in_seq.shape ().ndims ()), 2 );
365
- CHECK_EQ (static_cast < int >( in_seq.getSequenceId ().shape ().ndims ()), 1 );
365
+ CHECK_EQ (outputs[0 ].shape ().ndims (), 2UL );
366
+ CHECK_EQ (in_seq.shape ().ndims (), 2UL );
367
+ CHECK_EQ (in_seq.getSequenceId ().shape ().ndims (), 1UL );
366
368
CHECK_EQ (in_seq.shape ()[0 ], outputs[0 ].shape ()[0 ]);
367
369
// / output layer grad dim == weight dim * context_length_
368
370
CHECK_EQ (in_seq.shape ()[1 ], outputs[0 ].shape ()[1 ] * context_length_);
0 commit comments