Skip to content

Commit bd60e69

Browse files
authored
[XPU] Add bool type for concat op (#10527)
1 parent 79b0a58 commit bd60e69

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

lite/kernels/xpu/concat_compute.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ using concati64 =
104104
paddle::lite::kernels::xpu::ConcatCompute<int64_t, PRECISION(kFloat)>;
105105
using concati8 =
106106
paddle::lite::kernels::xpu::ConcatCompute<int8_t, PRECISION(kInt8)>;
107+
using concatbool =
108+
paddle::lite::kernels::xpu::ConcatCompute<bool, PRECISION(kFloat)>;
107109

108110
REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatfp32, def)
109111
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
@@ -147,3 +149,9 @@ REGISTER_LITE_KERNEL(concat, kXPU, kInt8, kNCHW, concati8, concat_INT8)
147149
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
148150
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))})
149151
.Finalize();
152+
REGISTER_LITE_KERNEL(concat, kXPU, kFloat, kNCHW, concatbool, concat_BOOL)
153+
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))})
154+
.BindInput("AxisTensor",
155+
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt32))})
156+
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kBool))})
157+
.Finalize();

0 commit comments

Comments
 (0)