Skip to content

Commit 5236af8

Browse files
authored
[MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (#170409)
This PR extends XeGPU layout propagation and distribution for vector.broadcast operation. It relaxes the restriction of layout propagation to allow low-rank and scalar source input, and adds a pattern in sg-to-wi distribution to support the lowering.
1 parent 94ebcfd commit 5236af8

File tree

8 files changed

+550
-18
lines changed

8 files changed

+550
-18
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
223223
InterfaceMethod<"Derive a new layout by dropping InstData",
224224
"xegpu::DistributeLayoutAttr",
225225
"dropInstData">,
226+
InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
227+
"xegpu::DistributeLayoutAttr",
228+
"setUnitDimData",
229+
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
230+
InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
231+
"xegpu::DistributeLayoutAttr",
232+
"setUnitDimLayout",
233+
/*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
226234
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
227235
indices based on the effective layout level.}],
228236
"FailureOr<SmallVector<Value>>",
@@ -283,9 +291,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
283291
}
284292
return true;
285293
}]>,
286-
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
294+
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
287295
/*retTy=*/"bool",
288296
/*methodName=*/"isSliceOf",
297+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
298+
299+
InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
300+
/*retTy=*/"bool",
301+
/*methodName=*/"isEqualTo",
289302
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
290303
];
291304
}
@@ -487,6 +500,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
487500
return {};
488501
}
489502

503+
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
504+
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
505+
506+
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
507+
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
508+
490509
/// Delinearizes a linear ID into its multidimensional indices
491510
/// based on the effective level of the layout.
492511
FailureOr<SmallVector<Value>>
@@ -501,6 +520,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
501520

502521
/// Check if this is slice of some other layout.
503522
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
523+
524+
/// Check if this is identical to some other layout.
525+
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
504526

505527
}];
506528

@@ -649,6 +671,12 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
649671
return SliceAttr::get(getContext(), parent, attr.getDims());
650672
}
651673

674+
//set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
675+
DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
676+
677+
//set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
678+
DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
679+
652680
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
653681
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
654682
/// it will coalese two slice operations and return a simplified SliceAttr
@@ -670,7 +698,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
670698

671699
/// Check if this is slice of some other layout.
672700
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
673-
701+
702+
/// Check if this is identical to some other layout.
703+
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
674704
}];
675705

676706
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
405405
OptionalAttr<DenseI64ArrayAttr>: $transpose,
406406
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
407407
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
408-
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
408+
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
409409
OptionalAttr<DistributeLayoutAttr>:$layout);
410410

411411
let results = (outs XeGPU_ValueType: $value);

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
390390
return genCoordinates(builder, loc, ids, layout, subShape, shape);
391391
}
392392

393+
bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
394+
if (dyn_cast<xegpu::SliceAttr>(other))
395+
return false;
396+
397+
return *this == dyn_cast<xegpu::LayoutAttr>(other);
398+
}
399+
400+
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
401+
DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
402+
auto sgDataOpt = getSgData();
403+
auto instDataOpt = getInstData();
404+
auto laneDataOpt = getLaneData();
405+
406+
SmallVector<int32_t> sgData;
407+
SmallVector<int32_t> instData;
408+
SmallVector<int32_t> laneData;
409+
410+
if (sgDataOpt) {
411+
sgData = llvm::to_vector(sgDataOpt.asArrayRef());
412+
}
413+
if (instDataOpt) {
414+
instData = llvm::to_vector(instDataOpt.asArrayRef());
415+
}
416+
if (laneDataOpt) {
417+
laneData = llvm::to_vector(laneDataOpt.asArrayRef());
418+
}
419+
420+
for (auto dim : unitDims) {
421+
if (dim < static_cast<int64_t>(sgData.size()))
422+
sgData[dim] = 1;
423+
if (dim < static_cast<int64_t>(instData.size()))
424+
instData[dim] = 1;
425+
if (dim < static_cast<int64_t>(laneData.size()))
426+
laneData[dim] = 1;
427+
}
428+
429+
return LayoutAttr::get(
430+
getContext(), getSgLayout(),
431+
sgData.empty() ? DenseI32ArrayAttr()
432+
: DenseI32ArrayAttr::get(getContext(), sgData),
433+
instData.empty() ? DenseI32ArrayAttr()
434+
: DenseI32ArrayAttr::get(getContext(), instData),
435+
getLaneLayout(),
436+
laneData.empty() ? DenseI32ArrayAttr()
437+
: DenseI32ArrayAttr::get(getContext(), laneData),
438+
getOrder());
439+
}
440+
441+
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
442+
DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
443+
auto sgLayoutOpt = getSgLayout();
444+
auto laneLayoutOpt = getLaneLayout();
445+
446+
SmallVector<int32_t> sgLayout;
447+
SmallVector<int32_t> laneLayout;
448+
449+
if (sgLayoutOpt) {
450+
sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
451+
}
452+
if (laneLayoutOpt) {
453+
laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
454+
}
455+
456+
for (auto dim : unitDims) {
457+
if (dim < static_cast<int64_t>(sgLayout.size()))
458+
sgLayout[dim] = 1;
459+
if (dim < static_cast<int64_t>(laneLayout.size()))
460+
laneLayout[dim] = 1;
461+
}
462+
463+
return LayoutAttr::get(
464+
getContext(),
465+
sgLayout.empty() ? DenseI32ArrayAttr()
466+
: DenseI32ArrayAttr::get(getContext(), sgLayout),
467+
getSgData(), getInstData(),
468+
laneLayout.empty() ? DenseI32ArrayAttr()
469+
: DenseI32ArrayAttr::get(getContext(), laneLayout),
470+
getLaneData(), getOrder());
471+
}
472+
393473
//===----------------------------------------------------------------------===//
394474
// XeGPU_SliceAttr
395475
//===----------------------------------------------------------------------===//
@@ -510,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
510590
[&](int64_t dim) { return thisDims.contains(dim); });
511591
}
512592

593+
bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
594+
if (dyn_cast<xegpu::LayoutAttr>(other))
595+
return false;
596+
597+
auto flattenedThis = flatten();
598+
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
599+
600+
return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
601+
(flattenedThis.getDims() == flattenedOther.getDims()));
602+
}
603+
604+
// Helper function to adjust unit dimensions from sliced space to parent space
605+
static SetVector<int64_t>
606+
adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
607+
ArrayRef<int64_t> sliceDims) {
608+
// Reconstruct parent's non-sliced dimensions
609+
610+
int64_t parentRank = sliceDims.size() + unitDims.size();
611+
llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
612+
sliceDims.end());
613+
SmallVector<int64_t> nonSlicedDims;
614+
for (int64_t i = 0; i < parentRank; ++i) {
615+
if (!slicedDimsSet.contains(i))
616+
nonSlicedDims.push_back(i);
617+
}
618+
619+
// Map unit dims from sliced space to parent space
620+
SetVector<int64_t> adjustUnitDims;
621+
for (auto dim : unitDims) {
622+
if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
623+
adjustUnitDims.insert(nonSlicedDims[dim]);
624+
}
625+
}
626+
627+
return adjustUnitDims;
628+
}
629+
630+
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
631+
DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
632+
SliceAttr attr = flatten();
633+
ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
634+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
635+
636+
SetVector<int64_t> adjustUnitDims =
637+
adjustUnitDimsWithSliceDims(unitDims, sliceDims);
638+
639+
return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
640+
attr.getDims());
641+
}
642+
643+
// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
644+
DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
645+
SliceAttr attr = flatten();
646+
ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
647+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
648+
649+
SetVector<int64_t> adjustUnitDims =
650+
adjustUnitDimsWithSliceDims(unitDims, sliceDims);
651+
652+
return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
653+
attr.getDims());
654+
}
655+
513656
//===----------------------------------------------------------------------===//
514657
// XeGPU_RangeAttr
515658
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -580,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
580580
// Only consider vector to vector broadcasts for now.
581581
VectorType resultTy = broadcast.getResultVectorType();
582582
VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
583-
if (!sourceTy) {
584-
broadcast.emitWarning("Expecting source type to be a vector type.");
583+
// skip layout propagation for non-vector source operand.
584+
if (!sourceTy)
585585
return;
586-
}
587586

588-
// Only consider nD -> nD broadcast.
587+
// Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
589588
if (sourceTy.getRank() != resultTy.getRank()) {
590-
broadcast.emitWarning("Expecting source and result to have same rank.");
589+
auto sourceDims = sourceTy.getShape();
590+
auto resultDims = resultTy.getShape();
591+
SmallVector<int64_t> bcastDims;
592+
auto dimDiff = resultTy.getRank() - sourceTy.getRank();
593+
// adding the missing leading dims
594+
for (int i = 0; i < dimDiff; i++)
595+
bcastDims.push_back(i);
596+
597+
// for the rest dims in the resultTy, if sourceTy dim is 1, then it's
598+
// broadcasted dim
599+
for (size_t i = 0; i < sourceDims.size(); i++)
600+
if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
601+
bcastDims.push_back(i + dimDiff);
602+
603+
// create a slice layout for the source
604+
xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
605+
broadcast->getContext(),
606+
cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
607+
DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
608+
609+
propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
591610
return;
592611
}
612+
593613
SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
594-
if (broadcastUnitDims.size() != 1) {
595-
broadcast.emitWarning("Expecting source type to be nD vector only with "
596-
"one broadcasted dimension.");
597-
return;
598-
}
599-
// Propagate the result layout to the source operand.
614+
resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
615+
.setUnitDimData(broadcastUnitDims);
600616
propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
601617
}
602618

@@ -917,7 +933,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
917933
} else {
918934

919935
// The layout is strictly determined by the payload type.
920-
auto payloadTy = dyn_cast<VectorType>(load.getValueType());
936+
VectorType payloadTy = load.getValueType();
921937
if (!payloadTy) {
922938
load.emitWarning("Not propagating, non-vector payload supplied.");
923939
return;
@@ -987,7 +1003,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
9871003
// Currently, for 2D StoreScatterOp we expect that the height dimension of
9881004
// the tensor descriptor is equal to the subgroup size. This is ensured by
9891005
// the op verifier.
990-
auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
1006+
VectorType payloadTy = storeScatter.getValueType();
9911007
if (!payloadTy) {
9921008
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
9931009
return;

0 commit comments

Comments
 (0)