Skip to content

Commit fe5b72a

Browse files
authored
[mlir][Vector] Pattern to linearize broadcast (llvm#163845)
The PR llvm#162167 removed a pattern to linearize vector.splat, without adding the equivalent pattern for vector.broadcast. This PR adds such a pattern, hopefully brining vector.broadcast up to full parity with vector.splat that has now been removed. --------- Signed-off-by: James Newling <[email protected]>
1 parent 64c8ebb commit fe5b72a

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
817817
}
818818
};
819819

820+
/// Convert broadcasts from scalars or 1-element vectors, such as
821+
///
822+
/// ```mlir
823+
/// vector.broadcast %value : f32 to vector<4x4xf32>
824+
/// ```
825+
///
826+
/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
827+
/// The above becomes,
828+
///
829+
/// ```mlir
830+
/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
831+
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
832+
/// ```
833+
struct LinearizeVectorBroadcast final
834+
: public OpConversionPattern<vector::BroadcastOp> {
835+
using Base::Base;
836+
837+
LinearizeVectorBroadcast(const TypeConverter &typeConverter,
838+
MLIRContext *context, PatternBenefit benefit = 1)
839+
: OpConversionPattern(typeConverter, context, benefit) {}
840+
841+
LogicalResult
842+
matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
843+
ConversionPatternRewriter &rewriter) const override {
844+
845+
int numElements = 1;
846+
Type sourceType = broadcastOp.getSourceType();
847+
if (auto vecType = dyn_cast<VectorType>(sourceType)) {
848+
numElements = vecType.getNumElements();
849+
}
850+
851+
if (numElements != 1) {
852+
return rewriter.notifyMatchFailure(
853+
broadcastOp, "only broadcasts of single elements can be linearized.");
854+
}
855+
856+
auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
857+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
858+
adaptor.getSource());
859+
860+
return success();
861+
}
862+
};
863+
820864
} // namespace
821865

822866
/// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
909953
patterns
910954
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
911955
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
912-
LinearizeVectorFromElements, LinearizeVectorToElements>(
913-
typeConverter, patterns.getContext());
956+
LinearizeVectorBroadcast, LinearizeVectorFromElements,
957+
LinearizeVectorToElements>(typeConverter, patterns.getContext());
914958
}
915959

916960
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
428428

429429
// -----
430430

431+
// CHECK-LABEL: linearize_vector_broadcast_scalar_source
432+
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
433+
func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
434+
435+
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
436+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
437+
// CHECK: return %[[CAST]] : vector<4x2xi32>
438+
%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
439+
return %0 : vector<4x2xi32>
440+
}
441+
442+
// -----
443+
444+
// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
445+
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
446+
func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
447+
448+
// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
449+
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
450+
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
451+
// CHECK: return %[[CAST1]] : vector<4x2xi32>
452+
%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
453+
return %0 : vector<4x2xi32>
454+
}
455+
456+
// -----
457+
458+
// CHECK-LABEL: linearize_scalable_vector_broadcast
459+
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
460+
func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
461+
462+
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32>
463+
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32>
464+
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
465+
%0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32>
466+
return %0 : vector<4x[2]xi32>
467+
468+
}
469+
470+
// -----
471+
431472
// CHECK-LABEL: linearize_create_mask
432473
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
433474
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {

0 commit comments

Comments
 (0)