Skip to content

Commit ba2eb55

Browse files
authored
Clean up id_model code (#5989)
1 parent d689097 commit ba2eb55

File tree

2 files changed

+23
-40
lines changed

2 files changed

+23
-40
lines changed

csrc/id_model/id_model.cpp

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <utility>
1313

1414
#include "device_lower/analysis/circular_buffer.h"
15-
#include "device_lower/analysis/trivial_broadcast.h"
1615
#include "device_lower/lower2device.h"
1716
#include "device_lower/utils.h"
1817
#include "disjoint_set.h"
@@ -25,7 +24,6 @@
2524
#include "iter_visitor.h"
2625
#include "logical_domain_map.h"
2726
#include "transform_iter.h"
28-
#include "val_graph_visitor.h"
2927

3028
namespace nvfuser {
3129

@@ -58,14 +56,10 @@ IdModel::IdModel(
5856
: allow_self_mapping_(allow_self_mapping),
5957
loop_promotion_map_builder_callback_(
6058
loop_promotion_map_builder_callback) {
61-
std::copy_if(
62-
exprs.begin(),
63-
exprs.end(),
64-
std::back_inserter(tv_exprs_),
65-
[](Expr* expr) {
66-
NVF_ERROR(expr != nullptr);
67-
return ir_utils::isTvOp(expr);
68-
});
59+
std::ranges::copy_if(exprs, std::back_inserter(tv_exprs_), [](Expr* expr) {
60+
NVF_ERROR(expr != nullptr);
61+
return ir_utils::isTvOp(expr);
62+
});
6963

7064
auto all_tvs = ir_utils::allTvsOfExprs(tv_exprs_);
7165
all_tvs.pushBack(additional_tvs.begin(), additional_tvs.end());
@@ -96,11 +90,8 @@ IdModel::IdModel(
9690
loop_promotion_map_builder_callback_(
9791
loop_promotion_map_builder_callback) {
9892
auto all_exprs = fusion->exprs();
99-
std::copy_if(
100-
all_exprs.begin(),
101-
all_exprs.end(),
102-
std::back_inserter(tv_exprs_),
103-
[](Expr* expr) {
93+
std::ranges::copy_if(
94+
all_exprs, std::back_inserter(tv_exprs_), [](Expr* expr) {
10495
NVF_ERROR(expr != nullptr);
10596
return ir_utils::isTvOp(expr);
10697
});
@@ -160,8 +151,7 @@ void IdModel::buildIterDomainDefinitionsAndUses() {
160151
// domain is marked as an rfactor product and is in the rfactor
161152
// domain, it's a view like rfactor iteration domain
162153
const auto& logical_domain = tv->domain()->logical();
163-
if (std::find(logical_domain.begin(), logical_domain.end(), id) !=
164-
logical_domain.end()) {
154+
if (std::ranges::find(logical_domain, id) != logical_domain.end()) {
165155
view_rfactor_ids_.emplace(id);
166156
}
167157
}
@@ -184,13 +174,10 @@ void IdModel::buildIterDomainDefinitionsAndUses() {
184174
// not include the definition in the model. Note that it is
185175
// possible that some are included but not all since a single ID
186176
// may be used by multiple exprs.
187-
if (std::any_of(
188-
def->inputs().begin(), def->inputs().end(), [&](Val* inp) {
189-
return std::find(
190-
all_ids.begin(),
191-
all_ids.end(),
192-
inp->as<IterDomain>()) == all_ids.end();
193-
})) {
177+
if (std::ranges::any_of(def->inputs(), [&](Val* inp) {
178+
return std::ranges::find(all_ids, inp->as<IterDomain>()) ==
179+
all_ids.end();
180+
})) {
194181
continue;
195182
}
196183

@@ -220,10 +207,10 @@ std::string IdModel::toString() const {
220207
ss << " Disjoint Ids:\n"
221208
<< idGroupsString(idGraph(mode), 2)
222209
<< "\n Disjoint Expression groups:\n"
223-
<< exprGroupsString(idGraph(mode), 2) << std::endl;
224-
ss << " } IdGraph\n" << std::endl;
210+
<< exprGroupsString(idGraph(mode), 2) << '\n';
211+
ss << " } IdGraph\n" << '\n';
225212
}
226-
ss << " } IterDomainGraphs\n" << std::endl;
213+
ss << " } IterDomainGraphs\n" << '\n';
227214
return ss.str();
228215
}
229216

@@ -240,10 +227,9 @@ ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const {
240227
all_ids.push_back(id);
241228
}
242229

243-
std::sort(
244-
all_ids.begin(), all_ids.end(), [](IterDomain* id1, IterDomain* id2) {
245-
return id1->name() < id2->name();
246-
});
230+
std::ranges::sort(all_ids, [](IterDomain* id1, IterDomain* id2) {
231+
return id1->name() < id2->name();
232+
});
247233

248234
for (auto id : all_ids) {
249235
auto uses_it = id_uses_.find(id);
@@ -785,7 +771,7 @@ void buildAsyncWarpInliningInfo(
785771
std::vector<AsyncWarp> async_warps = createAsyncWarps(exprs);
786772

787773
// short-circuit: no async operations detected.
788-
if (async_warps.size() == 0) {
774+
if (async_warps.empty()) {
789775
return;
790776
}
791777
NVF_ERROR(
@@ -1105,13 +1091,10 @@ Expr* IdModel::addReplayAs(std::vector<IterDomain*> new_inputs, Expr* expr) {
11051091

11061092
// Replace the provided inputs with IterType::Iteration domains as
11071093
// reduction domains cannot be merged with non-reduction domains.
1108-
if (std::any_of(
1109-
new_inputs.begin(),
1110-
new_inputs.end(),
1111-
[](IterDomain* id) { return id->isReduction(); }) &&
1112-
std::any_of(new_inputs.begin(), new_inputs.end(), [](IterDomain* id) {
1113-
return !id->isReduction();
1114-
})) {
1094+
if (std::ranges::any_of(
1095+
new_inputs, [](IterDomain* id) { return id->isReduction(); }) &&
1096+
std::ranges::any_of(
1097+
new_inputs, [](IterDomain* id) { return !id->isReduction(); })) {
11151098
// Inputs have mismatched type, replace new_inputs
11161099
auto tmp_inputs = new_inputs;
11171100
for (const auto i : arange(new_inputs.size())) {

csrc/id_model/id_model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ StatefulInliningInfo buildStatefulInliningInfo(
104104
// IdMappingMode::LOOP
105105
// Subgraph of the permissive graph. Maps only CA and their
106106
// dependent domains.
107-
class NVF_API IdModel : public PolymorphicBase {
107+
class NVF_API IdModel {
108108
public:
109109
// Sometimes fusion inputs or outputs are disconnected from expressions, in
110110
// those cases we still may want to send in some additional tensor views from

0 commit comments

Comments
 (0)