Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];

let arguments = (ins
TransformHandleTypeInterface : $target,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
TransformHandleTypeInterface:$target,
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
);

let results = (outs TransformHandleTypeInterface : $transformed);
let results = (outs TransformHandleTypeInterface:$transformed);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
Expand Down Expand Up @@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}

def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
  AttrSizedOperandSegments,
  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
  TransformOpInterface
]> {

  let summary = "Set xegpu.layout attribute of an op.";
  let description = [{
    Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
    `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
    target operand/result value is defined by the `index` argument. The layout
    is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
  }];

  let arguments = (ins TransformHandleTypeInterface:$target,
                   DefaultValuedOptionalAttr<I64Attr, "0">:$index,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
                   DefaultValuedAttr<UnitAttr, "false">:$result
  );

  let results = (outs);
  let builders = [
    OpBuilder<(ins "Value":$target,
                   "int64_t":$index,
                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
                   "ArrayRef<OpFoldResult>":$mixedSgData,
                   "ArrayRef<OpFoldResult>":$mixedInstData,
                   CArg<"bool", "false">:$result
    )>,
  ];

  let assemblyFormat = [{
    $target (`result` $result^)? (`index` `=` $index^)?
    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
    attr-dict `:` qualified(type(operands))
  }];

  let extraClassDeclaration = [{
    ::mlir::DiagnosedSilenceableFailure apply(
        ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state);

    // Accessors folding the static (attribute) and dynamic (operand) layout
    // components back into a single mixed OpFoldResult list.
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
      Builder b(getContext());
      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
      Builder b(getContext());
      return getMixedValues(getStaticSgData(), getSgData(), b);
    }
    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
      Builder b(getContext());
      return getMixedValues(getStaticInstData(), getInstData(), b);
    }
  }];
}

#endif // XEGPU_TRANSFORM_OPS
121 changes: 102 additions & 19 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,36 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}

/// Generate `xegpu::LayoutAttr` from op mixed layout values.
///
/// Each mixed (static + dynamic) list is resolved through the transform
/// `state` into concrete int32_t values. `mixedInstData` is optional: an
/// empty list produces a layout without inst_data. On success the attribute
/// is returned through `layoutAttr`; on failure the status of the failed
/// conversion is propagated unchanged.
///
/// Marked `static`: this is a file-local helper (matching the other helpers
/// in this file) and should not have external linkage.
static DiagnosedSilenceableFailure
getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
                          TransformOpInterface transformOp,
                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
                          xegpu::LayoutAttr &layoutAttr) {
  SmallVector<int32_t> sgLayout, sgData, instData;
  auto status =
      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
  if (!status.succeeded())
    return status;

  // inst_data is optional; forward std::nullopt when it was not provided.
  auto maybeInstData = instData.empty()
                           ? std::nullopt
                           : std::optional<ArrayRef<int32_t>>(instData);

  layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);

  return DiagnosedSilenceableFailure::success();
}

/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
Expand Down Expand Up @@ -142,26 +172,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();

SmallVector<int32_t> sgLayout;
DiagnosedSilenceableFailure status =
convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
if (!status.succeeded())
return status;

SmallVector<int32_t> sgData;
status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
xegpu::LayoutAttr layoutAttr = nullptr;
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
getMixedSgLayout(), getMixedSgData(),
getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;

SmallVector<int32_t> instData;
status =
convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
if (!status.succeeded())
return status;
auto maybeInstData = instData.empty()
? std::nullopt
: std::optional<ArrayRef<int32_t>>(instData);

// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
Expand All @@ -173,8 +190,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}

// Set layout attr in desc op's return type. Replaces old desc op.
auto layoutAttr =
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);

// Map result handles.
Expand All @@ -193,6 +208,74 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}

/// Builder: partitions each mixed layout list into its dynamic (SSA value)
/// and static (integer) parts, then delegates to the generated builder.
void transform::SetOpLayoutAttrOp::build(
    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
    ArrayRef<OpFoldResult> mixedInstData, bool result) {
  // Split every mixed list: Values go to the *Vals vectors (op operands),
  // constants go to the *Ints vectors (static attributes).
  SmallVector<Value> sgLayoutVals, sgDataVals, instDataVals;
  SmallVector<int64_t> sgLayoutInts, sgDataInts, instDataInts;
  dispatchIndexOpFoldResults(mixedSgLayout, sgLayoutVals, sgLayoutInts);
  dispatchIndexOpFoldResults(mixedSgData, sgDataVals, sgDataInts);
  dispatchIndexOpFoldResults(mixedInstData, instDataVals, instDataInts);
  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*index=*/index,
        /*sg_layout=*/sgLayoutVals,
        /*sg_data=*/sgDataVals,
        /*inst_data=*/instDataVals,
        /*static_sg_layout=*/sgLayoutInts,
        /*static_sg_data=*/sgDataInts,
        /*static_inst_data=*/instDataInts,
        /*result=*/result);
}

/// Applies the transform: resolves the mixed layout operands into an
/// xegpu::LayoutAttr and attaches it to the operand/result of the single
/// payload op selected by `index` (result when the `result` unit attr is
/// set, operand otherwise).
DiagnosedSilenceableFailure
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  // `result` selects whether `index` addresses a result or an operand.
  bool resultTarget = getResult();

  // Validate the index. The I64Attr does not forbid negative values, so
  // reject them explicitly before indexing into results/operands.
  int64_t index = getIndex();
  if (resultTarget &&
      (index < 0 ||
       index >= static_cast<int64_t>(target->getNumResults()))) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op results";
  }
  if (!resultTarget &&
      (index < 0 ||
       index >= static_cast<int64_t>(target->getNumOperands()))) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op operands";
  }

  // Resolve the mixed sg_layout/sg_data/inst_data operands into a LayoutAttr.
  xegpu::LayoutAttr layoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                          getMixedSgLayout(), getMixedSgData(),
                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;

  // Set layout attribute for the op result or operand.
  if (resultTarget)
    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
  else
    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
  return DiagnosedSilenceableFailure::success();
}

// Declares the op's transform effects: every handle operand is only read
// (no handle is consumed/invalidated), while the payload IR is modified in
// place by attaching the layout attribute.
void transform::SetOpLayoutAttrOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);
  modifiesPayload(effects);
}

namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
Expand Down
47 changes: 47 additions & 0 deletions mlir/python/mlir/dialects/transform/xegpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,50 @@ def __init__(
loc=loc,
ip=ip,
)


@_ods_cext.register_operation(_Dialect, replace=True)
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
    """Specialization for SetOpLayoutAttrOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        sg_layout: MixedValues,
        sg_data: MixedValues,
        *,
        inst_data: Optional[MixedValues] = None,
        index: Optional[Union[int, Attribute]] = None,
        result: Optional[Union[bool, Attribute]] = None,
        loc=None,
        ip=None,
    ):
        """Builds a `transform.xegpu.set_op_layout_attr` op.

        Each of `sg_layout`, `sg_data` and `inst_data` is a mixed list that
        is split into its dynamic (SSA value) and static (integer) parts
        before being forwarded to the generated constructor. An omitted
        `inst_data` behaves like an empty list.
        """
        dyn_sg_layout, static_sg_layout, _ = _dispatch_dynamic_index_list(sg_layout)
        dyn_sg_data, static_sg_data, _ = _dispatch_dynamic_index_list(sg_data)
        dyn_inst_data, static_inst_data, _ = _dispatch_dynamic_index_list(
            [] if inst_data is None else inst_data
        )
        super().__init__(
            _get_op_result_or_value(target),
            dyn_sg_layout,
            dyn_sg_data,
            dyn_inst_data,
            static_sg_layout=static_sg_layout,
            static_sg_data=static_sg_data,
            static_inst_data=static_inst_data,
            index=index,
            result=result,
            loc=loc,
            ip=ip,
        )
58 changes: 58 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Negative test: `result index = 1` addresses result #1 of arith.extf, which
// has exactly one result, so the transform must emit a silenceable failure.
// CHECK-LABEL: @set_op_layout_attr_bad_result_index
func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Index exceeds the number of op results}}
    transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: `index = 1` addresses operand #1 of arith.extf, which has
// exactly one operand, so the transform must emit a silenceable failure.
// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Index exceeds the number of op operands}}
    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: the match produces two arith.extf payload ops, but the
// transform requires exactly one target handle, so it must fail.
// CHECK-LABEL: @set_op_layout_attr_multiple
func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
  %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
    transform.yield
  }
}
Loading