@@ -207,6 +207,86 @@ LogicalResult ConstSizeOpConversion::matchAndRewrite(
207
207
return success ();
208
208
}
209
209
210
+ namespace {
211
+ struct IsBroadcastableOpConverter
212
+ : public OpConversionPattern<IsBroadcastableOp> {
213
+ using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
214
+
215
+ LogicalResult
216
+ matchAndRewrite (IsBroadcastableOp op, ArrayRef<Value> operands,
217
+ ConversionPatternRewriter &rewriter) const override ;
218
+ };
219
+ } // namespace
220
+
221
+ LogicalResult IsBroadcastableOpConverter::matchAndRewrite (
222
+ IsBroadcastableOp op, ArrayRef<Value> operands,
223
+ ConversionPatternRewriter &rewriter) const {
224
+ // For now, this lowering is only defined on `tensor<?xindex>` operands, not
225
+ // on shapes.
226
+ IsBroadcastableOp::Adaptor transformed (operands);
227
+ if (transformed.lhs ().getType ().isa <ShapeType>() ||
228
+ transformed.rhs ().getType ().isa <ShapeType>())
229
+ return failure ();
230
+
231
+ auto loc = op.getLoc ();
232
+ Value zero = rewriter.create <ConstantIndexOp>(loc, 0 );
233
+ Value one = rewriter.create <ConstantIndexOp>(loc, 1 );
234
+
235
+ // Find smaller and greater rank and extent tensor.
236
+ Value lhsRank = rewriter.create <DimOp>(loc, transformed.lhs (), zero);
237
+ Value rhsRank = rewriter.create <DimOp>(loc, transformed.rhs (), zero);
238
+ Value lhsRankULE =
239
+ rewriter.create <CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
240
+ Type indexTy = rewriter.getIndexType ();
241
+ Value lesserRank =
242
+ rewriter.create <SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
243
+ Value greaterRank =
244
+ rewriter.create <SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
245
+ auto erasedRankType =
246
+ RankedTensorType::get ({ShapedType::kDynamicSize }, indexTy);
247
+ Value rankErasedLhs =
248
+ rewriter.create <TensorCastOp>(loc, erasedRankType, transformed.lhs ());
249
+ Value rankErasedRhs =
250
+ rewriter.create <TensorCastOp>(loc, erasedRankType, transformed.rhs ());
251
+ Value lesserRankOperand =
252
+ rewriter.create <SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
253
+ Value greaterRankOperand =
254
+ rewriter.create <SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
255
+ Value rankDiff =
256
+ rewriter.create <SubIOp>(loc, indexTy, greaterRank, lesserRank);
257
+ Type i1Ty = rewriter.getI1Type ();
258
+ Value init =
259
+ rewriter.create <ConstantOp>(loc, i1Ty, rewriter.getBoolAttr (true ));
260
+
261
+ // Determine if all overlapping extents are broadcastable.
262
+ auto reduceResult = rewriter.create <ForOp>(
263
+ loc, rankDiff, greaterRank, one, ValueRange{init},
264
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
265
+ Value greaterRankOperandExtent =
266
+ b.create <ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
267
+ Value greaterRankOperandExtentIsOne = b.create <CmpIOp>(
268
+ loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
269
+ Value ivShifted = b.create <SubIOp>(loc, indexTy, iv, rankDiff);
270
+ Value lesserRankOperandExtent = b.create <ExtractElementOp>(
271
+ loc, lesserRankOperand, ValueRange{ivShifted});
272
+ Value lesserRankOperandExtentIsOne = b.create <CmpIOp>(
273
+ loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
274
+ Value extentsAreEqual =
275
+ b.create <CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
276
+ lesserRankOperandExtent);
277
+ Value broadcastableExtents = b.create <AndOp>(
278
+ loc, iterArgs[0 ],
279
+ b.create <OrOp>(loc,
280
+ b.create <OrOp>(loc, greaterRankOperandExtentIsOne,
281
+ lesserRankOperandExtentIsOne),
282
+ extentsAreEqual));
283
+ b.create <scf::YieldOp>(loc, broadcastableExtents);
284
+ });
285
+
286
+ rewriter.replaceOp (op, reduceResult.results ().front ());
287
+ return success ();
288
+ }
289
+
210
290
namespace {
211
291
class GetExtentOpConverter : public OpConversionPattern <GetExtentOp> {
212
292
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -522,6 +602,7 @@ void mlir::populateShapeToStandardConversionPatterns(
522
602
BroadcastOpConverter,
523
603
ConstShapeOpConverter,
524
604
ConstSizeOpConversion,
605
+ IsBroadcastableOpConverter,
525
606
GetExtentOpConverter,
526
607
RankOpConverter,
527
608
ReduceOpConverter,
0 commit comments