Skip to content

Commit 7ec9073

Browse files
authored
[0-size Tensor No.82] Add 0-size Tensor support for paddle.incubate.nn.functional.fused_bias_act [fluid_ops] (#74259)
* Fix * Fix * Fix * ci
1 parent 32e3f18 commit 7ec9073

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2476,7 +2476,7 @@ void FusedBiasActInferMeta(const MetaTensor& x,
24762476
x_shapes.push_back(x_dims[i]);
24772477
}
24782478

2479-
if (config.is_runtime) {
2479+
if (config.is_runtime && x.numel() != 0) {
24802480
PADDLE_ENFORCE_GT(
24812481
x.numel() / dim,
24822482
0,

paddle/phi/kernels/fusion/gpu/fused_bias_act_kernel.cu

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,20 @@ void FusedBiasActKernel(const Context &dev_ctx,
556556
float quant_max_bound,
557557
float quant_min_bound,
558558
DenseTensor *out) {
559+
if (out && out->numel() == 0) {
560+
if (quant_scale > 0) {
561+
dev_ctx.template Alloc<int8_t>(out);
562+
} else if (compute_dtype == "fp16") {
563+
dev_ctx.template Alloc<phi::dtype::float16>(out);
564+
} else if (compute_dtype == "bf16") {
565+
dev_ctx.template Alloc<phi::dtype::bfloat16>(out);
566+
} else if (compute_dtype == "fp32") {
567+
dev_ctx.template Alloc<float>(out);
568+
} else {
569+
dev_ctx.template Alloc<T>(out);
570+
}
571+
return;
572+
}
559573
int64_t cols = x.dims()[x.dims().size() - 1];
560574
int64_t rows = x.numel() / cols;
561575
if (x.dtype() == phi::DataType::INT32) {

paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ void FusedBiasActKernel(const Context &dev_ctx,
112112
DenseTensor *out) {
113113
auto xpu_ctx = static_cast<const phi::XPUContext *>(&dev_ctx);
114114
dev_ctx.template Alloc<T>(out);
115+
if (out->numel() == 0) return;
115116

116117
if (dequant_scales && dequant_scales.get().numel() > 0) {
117118
return DispatchComputeImpl<T>(xpu_ctx,

test/legacy_test/test_fused_bias_act_op.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,5 +784,31 @@ def test_check_output(self):
784784
)
785785

786786

787+
@unittest.skipIf(
    not (core.is_compiled_with_cuda() or core.is_compiled_with_rocm()),
    "core is not compiled with CUDA or ROCm",
)
class TestFusedBiasActOp_ZeroSize(TestWithoutBias):
    # Zero-size-tensor variant of TestWithoutBias: seq_len == 0 presumably
    # makes the op input a 0-size tensor (shape driven by batch_size /
    # seq_len / cols — confirm against init_test_case), exercising the
    # kernel's empty-input early-return path.
    def setUp(self):
        # Deterministic RNG so generated reference inputs are reproducible.
        paddle.seed(2017)
        np.random.seed(2017)

        self.op_type = "fused_bias_act"

        # Comparison tolerances used by the inherited checks.
        self.rtol = 1e-5
        self.atol = 1e-3

        # seq_len = 0 is the point of this case: it produces the 0-size input.
        self.batch_size = 2
        self.seq_len = 0
        self.cols = 512

        self.dtype = 'float32'
        self.act_method = 'gelu'
        self.use_glu = False

        self.init_test_case()
        self.generate_inputs()
787813
if __name__ == '__main__':
788814
unittest.main()

0 commit comments

Comments
 (0)