Skip to content

Commit 6b15cb2

Browse files
merge develop
2 parents ff5b73d + 2d2476e commit 6b15cb2

File tree

269 files changed

+2793
-944
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

269 files changed

+2793
-944
lines changed

cmake/external/xpu.cmake

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ if(WITH_XPU_XRE5)
253253
DOWNLOAD_COMMAND
254254
bash ${CMAKE_SOURCE_DIR}/tools/xpu/pack_paddle_dependence.sh
255255
${XPU_XRE_URL} ${XPU_XRE_DIR_NAME} ${XPU_XHPC_URL} ${XPU_XHPC_DIR_NAME}
256-
${XPU_XCCL_URL} ${XPU_XCCL_DIR_NAME} 1 ${WITH_MKL}
257-
"${CMAKE_SOURCE_DIR}/build" && wget ${XPU_XFT_GET_DEPENCE_URL} && bash
258-
${XFT_COMMAND} ${XPU_XFT_URL} ${XPU_XFT_DIR_NAME} && bash
256+
${XPU_XCCL_URL} ${XPU_XCCL_DIR_NAME} 1 ${WITH_MKL} "${CMAKE_BINARY_DIR}"
257+
&& wget ${XPU_XFT_GET_DEPENCE_URL} && bash ${XFT_COMMAND} ${XPU_XFT_URL}
258+
${XPU_XFT_DIR_NAME} && bash
259259
${CMAKE_SOURCE_DIR}/tools/xpu/get_xpti_dependence.sh ${XPU_XPTI_URL}
260260
${XPU_XPTI_DIR_NAME} && bash
261261
${CMAKE_SOURCE_DIR}/tools/xpu/get_xpufft_dependence.sh ${XPU_FFT_URL}

paddle/cinn/common/integer_set.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ cas_intervals_t CollectVarIntervalsOfExprs(const std::vector<ir::Expr>& exprs,
164164
lower_bound = ir::Expr(1);
165165
}
166166
var_intervals.insert(
167-
{var->name, CasInterval(lower_bound, upper_bound)});
167+
{var->name,
168+
CasInterval(lower_bound, NormalizeUpperBound(upper_bound))});
168169
}
169170
return false;
170171
});
@@ -572,14 +573,21 @@ class BoundReplacer : public ir::IRMutator<> {
572573
ir::Expr SymbolicExprAnalyzer::LowerBound(const ir::Expr& expr) const {
573574
BoundReplacer bound_replacer(var_intervals_, true);
574575
ir::Expr bound = ir::ir_utils::IRCopy(expr);
576+
if (bound.is_index()) {
577+
bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3);
578+
}
575579
bound_replacer(&bound);
576580
return optim::ArithSimplify(bound);
577581
}
578582

579583
ir::Expr SymbolicExprAnalyzer::UpperBound(const ir::Expr& expr) const {
580584
BoundReplacer bound_replacer(var_intervals_, false);
581585
ir::Expr bound = ir::ir_utils::IRCopy(expr);
586+
if (bound.is_index()) {
587+
bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3);
588+
}
582589
bound_replacer(&bound);
590+
583591
return optim::ArithSimplify(bound);
584592
}
585593

@@ -709,7 +717,8 @@ SingleIntervalIntSet::SingleIntervalIntSet(const ir::Expr& min,
709717
? x->as_var()->upper_bound
710718
: SymbolicExprLimit::positive_inf;
711719
var_intervals_.insert(
712-
{x->as_var()->name, CasInterval(lower_bound, upper_bound)});
720+
{x->as_var()->name,
721+
CasInterval(lower_bound, NormalizeUpperBound(upper_bound))});
713722
}
714723
return false;
715724
};

paddle/cinn/common/ir_util.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,16 @@ bool is_zero(Expr v) {
270270
return false;
271271
}
272272

273+
Expr NormalizeUpperBound(Expr upper_bound, bool minus_one /* = true */) {
274+
if (upper_bound == SymbolicExprLimit::positive_inf) {
275+
return upper_bound;
276+
}
277+
if (minus_one) {
278+
return upper_bound - ir::Expr(1); // [lower, upper) to [lower, upper]
279+
}
280+
return upper_bound + ir::Expr(1); // (lower, upper] to [lower, upper)
281+
}
282+
273283
Expr CastIfNeeded(Expr body, Type type) {
274284
if (body.type() == type) return body;
275285
return ir::Cast::Make(type, body);

paddle/cinn/common/ir_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ std::vector<std::string> GatherItersToTensorProducer(
9191

9292
bool is_zero(Expr v);
9393

94+
Expr NormalizeUpperBound(Expr upper_bound, bool minus_one = true);
95+
9496
bool MathEqual(const Expr &a, const Expr &b);
9597

9698
//! helper function to get a ir::Select node.

paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ std::vector<std::pair<SymbolicPredicate, ir::Expr>>
136136
DynamicShapeGroupScheduler::GetCX86IRs() {
137137
std::vector<std::pair<SymbolicPredicate, ir::Expr>> irs(1);
138138
irs[0].first = ir::EQ::Make(ir::Expr(1), ir::Expr(1));
139-
irs[1].second = ir_sch_->GetModule().GetExprs()[0];
139+
irs[0].second = ir_sch_->GetModule().GetExprs()[0];
140140
return irs;
141141
}
142142

paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ IntSet Evaluate(Expr expr,
141141
const std::unordered_map<ir::Var, IntSet>& var_domain) {
142142
Expr copy_for_upper_bound = ir::ir_utils::IRCopy(expr);
143143
Expr copy_for_lower_bound = ir::ir_utils::IRCopy(expr);
144-
common::cas_intervals_t var_intervals;
144+
common::cas_intervals_t
145+
var_intervals; // variable name -> CasIntervals[lower_bound, upper_bound]
145146
std::vector<ir::Expr> var_vec = ir::ir_utils::CollectIRNodesWithoutTensor(
146147
expr, [](const ir::Expr* x) { return x->as_var(); });
147148
for (Expr var_expr : var_vec) {
@@ -150,7 +151,9 @@ IntSet Evaluate(Expr expr,
150151
const ir::Var& fixed_var = fixed.at(var);
151152
var_intervals.emplace(
152153
fixed_var->name,
153-
common::CasInterval(fixed_var->lower_bound, fixed_var->upper_bound));
154+
common::CasInterval(
155+
fixed_var->lower_bound,
156+
cinn::common::NormalizeUpperBound(fixed_var->upper_bound)));
154157
optim::ReplaceVarWithExpr(&copy_for_lower_bound, var, Expr(fixed_var));
155158
optim::ReplaceVarWithExpr(&copy_for_upper_bound, var, Expr(fixed_var));
156159
} else if (var_domain.count(var) != 0) {
@@ -172,7 +175,8 @@ IntSet Evaluate(Expr expr,
172175
::common::errors::InvalidArgument(
173176
"The 'upper_bound' of the variable must be defined."));
174177
optim::ReplaceVarWithExpr(&copy_for_lower_bound, var, var->lower_bound);
175-
optim::ReplaceVarWithExpr(&copy_for_upper_bound, var, var->upper_bound);
178+
optim::ReplaceVarWithExpr(
179+
&copy_for_upper_bound, var, NormalizeUpperBound(var->upper_bound));
176180
}
177181
}
178182
ir::Expr lower_bound = optim::ArithSimplify(copy_for_lower_bound);

paddle/cinn/ir/ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ struct _Var_ : public ExprNode<_Var_> {
421421
};
422422

423423
//! A named variable.
424+
// i ∈ [lower_bound, upper_bound)
424425
struct Var : public IrNodeRef {
425426
Var() = default;
426427
explicit Var(IrNode* n) : IrNodeRef(n) {}
@@ -846,6 +847,7 @@ struct For : public ExprNode<For>, public ForBase {
846847
//! The minimum value of the iteration.
847848
Expr min;
848849
//! The extent of the iteration.
850+
// loop_var ∈ [min, min + extent)
849851
Expr extent;
850852

851853
Expr body;

paddle/cinn/ir/ir_analyzer/ir_analyzer.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,8 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
621621
if (e.is_constant()) {
622622
std::string var_name =
623623
cinn::UniqName("constant" + static_cast<int>(e.get_constant()));
624-
result.emplace_back(e, e, var_name, /* is_reduce = */ false);
624+
result.emplace_back(
625+
e, NormalizeUpperBound(e, false), var_name, /* is_reduce = */ false);
625626
} else if (e.As<ir::_Var_>() != nullptr) {
626627
ir::Expr copy_e = ir::ir_utils::IRCopy(e);
627628
ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
@@ -635,14 +636,17 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
635636
ir::Var var = x->as_var_ref();
636637
var_intervals.insert(
637638
{var->name,
638-
common::CasInterval{var->lower_bound, var->upper_bound}});
639+
common::CasInterval{var->lower_bound,
640+
NormalizeUpperBound(var->upper_bound)}});
639641
if (var->is_reduce_axis) is_reduce = true;
640642
}
641643
return false;
642644
});
643645
common::SymbolicExprAnalyzer analyzer(var_intervals);
644-
result.emplace_back(
645-
analyzer.LowerBound(e), analyzer.UpperBound(e), var_name, is_reduce);
646+
result.emplace_back(analyzer.LowerBound(e),
647+
NormalizeUpperBound(analyzer.UpperBound(e), false),
648+
var_name,
649+
is_reduce);
646650
}
647651
}
648652
return result;

0 commit comments

Comments
 (0)