1
1
#include <paddle/capi.h>
2
2
#include <time.h>
3
- #include < iostream>
4
- #include < vector>
5
3
6
4
#include "../common/common.h"
7
5
8
6
#define CONFIG_BIN "./trainer_config.bin"
9
7
10
8
int main () {
11
9
// Initalize Paddle
12
- std::string comand [] = {" --use_gpu=False" };
13
- CHECK (paddle_init (1 , (char **)comand ));
10
+ char * argv [] = {"--use_gpu=False" };
11
+ CHECK (paddle_init (1 , (char * * )argv ));
14
12
15
13
// Reading config binary file. It is generated by `convert_protobin.sh`
16
14
long size ;
@@ -30,20 +28,19 @@ int main() {
30
28
CHECK (paddle_arguments_resize (in_args , 1 ));
31
29
32
30
// Create input matrix.
33
- paddle_matrix mat = paddle_matrix_create (/* sample_num */ 10 ,
31
+ paddle_matrix mat = paddle_matrix_create (/* sample_num */ 1 ,
34
32
/* size */ 784 ,
35
33
/* useGPU */ false);
36
34
srand (time (0 ));
37
35
38
- std::vector<paddle_real> input;
39
- input.resize (784 * 10 );
36
+ paddle_real * array ;
40
37
41
- for (int i = 0 ; i < input.size (); ++i) {
42
- input[i] = rand () / ((float )RAND_MAX);
43
- }
38
+ // Get First row.
39
+ CHECK (paddle_matrix_get_row (mat , 0 , & array ));
44
40
45
- // Set value for the input matrix
46
- CHECK (paddle_matrix_set_value (mat, input.data ()));
41
+ for (int i = 0 ; i < 784 ; ++ i ) {
42
+ array [i ] = rand () / ((float )RAND_MAX );
43
+ }
47
44
48
45
CHECK (paddle_arguments_set_value (in_args , 0 , mat ));
49
46
@@ -56,17 +53,15 @@ int main() {
56
53
57
54
CHECK (paddle_arguments_get_value (out_args , 0 , prob ));
58
55
59
- std::vector<paddle_real> result;
60
56
uint64_t height ;
61
57
uint64_t width ;
62
58
63
59
CHECK (paddle_matrix_get_shape (prob , & height , & width ));
64
- result.resize (height * width);
65
- CHECK (paddle_matrix_get_value (prob, result.data ()));
60
+ CHECK (paddle_matrix_get_row (prob , 0 , & array ));
66
61
67
62
printf ("Prob: \n" );
68
63
for (int i = 0 ; i < height * width ; ++ i ) {
69
- printf (" %.4f " , result [i]);
64
+ printf ("%.4f " , array [i ]);
70
65
if ((i + 1 ) % width == 0 ) {
71
66
printf ("\n" );
72
67
}
0 commit comments