Skip to content

Commit e9b3ac6

Browse files
authored
Fix CUDA memory leak (#655)
* fix mem leak
1 parent 5a4f57c commit e9b3ac6

File tree

2 files changed

+83
-70
lines changed

2 files changed

+83
-70
lines changed

paddlenlp/ops/faster_transformer/src/fusion_decoding_op.cu

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,11 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
340340
cublasCreate(&cublas_handle_);
341341
cublasSetStream(cublas_handle_, stream);
342342

343+
std::vector<paddle::Tensor> ret;
344+
343345
switch (input.type()) {
344346
case paddle::DataType::FLOAT16: {
345-
return decoding_kernel<paddle::DataType::FLOAT16>(
347+
ret = decoding_kernel<paddle::DataType::FLOAT16>(
346348
input,
347349
mem_seq_len,
348350
word_embedding,
@@ -393,9 +395,10 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
393395
beam_search_diversity_rate,
394396
cublas_handle_,
395397
stream);
398+
break;
396399
}
397400
case paddle::DataType::FLOAT32: {
398-
return decoding_kernel<paddle::DataType::FLOAT32>(
401+
ret = decoding_kernel<paddle::DataType::FLOAT32>(
399402
input,
400403
mem_seq_len,
401404
word_embedding,
@@ -446,11 +449,16 @@ std::vector<paddle::Tensor> DecodingCUDAForward(
446449
beam_search_diversity_rate,
447450
cublas_handle_,
448451
stream);
452+
break;
449453
}
450454
default: {
451455
PD_THROW(
452456
"NOT supported data type. "
453457
"Only float16 and float32 are supported. ");
458+
break;
454459
}
455460
}
461+
462+
cublasDestroy(cublas_handle_);
463+
return ret;
456464
}

paddlenlp/ops/faster_transformer/src/fusion_gpt_op.cu

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -190,75 +190,80 @@ std::vector<paddle::Tensor> GPT2CUDAForward(
190190
cublasCreate(&cublas_handle_);
191191
cublasSetStream(cublas_handle_, stream);
192192

193+
std::vector<paddle::Tensor> ret;
194+
193195
if (use_fp16) {
194-
return gpt2_kernel<paddle::DataType::FLOAT16>(input,
195-
word_embedding,
196-
self_ln_weight,
197-
self_ln_bias,
198-
self_q_weight,
199-
self_q_bias,
200-
self_k_weight,
201-
self_k_bias,
202-
self_v_weight,
203-
self_v_bias,
204-
self_out_weight,
205-
self_out_bias,
206-
ffn_ln_weight,
207-
ffn_ln_bias,
208-
ffn_inter_weight,
209-
ffn_inter_bias,
210-
ffn_out_weight,
211-
ffn_out_bias,
212-
decoder_ln_weight,
213-
decoder_ln_bias,
214-
positional_embedding_weight,
215-
emb_weight,
216-
output_ids,
217-
topk,
218-
topp,
219-
max_len,
220-
n_head,
221-
size_per_head,
222-
num_layer,
223-
bos_id,
224-
eos_id,
225-
temperature,
226-
cublas_handle_,
227-
stream);
196+
ret = gpt2_kernel<paddle::DataType::FLOAT16>(input,
197+
word_embedding,
198+
self_ln_weight,
199+
self_ln_bias,
200+
self_q_weight,
201+
self_q_bias,
202+
self_k_weight,
203+
self_k_bias,
204+
self_v_weight,
205+
self_v_bias,
206+
self_out_weight,
207+
self_out_bias,
208+
ffn_ln_weight,
209+
ffn_ln_bias,
210+
ffn_inter_weight,
211+
ffn_inter_bias,
212+
ffn_out_weight,
213+
ffn_out_bias,
214+
decoder_ln_weight,
215+
decoder_ln_bias,
216+
positional_embedding_weight,
217+
emb_weight,
218+
output_ids,
219+
topk,
220+
topp,
221+
max_len,
222+
n_head,
223+
size_per_head,
224+
num_layer,
225+
bos_id,
226+
eos_id,
227+
temperature,
228+
cublas_handle_,
229+
stream);
228230
} else {
229-
return gpt2_kernel<paddle::DataType::FLOAT32>(input,
230-
word_embedding,
231-
self_ln_weight,
232-
self_ln_bias,
233-
self_q_weight,
234-
self_q_bias,
235-
self_k_weight,
236-
self_k_bias,
237-
self_v_weight,
238-
self_v_bias,
239-
self_out_weight,
240-
self_out_bias,
241-
ffn_ln_weight,
242-
ffn_ln_bias,
243-
ffn_inter_weight,
244-
ffn_inter_bias,
245-
ffn_out_weight,
246-
ffn_out_bias,
247-
decoder_ln_weight,
248-
decoder_ln_bias,
249-
positional_embedding_weight,
250-
emb_weight,
251-
output_ids,
252-
topk,
253-
topp,
254-
max_len,
255-
n_head,
256-
size_per_head,
257-
num_layer,
258-
bos_id,
259-
eos_id,
260-
temperature,
261-
cublas_handle_,
262-
stream);
231+
ret = gpt2_kernel<paddle::DataType::FLOAT32>(input,
232+
word_embedding,
233+
self_ln_weight,
234+
self_ln_bias,
235+
self_q_weight,
236+
self_q_bias,
237+
self_k_weight,
238+
self_k_bias,
239+
self_v_weight,
240+
self_v_bias,
241+
self_out_weight,
242+
self_out_bias,
243+
ffn_ln_weight,
244+
ffn_ln_bias,
245+
ffn_inter_weight,
246+
ffn_inter_bias,
247+
ffn_out_weight,
248+
ffn_out_bias,
249+
decoder_ln_weight,
250+
decoder_ln_bias,
251+
positional_embedding_weight,
252+
emb_weight,
253+
output_ids,
254+
topk,
255+
topp,
256+
max_len,
257+
n_head,
258+
size_per_head,
259+
num_layer,
260+
bos_id,
261+
eos_id,
262+
temperature,
263+
cublas_handle_,
264+
stream);
263265
}
266+
267+
cublasDestroy(cublas_handle_);
268+
return ret;
264269
}

0 commit comments

Comments (0)