-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU][TransformOps] Add set_op_layout_attr op #166854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,6 +90,36 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, | |
| /*order=*/nullptr); | ||
| } | ||
|
|
||
| /// Generate `xegpu::LayoutAttr` from op mixed layout values. | ||
| DiagnosedSilenceableFailure | ||
| getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state, | ||
| TransformOpInterface transformOp, | ||
tkarna marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ArrayRef<::mlir::OpFoldResult> mixedSgLayout, | ||
| ArrayRef<::mlir::OpFoldResult> mixedSgData, | ||
| ArrayRef<::mlir::OpFoldResult> mixedInstData, | ||
| xegpu::LayoutAttr &layoutAttr) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing order attribute. It defines the layout of subgroup ids and lane ids. It is optional, if not set, the layout of these ids are row major. It can be added later when you need to handle vector.transpose.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I prefer to add features when they are needed. It's also easy to add relevant test cases then. |
||
| SmallVector<int32_t> sgLayout, sgData, instData; | ||
| auto status = | ||
| convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); | ||
| if (!status.succeeded()) | ||
| return status; | ||
| auto maybeInstData = instData.empty() | ||
| ? std::nullopt | ||
| : std::optional<ArrayRef<int32_t>>(instData); | ||
|
|
||
| layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData); | ||
|
|
||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| /// Replace xegpu.create_nd_desc op with a new one with the given layout. | ||
| static xegpu::CreateNdDescOp | ||
| setDescLayout(transform::TransformRewriter &rewriter, | ||
|
|
@@ -142,26 +172,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, | |
| } | ||
| Operation *target = *targetOps.begin(); | ||
|
|
||
| SmallVector<int32_t> sgLayout; | ||
| DiagnosedSilenceableFailure status = | ||
| convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> sgData; | ||
| status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); | ||
| xegpu::LayoutAttr layoutAttr = nullptr; | ||
| auto status = getLayoutAttrFromOperands(getContext(), state, (*this), | ||
| getMixedSgLayout(), getMixedSgData(), | ||
| getMixedInstData(), layoutAttr); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> instData; | ||
| status = | ||
| convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
| auto maybeInstData = instData.empty() | ||
| ? std::nullopt | ||
| : std::optional<ArrayRef<int32_t>>(instData); | ||
|
|
||
| // For now only create_nd_desc op is supported. | ||
| auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); | ||
| if (!descOp) { | ||
|
|
@@ -173,8 +190,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, | |
| } | ||
|
|
||
| // Set layout attr in desc op's return type. Replaces old desc op. | ||
| auto layoutAttr = | ||
| createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); | ||
| auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); | ||
|
|
||
| // Map result handles. | ||
|
|
@@ -193,6 +208,74 @@ void transform::SetDescLayoutOp::getEffects( | |
| modifiesPayload(effects); | ||
| } | ||
|
|
||
| void transform::SetOpLayoutAttrOp::build( | ||
| OpBuilder &builder, OperationState &ostate, Value target, int64_t index, | ||
| ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, | ||
| ArrayRef<OpFoldResult> mixedInstData, bool result) { | ||
| SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; | ||
| SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; | ||
| dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); | ||
| dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); | ||
| dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); | ||
| build(builder, ostate, target.getType(), | ||
| /*target=*/target, | ||
| /*index=*/index, | ||
| /*sg_layout=*/dynamicSgLayout, | ||
| /*sg_data=*/dynamicSgData, | ||
| /*inst_data=*/dynamicInstData, | ||
| /*static_sg_layout=*/staticSgLayout, | ||
| /*static_sg_data=*/staticSgData, | ||
| /*static_inst_data=*/staticInstData, | ||
| /*result=*/result); | ||
| } | ||
|
|
||
| DiagnosedSilenceableFailure | ||
| transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, | ||
| transform::TransformResults &results, | ||
| transform::TransformState &state) { | ||
| auto targetOps = state.getPayloadOps(getTarget()); | ||
| if (!llvm::hasSingleElement(targetOps)) { | ||
| return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " | ||
| << llvm::range_size(targetOps) << ")"; | ||
| } | ||
| Operation *target = *targetOps.begin(); | ||
|
|
||
| bool resultTarget = getResult(); | ||
|
|
||
| int64_t index = getIndex(); | ||
| if (resultTarget && index >= target->getNumResults()) { | ||
| return emitSilenceableFailure(getLoc()) | ||
| << "Index exceeds the number of op results"; | ||
| } | ||
| if (!resultTarget && index >= target->getNumOperands()) { | ||
| return emitSilenceableFailure(getLoc()) | ||
| << "Index exceeds the number of op operands"; | ||
| } | ||
|
|
||
| xegpu::LayoutAttr layoutAttr = nullptr; | ||
| auto status = getLayoutAttrFromOperands(getContext(), state, (*this), | ||
| getMixedSgLayout(), getMixedSgData(), | ||
| getMixedInstData(), layoutAttr); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| // Set layout attribute for the op result or operand | ||
| if (resultTarget) | ||
| xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr); | ||
tkarna marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| else | ||
| xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr); | ||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| void transform::SetOpLayoutAttrOp::getEffects( | ||
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
| onlyReadsHandle(getTargetMutable(), effects); | ||
| onlyReadsHandle(getSgLayoutMutable(), effects); | ||
| onlyReadsHandle(getSgDataMutable(), effects); | ||
| onlyReadsHandle(getInstDataMutable(), effects); | ||
| modifiesPayload(effects); | ||
| } | ||
|
|
||
| namespace { | ||
| class XeGPUTransformDialectExtension | ||
| : public transform::TransformDialectExtension< | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a warning here: XeGPU is refactoring the layout-setting API so that the user only needs to set the anchor op's layout, and XeGPU will take care of the propagation automatically. The new API that sets anchor ops' layouts doesn't require "layout_result_" and "layout_operand_", so it may also impact this operation.