diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 8728e666cd59d..70d424bae9285 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -21,13 +21,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" -// This is roughly similar to OpFoldResult assuming the handle produces a single -// value in the payload IR. -def TransformAnyParamTypeOrAnyHandle : Type< - Or<[TransformHandleTypeInterface.predicate, - TransformParamTypeInterface.predicate]>, - "transform any param type or any handle type">; - //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td index 2d9a26e165b67..3e3fff4c63d2b 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td @@ -103,4 +103,9 @@ def TransformAnyHandle : Type< TransformValueHandleTypeInterface.predicate]>, "transform operation or value handle">; +def TransformAnyParamTypeOrAnyHandle : Type< + Or<[TransformHandleTypeInterface.predicate, + TransformParamTypeInterface.predicate]>, + "transform any param type or any handle type">; + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt index 9f57627c321fb..cb1e9d01821a2 100644 --- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..5924606402a02 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td) +mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls) +mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen) + +add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h new file mode 100644 index 0000000000000..3e16d1e4a7c94 --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h @@ -0,0 +1,28 @@ +//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H +#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace xegpu { +void registerTransformDialectExtension(DialectRegistry ®istry); +} // namespace xegpu +} // namespace mlir + +#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td new file mode 100644 index 0000000000000..b985d5450be0e --- /dev/null +++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td @@ -0,0 +1,81 @@ +//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef XEGPU_TRANSFORM_OPS +#define XEGPU_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" + +def SetDescLayoutOp : Op, + TransformOpInterface +]> { + + let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result."; + let description = [{ + Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout` + attribute to the result tensor descriptor. The layout is defined by the + `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle + to the transformed op. + }]; + + let arguments = (ins + TransformHandleTypeInterface : $target, + Variadic : $sg_layout, + Variadic : $sg_data, + Variadic : $inst_data, + DefaultValuedOptionalAttr:$static_sg_layout, + DefaultValuedOptionalAttr:$static_sg_data, + DefaultValuedOptionalAttr:$static_inst_data + ); + + let results = (outs TransformHandleTypeInterface : $transformed); + let builders = [ + OpBuilder<(ins "Value":$target, + "ArrayRef":$mixedSgLayout, + "ArrayRef":$mixedSgData, + "ArrayRef":$mixedInstData + )>, + ]; + + let assemblyFormat = [{ + $target + `sg_layout` `=` custom($sg_layout, $static_sg_layout) + `sg_data` `=` custom($sg_data, $static_sg_data) + (`inst_data` `=` custom($inst_data, $static_inst_data)^)? + attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::transform::TransformResults &transformResults, + ::mlir::transform::TransformState &state); + + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() { + Builder b(getContext()); + return getMixedValues(getStaticSgLayout(), getSgLayout(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() { + Builder b(getContext()); + return getMixedValues(getStaticSgData(), getSgData(), b); + } + ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() { + Builder b(getContext()); + return getMixedValues(getStaticInstData(), getInstData(), b); + } + }]; +} + +#endif // XEGPU_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt index 31167e6af908b..46b8251a57797 100644 --- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt new file mode 100644 index 0000000000000..48fe841afaa83 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRXeGPUTransformOps + XeGPUTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/ + + DEPENDS + MLIRXeGPUTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRXeGPUDialect + MLIRXeGPUTransforms + MLIRIR + MLIRTransformDialect + MLIRFuncDialect + MLIRSCFDialect +) diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp new file mode 100644 index 0000000000000..8943ba09d9c34 --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp @@ -0,0 +1,225 @@ +//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" + +#include + +using namespace mlir; +using namespace mlir::transform; + +/// Assuming that `ofr` is an index attr or a param of index type +/// or a transform dialect handle mapped to exactly one op +/// with one index result, get that value and cast it to int type. +static DiagnosedSilenceableFailure convertMixedValuesToInt( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVectorImpl &result, ArrayRef ofrs) { + for (OpFoldResult ofr : ofrs) { + // Attribute case. + if (auto attr = dyn_cast(ofr)) { + if (auto intAttr = dyn_cast(attr)) { + result.push_back(intAttr.getInt()); + continue; + } + return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; + } + + // Transform param case. + Value transformValue = cast(ofr); + if (isa(transformValue.getType())) { + ArrayRef params = state.getParams(transformValue); + if (params.size() != 1) + return transformOp.emitDefiniteFailure() + << "requires exactly one parameter associated"; + result.push_back( + cast(params.front()).getValue().getSExtValue()); + continue; + } + + // Payload value case. + auto payloadOps = state.getPayloadOps(transformValue); + if (!llvm::hasSingleElement(payloadOps)) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "handle must be mapped to exactly one payload op"; + diag.attachNote(transformValue.getLoc()) + << "mapped to " << llvm::range_size(payloadOps) << " payload ops"; + return diag; + } + + Operation *op = *payloadOps.begin(); + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + + IntegerAttr intAttr; + if (!matchPattern(op->getResult(0), m_Constant(&intAttr))) + return transformOp.emitSilenceableError() + << "requires param or handle to be the result of a constant like " + "op"; + + result.push_back(intAttr.getInt()); + } + return DiagnosedSilenceableFailure::success(); +} + +/// Create a layout attribute from the given parameters. +static xegpu::LayoutAttr +createLayoutAttr(MLIRContext *ctx, ArrayRef sgLayout, + ArrayRef sgData, + std::optional> instData) { + return xegpu::LayoutAttr::get( + ctx, DenseI32ArrayAttr::get(ctx, sgLayout), + DenseI32ArrayAttr::get(ctx, sgData), + instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr, + /*lane_layout=*/nullptr, + /*lane_data=*/nullptr, + /*order=*/nullptr); +} + +/// Replace xegpu.create_nd_desc op with a new one with the given layout. +static xegpu::CreateNdDescOp +setDescLayout(transform::TransformRewriter &rewriter, + xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) { + assert(descOp.getMixedOffsets().size() == 0 && + "create desc op with offsets is not supported"); + auto oldTensorDesc = descOp.getType(); + auto descType = xegpu::TensorDescType::get( + oldTensorDesc.getShape(), oldTensorDesc.getElementType(), + /*array_length=*/oldTensorDesc.getArrayLength(), + /*boundary_check=*/oldTensorDesc.getBoundaryCheck(), + /*memory_space=*/oldTensorDesc.getMemorySpace(), + /*layout=*/layout); + + rewriter.setInsertionPointAfter(descOp); + auto newDescOp = rewriter.replaceOpWithNewOp( + descOp, descType, descOp.getSource(), descOp.getMixedSizes(), + descOp.getMixedStrides()); + return newDescOp; +} + +void transform::SetDescLayoutOp::build(OpBuilder &builder, + OperationState &result, Value target, + ArrayRef mixedSgLayout, + ArrayRef mixedSgData, + ArrayRef mixedInstData) { + SmallVector staticSgLayout, staticSgData, staticInstData; + SmallVector dynamicSgLayout, dynamicSgData, dynamicInstData; + dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); + dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); + dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); + build(builder, result, target.getType(), + /*target=*/target, + /*sg_layout=*/dynamicSgLayout, + /*sg_data=*/dynamicSgData, + /*inst_data=*/dynamicInstData, + /*static_sg_layout=*/staticSgLayout, + /*static_sg_data=*/staticSgData, + /*static_inst_data=*/staticInstData); +} + +DiagnosedSilenceableFailure +transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto targetOps = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(targetOps)) { + return emitDefiniteFailure() << "requires exactly one targetOp handle (got " + << llvm::range_size(targetOps) << ")"; + } + Operation *target = *targetOps.begin(); + + SmallVector sgLayout; + DiagnosedSilenceableFailure status = + convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); + if (!status.succeeded()) + return status; + + SmallVector sgData; + status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); + if (!status.succeeded()) + return status; + + SmallVector instData; + status = + convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); + if (!status.succeeded()) + return status; + auto maybeInstData = instData.empty() + ? std::nullopt + : std::optional>(instData); + + // For now only create_nd_desc op is supported. + auto descOp = dyn_cast(target); + if (!descOp) { + auto diag = emitSilenceableFailure(getLoc()) + << "Expected a xegpu.create_nd_desc op, but got: " + << target->getName(); + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + // Set layout attr in desc op's return type. Replaces old desc op. + auto layoutAttr = + createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); + auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); + + // Map result handles. + results.set(cast(getTransformed()), {newdescOp.getOperation()}); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::SetDescLayoutOp::getEffects( + ::llvm::SmallVectorImpl &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getSgLayoutMutable(), effects); + onlyReadsHandle(getSgDataMutable(), effects); + onlyReadsHandle(getInstDataMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +namespace { +class XeGPUTransformDialectExtension + : public transform::TransformDialectExtension< + XeGPUTransformDialectExtension> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension) + + using Base::Base; + + void init(); +}; + +void XeGPUTransformDialectExtension::init() { + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + >(); +} +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc" + +void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) { + registry.addExtensions(); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 3839172fd0b42..c857c38df717c 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -56,6 +56,7 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry ®istry) { transform::registerSMTExtension(registry); transform::registerTuneExtension(registry); vector::registerTransformDialectExtension(registry); + xegpu::registerTransformDialectExtension(registry); arm_neon::registerTransformDialectExtension(registry); arm_sve::registerTransformDialectExtension(registry); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 20ed3ab41a0b4..51c75764faf3c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings( "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" ) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/XeGPUTransformOps.td + SOURCES + dialects/transform/xegpu.py + DIALECT_NAME transform + EXTENSION_NAME xegpu_transform) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td new file mode 100644 index 0000000000000..5a5e7b912c4a5 --- /dev/null +++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td @@ -0,0 +1,19 @@ +//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the XeGPU transform ops. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS + +include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td" + +#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py new file mode 100644 index 0000000000000..2918bf592880a --- /dev/null +++ b/mlir/python/mlir/dialects/transform/xegpu.py @@ -0,0 +1,66 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._xegpu_transform_ops_gen import * +from .._xegpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + MixedValues, + get_op_result_or_value as _get_op_result_or_value, + _dispatch_dynamic_index_list, + ) + +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SetDescLayoutOp(SetDescLayoutOp): + """Specialization for SetDescLayoutOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + sg_layout: MixedValues, + sg_data: MixedValues, + *, + inst_data: Optional[MixedValues] = None, + loc=None, + ip=None, + ): + target_handle = _get_op_result_or_value(target) + inst_data = [] if inst_data is None else inst_data + ( + dynamic_sg_layout, + static_sg_layout, + _, + ) = _dispatch_dynamic_index_list(sg_layout) + ( + dynamic_sg_data, + static_sg_data, + _, + ) = _dispatch_dynamic_index_list(sg_data) + ( + dynamic_inst_data, + static_inst_data, + _, + ) = _dispatch_dynamic_index_list(inst_data) + + super().__init__( + target_handle.type, + target_handle, + dynamic_sg_layout, + dynamic_sg_data, + dynamic_inst_data, + static_sg_layout=static_sg_layout, + static_sg_data=static_sg_data, + static_inst_data=static_inst_data, + loc=loc, + ip=ip, + ) diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir new file mode 100644 index 0000000000000..303584518f9f4 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics + +func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { + %c32 = arith.constant 32 : index // expected-note {{target op}} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir new file mode 100644 index 0000000000000..23e1cd946b4cd --- /dev/null +++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @set_desc_layout +func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.block_tdesc_attr + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_desc_layout_minimal +func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @set_desc_layout_param +func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) { + // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0 + // CHECK-SAME: #xegpu.layout> + %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // CHECK: transform.xegpu.set_desc_layout %{{.*}} + %layout0 = transform.param.constant 8 : i64 -> !transform.param + %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py new file mode 100644 index 0000000000000..1c8a2bcc6a2fb --- /dev/null +++ b/mlir/test/python/dialects/transform_xegpu_ext.py @@ -0,0 +1,51 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import transform +from mlir.dialects.transform import xegpu +from mlir.dialects.transform import structured + + +def run(f): + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + print("\nTEST:", f.__name__) + f() + print(module) + return f + + +@run +def setDescLayoutMinimal(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16]) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutMinimal + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + + +@run +def setDescLayoutInstData(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("xegpu.create_nd_tdesc"), + ) + with InsertionPoint(sequence.body): + xegpu.SetDescLayoutOp( + sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16] + ) + transform.YieldOp() + # CHECK-LABEL: TEST: setDescLayoutInstData + # CHECK: %0 = transform.xegpu.set_desc_layout % + # CHECK: sg_layout = [6, 4] + # CHECK: sg_data = [32, 16] + # CHECK: inst_data = [8, 16]