Skip to content

Commit ca6fc14

Browse files
committed
add some comments
1 parent 6bf6034 commit ca6fc14

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

paddle/py_paddle/dataprovider_converter.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@ def __init__(self, input_type, pos):
2626
if not isinstance(self.input_type, dp2.InputType):
2727
raise ValueError("input type should be dataprovider2.InputType")
2828
self.pos = pos
29-
self.use_gpu = True if swig_paddle.isUsingGpu() and (
29+
# data_in_gpu is used to indicate whether to create argument on GPU
30+
# or not in GPU mode. Now if using one thread (trainer_count=1),
31+
# trainer uses NeuralNetwork which needs to create argument on GPU
32+
# before calling forward function. So, set data_in_gpu to True.
33+
# Otherwise, trainer uses MultiGradientMachine which will transfer
34+
# data from CPU to GPU in the forward function, set data_in_gpu to
35+
# False in this case.
36+
self.data_in_gpu = True if swig_paddle.isUsingGpu() and (
3037
swig_paddle.getTrainerCount() == 1) else False
3138

3239
def scan(self, dat):

0 commit comments

Comments
 (0)