@@ -123,11 +123,11 @@ tt::CoarseSchedule::splitClusterBefore(Operation *op, scf::ForOp forOp) {
123
123
}
124
124
125
125
// Check if op a will show up before op b in the final unrolled code.
126
- bool tt::CoarseSchedule::isOpBefore (Operation *a, Operation *b) {
126
+ bool tt::CoarseSchedule::isOpBefore (Operation *a, Operation *b) const {
127
127
assert (opToStageAndCluster.count (a) && opToStageAndCluster.count (b) &&
128
128
" Operations must be in the schedule" );
129
- auto [aStage, aCluster] = opToStageAndCluster[a] ;
130
- auto [bStage, bCluster] = opToStageAndCluster[b] ;
129
+ auto [aStage, aCluster] = opToStageAndCluster. at (a) ;
130
+ auto [bStage, bCluster] = opToStageAndCluster. at (b) ;
131
131
if (aStage != bStage) {
132
132
return aStage < bStage;
133
133
}
@@ -137,21 +137,22 @@ bool tt::CoarseSchedule::isOpBefore(Operation *a, Operation *b) {
137
137
return a->isBeforeInBlock (b);
138
138
}
139
139
140
- bool tt::CoarseSchedule::isOpInEarlierCluster (Operation *a, Operation *b) {
140
+ bool tt::CoarseSchedule::isOpInEarlierCluster (Operation *a,
141
+ Operation *b) const {
141
142
assert (opToStageAndCluster.count (a) && opToStageAndCluster.count (b) &&
142
143
" Operations must be in the schedule" );
143
- return clusters.isBefore (opToStageAndCluster[a] .second ,
144
- opToStageAndCluster[b] .second );
144
+ return clusters.isBefore (opToStageAndCluster. at (a) .second ,
145
+ opToStageAndCluster. at (b) .second );
145
146
}
146
147
147
- bool tt::CoarseSchedule::isOpInSameCluster (Operation *a, Operation *b) {
148
+ bool tt::CoarseSchedule::isOpInSameCluster (Operation *a, Operation *b) const {
148
149
assert (opToStageAndCluster.count (a) && opToStageAndCluster.count (b) &&
149
150
" Operations must be in the schedule" );
150
- return opToStageAndCluster[a] .second == opToStageAndCluster[b] .second ;
151
+ return opToStageAndCluster. at (a) .second == opToStageAndCluster. at (b) .second ;
151
152
}
152
153
153
154
SmallVector<std::tuple<Operation *, int , tt::CoarseSchedule::Cluster>>
154
- tt::CoarseSchedule::getOpsInOrder (scf::ForOp forOp) {
155
+ tt::CoarseSchedule::getOpsInOrder (scf::ForOp forOp) const {
155
156
SmallVector<SmallVector<std::tuple<Operation *, int , Cluster>>, 8 >
156
157
orderClusters (clusters.size ());
157
158
for (auto &op : forOp.getBody ()->without_terminator ()) {
@@ -160,12 +161,11 @@ tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
160
161
continue ;
161
162
}
162
163
auto [stage, cluster] = it->second ;
163
- if (cluster == Cluster{}) {
164
- continue ;
165
- }
164
+ assert (cluster != Cluster{} && " Op with invalid cluster!" );
166
165
assert (stage < numStages && " Op with invalid stage!" );
167
166
int clusterId = *cluster;
168
- assert (clusterId == std::distance (clusters.begin (), cluster) &&
167
+ assert (clusterId == std::distance (clusters.begin (),
168
+ ClusterList::const_iterator (cluster)) &&
169
169
" Cluster ID mismatch!" );
170
170
orderClusters[clusterId].push_back (make_tuple (&op, stage, cluster));
171
171
}
@@ -180,7 +180,7 @@ tt::CoarseSchedule::getOpsInOrder(scf::ForOp forOp) {
180
180
}
181
181
182
182
std::vector<std::pair<Operation *, unsigned >>
183
- tt::CoarseSchedule::createFinalSchedule (scf::ForOp forOp) {
183
+ tt::CoarseSchedule::createFinalSchedule (scf::ForOp forOp) const {
184
184
SmallVector<std::tuple<Operation *, int , tt::CoarseSchedule::Cluster>>
185
185
opsInOrder = getOpsInOrder (forOp);
186
186
std::vector<std::pair<Operation *, unsigned >> schedule;
@@ -248,7 +248,7 @@ static std::optional<int> tryGetMaxStage(scf::ForOp &forOp) {
248
248
}
249
249
250
250
// Set <stage, cluster> based on CoarseSchedule.
251
- void tt::CoarseSchedule::serialize (scf::ForOp &forOp) {
251
+ void tt::CoarseSchedule::serialize (scf::ForOp &forOp) const {
252
252
for (auto [op, stage, cluster] : getOpsInOrder (forOp)) {
253
253
setStageCluster (op, stage, *cluster);
254
254
}
0 commit comments