Commit ca61a9d

[mlir][xegpu] Support offset arguments in LoadNd, StoreNd and PrefetchNd subgroup distribution. (#160417)
Currently, offsets are given as operands of the `CreateNd` op, and subgroup distribution does not support offset arguments at the consumer ops. This PR adds support for offsets supplied at the consumer ops (such as `LoadNd`). With this change, the offsets must be specified at the consumer op (`LoadNd`, `StoreNd`, `PrefetchNd`) of the tile, otherwise distribution will fail. This also removes the need for the `UpdateNdOffset` op, so the PR drops distribution support for `UpdateNdOffset`.
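A minimal sketch of the IR shape this change expects, assuming the `#layout0` attribute and `4x8xf32` shapes used in the updated doc comments below; `%src`, `%x`, and `%y` are hypothetical values, and the exact printed form of `load_nd` with offsets is not shown in this diff:

```mlir
#layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
// The tensor descriptor is created without offsets...
%td = xegpu.create_nd_tdesc %src
  : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
// ...and the offsets are supplied at each consumer op instead.
%v = xegpu.load_nd %td [%x, %y]
  : !xegpu.tensor_desc<4x8xf32, #layout0> -> vector<4x8xf32>
xegpu.store_nd %v, %td [%x, %y]
  : vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32, #layout0>
xegpu.prefetch_nd %td [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
```

The distribution patterns below yield these offsets through the warp op and forward them to the new consumer ops created outside of it; an `update_nd_offset` in the chain is no longer expected.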
1 parent 1b0553c commit ca61a9d

File tree

2 files changed: +199 -298 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 81 additions & 124 deletions
@@ -268,7 +268,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// %r = gpu.warp_execute_on_lane_0(%laneid) ->
 ///   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
 ///   ...
-///   %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///   %td = xegpu.create_nd_tdesc %arg0
 ///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///   vector.yield %td
 /// }
@@ -277,11 +277,11 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// ```
 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (...) {
 ///   ...
-///   %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///   %dead = xegpu.create_nd_tdesc %arg0
 ///     : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32, #layout0>
 ///   vector.yield %arg0, %dead
 /// }
-/// %td = xegpu.create_nd_tdesc %r#0[0, 0]: memref<4x8xf32>
+/// %td = xegpu.create_nd_tdesc %r#0: memref<4x8xf32>
 ///   -> !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
@@ -301,6 +301,10 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     if (!layout)
       return rewriter.notifyMatchFailure(
          descOp, "the tensor descriptor lacks layout attribute");
+    // CreateNdOp must not have offsets.
+    if (descOp.getMixedOffsets().size())
+      return rewriter.notifyMatchFailure(
+          descOp, "xegpu::CreateNdDescOp must not have offsets");
 
     SmallVector<size_t> newRetIndices;
     rewriter.setInsertionPoint(warpOp);
@@ -339,22 +343,23 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
 ///   ...
-///   xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///   xegpu.store_nd %arg0, %arg1 [%x, %y]: vector<4x8xf32>,
 ///     !xegpu.tensor_desc<4x8xf32, #layout0>
 /// }
 /// ```
 /// To
 /// ```
 /// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
-///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
-///   gpu.yield %arg0, %arg1: vector<4x8xf32>, !xegpu.tensor_desc<4x8xf32,
-///     #layout0>
+///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
+///   ...
+///   gpu.yield %arg0, %arg1, %x, %y: vector<4x8xf32>,
+///     !xegpu.tensor_desc<4x8xf32, #layout0>, index, index
 /// }
 /// %0 = vector.shape_cast %r#0: vector<4x1xf32> to vector<4xf32>
 /// %1 = unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
 ///   #layout0>
 ///   -> !xegpu.tensor_desc<4x8xf32>
-/// xegpu.store_nd %0, %1: vector<4xf32>,
+/// xegpu.store_nd %0, %1 [%r#2, %r#3]: vector<4xf32>,
 ///   !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
@@ -368,10 +373,15 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     if (!storeOp)
       return failure();
 
-    int64_t offsetSize = static_cast<int64_t>(storeOp.getOffsets().size());
-    if ((offsetSize != 0) || storeOp.getConstOffsetsAttr())
-      return failure();
-
+    SmallVector<OpFoldResult> offsets = storeOp.getMixedOffsets();
+    // Expecting offsets to be present.
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "the store op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, storeOp.getLoc(), offsets);
+    SmallVector<Type> offsetTypes = llvm::to_vector(
+        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
     xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
     xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
     if (!layout)
@@ -387,13 +397,13 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
         distributedTypeByWarpOpOrFailure.value();
 
     SmallVector<size_t> newRetIndices;
+    SmallVector<Value> newYieldedValues = {storeOp.getValue(),
+                                           storeOp.getTensorDesc()};
+    SmallVector<Type> newYieldedTypes = {distributedTypeByWarpOp, tensorDescTy};
+    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
+    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp,
-        /* new yielded values = */
-        ValueRange{storeOp.getValue(), storeOp.getTensorDesc()},
-        /* new yielded types = */
-        TypeRange{distributedTypeByWarpOp, storeOp.getTensorDescType()},
-        newRetIndices);
+        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
     // Create a new store op outside the warp op with the distributed vector
     // type. Tensor descriptor is not distributed.
     rewriter.setInsertionPointAfter(newWarpOp);
@@ -418,6 +428,9 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     newStoreOperands.push_back(
        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
                             distributedTensorDescTy, rewriter));
+    // Collect offsets.
+    for (size_t i = 2; i < newRetIndices.size(); ++i)
+      newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
 
     auto newStoreOp =
         xegpu::StoreNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
@@ -491,9 +504,15 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
          loadOp,
          "xegpu::LoadNdOp require chip information to determine transpose "
          "requirement");
-    int64_t offsetSize = static_cast<int64_t>(loadOp.getOffsets().size());
-    if ((offsetSize != 0) || loadOp.getConstOffsetsAttr())
-      return failure();
+    // Expecting offsets to be present.
+    SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, loadOp.getLoc(), offsets);
+    SmallVector<Type> offsetTypes = llvm::to_vector(
+        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
 
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
     xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
@@ -506,10 +525,12 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
        cast<VectorType>(warpOp.getResult(operandIdx).getType());
 
     SmallVector<size_t> newRetIndices;
+    SmallVector<Value> newYieldedValues = {loadOp.getTensorDesc()};
+    SmallVector<Type> newYieldedTypes = {tensorDescTy};
+    newYieldedValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
+    newYieldedTypes.append(offsetTypes.begin(), offsetTypes.end());
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp,
-        /* new yielded values = */ loadOp.getTensorDesc(),
-        /* new yielded types = */ tensorDescTy, newRetIndices);
+        rewriter, warpOp, newYieldedValues, newYieldedTypes, newRetIndices);
 
     // Create a new load op outside the warp op with the distributed vector
     // type.
@@ -523,11 +544,15 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
        loadOp.getTensorDescType().dropLayouts(); // Distributed tensor
                                                  // descriptor type does not
                                                  // contain layout info.
+    SmallVector<Value> newLoadOperands{
+        resolveDistributedTy(newWarpOp.getResult(newRetIndices[0]),
+                             distributedTensorDescTy, rewriter)};
+    // Collect offsets.
+    for (size_t i = 1; i < newRetIndices.size(); ++i)
+      newLoadOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
     auto newLoadOp = xegpu::LoadNdOp::create(
         rewriter, newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
-        resolveDistributedTy(newWarpOp->getResult(newRetIndices[0]),
-                             distributedTensorDescTy, rewriter),
-        loadOp->getAttrs());
+        newLoadOperands, loadOp->getAttrs());
     xegpu::removeLayoutAttrs(newLoadOp);
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(requirePacked(layout));
@@ -677,85 +702,6 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-/// Sink an update_nd_offset op feeding into yield op of an enclosing
-/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
-/// original op that will not be used by the yield op (and should be cleaned
-/// up later). The yield op will bypass the updateOp's arguments. The tensor
-/// descriptor type is not distributed. Appropriate cast ops are inserted if
-/// the distributed types does not match expected xegpu SIMT types.
-/// Example:
-/// ```
-/// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
-/// %r = gpu.warp_execute_on_lane_0(%laneid) ->
-///   (!xegpu.tensor_desc<4x8xf32, #layout0>) {
-///   ...
-///   %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
-///     !xegpu.tensor_desc<4x8xf32, #layout0>
-///   gpu.yield %update
-/// }
-/// ...
-/// ```
-/// To
-/// ```
-/// %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (
-///   !xegpu.tensor_desc<4x8xf32, #layout0>,
-///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
-///   ...
-///   %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
-///     !xegpu.tensor_desc<4x8xf32, #layout0> gpu.yield %dead, %arg0
-///   gpu.yield %dead, %arg0, %c32, %c16
-/// }
-/// %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
-///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
-/// %1 = xegpu.update_nd_offset %0, [%r#2, %r#3]:
-///   !xegpu.tensor_desc<4x8xf32>
-/// ...
-/// ```
-struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
-    if (!operand)
-      return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::UpdateNdOffset op");
-    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
-    unsigned operandIdx = operand->getOperandNumber();
-
-    SmallVector<size_t> newRetIndices;
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, updateOp->getOperands(), updateOp.getOperandTypes(),
-        newRetIndices);
-    rewriter.setInsertionPointAfter(newWarpOp);
-    // new update op does not have layout attribute.
-    xegpu::TensorDescType distributedTensorDescTy =
-        updateOp.getTensorDescType().dropLayouts();
-    SmallVector<Value> newUpdateOperands =
-        llvm::map_to_vector(newRetIndices, [&](size_t i) {
-          // For the tensor descriptor operand, the layout attribute is
-          // dropped after distribution. Types needs to be resolved in this
-          // case.
-          if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
-            return resolveDistributedTy(newWarpOp.getResult(i),
-                                        distributedTensorDescTy, rewriter);
-          }
-          return newWarpOp.getResult(i);
-        });
-    // Create a new update op outside the warp op.
-    auto newUpdateOp = xegpu::UpdateNdOffsetOp::create(
-        rewriter, newWarpOp.getLoc(), distributedTensorDescTy,
-        newUpdateOperands, updateOp->getAttrs());
-    xegpu::removeLayoutAttrs(newUpdateOp);
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    // Resolve the distributed type with the original type.
-    Value typeResolved = resolveDistributedTy(
-        newUpdateOp.getResult(), distributedVal.getType(), rewriter);
-    rewriter.replaceAllUsesWith(distributedVal, typeResolved);
-    return success();
-  }
-};
-
 /// Distribute a prefetch_nd op at the end of enclosing
 /// `gpu.warp_execute_on_lane_0`. In case arguments for the prefetch are passed
 /// through the warp op interface they would be propagated as returned values.
@@ -769,18 +715,19 @@ struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
 /// #layout0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
 /// gpu.warp_execute_on_lane_0(%laneid) -> () {
 ///   ...
-///   xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf32, #layout0>
+///   xegpu.prefetch_nd %arg0 [%x, %y] : !xegpu.tensor_desc<4x8xf32, #layout0>
 /// }
 /// ```
 /// To
 /// ```
 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid) -> (
-///   !xegpu.tensor_desc<4x8xf32, #layout0>) {
-///   gpu.yield %arg0: !xegpu.tensor_desc<4x8xf32, #layout0>
+///   !xegpu.tensor_desc<4x8xf32, #layout0>, index, index) {
+///   gpu.yield %arg0, %x, %y: !xegpu.tensor_desc<4x8xf32, #layout0>, index,
+///     index
 /// }
 /// %1 = unrealized_conversion_cast %r#0: !xegpu.tensor_desc<4x8xf32,
 ///   #layout0> -> !xegpu.tensor_desc<4x8xf32>
-/// xegpu.prefetch_nd %1 : !xegpu.tensor_desc<4x8xf32>
+/// xegpu.prefetch_nd %1 [%r#1, %r#2] : !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
 struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
@@ -793,17 +740,25 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     if (!prefetchOp)
      return failure();
 
-    int64_t offsetSize = static_cast<int64_t>(prefetchOp.getOffsets().size());
-    if ((offsetSize != 0) || prefetchOp.getConstOffsetsAttr())
-      return failure();
+    SmallVector<OpFoldResult> offsets = prefetchOp.getMixedOffsets();
+    // PrefetchNdOp must have offsets.
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(prefetchOp,
+                                         "the prefetch op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, prefetchOp.getLoc(), offsets);
+    SmallVector<Type> offsetTypes = llvm::to_vector(
+        llvm::map_range(offsetsAsValues, [](Value v) { return v.getType(); }));
 
     xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
          prefetchOp, "the source tensor descriptor lacks layout attribute");
 
-    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
-    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<Value> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type> newYieldTypes = {prefetchOp.getTensorDescType()};
+    newYieldValues.append(offsetsAsValues.begin(), offsetsAsValues.end());
+    newYieldTypes.append(offsetTypes.begin(), offsetTypes.end());
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, newYieldValues, newYieldTypes, newRetIndices);
@@ -814,6 +769,9 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    // Collect offsets.
+    for (size_t i = 1; i < newRetIndices.size(); ++i)
+      newPrefetchOperands.push_back(newWarpOp.getResult(newRetIndices[i]));
     xegpu::PrefetchNdOp::create(rewriter, newWarpOp.getLoc(), TypeRange{},
                                 newPrefetchOperands, prefetchOp->getAttrs());
     xegpu::removeLayoutAttrs(prefetchOp);
@@ -1456,15 +1414,14 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
-           GpuBarrierDistribution, VectorMultiReductionDistribution,
-           LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-           VectorBitcastDistribution,
-           MemrefExtractAlignedPointerAsIndexDistribution>(
-          patterns.getContext(),
-          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               GpuBarrierDistribution, VectorMultiReductionDistribution,
+               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+               VectorBitcastDistribution,
+               MemrefExtractAlignedPointerAsIndexDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
