@@ -28,11 +28,12 @@ void testMatrixProjectionForward(int context_start,
28
28
std::max (0 , (int )(context_start + context_length - 1 ));
29
29
if (pad == 0 ) is_padding = false ;
30
30
31
- FunctionCompare test (" ContextProjectionForward" ,
32
- FuncConfig ()
33
- .set (" context_length" , context_length)
34
- .set (" context_start" , context_start)
35
- .set (" begin_pad" , std::max (0 , -context_start)));
31
+ FunctionCompare test (
32
+ " ContextProjectionForward" ,
33
+ FuncConfig ()
34
+ .set (" context_length" , context_length)
35
+ .set (" context_start" , context_start)
36
+ .set (" begin_pad" , (size_t )std::max (0 , -context_start)));
36
37
37
38
// prepare input arguments
38
39
test.addSequence (SequenceIdArg (TensorShape{batch_size}));
@@ -51,21 +52,22 @@ void testMatrixProjectionForward(int context_start,
51
52
}
52
53
53
54
void testMatrixProjectionBackward (int context_start,
54
- int context_length,
55
+ size_t context_length,
55
56
bool is_padding,
56
57
size_t batch_size,
57
58
size_t input_dim) {
58
59
size_t pad = std::max (0 , -context_start) +
59
60
std::max (0 , (int )(context_start + context_length - 1 ));
60
61
if (pad == 0 ) is_padding = false ;
61
62
62
- FunctionCompare test (" ContextProjectionBackward" ,
63
- FuncConfig ()
64
- .set (" context_length" , context_length)
65
- .set (" context_start" , context_start)
66
- .set (" begin_pad" , std::max (0 , -context_start))
67
- .set (" is_padding" , is_padding)
68
- .set (" total_pad" , pad));
63
+ FunctionCompare test (
64
+ " ContextProjectionBackward" ,
65
+ FuncConfig ()
66
+ .set (" context_length" , context_length)
67
+ .set (" context_start" , context_start)
68
+ .set (" begin_pad" , (size_t )std::max (0 , -context_start))
69
+ .set (" is_padding" , is_padding)
70
+ .set (" total_pad" , pad));
69
71
70
72
// prepare input arguments
71
73
test.addSequence (SequenceIdArg (TensorShape{batch_size}));
0 commit comments