Skip to content

Commit 3b0a95d

Browse files
authored
add data copy (#475)
[FasterTransformer] Add data copies in the custom ops (copying tensors to the expected CPU/GPU place before use) to stand in for Paddle's ChooseKernel and DataTransform handling, and disable CUDAPinnedPlace in the DataLoader since the custom ops do not support it.
1 parent 62b460c commit 3b0a95d

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

examples/machine_translation/transformer/faster_transformer/encoder_decoding_predict.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def do_predict(args):
8787
place = paddle.set_device(place)
8888

8989
# Define data loader
90+
# NOTE: Data yielded by DataLoader may be on CUDAPinnedPlace,
91+
# but custom op doesn't support CUDAPinnedPlace. Hence,
92+
# disable using CUDAPinnedPlace in DataLoader.
93+
paddle.fluid.reader.use_pinned_memory(False)
9094
test_loader, to_tokens = reader.create_infer_loader(args)
9195

9296
# Define model

paddlenlp/ops/src/fusion_decoding_op.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,23 @@ std::vector<paddle::Tensor> DecodingForward(
8080
} else {
8181
PD_THROW("Not supported decoding strategy. ");
8282
}
83-
auto output_ids = paddle::Tensor(input.place(), output_dims);
84-
auto parent_ids = paddle::Tensor(input.place(), parent_ids_dims);
85-
auto sequence_length = paddle::Tensor(input.place(), sequence_length_dims);
8683

8784
if (input.place() == paddle::PlaceType::kGPU) {
85+
auto output_ids = paddle::Tensor(paddle::PlaceType::kGPU, output_dims);
86+
auto parent_ids = paddle::Tensor(paddle::PlaceType::kGPU, parent_ids_dims);
87+
auto sequence_length =
88+
paddle::Tensor(paddle::PlaceType::kGPU, sequence_length_dims);
89+
90+
paddle::Tensor seq_len = paddle::Tensor(paddle::PlaceType::kGPU);
91+
92+
if (mem_seq_len.place() != paddle::PlaceType::kGPU) {
93+
seq_len = mem_seq_len.copy_to<int>(paddle::PlaceType::kGPU);
94+
} else {
95+
seq_len = mem_seq_len;
96+
}
97+
8898
return DecodingCUDAForward(input,
89-
mem_seq_len,
99+
seq_len,
90100
word_embedding,
91101
self_ln_weight,
92102
self_ln_bias,

paddlenlp/ops/src/fusion_gpt_op.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,16 @@ std::vector<paddle::Tensor> GPT2Forward(
4444
std::vector<int64_t> output_dims({total_len, batch_size});
4545
auto output_ids = paddle::Tensor(input.place(), output_dims);
4646

47-
if (input.place() == paddle::PlaceType::kGPU) {
48-
return GPT2CUDAForward(input,
47+
if (word_embedding.place() == paddle::PlaceType::kGPU) {
48+
paddle::Tensor input_ids = paddle::Tensor(paddle::PlaceType::kCPU);
49+
50+
if (input.place() != paddle::PlaceType::kCPU) {
51+
input_ids = input.copy_to<int>(paddle::PlaceType::kCPU);
52+
} else {
53+
input_ids = input;
54+
}
55+
56+
return GPT2CUDAForward(input_ids,
4957
word_embedding,
5058
self_ln_weight,
5159
self_ln_bias,

paddlenlp/ops/src/fusion_gpt_op.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,16 @@ std::vector<paddle::Tensor> gpt2_kernel(
6161
DecodingInitParam<DataType_> decoding_params;
6262
decoding_params.cublas_handle = cublas_handle_;
6363

64-
decoding_params.output_ids = output_ids.mutable_data<int>(input.place());
64+
decoding_params.output_ids = output_ids.mutable_data<int>(word_emb.place());
6565

6666
typedef DecoderTransformerTraits<traits_::OpType> DecodingTraits_;
6767
decoding_params.stream = stream;
6868
fastertransformer::Allocator<AllocatorType::PD> allocator_(stream);
6969

7070
DecodingGpt2<DecodingTraits_::OpType>* gpt2_decoding;
7171

72-
// input data is on gpu.
73-
int* h_input_data = new int[batch_size_ * start_len];
74-
cudaMemcpy(h_input_data,
75-
input.data<int>(),
76-
sizeof(int) * batch_size_ * start_len,
77-
cudaMemcpyDeviceToHost);
72+
// input data should be on CPU.
73+
int* h_input_data = input.data<int>();
7874
gpt2_decoding = new DecodingGpt2<DecodingTraits_::OpType>(allocator_,
7975
batch_size_,
8076
max_len,
@@ -189,7 +185,7 @@ std::vector<paddle::Tensor> GPT2CUDAForward(
189185
const int& eos_id,
190186
const float& temperature,
191187
const bool& use_fp16 = false) {
192-
auto stream = input.stream();
188+
auto stream = word_embedding.stream();
193189
cublasHandle_t cublas_handle_;
194190
cublasCreate(&cublas_handle_);
195191
cublasSetStream(cublas_handle_, stream);

0 commit comments

Comments
 (0)