Skip to content

Commit c3eb447

Browse files
authored
[CINN] Fixed dynamic arange symbolic values extraction. (#74412)
* [CINN] Fix cinn_op.generate_op attribute storing useless dim_expr
* [CINN] Removed unnecessary VLOGs
* [CINN] Simplify dynamic arange logic and fix bugs.
1 parent ded8a1e commit c3eb447

File tree

3 files changed

+24
-97
lines changed

3 files changed

+24
-97
lines changed

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -667,79 +667,6 @@ std::vector<ir::LoweredFunc> OpLowererImpl::DoOpLower(
667667
return funcs;
668668
}
669669

670-
/**
671-
* This function converts pir::Value::defining_op for ir::Tensor::operation
672-
* Normally, ir::Tensor::operation will only be used to record the name
673-
* of the compiler-generated var name, which is useless. However, operation
674-
* has Attributes field, so can be used to record the op info.
675-
*/
676-
ir::PlaceholderOp* TensorOperationRecording(const ::pir::Value& value) {
677-
// TODO(heqianyue): I think this is kinda ugly, since we should manually
678-
// specify the rules to convert all the op (and their attribute), yet current
679-
// implementation works and can be quickly written.
680-
const ::pir::Operation* define_op = value.defining_op();
681-
ir::PlaceholderOp* res = nullptr;
682-
if (!define_op) return res;
683-
res = cinn::common::make_shared<ir::PlaceholderOp>();
684-
res->name = define_op->name();
685-
// we filter some of the ops, and only record the **needed** attributes
686-
if (define_op->name() == "pd_op.full") {
687-
auto dtype = define_op->attribute("dtype")
688-
.dyn_cast<paddle::dialect::DataTypeAttribute>()
689-
.data();
690-
phi::Scalar data = define_op->attribute("value")
691-
.dyn_cast<paddle::dialect::ScalarAttribute>()
692-
.data();
693-
ir::Expr value;
694-
#define DEFINE_CASE(TypeFlag, Type) \
695-
case phi::DataType::TypeFlag: \
696-
value = ir::Expr(data.to<Type>()); \
697-
break;
698-
switch (dtype) {
699-
DEFINE_CASE(FLOAT32, float)
700-
DEFINE_CASE(FLOAT64, double)
701-
DEFINE_CASE(INT32, int)
702-
DEFINE_CASE(BFLOAT16, float)
703-
value->set_type(cinn::common::BFloat16());
704-
break;
705-
DEFINE_CASE(FLOAT16, float)
706-
value->set_type(cinn::common::Float16());
707-
break;
708-
default:
709-
value = ir::Expr(data.to<int64_t>());
710-
}
711-
#undef DEFINE_CASE
712-
res->attrs.emplace("value", value);
713-
} else if (define_op->name() == "cinn_op.generate_shape") {
714-
// pir::Attribute --> symbol::DimExpr --> ir::Expr
715-
716-
auto ir_dim_expr = [&]() {
717-
auto dim_expr_attr = define_op->attribute("output_dim_exprs");
718-
auto dim_exprs = dialect::ConvertAttributeToDimExprs(dim_expr_attr);
719-
720-
PADDLE_ENFORCE_EQ(
721-
dim_exprs.has_value(),
722-
true,
723-
::common::errors::PreconditionNotMet(
724-
"Required success to execute convert attribute to dim exprs."));
725-
726-
auto expr_vec = dim_exprs.value();
727-
PADDLE_ENFORCE_EQ(
728-
expr_vec.empty(),
729-
false,
730-
::common::errors::PreconditionNotMet(
731-
"Generate shape op can not yield empty symbolic shape."));
732-
// only the first dim_expr matters for ArangeOp
733-
return common::DimExprConverter().ConvertToIrExpr(expr_vec[0]);
734-
}();
735-
res->attrs.emplace("value", ir_dim_expr);
736-
} else {
737-
VLOG(6) << "Tensor defining op recording: not currently supported op.";
738-
return nullptr;
739-
}
740-
return res;
741-
}
742-
743670
ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
744671
const ::pir::Value& value) {
745672
auto type_info = value.type().dyn_cast<paddle::dialect::DenseTensorType>();
@@ -778,9 +705,6 @@ ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group,
778705
tensor->set_value(*tensor_value);
779706
}
780707
}
781-
if (auto op_ptr = TensorOperationRecording(value)) {
782-
tensor->operation = ir::FunctionRef(op_ptr);
783-
}
784708
return tensor;
785709
}
786710

paddle/cinn/hlir/framework/pir/op_lowering_util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ std::unordered_set<::pir::Operation*> GetMasters(
180180
}
181181

182182
bool IsConstOp(const ::pir::Operation* op) {
183-
static std::unordered_set<std::string> const_op_type = {
184-
"const_scalar", "fill_constant", "arange"};
183+
static std::unordered_set<std::string> const_op_type = {"const_scalar",
184+
"fill_constant"};
185185
return const_op_type.count(CompatibleInfo::OpName(*op));
186186
}
187187

paddle/cinn/hlir/op/elementwise.cc

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,16 +1257,12 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
12571257
const Target &target) {
12581258
bool all_static = true;
12591259
for (int i = 0; i < 3; i++) {
1260-
auto op_node = inputs[i]->operation->as<ir::PlaceholderOp>();
1261-
PADDLE_ENFORCE_NE(
1262-
op_node,
1263-
nullptr,
1264-
::common::errors::PreconditionNotMet(
1265-
"The defining op of the input tensor is not set! Please check."));
1266-
if (op_node->name == "cinn_op.generate_shape") {
1267-
all_static = false;
1268-
break;
1269-
}
1260+
if (!inputs[i]->value().has_value()) continue;
1261+
auto input_val = inputs[i]->value().value();
1262+
if (input_val.empty()) continue;
1263+
if (input_val[0].is_constant()) continue;
1264+
all_static = false;
1265+
break;
12701266
}
12711267
auto attr_store = attrs.attr_store;
12721268
auto dtype =
@@ -1341,15 +1337,22 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
13411337
"bfloat16 or float16."));
13421338
}
13431339
#undef EXPR_FROM_ATTR
1344-
} else { // has dynamic shape, some of the operands come from
1345-
// cinn_op.generate_shape
1346-
// in op_lowering_impl.cc, tensor op recorder unified the attribute name
1347-
start = Expr(
1348-
inputs[0]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
1349-
step = Expr(
1350-
inputs[2]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
1351-
Expr end = Expr(
1352-
inputs[1]->operation->as<ir::PlaceholderOp>()->attrs.at("value").ptr());
1340+
} else {
1341+
for (int i = 0; i < 3; i++) {
1342+
PADDLE_ENFORCE_EQ(
1343+
inputs[i]->value().has_value(),
1344+
true,
1345+
::common::errors::InvalidArgument(
1346+
"The input tensor of dynamic arange should have valid values."));
1347+
PADDLE_ENFORCE_NE(
1348+
inputs[i]->value().value().empty(),
1349+
true,
1350+
::common::errors::InvalidArgument(
1351+
"The tensor value of dynamic arange should not be empty."));
1352+
}
1353+
start = inputs[0]->value().value()[0];
1354+
step = inputs[2]->value().value()[0];
1355+
Expr end = inputs[1]->value().value()[0];
13531356
auto IrAbs = [=](Expr ir) -> Expr {
13541357
return ir::Call::Make(step.type(), "abs", {ir}, {}, ir::CallType::Extern);
13551358
};

0 commit comments

Comments (0)