1212#include < utility>
1313
1414#include " device_lower/analysis/circular_buffer.h"
15- #include " device_lower/analysis/trivial_broadcast.h"
1615#include " device_lower/lower2device.h"
1716#include " device_lower/utils.h"
1817#include " disjoint_set.h"
2524#include " iter_visitor.h"
2625#include " logical_domain_map.h"
2726#include " transform_iter.h"
28- #include " val_graph_visitor.h"
2927
3028namespace nvfuser {
3129
@@ -58,14 +56,10 @@ IdModel::IdModel(
5856 : allow_self_mapping_(allow_self_mapping),
5957 loop_promotion_map_builder_callback_ (
6058 loop_promotion_map_builder_callback) {
61- std::copy_if (
62- exprs.begin (),
63- exprs.end (),
64- std::back_inserter (tv_exprs_),
65- [](Expr* expr) {
66- NVF_ERROR (expr != nullptr );
67- return ir_utils::isTvOp (expr);
68- });
59+ std::ranges::copy_if (exprs, std::back_inserter (tv_exprs_), [](Expr* expr) {
60+ NVF_ERROR (expr != nullptr );
61+ return ir_utils::isTvOp (expr);
62+ });
6963
7064 auto all_tvs = ir_utils::allTvsOfExprs (tv_exprs_);
7165 all_tvs.pushBack (additional_tvs.begin (), additional_tvs.end ());
@@ -96,11 +90,8 @@ IdModel::IdModel(
9690 loop_promotion_map_builder_callback_(
9791 loop_promotion_map_builder_callback) {
9892 auto all_exprs = fusion->exprs ();
99- std::copy_if (
100- all_exprs.begin (),
101- all_exprs.end (),
102- std::back_inserter (tv_exprs_),
103- [](Expr* expr) {
93+ std::ranges::copy_if (
94+ all_exprs, std::back_inserter (tv_exprs_), [](Expr* expr) {
10495 NVF_ERROR (expr != nullptr );
10596 return ir_utils::isTvOp (expr);
10697 });
@@ -160,8 +151,7 @@ void IdModel::buildIterDomainDefinitionsAndUses() {
160151 // domain is marked as an rfactor product and is in the rfactor
161152 // domain, it's a view like rfactor iteration domain
162153 const auto & logical_domain = tv->domain ()->logical ();
163- if (std::find (logical_domain.begin (), logical_domain.end (), id) !=
164- logical_domain.end ()) {
154+ if (std::ranges::find (logical_domain, id) != logical_domain.end ()) {
165155 view_rfactor_ids_.emplace (id);
166156 }
167157 }
@@ -184,13 +174,10 @@ void IdModel::buildIterDomainDefinitionsAndUses() {
184174 // not include the definition in the model. Note that it is
185175 // possible that some are included but not all since a single ID
186176 // may be used by multiple exprs.
187- if (std::any_of (
188- def->inputs ().begin (), def->inputs ().end (), [&](Val* inp) {
189- return std::find (
190- all_ids.begin (),
191- all_ids.end (),
192- inp->as <IterDomain>()) == all_ids.end ();
193- })) {
177+ if (std::ranges::any_of (def->inputs (), [&](Val* inp) {
178+ return std::ranges::find (all_ids, inp->as <IterDomain>()) ==
179+ all_ids.end ();
180+ })) {
194181 continue ;
195182 }
196183
@@ -220,10 +207,10 @@ std::string IdModel::toString() const {
220207 ss << " Disjoint Ids:\n "
221208 << idGroupsString (idGraph (mode), 2 )
222209 << " \n Disjoint Expression groups:\n "
223- << exprGroupsString (idGraph (mode), 2 ) << std::endl ;
224- ss << " } IdGraph\n " << std::endl ;
210+ << exprGroupsString (idGraph (mode), 2 ) << ' \n ' ;
211+ ss << " } IdGraph\n " << ' \n ' ;
225212 }
226- ss << " } IterDomainGraphs\n " << std::endl ;
213+ ss << " } IterDomainGraphs\n " << ' \n ' ;
227214 return ss.str ();
228215}
229216
@@ -240,10 +227,9 @@ ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const {
240227 all_ids.push_back (id);
241228 }
242229
243- std::sort (
244- all_ids.begin (), all_ids.end (), [](IterDomain* id1, IterDomain* id2) {
245- return id1->name () < id2->name ();
246- });
230+ std::ranges::sort (all_ids, [](IterDomain* id1, IterDomain* id2) {
231+ return id1->name () < id2->name ();
232+ });
247233
248234 for (auto id : all_ids) {
249235 auto uses_it = id_uses_.find (id);
@@ -785,7 +771,7 @@ void buildAsyncWarpInliningInfo(
785771 std::vector<AsyncWarp> async_warps = createAsyncWarps (exprs);
786772
787773 // short-circuit: no async operations detected.
788- if (async_warps.size () == 0 ) {
774+ if (async_warps.empty () ) {
789775 return ;
790776 }
791777 NVF_ERROR (
@@ -1105,13 +1091,10 @@ Expr* IdModel::addReplayAs(std::vector<IterDomain*> new_inputs, Expr* expr) {
11051091
11061092 // Replace the provided inputs with IterType::Iteration domains as
11071093 // reduction domains cannot be merged with non-reduction domains.
1108- if (std::any_of (
1109- new_inputs.begin (),
1110- new_inputs.end (),
1111- [](IterDomain* id) { return id->isReduction (); }) &&
1112- std::any_of (new_inputs.begin (), new_inputs.end (), [](IterDomain* id) {
1113- return !id->isReduction ();
1114- })) {
1094+ if (std::ranges::any_of (
1095+ new_inputs, [](IterDomain* id) { return id->isReduction (); }) &&
1096+ std::ranges::any_of (
1097+ new_inputs, [](IterDomain* id) { return !id->isReduction (); })) {
11151098 // Inputs have mismatched type, replace new_inputs
11161099 auto tmp_inputs = new_inputs;
11171100 for (const auto i : arange (new_inputs.size ())) {
0 commit comments