1 file changed: +10 −7 lines
python/paddle/incubate/nn/functional
@@ -15,15 +15,21 @@
 from __future__ import annotations
 
+import functools
 from typing import TYPE_CHECKING
 
 import paddle
-from paddle import _C_ops
+from paddle import Tensor, _C_ops
 from paddle.framework import in_dynamic_or_pir_mode
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
-    from paddle import Tensor
+
+# special re-use of empty to reduce launch cost.
+@functools.cache
+def _empty_tensor() -> Tensor:
+    """Get tensor with no entries and no data"""
+    return Tensor()
 
 
 def fused_stack_transpose_quant(
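Note: the new helper leans on functools.cache (Python 3.9+), so Tensor() is constructed exactly once and every later call hands back that same object; that is where the per-call allocation saving comes from. A minimal sketch of the caching behavior, independent of Paddle (the make_sentinel name is hypothetical):

import functools


@functools.cache
def make_sentinel() -> tuple:
    # Body runs only on the first call; later calls hit the cache.
    print("allocating sentinel")
    return ()


a = make_sentinel()  # prints "allocating sentinel"
b = make_sentinel()  # cache hit: nothing printed, nothing allocated
assert a is b        # the very same object, not merely an equal one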
@@ -145,7 +151,7 @@ def fp8_gemm_blockwise(
     assert bias is None, "Bias is not supported"
 
     if bias is None:
-        bias = paddle.empty([0], dtype=paddle.float32)
+        bias = _empty_tensor()
     else:
         assert bias.dtype in (
             paddle.float16,
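The bias is None branch now substitutes the shared zero-size tensor instead of allocating paddle.empty([0], dtype=paddle.float32) on every invocation. A sketch of the same optional-argument pattern, assuming Paddle is installed; gemm_like is a hypothetical stand-in, not the real fp8_gemm_blockwise signature:

from __future__ import annotations

import functools

import paddle
from paddle import Tensor


@functools.cache
def _empty_tensor() -> Tensor:
    """Zero-entry tensor shared by every caller as a 'no value' sentinel."""
    return Tensor()


def gemm_like(x: Tensor, bias: Tensor | None = None) -> Tensor:
    if bias is None:
        # Pass the op a real Tensor object without paying for a fresh
        # allocation on each call.
        bias = _empty_tensor()
    return x if bias is _empty_tensor() else x + bias


x = paddle.ones([2, 2])
assert _empty_tensor() is _empty_tensor()   # allocated once, reused after
print(gemm_like(x))                         # no bias: x passes through
print(gemm_like(x, paddle.ones([2, 2])))    # bias supplied: x + bias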
@@ -172,9 +178,6 @@ def fp8_gemm_blockwise(
         else 4_194_304
     )
     workspace = paddle.empty([workspace_size], dtype=paddle.uint8)
-
-    empty_pre_gelu_out = paddle.empty([0], dtype=paddle.float32)
-
     transa, transb = True, False
     grad = False
     math_sm_count = 112
@@ -187,7 +190,7 @@ def fp8_gemm_blockwise(
         a_decode_scale,
         out,
         bias,
-        empty_pre_gelu_out,
+        _empty_tensor(),
         workspace,
         transa,
         transb,
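Same substitution at the call site: the unused pre-GELU output slot now receives the cached sentinel inline rather than a per-call paddle.empty([0], ...). A rough way to observe the saving the "special re-use" comment refers to; absolute numbers depend on device and build, only the relative cost is the point:

import functools
import timeit

import paddle
from paddle import Tensor


@functools.cache
def _empty_tensor() -> Tensor:
    return Tensor()


_empty_tensor()  # warm the cache so the loop below measures pure reuse

fresh = timeit.timeit(
    lambda: paddle.empty([0], dtype=paddle.float32), number=10_000
)
cached = timeit.timeit(_empty_tensor, number=10_000)
print(f"fresh paddle.empty([0]): {fresh:.4f}s   cached sentinel: {cached:.4f}s")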