Skip to content

Commit 8b20ca3

Browse files
authored
[GCU] Add custom_engine ppocr_cls testcase (#1616)
1 parent f15d7d1 commit 8b20ca3

File tree

6 files changed

+490
-0
lines changed

6 files changed

+490
-0
lines changed

backends/gcu/custom_engine/ir_translator/operators/reshape.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,4 @@ static GcuOpPtr TranslateReshape(
3838
} // namespace custom_engine
3939

4040
REGISTER_OP_TRANSLATOR(pd_op_reshape, custom_engine::TranslateReshape)
41+
REGISTER_OP_TRANSLATOR(pd_op_reshape_, custom_engine::TranslateReshape)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <vector>
16+
17+
#include "custom_engine/ir_translator/translator_registry.h"
18+
19+
namespace custom_engine {
20+
21+
// Translates a pd_op.shape operation into GCU builder ops: queries the
// shape of the sole input tensor and converts the result to S32
// (presumably to match the int32 output dtype of Paddle's shape op —
// confirm against the op definition).
static GcuOpPtr TranslateShape(
    GcuBuilderPtr gcu_builder,
    const pir::Operation *op,
    const std::vector<std::vector<GcuOpPtr>> &gcu_op_inputs) {
  auto input = *(gcu_op_inputs[0][0]);
  auto shape_op = builder::Shape(input);
  // Cast the shape result down to S32, keeping its own shape unchanged.
  auto shape_s32 = builder::Convert(
      shape_op, {shape_op.GetType().GetShape(), builder::PrimitiveType::S32()});
  return std::make_shared<GcuOp>(shape_s32);
}
31+
32+
} // namespace custom_engine
33+
34+
REGISTER_OP_TRANSLATOR(pd_op_shape, custom_engine::TranslateShape)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <vector>
16+
17+
#include "custom_engine/ir_translator/translator_registry.h"
18+
19+
namespace custom_engine {
20+
21+
static GcuOpPtr TranslateSlice(
22+
GcuBuilderPtr gcu_builder,
23+
const pir::Operation *op,
24+
const std::vector<std::vector<GcuOpPtr>> &gcu_op_inputs) {
25+
// Get attributes
26+
const auto &attributes = op->attributes();
27+
auto axes_list =
28+
attributes.at("axes").dyn_cast<pir::ArrayAttribute>().AsVector();
29+
std::vector<int64_t> axes;
30+
if (axes_list.size() > 0) {
31+
PADDLE_ENFORCE_EQ(axes_list[0].isa<pir::Int64Attribute>(),
32+
true,
33+
common::errors::Unimplemented(
34+
"the 0th axes MUST be pir::Int64Attribute"));
35+
for (size_t i = 0; i < axes_list.size(); ++i) {
36+
axes.push_back(axes_list[i].dyn_cast<pir::Int64Attribute>().data());
37+
}
38+
}
39+
40+
auto infer_flags_list =
41+
attributes.at("infer_flags").dyn_cast<pir::ArrayAttribute>().AsVector();
42+
std::vector<int64_t> infer_flags;
43+
if (infer_flags_list.size() > 0) {
44+
PADDLE_ENFORCE_EQ(infer_flags_list[0].isa<pir::Int64Attribute>(),
45+
true,
46+
common::errors::Unimplemented(
47+
"the 0th infer_flags MUST be pir::Int64Attribute"));
48+
for (size_t i = 0; i < infer_flags_list.size(); ++i) {
49+
infer_flags.push_back(
50+
infer_flags_list[i].dyn_cast<pir::Int64Attribute>().data());
51+
}
52+
}
53+
54+
auto decrease_axis_list =
55+
attributes.at("decrease_axis").dyn_cast<pir::ArrayAttribute>().AsVector();
56+
std::vector<int64_t> decrease_axis;
57+
if (decrease_axis_list.size() > 0) {
58+
PADDLE_ENFORCE_EQ(decrease_axis_list[0].isa<pir::Int64Attribute>(),
59+
true,
60+
common::errors::Unimplemented(
61+
"the 0th decrease_axis MUST be pir::Int64Attribute"));
62+
for (size_t i = 0; i < decrease_axis_list.size(); ++i) {
63+
decrease_axis.push_back(
64+
decrease_axis_list[i].dyn_cast<pir::Int64Attribute>().data());
65+
}
66+
}
67+
68+
auto input = *(gcu_op_inputs[0][0]);
69+
70+
auto starts_tensor = *(gcu_op_inputs[1][0]);
71+
PADDLE_ENFORCE_EQ(starts_tensor.IsConstant(),
72+
true,
73+
common::errors::PreconditionNotMet(
74+
"Input[1] starts_tensor is not a Constant."));
75+
auto starts = starts_tensor.GetConstData<int64_t>();
76+
77+
auto ends_tensor = *(gcu_op_inputs[2][0]);
78+
PADDLE_ENFORCE_EQ(ends_tensor.IsConstant(),
79+
true,
80+
common::errors::PreconditionNotMet(
81+
"Input[1] ends_tensor is not a Constant."));
82+
auto ends = ends_tensor.GetConstData<int64_t>();
83+
84+
auto rank = input.GetType().GetRank();
85+
const std::vector<int64_t> &input_shapes = input.GetType().GetShape();
86+
std::vector<int64_t> start_indices(rank, 0);
87+
std::vector<int64_t> limit_indices = input_shapes;
88+
for (size_t i = 0; i < axes.size(); ++i) {
89+
int dim = axes[i];
90+
if (dim < 0) {
91+
dim += rank;
92+
}
93+
start_indices[dim] =
94+
starts[i] < 0 ? starts[i] + input_shapes[dim] : starts[i];
95+
start_indices[dim] = std::max(start_indices[dim], 0L);
96+
start_indices[dim] = std::min(start_indices[dim], input_shapes[dim]);
97+
98+
limit_indices[dim] = ends[i] < 0 ? ends[i] + input_shapes[dim] : ends[i];
99+
limit_indices[dim] = std::min(limit_indices[dim], input_shapes[dim]);
100+
limit_indices[dim] = std::max(limit_indices[dim], 0L);
101+
}
102+
std::vector<int64_t> strides(rank, 1);
103+
104+
auto slice = builder::Slice(input, start_indices, limit_indices, strides);
105+
106+
if (decrease_axis.size() == 0) {
107+
return std::make_shared<GcuOp>(slice);
108+
} else {
109+
auto slice_shape = slice.GetType().GetShape();
110+
std::vector<int64_t> new_shape;
111+
size_t iter = 0;
112+
for (int64_t i = 0; i < static_cast<int64_t>(slice_shape.size()); ++i) {
113+
if (iter < decrease_axis.size() && i == decrease_axis[iter]) {
114+
++iter;
115+
} else {
116+
new_shape.emplace_back(slice_shape[i]);
117+
}
118+
}
119+
if (new_shape.empty()) {
120+
new_shape.emplace_back(1);
121+
}
122+
return std::make_shared<GcuOp>(builder::Reshape(slice, new_shape));
123+
}
124+
}
125+
126+
} // namespace custom_engine
127+
128+
REGISTER_OP_TRANSLATOR(pd_op_slice, custom_engine::TranslateSlice)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <vector>
16+
17+
#include "custom_engine/ir_translator/translator_registry.h"
18+
19+
namespace custom_engine {
20+
21+
// Translates pd_op.softmax into GCU builder ops: runs builder::Softmax
// along `axis`, then clamps the result into dtype-specific [min, max]
// bounds so outputs never reach exact zero.
static GcuOpPtr TranslateSoftmax(
    GcuBuilderPtr gcu_builder,
    const pir::Operation* op,
    const std::vector<std::vector<GcuOpPtr>>& gcu_op_inputs) {
  auto input = *(gcu_op_inputs[0][0]);

  // Get attributes
  // `axis` is stored as an Int32Attribute on the op; widen to int64_t
  // for the builder call.
  const auto& attributes = op->attributes();
  int64_t axis = static_cast<int64_t>(
      attributes.at("axis").dyn_cast<pir::Int32Attribute>().data());

  // Reject anything other than F32/F64 up front.
  if (!(input.GetType().GetPrimitiveType() == builder::PrimitiveType::F32()) &&
      !(input.GetType().GetPrimitiveType() == builder::PrimitiveType::F64())) {
    PADDLE_THROW(phi::errors::Unimplemented(
        "GCU softmax only support FP32/FP64 datatype so far as now!"));
  }

  // to avoid 0
  // Dtype-specific clamp bounds. max_ptr/min_ptr alias whichever pair of
  // stack locals matches the input's primitive type; they must stay alive
  // until the builder::Const calls below consume them.
  double max_value_d = 1.0;
  double min_value_d = 1e-16;
  float max_value_f = 1.0;
  float min_value_f = 1e-7;
  void* max_ptr = nullptr;
  void* min_ptr = nullptr;
  auto scalar_type = builder::Type(input.GetType().GetPrimitiveType());
  if (input.GetType().GetPrimitiveType() == builder::PrimitiveType::F32()) {
    max_ptr = static_cast<void*>(&max_value_f);
    min_ptr = static_cast<void*>(&min_value_f);
  } else if (input.GetType().GetPrimitiveType() ==
             builder::PrimitiveType::F64()) {
    max_ptr = static_cast<void*>(&max_value_d);
    min_ptr = static_cast<void*>(&min_value_d);
  } else {
    // Unreachable: the F32/F64 check above already rejected other dtypes;
    // kept as a defensive guard.
    PADDLE_THROW(phi::errors::InvalidArgument("Unsupported datatype"));
  }

  // Materialize the bounds as scalar constants and clamp the softmax
  // output. NOTE(review): the argument order builder::Clamp(min, value,
  // max) and the Softmax flags (true, false, 0.0) are assumed from this
  // call site — confirm against the GCU builder API.
  auto max_op = builder::Const(gcu_builder, max_ptr, scalar_type);
  auto min_op = builder::Const(gcu_builder, min_ptr, scalar_type);
  auto softmax = builder::Softmax(input, axis, true, false, 0.0);
  auto res = builder::Clamp(min_op, softmax, max_op);
  return std::make_shared<GcuOp>(res);
}
63+
64+
} // namespace custom_engine
65+
66+
REGISTER_OP_TRANSLATOR(pd_op_softmax, custom_engine::TranslateSoftmax)
67+
REGISTER_OP_TRANSLATOR(pd_op_softmax_, custom_engine::TranslateSoftmax)

backends/gcu/passes/gcu_op_marker_pass.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ DEFINE_GENERAL_PATTERN(Pool2d, paddle::dialect::Pool2dOp)
8080
DEFINE_GENERAL_PATTERN(Relu, paddle::dialect::ReluOp)
8181
DEFINE_GENERAL_PATTERN(Relu_, paddle::dialect::Relu_Op)
8282
DEFINE_GENERAL_PATTERN(Reshape, paddle::dialect::ReshapeOp)
83+
DEFINE_GENERAL_PATTERN(Shape, paddle::dialect::ShapeOp)
8384
DEFINE_GENERAL_PATTERN(Sigmoid, paddle::dialect::SigmoidOp)
85+
DEFINE_GENERAL_PATTERN(Slice, paddle::dialect::SliceOp)
86+
DEFINE_GENERAL_PATTERN(Softmax, paddle::dialect::SoftmaxOp)
8487
DEFINE_GENERAL_PATTERN(Sqrt, paddle::dialect::SqrtOp)
8588
DEFINE_GENERAL_PATTERN(Where, paddle::dialect::WhereOp)
8689

@@ -127,7 +130,10 @@ class GcuOpMarkerPass : public pir::PatternRewritePass {
127130
ADD_PATTERN(Relu)
128131
ADD_PATTERN(Relu_)
129132
ADD_PATTERN(Reshape)
133+
// ADD_PATTERN(Shape)
130134
ADD_PATTERN(Sigmoid)
135+
// ADD_PATTERN(Slice)
136+
ADD_PATTERN(Softmax)
131137
ADD_PATTERN(Sqrt)
132138
ADD_PATTERN(Where)
133139

0 commit comments

Comments
 (0)