Skip to content

Commit 6d22968

Browse files
committed
Assume distributable offset and mask producers
1 parent e174b69 commit 6d22968

File tree

2 files changed

+31
-53
lines changed

2 files changed

+31
-53
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -817,14 +817,14 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
817817
auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
818818
if (!storeScatterOp)
819819
return failure();
820-
else if (!storeScatterOp.getOffsets())
820+
if (!storeScatterOp.getOffsets())
821821
return rewriter.notifyMatchFailure(storeScatterOp,
822822
"Store op must have offsets argument");
823-
else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
824-
.getRank() != 1)
823+
VectorType offsetsTy =
824+
cast<VectorType>(storeScatterOp.getOffsets().getType());
825+
if (offsetsTy.getRank() != 1)
825826
return rewriter.notifyMatchFailure(storeScatterOp,
826827
"Expected 1D offsets vector");
827-
828828
VectorType storeVecTy =
829829
cast<VectorType>(storeScatterOp.getValue().getType());
830830
assert(storeVecTy.getRank() <= 2 &&
@@ -836,33 +836,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
836836
distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
837837

838838
SmallVector<size_t> newRetIndices;
839-
SmallVector<Value> operands =
840-
llvm::to_vector_of<Value>(storeScatterOp->getOperands());
839+
SmallVector<Value> operands = storeScatterOp->getOperands();
841840
SmallVector<Type> operandTypes =
842841
llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
843842
operandTypes[0] = distStoreVecTy;
843+
// Assume offset and mask producers will be distributed as well.
844+
operandTypes[2] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
845+
operandTypes[3] = VectorType::get(
846+
{1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
844847

845848
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
846849
rewriter, warpOp, operands, operandTypes, newRetIndices);
847850
SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
848851
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
849852

850-
Value offsetsVec = newStoreScatterOpOperands[2];
851-
Value maskVec = newStoreScatterOpOperands[3];
852-
853853
auto loc = newWarpOp.getLoc();
854-
Value laneId = warpOp.getLaneid();
855854
rewriter.setInsertionPointAfter(newWarpOp);
856-
Value laneOffset =
857-
vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
858-
laneOffset = vector::BroadcastOp::create(
859-
rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
860-
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
861-
laneMask = vector::BroadcastOp::create(
862-
rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
863-
newStoreScatterOpOperands[2] = laneOffset;
864-
newStoreScatterOpOperands[3] = laneMask;
865-
866855
xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
867856
rewriter, loc, TypeRange{}, newStoreScatterOpOperands,
868857
storeScatterOp->getAttrs());
@@ -892,16 +881,20 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
892881
if (!loadGatherOp.getOffsets())
893882
return rewriter.notifyMatchFailure(loadGatherOp,
894883
"Load op must have offsets argument");
895-
else if (cast<VectorType>(loadGatherOp.getOffsets().getType()).getRank() !=
896-
1)
884+
VectorType offsetsTy =
885+
cast<VectorType>(loadGatherOp.getOffsets().getType());
886+
if (offsetsTy.getRank() != 1)
897887
return rewriter.notifyMatchFailure(loadGatherOp,
898888
"Expected 1D offsets vector");
899889

900890
SmallVector<size_t> newRetIndices;
901-
SmallVector<Value> operands =
902-
llvm::to_vector_of<Value>(loadGatherOp->getOperands());
891+
SmallVector<Value> operands = loadGatherOp->getOperands();
903892
SmallVector<Type> operandTypes =
904893
llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
894+
// Assume offset and mask producers will be distributed as well.
895+
operandTypes[1] = VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
896+
operandTypes[2] = VectorType::get(
897+
{1}, getElementTypeOrSelf(loadGatherOp.getMask().getType()));
905898

906899
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
907900
rewriter, warpOp, operands, operandTypes, newRetIndices);
@@ -914,21 +907,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
914907
cast<VectorType>(warpOp.getResult(operandIdx).getType());
915908
assert(loadVecTy.getRank() == 1 && "Expected a distributed vector");
916909

917-
Value offsetsVec = newLoadGatherOperands[1];
918-
Value maskVec = newLoadGatherOperands[2];
919910
auto loc = newWarpOp.getLoc();
920-
Value laneId = warpOp.getLaneid();
921911
rewriter.setInsertionPointAfter(newWarpOp);
922-
Value laneOffset =
923-
vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
924-
laneOffset = vector::BroadcastOp::create(
925-
rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
926-
Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
927-
laneMask = vector::BroadcastOp::create(
928-
rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
929-
newLoadGatherOperands[1] = laneOffset;
930-
newLoadGatherOperands[2] = laneMask;
931-
932912
xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
933913
loc, loadVecTy, newLoadGatherOperands, loadGatherOp->getAttrs());
934914
Value distributedVal = newWarpOp.getResult(operandIdx);

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,19 +323,18 @@ gpu.module @test {
323323
// -----
324324
// CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
325325
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
326-
// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
327-
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
328-
// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
329-
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
326+
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
327+
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
330328
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
331-
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
329+
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64, l1_hint = #xegpu.cache_hint<cached>,
332330
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
333331
gpu.module @test {
334-
gpu.func @scatter_ops_chunksize(%src: memref<256xf16>, %offset1: vector<16xindex>) {
332+
gpu.func @scatter_ops_chunksize(%src: memref<256xf16>) {
335333
%1 = arith.constant dense<1>: vector<16xi1>
336-
%3 = xegpu.load %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
334+
%offset = arith.constant dense<12> : vector<16xindex>
335+
%3 = xegpu.load %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
337336
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
338-
xegpu.store %3, %src[%offset1], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
337+
xegpu.store %3, %src[%offset], %1 <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
339338
: vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
340339
gpu.return
341340
}
@@ -344,19 +343,18 @@ gpu.module @test {
344343
// -----
345344
// CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
346345
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1xi1>
347-
// CHECK-NEXT: %[[LANE_ID:.*]] = gpu.lane_id
348-
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = vector.extract %arg1[%[[LANE_ID]]] : index from vector<16xindex>
349-
// CHECK-NEXT: %[[LANE_OFFSET_VEC:.*]] = vector.broadcast %[[LANE_OFFSET]] : index to vector<1xindex>
350-
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
346+
// CHECK-NEXT: %[[LANE_OFFSET:.*]] = arith.constant dense<12> : vector<1xindex>
347+
// CHECK-NEXT: %[[LOADED:.*]] = xegpu.load %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
351348
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<1xf16>
352-
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET_VEC]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
349+
// CHECK-NEXT: xegpu.store %[[LOADED]], %arg0[%[[LANE_OFFSET]]], %[[MASK]] <{l1_hint = #xegpu.cache_hint<cached>,
353350
// CHECK-SAME: l2_hint = #xegpu.cache_hint<uncached>}> : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
354351
gpu.module @test {
355-
gpu.func @scatter_ops(%src: memref<256xf16>, %offset1: vector<16xindex>) {
352+
gpu.func @scatter_ops(%src: memref<256xf16>) {
356353
%1 = arith.constant dense<1>: vector<16xi1>
357-
%3 = xegpu.load %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
354+
%offset = arith.constant dense<12> : vector<16xindex>
355+
%3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
358356
: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
359-
xegpu.store %3, %src[%offset1], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
357+
xegpu.store %3, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
360358
: vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
361359
gpu.return
362360
}

0 commit comments

Comments
 (0)