Commit f2fdcc2

Fix fp8 gemm interface performance issue (#73512)
* fix gemm interface performance issue.
* test=document_fix
1 parent 846e99a commit f2fdcc2

File tree

1 file changed (+10, -7 lines)

  • python/paddle/incubate/nn/functional/fp8.py


python/paddle/incubate/nn/functional/fp8.py

Lines changed: 10 additions & 7 deletions
@@ -14,16 +14,22 @@
 
 from __future__ import annotations
 
+import functools
 from typing import TYPE_CHECKING
 
 import paddle
-from paddle import _C_ops
+from paddle import Tensor, _C_ops
 from paddle.framework import in_dynamic_or_pir_mode
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
-    from paddle import Tensor
+
+# special re-use of empty to reduce launch cost.
+@functools.cache
+def _empty_tensor() -> Tensor:
+    """Get tensor with no entries and no data"""
+    return Tensor()
 
 
 def fused_stack_transpose_quant(
@@ -145,7 +151,7 @@ def fp8_gemm_blockwise(
     assert bias is None, "Bias is not supported"
 
     if bias is None:
-        bias = paddle.empty([0], dtype=paddle.float32)
+        bias = _empty_tensor()
     else:
         assert bias.dtype in (
             paddle.float16,
@@ -172,9 +178,6 @@ def fp8_gemm_blockwise(
         else 4_194_304
     )
     workspace = paddle.empty([workspace_size], dtype=paddle.uint8)
-
-    empty_pre_gelu_out = paddle.empty([0], dtype=paddle.float32)
-
     transa, transb = True, False
     grad = False
     math_sm_count = 112
@@ -187,7 +190,7 @@ def fp8_gemm_blockwise(
         a_decode_scale,
         out,
         bias,
-        empty_pre_gelu_out,
+        _empty_tensor(),
         workspace,
         transa,
         transb,
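
The fix swaps the per-call `paddle.empty([0], dtype=paddle.float32)` placeholders for a single zero-size tensor that `functools.cache` constructs once and then reuses, so the optional bias and pre-GELU slots of the GEMM call no longer pay an allocation/launch cost on every invocation. Below is a minimal sketch of that caching pattern, assuming a Paddle install; it mirrors the `_empty_tensor` helper from the diff above and is only meant to illustrate the reuse, not to stand in for the library code.

import functools

import paddle
from paddle import Tensor


@functools.cache
def _empty_tensor() -> Tensor:
    """Build the zero-size placeholder once; later calls reuse the cached object."""
    return Tensor()


# The cached helper hands back the same object every time, so the
# per-call cost of paddle.empty([0], dtype=paddle.float32) disappears.
assert _empty_tensor() is _empty_tensor()

# Old behavior for comparison: a fresh 0-element tensor on every call.
legacy_placeholder = paddle.empty([0], dtype=paddle.float32)

The placeholder carries no data and only exists to fill the optional arguments of the underlying `_C_ops` GEMM call, which is why the commit can share one cached instance across calls to cut the launch cost.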
