Skip to content

Commit 7d5b1b4

Browse files
committed
[MLIR][Vector] Allow any shaped typed to be distributed for vector.warp_execute_on_lane_0's return values
1 parent fcfd643 commit 7d5b1b4

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6558,14 +6558,14 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
65586558
// If the types matches there is no distribution.
65596559
if (expanded == distributed)
65606560
return success();
6561-
auto expandedVecType = llvm::dyn_cast<VectorType>(expanded);
6562-
auto distributedVecType = llvm::dyn_cast<VectorType>(distributed);
6561+
auto expandedVecType = llvm::dyn_cast<ShapedType>(expanded);
6562+
auto distributedVecType = llvm::dyn_cast<ShapedType>(distributed);
65636563
if (!expandedVecType || !distributedVecType)
6564-
return op->emitOpError("expected vector type for distributed operands.");
6564+
return op->emitOpError("expected shaped type for distributed operands.");
65656565
if (expandedVecType.getRank() != distributedVecType.getRank() ||
65666566
expandedVecType.getElementType() != distributedVecType.getElementType())
65676567
return op->emitOpError(
6568-
"expected distributed vectors to have same rank and element type.");
6568+
"expected distributed types to have same rank and element type.");
65696569

65706570
SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
65716571
for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
@@ -6575,8 +6575,8 @@ static LogicalResult verifyDistributedType(Type expanded, Type distributed,
65756575
continue;
65766576
if (eDim % dDim != 0)
65776577
return op->emitOpError()
6578-
<< "expected expanded vector dimension #" << i << " (" << eDim
6579-
<< ") to be a multipler of the distributed vector dimension ("
6578+
<< "expected expanded type dimension #" << i << " (" << eDim
6579+
<< ") to be a multipler of the distributed type dimension ("
65806580
<< dDim << ")";
65816581
scales[i] = eDim / dDim;
65826582
}

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,7 +1665,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
16651665
// -----
16661666

16671667
func.func @warp_2_distributed_dims(%laneid: index) {
1668-
// expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multipler of the distributed vector dimension (3)}}
1668+
// expected-error@+1 {{expected expanded type dimension #1 (8) to be a multipler of the distributed type dimension (3)}}
16691669
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
16701670
%0 = arith.constant dense<2>: vector<4x8xi32>
16711671
vector.yield %0 : vector<4x8xi32>
@@ -1676,7 +1676,7 @@ func.func @warp_2_distributed_dims(%laneid: index) {
16761676
// -----
16771677

16781678
func.func @warp_mismatch_rank(%laneid: index) {
1679-
// expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
1679+
// expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed types to have same rank and element type.}}
16801680
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
16811681
%0 = arith.constant dense<2>: vector<128xi32>
16821682
vector.yield %0 : vector<128xi32>
@@ -1687,7 +1687,7 @@ func.func @warp_mismatch_rank(%laneid: index) {
16871687
// -----
16881688

16891689
func.func @warp_mismatch_rank(%laneid: index) {
1690-
// expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected vector type for distributed operands.}}
1690+
// expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected shaped type for distributed operands.}}
16911691
%2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (i32) {
16921692
%0 = arith.constant dense<2>: vector<128xi32>
16931693
vector.yield %0 : vector<128xi32>

0 commit comments

Comments
 (0)