
Commit ca5d9d6

tpopp authored and memfrob committed
[mlir] Add lowering for IsBroadcastable to Std dialect.
Differential Revision: https://reviews.llvm.org/D90407
1 parent be460d1 commit ca5d9d6
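
For orientation, the new lowering materializes the standard broadcast-compatibility rule: the two extent tensors are aligned at their trailing ends, and each pair of overlapping extents must be equal or contain a 1. Below is a minimal standalone C++ sketch of that predicate, mirroring the scf.for loop the converter emits; the function name and the std::vector representation are illustrative, not part of the commit.

// Minimal sketch (illustrative, not part of the commit): the predicate that
// the generated scf.for loop computes, evaluated on concrete extent lists.
#include <cstdint>
#include <vector>

bool isBroadcastable(const std::vector<int64_t> &lhs,
                     const std::vector<int64_t> &rhs) {
  // Mirror the lowering: pick the lesser- and greater-rank operand
  // (ties go to lhs, matching the `ule` comparison in the converter).
  const auto &lesser = lhs.size() <= rhs.size() ? lhs : rhs;
  const auto &greater = lhs.size() <= rhs.size() ? rhs : lhs;
  size_t rankDiff = greater.size() - lesser.size();
  // The leading extents of the greater-rank operand have no counterpart and
  // are always compatible; only the overlapping suffix is checked.
  for (size_t i = rankDiff; i < greater.size(); ++i) {
    int64_t g = greater[i];
    int64_t l = lesser[i - rankDiff];
    if (g != 1 && l != 1 && g != l)
      return false;
  }
  return true;
}

For example, isBroadcastable({1, 4}, {3, 1, 4}) returns true, while isBroadcastable({2}, {3}) returns false because the trailing extents differ and neither is 1.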

2 files changed: +119 −0 lines


mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

Lines changed: 81 additions & 0 deletions
@@ -207,6 +207,86 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite(
   return success();
 }
 
+namespace {
+struct IsBroadcastableOpConverter
+    : public OpConversionPattern<IsBroadcastableOp> {
+  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(IsBroadcastableOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
+    IsBroadcastableOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
+  // on shapes.
+  IsBroadcastableOp::Adaptor transformed(operands);
+  if (transformed.lhs().getType().isa<ShapeType>() ||
+      transformed.rhs().getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+  // Find smaller and greater rank and extent tensor.
+  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
+  Value lhsRankULE =
+      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+  Type indexTy = rewriter.getIndexType();
+  Value lesserRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  auto erasedRankType =
+      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  Value rankErasedLhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs());
+  Value rankErasedRhs =
+      rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs());
+  Value lesserRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
+  Value greaterRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
+  Value rankDiff =
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+  Type i1Ty = rewriter.getI1Type();
+  Value init =
+      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+
+  // Determine if all overlapping extents are broadcastable.
+  auto reduceResult = rewriter.create<ForOp>(
+      loc, rankDiff, greaterRank, one, ValueRange{init},
+      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
+        Value greaterRankOperandExtent =
+            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
+        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+        Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+            loc, lesserRankOperand, ValueRange{ivShifted});
+        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+        Value extentsAreEqual =
+            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                             lesserRankOperandExtent);
+        Value broadcastableExtents = b.create<AndOp>(
+            loc, iterArgs[0],
+            b.create<OrOp>(loc,
+                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                          lesserRankOperandExtentIsOne),
+                           extentsAreEqual));
+        b.create<scf::YieldOp>(loc, broadcastableExtents);
+      });
+
+  rewriter.replaceOp(op, reduceResult.results().front());
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -522,6 +602,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       BroadcastOpConverter,
       ConstShapeOpConverter,
       ConstSizeOpConversion,
+      IsBroadcastableOpConverter,
       GetExtentOpConverter,
       RankOpConverter,
       ReduceOpConverter,
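
For context on how these patterns are exercised: the in-tree ConvertShapeToStandard pass wires them into a dialect conversion. The sketch below shows that wiring under the conversion-framework API of this era (OwningRewritePatternList, applyPartialConversion); the function name and the exact legality choices are illustrative assumptions, not part of this diff.

// Rough sketch (assumptions noted above): applying the ShapeToStandard
// patterns via a partial dialect conversion.
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult runShapeToStandard(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  ConversionTarget target(*ctx);
  // The lowering emits std and scf ops, so those dialects remain legal;
  // shape.is_broadcastable itself must be rewritten away.
  target.addLegalDialect<StandardOpsDialect, scf::SCFDialect>();
  target.addIllegalOp<shape::IsBroadcastableOp>();

  OwningRewritePatternList patterns;
  populateShapeToStandardConversionPatterns(patterns, ctx);
  return applyPartialConversion(module, target, std::move(patterns));
}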

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 38 additions & 0 deletions
@@ -382,3 +382,41 @@ func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xinde
     : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+func @try_is_broadcastable(%a : tensor<3xindex>, %b : tensor<?xindex>) -> i1 {
+  %0 = shape.is_broadcastable %a, %b : tensor<3xindex>, tensor<?xindex>
+  return %0 : i1
+}
+
+// CHECK-LABEL: func @try_is_broadcastable(
+// CHECK-SAME:  %[[LHS:.*]]: tensor<3xindex>,
+// CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) -> i1 {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<3xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<3xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[I:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[I]]] : tensor<?xindex>
+// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK: %[[SMALLER_EXTENT_INDEX:.*]] = subi %[[I]], %[[RANK_DIFF]] : index
+// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[SMALLER_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK: }
+// CHECK: return %[[ALL_RESULT]] : i1
+// CHECK: }
