@@ -104,7 +104,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
104104
105105 Value v = dyn_cast<Value>(ofr);
106106 if (!v)
107- v = b.create <arith::ConstantOp>(loc, cast<IntegerAttr>(cast<Attribute>(ofr)));
107+ v = b.create <arith::ConstantOp>(loc,
108+ cast<IntegerAttr>(cast<Attribute>(ofr)));
108109
109110 Type ty = v.getType ();
110111 if (targetTy == ty)
@@ -126,7 +127,8 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
126127 // This path is for case like:
127128 // input_ptr + (row_indices[:, None] + row_offsets[:,None] % mod_offset) *
128129 // stride_m + col_offsets[None, :] * stride_n
129- // The modulo will be in shape of [ROW_SIZE, 1] while row_indices is in shape of [ROW_SIZE,].
130+ // The modulo will be in shape of [ROW_SIZE, 1] while row_indices is in
131+ // shape of [ROW_SIZE,].
130132 LLVM_DEBUG ({
131133 llvm::dbgs () << " Reshaping " ;
132134 shapedTy.dump ();
@@ -135,14 +137,15 @@ OpFoldResult expandOFRIndex(OpFoldResult ofr, OpFoldResult targetForTy,
135137 });
136138 SmallVector<Value> shapeValues;
137139 for (auto dim : targetShapedTy.getShape ()) {
138- shapeValues.push_back (b. create <arith::ConstantOp>(
139- loc, b.getIndexAttr (dim)));
140+ shapeValues.push_back (
141+ b. create <arith::ConstantOp>( loc, b.getIndexAttr (dim)));
140142 }
141143 RankedTensorType targetShapeTensorTy = RankedTensorType::get (
142144 targetShapedTy.getShape ().size (), b.getIndexType ());
143145 auto shapeTensor = b.create <tensor::FromElementsOp>(
144146 loc, targetShapeTensorTy, shapeValues);
145- return b.create <triton::ReshapeOp>(loc, targetTy, v, shapeTensor).getResult ();
147+ return b.create <triton::ReshapeOp>(loc, targetTy, v, shapeTensor)
148+ .getResult ();
146149 }
147150 if (isa<IndexType>(targetEltTy) || isa<IndexType>(eltTy)) {
148151 assert ((isa<IntegerType>(targetEltTy) || isa<IntegerType>(eltTy)) &&
@@ -228,7 +231,7 @@ OpFoldResult subOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
228231}
229232
230233OpFoldResult mulOFRs (const OpFoldResult lhs, const OpFoldResult rhs,
231- const Location loc, OpBuilder &b) {
234+ const Location loc, OpBuilder &b) {
232235 auto lhsIntAttr = getIntAttr (lhs);
233236 auto rhsIntAttr = getIntAttr (rhs);
234237
@@ -336,44 +339,65 @@ OpFoldResult maxOFRs(const OpFoldResult lhs, const OpFoldResult rhs,
336339 return maxOp.getResult ();
337340}
338341
342+ OpFoldResult selectOFRs (const OpFoldResult condOFR, const OpFoldResult trueOFR,
343+ const OpFoldResult falseOFR, const Location loc,
344+ OpBuilder &b) {
345+ auto trueValue = ofrToIndexValue (trueOFR, loc, b);
346+ auto falseValue = ofrToIndexValue (falseOFR, loc, b);
347+ auto condValue = ofrToIndexValue (condOFR, loc, b);
348+
349+ // Ideally we should not be passing around everything as index type since mask
350+ // analysis can come across i1 values, but that improvement is being left for
351+ // future work. For now we just unwrap an index back into it's i1 value if
352+ // necessary.
353+ if (!condValue.getType ().isInteger (1 )) {
354+ assert (condValue.getDefiningOp <arith::IndexCastOp>());
355+ condValue = condValue.getDefiningOp <arith::IndexCastOp>().getOperand ();
356+ assert (condValue.getType ().isInteger (1 ));
357+ }
358+
359+ auto selectOp =
360+ b.create <arith::SelectOp>(loc, condValue, trueValue, falseValue);
361+ return selectOp.getResult ();
362+ }
363+
339364OpFoldResult compareOFRs (const OpFoldResult lhs, const OpFoldResult rhs,
340- const arith::CmpIPredicate pred, const OpFoldResult trueOFR,
341- const OpFoldResult falseOFR, const Location loc, OpBuilder &b) {
365+ const arith::CmpIPredicate pred,
366+ const OpFoldResult trueOFR,
367+ const OpFoldResult falseOFR, const Location loc,
368+ OpBuilder &b) {
342369 auto lhsIntAttr = getIntAttr (lhs);
343370 auto rhsIntAttr = getIntAttr (rhs);
344371
345372 // both lhs and rhs are constants, return the result directly
346373 if (lhsIntAttr && rhsIntAttr) {
347374 switch (pred) {
348- case arith::CmpIPredicate::eq:
349- return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
350- case arith::CmpIPredicate::ne:
351- return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
352- case arith::CmpIPredicate::slt:
353- case arith::CmpIPredicate::ult:
354- return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
355- case arith::CmpIPredicate::sle:
356- case arith::CmpIPredicate::ule:
357- return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
358- case arith::CmpIPredicate::sgt:
359- case arith::CmpIPredicate::ugt:
360- return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
361- case arith::CmpIPredicate::sge:
362- case arith::CmpIPredicate::uge:
363- return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
364- default :
365- llvm_unreachable (" Unsupported predicate" );
375+ case arith::CmpIPredicate::eq:
376+ return *lhsIntAttr == *rhsIntAttr ? trueOFR : falseOFR;
377+ case arith::CmpIPredicate::ne:
378+ return *lhsIntAttr != *rhsIntAttr ? trueOFR : falseOFR;
379+ case arith::CmpIPredicate::slt:
380+ case arith::CmpIPredicate::ult:
381+ return *lhsIntAttr < *rhsIntAttr ? trueOFR : falseOFR;
382+ case arith::CmpIPredicate::sle:
383+ case arith::CmpIPredicate::ule:
384+ return *lhsIntAttr <= *rhsIntAttr ? trueOFR : falseOFR;
385+ case arith::CmpIPredicate::sgt:
386+ case arith::CmpIPredicate::ugt:
387+ return *lhsIntAttr > *rhsIntAttr ? trueOFR : falseOFR;
388+ case arith::CmpIPredicate::sge:
389+ case arith::CmpIPredicate::uge:
390+ return *lhsIntAttr >= *rhsIntAttr ? trueOFR : falseOFR;
391+ default :
392+ llvm_unreachable (" Unsupported predicate" );
366393 }
367394 }
368395
369396 auto lhsValue = ofrToIndexValue (lhs, loc, b);
370397 auto rhsValue = ofrToIndexValue (rhs, loc, b);
371- auto trueValue = ofrToIndexValue (trueOFR, loc, b);
372- auto falseValue = ofrToIndexValue (falseOFR, loc, b);
373398
374399 auto cmpOp = b.create <arith::CmpIOp>(loc, pred, lhsValue, rhsValue);
375- auto selectOp = b.create <arith::SelectOp>(loc, cmpOp, trueValue, falseValue);
376- return selectOp.getResult ();
400+ return selectOFRs (cmpOp.getResult (), trueOFR, falseOFR, loc, b);
377401}
378402
379403} // namespace mlir
0 commit comments