Skip to content

Commit 93f07e7

Browse files
committed
remove restriction
1 parent 13a2137 commit 93f07e7

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,11 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
10471047
}
10481048
};
10491049

1050+
/// Distribute a vector::BitCastOp feeding into yield op of an enclosing
1051+
/// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
1052+
/// diemension of the source/result vectors. Equivalent vector::BitCastOp is
1053+
/// created outside of the warp op with distributed source vector type (computed
1054+
/// using assigned layout).
10501055
struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
10511056
using gpu::WarpDistributionPattern::WarpDistributionPattern;
10521057
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
@@ -1069,11 +1074,6 @@ struct VectorBitcastDistribution final : public gpu::WarpDistributionPattern {
10691074
"vector::BitCast op");
10701075
VectorType distributedResultType =
10711076
cast<VectorType>(warpOp.getResult(operandIdx).getType());
1072-
if (distributedSourceType.getRank() != 2 ||
1073-
distributedResultType.getRank() != 2)
1074-
return rewriter.notifyMatchFailure(
1075-
bitcastOp, "the source or result vector of the bitcast op "
1076-
"are not 2D vectors");
10771077
SmallVector<size_t> newRetIndices;
10781078
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
10791079
rewriter, warpOp, bitcastOp.getSource(),

0 commit comments

Comments
 (0)