@@ -26,6 +26,15 @@ def __init__(self, input_type, pos):
26
26
if not isinstance (self .input_type , dp2 .InputType ):
27
27
raise ValueError ("input type should be dataprovider2.InputType" )
28
28
self .pos = pos
29
+ # data_in_gpu is used to indicate whether to create argument on GPU
30
+ # or not in GPU mode. Now if using one thread (trainer_count=1),
31
+ # trainer uses NeuralNetwork which needs to create argument on GPU
32
+ # before calling forward function. So, set data_in_gpu to True.
33
+ # Otherwise, trainer uses MultiGradientMachine which will transfer
34
+ # data from CPU to GPU in the forward function, set data_in_gpu to
35
+ # False in this case.
36
+ self .data_in_gpu = swig_paddle .isUsingGpu (
37
+ ) and swig_paddle .getTrainerCount () == 1
29
38
30
39
def scan (self , dat ):
31
40
pass
@@ -53,7 +62,8 @@ def finish_scan(self, argument):
53
62
assert isinstance (argument , swig_paddle .Arguments )
54
63
if self .__mat__ .dtype != numpy .float32 :
55
64
self .__mat__ = self .__mat__ .astype (numpy .float32 )
56
- m = swig_paddle .Matrix .createDenseFromNumpy (self .__mat__ , True , False )
65
+ m = swig_paddle .Matrix .createDenseFromNumpy (self .__mat__ , True ,
66
+ self .data_in_gpu )
57
67
argument .setSlotValue (self .pos , m )
58
68
59
69
@@ -75,10 +85,13 @@ def extend_cols(self, dat):
75
85
76
86
def finish_scan (self , argument ):
77
87
assert isinstance (argument , swig_paddle .Arguments )
78
- m = swig_paddle .Matrix .createSparse (self .__height__ ,
79
- self .input_type .dim ,
80
- len (self .__cols__ ),
81
- len (self .__value__ ) == 0 )
88
+ m = swig_paddle .Matrix .createSparse (
89
+ self .__height__ ,
90
+ self .input_type .dim ,
91
+ len (self .__cols__ ),
92
+ len (self .__value__ ) == 0 ,
93
+ False , # trans
94
+ False ) # TODO supoort GPU
82
95
assert isinstance (m , swig_paddle .Matrix )
83
96
m .sparseCopyFrom (self .__rows__ , self .__cols__ , self .__value__ )
84
97
argument .setSlotValue (self .pos , m )
@@ -102,7 +115,7 @@ def scan(self, dat):
102
115
self .__ids__ .append (dat )
103
116
104
117
def finish_scan (self , argument ):
105
- ids = swig_paddle .IVector .create (self .__ids__ )
118
+ ids = swig_paddle .IVector .create (self .__ids__ , self . data_in_gpu )
106
119
assert isinstance (argument , swig_paddle .Arguments )
107
120
argument .setSlotIds (self .pos , ids )
108
121
0 commit comments