diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index 45da8aca..dfa72232 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonArithToLinalg) +add_subdirectory(TritonToUnstructured) add_subdirectory(StructuredToMemref) diff --git a/include/triton-shared/Conversion/TritonToUnstructured/CMakeLists.txt b/include/triton-shared/Conversion/TritonToUnstructured/CMakeLists.txt new file mode 100644 index 00000000..116a3e3f --- /dev/null +++ b/include/triton-shared/Conversion/TritonToUnstructured/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructured) +add_public_tablegen_target(TritonToUnstructuredConversionPassIncGen) diff --git a/include/triton-shared/Conversion/TritonToUnstructured/Passes.h b/include/triton-shared/Conversion/TritonToUnstructured/Passes.h new file mode 100644 index 00000000..a2016c7a --- /dev/null +++ b/include/triton-shared/Conversion/TritonToUnstructured/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H +#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/TritonToUnstructured/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton-shared/Conversion/TritonToUnstructured/Passes.td b/include/triton-shared/Conversion/TritonToUnstructured/Passes.td new file mode 100644 index 00000000..542d087c --- /dev/null +++ b/include/triton-shared/Conversion/TritonToUnstructured/Passes.td @@ -0,0 +1,15 @@ +#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES 
+#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonToUnstructured : Pass<"triton-to-unstructured", "mlir::ModuleOp"> { + let summary = "Transforms tt.addptr ops into offset accumulation ops"; + let constructor = "triton::createTritonToUnstructuredPass()"; + let options = [ + Option<"offsetBitWidth", "offset-bit-width", "size_t", /*default*/"32", + "Bitwidth used for the starting offset of each pointer"> + ]; +} + +#endif diff --git a/include/triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h b/include/triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h new file mode 100644 index 00000000..03ccdcd1 --- /dev/null +++ b/include/triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h @@ -0,0 +1,17 @@ +#ifndef TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H +#define TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createTritonToUnstructuredPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h index bd01afd0..fbead0c5 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h @@ -1,18 +1,20 @@ #ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ #define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_ -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" -#include 
"mlir/IR/SymbolTable.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" + #include "triton/Dialect/Triton/IR/Dialect.h" -#include "mlir/IR/Dialect.h" +namespace mlir { +namespace tts { +namespace utils { +mlir::Value getScalarValue(mlir::Value operand, mlir::Location loc, + mlir::OpBuilder &builder); +} +} // namespace tts +} // namespace mlir //===----------------------------------------------------------------------===// // TritonStructured Operations diff --git a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td index c0f89bfc..599524ab 100644 --- a/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td +++ b/include/triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.td @@ -120,8 +120,6 @@ def TTS_MakeTensorPtrOp //let hasCanonicalizer = 1; } -// SameVariadicResultSize -// AttrSizedResultSegments def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> { let summary = "Placeholder for the structured pointer states computed during PtrAnalysis."; let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites."; @@ -145,6 +143,51 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe let hasVerifier = 1; } +def TTS_GatherOp : TTS_Op<"gather", [ + MemoryEffects<[MemRead]>, + AttrSizedOperandSegments, + OptionalTypesMatchWith<"mask type matches ptr type", "offset", "mask", "triton::getI1SameShape($_self)">, + OptionalTypesMatchWith<"other matches ptr type", "ptr", "other", "triton::getPointeeType($_self)"> +]> { + let summary = "optionally load data from in memory to fill a portion of the tensor"; + + let arguments = ( + ins + TT_Ptr:$ptr, + TT_IntLike:$offset, + Optional:$mask, + Optional:$other + ); + + let results = (outs TT_Type:$result); + + let assemblyFormat = 
[{ + $ptr `[` $offset `]` (`mask` `=` $mask^)? (`default` `=` $other^)? + attr-dict `:` `(` type($ptr) `,` type($offset) `)` `->` type($result) + }]; +} + +def TTS_ScatterOp : TTS_Op<"scatter", [ + MemoryEffects<[MemWrite]>, + OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask", + "triton::getI1SameShape($_self)"> +]> { + let summary = "optionally store data from in memory to fill a portion of the tensor"; + + let arguments = ( + ins + TT_Ptr:$ptr, + TT_IntLike:$offset, + TT_Type:$value, + Optional:$mask + ); + + let assemblyFormat = [{ + $value `into` $ptr `[` $offset `]` (`mask` `=` $mask^)? + attr-dict `:` type($value) `into` ` ` `(` type($ptr) `,` type($offset) `)` + }]; +} + def TTS_LoadOp : TTS_Op<"load", [ MemoryEffects<[MemRead]>, AttrSizedOperandSegments @@ -170,7 +213,7 @@ def TTS_LoadOp : TTS_Op<"load", [ } bool hasMask() { - return !getStaticMaskDims().empty(); + return !getMixedMaskDims().empty(); } }]; @@ -182,7 +225,7 @@ def TTS_LoadOp : TTS_Op<"load", [ def TTS_StoreOp : TTS_Op<"store", [ MemoryEffects<[MemWrite]> ]> { - let summary = "optionally load data from in memory to fill a portion of the tensor"; + let summary = "optionally store data from in memory to fill a portion of the tensor"; let arguments = (ins TT_PtrLike:$ptr, TT_Tensor:$value, @@ -201,7 +244,7 @@ def TTS_StoreOp : TTS_Op<"store", [ } bool hasMask() { - return !getStaticMaskDims().empty(); + return !getMixedMaskDims().empty(); } }]; diff --git a/lib/AnalysisStructured/PtrAnalysis.cpp b/lib/AnalysisStructured/PtrAnalysis.cpp index e3a2deea..f88770bc 100644 --- a/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/lib/AnalysisStructured/PtrAnalysis.cpp @@ -8,11 +8,8 @@ #include "triton-shared/AnalysisStructured/PtrAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" -#include 
"mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" @@ -33,7 +30,6 @@ #include "llvm/Support/LogicalResult.h" #include #include -#include #include #include #include @@ -42,74 +38,6 @@ namespace mlir { -// Extract a scalar value from v. -// If v is a scalar, return that directly. Otherwise, parse through operations -// (currently only support splat, sitofp, and truncf) that produce it to -// extract the underlying scalar value. We then reconstruct the chain of -// operations that can produce this constant with the original type. If no -// scalar value can be extracted, a nullptr is returned. -static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { - SmallVector ops; - - auto reconstructScalarValue = [&](Value src) { - for (auto op = ops.rbegin(); op != ops.rend(); ++op) { - src = TypeSwitch(*op) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return builder.create(loc, resType, src); - }) - .Case([&](Operation *op) { - auto resType = op->getResults()[0].getType(); - if (auto shapedType = dyn_cast(resType)) { - resType = shapedType.getElementType(); - } - return builder.create(loc, resType, src); - }) - .Default([](Operation *op) { - llvm_unreachable("unsupported op in generating "); - return nullptr; - }); - } - return src; - }; - - while (true) { - if (!dyn_cast(operand.getType())) { - return reconstructScalarValue(operand); - } else if (auto op = operand.getDefiningOp()) { - if (auto attr = dyn_cast(op.getValue())) { - if (!attr.isSplat()) { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load " - "produced by unsupported instruction"; - return nullptr; - } - auto elemValue = attr.getSplatValue(); - auto constOp = arith::ConstantOp::materialize( - builder, elemValue, attr.getElementType(), op.getLoc()); - return 
reconstructScalarValue(constOp.getResult()); - } - } else if (auto op = operand.getDefiningOp()) { - operand = op.getSrc(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else if (auto op = operand.getDefiningOp()) { - ops.push_back(op.getOperation()); - operand = op.getIn(); - } else { - InFlightDiagnostic diag = emitError(loc) - << "other value used in masked load produced " - "by unsupported instruction"; - return nullptr; - } - } - return nullptr; -} - namespace tts { int32_t PtrState::getRank() const { @@ -1126,7 +1054,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) { if (other) { assert(mask && "other value used while no masks are specified"); - scalarOther = getScalarValue(other, loc, builder); + scalarOther = utils::getScalarValue(other, loc, builder); if (!scalarOther) { op->emitRemark("other value used in masked load produced by " "unsupported instruction"); diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 45da8aca..2e4d1f2f 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) +add_subdirectory(TritonToUnstructured) add_subdirectory(TritonArithToLinalg) add_subdirectory(StructuredToMemref) diff --git a/lib/Conversion/TritonToUnstructured/CMakeLists.txt b/lib/Conversion/TritonToUnstructured/CMakeLists.txt new file mode 100644 index 00000000..57c7b798 --- /dev/null +++ b/lib/Conversion/TritonToUnstructured/CMakeLists.txt @@ -0,0 +1,22 @@ +add_triton_library(TritonToUnstructured + TritonToUnstructuredPass.cpp + + DEPENDS + TritonStructuredTableGen + TritonToUnstructuredConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + MLIRReconcileUnrealizedCasts + TritonIR + 
TritonTransforms + TritonSharedAnalysisStructured + TritonStructuredIR +) diff --git a/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp new file mode 100644 index 00000000..1de7b1ab --- /dev/null +++ b/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp @@ -0,0 +1,544 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// +// +//////////////////////////////////////////////////////////////////////////////// +// Overview +//////////////////////////////////////////////////////////////////////////////// +// +// This pass attempts to lower all loads and stores of unstructured pointers to +// tts.gather or tts.scatter that take a single base, a tensor of offsets, an +// optional tensor of mask values, and a default value in case of load. +// +// In addition, all pointer-producing ops will be eliminated and replaced by +// offset-producing ops. tts.gather and tts.scatter will use the pointer +// directly from the kernel arguments as opposed to pointer produced by ops such +// as tt.addptr and tt.splat. +// +// Example: +// +// %12 = tts.gather %arg0[%10] : (, tensor<64xi64>) -> tensor<64xf32> +// tts.scatter %12 into %arg1[%arg3] : tensor<64xf32> into (, +// tensor<64xi32>) +// +// Current assumptions and limitations: +// - For simplicity, the pass assumes that gather / scatter operations load / +// store from / to a single base with a tensor of random offsets. 
As a
+// result, the following triton program would not work:
+//
+// @triton.jit
+// def gather_simple(in0, in1, out0):
+// offs = tl.arange(0, 8)
+// in0_ptrs = in0 + offs
+// in1_ptrs = in1 + offs
+// ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
+// c = tl.load(ptrs)
+// out_offs = tl.arange(0, 16)
+// tl.store(out0 + out_offs, c)
+//
+// In the above program, `ptrs` contains 2 bases: `in0` and `in1` after the
+// `cat` operation.
+//
+////////////////////////////////////////////////////////////////////////////////
+// Future work
+////////////////////////////////////////////////////////////////////////////////
+//
+// Future work may include scaling the algorithm to support such cases -- one
+// possible solution is to let tts.gather and tts.scatter take in an additional
+// tensor of base pointers corresponding to the tensor of offsets. But because
+// we do not want pointer-producing ops to be present after this pass, we can
+// use a tensor of indices where each element indicates the index of the pointer
+// argument to be used. The drawback is a gather or scatter operation now needs
+// one extra lookup to get the base, which will affect performance.
+//
+////////////////////////////////////////////////////////////////////////////////
+// Algorithm
+////////////////////////////////////////////////////////////////////////////////
+//
+// Because the goal of triton-shared is to eventually lower all triton ops and
+// types to mlir, we want to transform the IR such that the usages of triton
+// pointers are as limited as possible. Doing so will help simplify conversion
+// to mlir dialects in subsequent passes. In a familiar fashion to the
+// triton-to-structured pass, we want triton pointers to only appear in
+// tts.gather and tts.scatter.
+//
+// With that goal in mind, we want to revisit the triton pointer type.
+//
+// Triton pointers are created and manipulated through a sequence of ops such as
+// tt.addptr, tt.splat, or tt.broadcast.
If a triton pointer is created
+// through `tt.addptr %ptr %offset`, the new pointer will contain the same base
+// pointer as the original pointer; its offset will also be accumulated.
+//
+// Triton pointers created through tt.splat and tt.broadcast retain their base
+// pointers and offsets. Tensors of pointers, however, may have different bases
+// when tl.cat is present. For simplicity, we assume tl.cat isn't present as
+// mentioned in the overview section.
+//
+// Therefore, a single triton pointer (tt.ptr) has two pieces of info that are
+// implicit:
+// - a base pointer which comes from the kernel arguments
+// - an offset which could be either a tensor of offsets or a single integer
+// offset
+//
+// Leveraging this insight, in order to limit the usages of triton pointers, we
+// can explicitly compute and split the above two pieces of info. So chains of
+// tt.addptr, tt.splat, and tt.broadcast which produce triton pointers can be
+// transformed to sequences of offset (of integer type) manipulation ops and a
+// base pointer which comes from the kernel arguments. With this approach, only
+// tts.gather and tts.scatter need to be aware of the pointer type.
+//
+// In essence, this pass transforms all sequences of tt.addptr into sequences of
+// offset accumulation ops which are then fed into a single op
+// tts.gather or tts.scatter that takes:
+//
+// - a base pointer from the kernel arguments
+// - a tensor of offsets (or single offset) that indicates the offsets from
+// the base pointer
+//
+// All intermediate tt.addptr ops are converted to arith.addi ops that compute
+// the offsets. Offsets start at 0 with the provided bit-width. All pointer
+// shape manipulation ops such as tt.splat and tt.broadcast will instead operate
+// on the offsets and will be converted to linalg in triton-arith-to-linalg.
+//
+// By default, the pass uses i32 for the initial offsets of all pointers
+// (configurable via offset-bit-width=width).
If any intermediate tt.addptr
+// introduces a larger bitwidth offset, the offsets will be sign-extended to the
+// larger bitwidth.
+//
+////////////////////////////////////////////////////////////////////////////////
+// Implementation
+////////////////////////////////////////////////////////////////////////////////
+//
+// This pass uses a standard worklist-based algorithm to walk the use-def chains
+// of all pointer arguments and create replacement ops that operate on offsets
+// instead of tt.ptr types.
+//
+// In cases such as tt.addptr, tt.splat, and tt.broadcast, we create
+// corresponding replacement ops which will then be used to map the results
+// at the end of the algorithm. We do not want to modify these ops in-place
+// because the use-def chains may be changed. In special cases like scf.for, we
+// also set the type of the iter-arg and result directly which is usually
+// frowned upon (but justified).
+//
+// This approach is used in favor of the traditional ConversionPatternRewriter
+// which converts all pointer types into an offset integer type because
+// TypeConverter does not support dynamic type based on value. This limitation
+// means we have to decide the same bitwidth for all tt.addptr sequences which
+// is not ideal.
+//
+// For instance, assume we have two sequences of tt.addptr: one operates on
+// 32-bit offsets while the other operates on 64-bit offsets. If we set the
+// default bitwidth to 64, the 32-bit sequence will require unnecessary
+// sign-extending when computing the offsets. In contrast, with the manual
+// approach we only sign-extend where necessary.
+ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/Passes.h" +#include "triton-shared/Analysis/OpFoldResultUtils.h" +#include "triton-shared/AnalysisStructured/PtrAnalysis.h" +#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Pass/PassManager.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include +#include + +#define DEBUG_TYPE "triton-to-unstructured" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/TritonToUnstructured/Passes.h.inc" + +namespace { + +static bool isPtrTypeLike(Type t) { + if (auto tensorType = dyn_cast(t)) { + return isa(tensorType.getElementType()); + } + return isa(t); +} + +// Given a type, return the offset type corresponding to that type with the +// specified width. +// If the type is a tensor, return a tensor of offsets of the same shape. If the +// type is a pointer, return a single offset type. 
+static Type getPtrOffsetType(Type type, unsigned int bitWidth) { + if (auto tensorType = dyn_cast(type)) { + if (auto ptrType = + dyn_cast(tensorType.getElementType())) { + return RankedTensorType::get( + tensorType.getShape(), IntegerType::get(type.getContext(), bitWidth)); + } + } + + if (auto ptrType = dyn_cast(type)) { + return IntegerType::get(type.getContext(), bitWidth); + } + + llvm_unreachable("unexpected type"); + return nullptr; +} + +static unsigned int getBitWidth(Type type) { + if (auto tensorType = dyn_cast(type)) { + if (auto integerType = dyn_cast(tensorType.getElementType())) { + return integerType.getWidth(); + } + } else if (auto integerType = dyn_cast(type)) { + return integerType.getWidth(); + } + + llvm_unreachable("unexpected type"); + return 0; +} + +class TritonToUnstructuredPass + : public TritonToUnstructuredBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + struct PtrOffset { + // the source pointer which comes from the kernel argument + Value ptr; + // the pointer type that corresponds to this offset; used when + // creating tts.make_unstructured_tptr + Type ptrType; + // bitwidth that is used for this offset, used to track if sign-extension is + // necessary + unsigned int bitWidth; + // the offset value + Value offset; + }; + + LogicalResult processUnstructuredPtrs(unsigned int defaultBitWidth = 32) { + llvm::SmallDenseSet ptrArgs; + llvm::DenseMap offsetMap; + std::queue workList; + + getOperation().walk([&](FunctionOpInterface func) { + for (auto arg : func.getArguments()) { + if (!isPtrTypeLike(arg.getType())) { + continue; + } + + OpBuilder b(func->getRegion(0)); + Value zero = b.create( + arg.getLoc(), + b.getIntegerAttr(IntegerType::get(&getContext(), defaultBitWidth), + 0)); + + ptrArgs.insert(arg); + offsetMap.insert({arg, {arg, arg.getType(), defaultBitWidth, zero}}); + workList.push(arg); + } + }); + + llvm::SmallVector toDelete; + llvm::SmallVector 
ptrUsers; + + while (!workList.empty()) { + auto val = workList.front(); + workList.pop(); + + for (auto &use : val.getUses()) { + auto user = use.getOwner(); + + auto res = + llvm::TypeSwitch(user) + .Case([&](triton::AddPtrOp addptr) { + OpBuilder b{addptr}; + auto loc = addptr->getLoc(); + + auto offsetInfo = offsetMap.at(addptr.getPtr()); + + auto prevOff = offsetInfo.offset; + auto off = addptr.getOffset(); + + auto lhsWidth = offsetInfo.bitWidth; + auto rhsWidth = getBitWidth(off.getType()); + auto resWidth = std::max(lhsWidth, rhsWidth); + + if (lhsWidth < resWidth) { + prevOff = b.create( + loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + prevOff); + } + + if (rhsWidth < resWidth) { + off = b.create( + loc, getPtrOffsetType(offsetInfo.ptrType, resWidth), + off); + } + + auto accumulatedOff = b.create( + loc, getPtrOffsetType(addptr.getType(), resWidth), + prevOff, off); + + PtrOffset newOffsetInfo{offsetInfo.ptr, addptr.getType(), + resWidth, accumulatedOff}; + + offsetMap.insert({addptr, newOffsetInfo}); + workList.push(addptr); + toDelete.push_back(addptr); + + return success(); + }) + .Case([&](Operation *op) { + auto res = op->getResult(0); + auto resType = res.getType(); + + if (!isPtrTypeLike(resType)) { + return success(); + } + + auto ptr = op->getOperand(0); + auto offsetInfo = offsetMap.at(ptr); + + OpBuilder b{op}; + auto clone = + b.create(op->getLoc(), op->getName().getIdentifier(), + ValueRange{offsetInfo.offset}, + TypeRange{getPtrOffsetType( + resType, offsetInfo.bitWidth)}); + + PtrOffset newOffsetInfo{offsetInfo.ptr, resType, + offsetInfo.bitWidth, + clone->getResult(0)}; + + offsetMap.insert({ + res, + newOffsetInfo, + }); + workList.push(res); + toDelete.push_back(op); + + return success(); + }) + .Case([&](Operation *op) { + // Special case: + // We do not want to create "unstructured tensor pointer" into + // tts.make_tptr if the base pointer is directly from the + // kernel arguments. 
+ if (auto makeTensorPtr = dyn_cast(op)) { + if (ptrArgs.contains(makeTensorPtr.getBase())) { + return success(); + } + } + + ptrUsers.push_back(op); + return success(); + }) + .Case([&](scf::ForOp forOp) { + // Index of the init-arg corresponding to this use, note that + // we have to subtract by 3 from the operand number because + // scf.for ops always have 3 leading operands for start, end, + // and step. + auto argIndex = use.getOperandNumber() - 3; + auto init = forOp.getInitArgs()[argIndex]; + + auto offsetInfo = offsetMap.at(init); + auto offsetType = + getPtrOffsetType(offsetInfo.ptrType, offsetInfo.bitWidth); + + // We're setting both the types of the iter-arg and the + // corresponding result directly to the offset type. + // At this point, the IR is in an invalid state because the + // init-args still have tt.ptr. But at the end, we will + // replace all uses of the tt.ptr to offset values. + auto iterArg = forOp.getRegionIterArg(argIndex); + iterArg.setType(offsetType); + + auto res = forOp.getResult(argIndex); + res.setType(offsetType); + + // For other ops, we only need to push the result into the + // worklist. But for scf.for, the iter-arg corresponding to + // the init-arg is used in the op's body instead, we have to + // process uses of the iter-arg. 
+ offsetMap.insert({ + iterArg, + offsetInfo, + }); + offsetMap.insert({ + res, + offsetInfo, + }); + workList.push(iterArg); + workList.push(res); + + return success(); + }) + .Case([](auto) { return success(); }) + .Case([](triton::CatOp op) { + op->emitError("Do not support gather / scatter with multiple " + "bases yet"); + return failure(); + }) + .Default([&](Operation *op) { + op->emitError("unexpected op in ptr sequence"); + return failure(); + }); + + if (failed(res)) { + return failure(); + } + } + } + + for (auto op : ptrUsers) { + OpBuilder b{op}; + auto loc = op->getLoc(); + auto res = + llvm::TypeSwitch(op) + .Case([&](triton::LoadOp load) { + auto offsetInfo = offsetMap.at(load.getPtr()); + + auto other = load.getOther(); + + if (other) { + other = tts::utils::getScalarValue(other, loc, b); + if (!other) { + load->emitError("cannot parse `other` value for load"); + return failure(); + } + } + + auto gather = b.create( + loc, load.getType(), offsetInfo.ptr, offsetInfo.offset, + load.getMask(), other); + + load->replaceAllUsesWith(gather->getResults()); + load->erase(); + return success(); + }) + .Case([&](triton::StoreOp store) { + auto offsetInfo = offsetMap.at(store.getPtr()); + auto scatter = b.create( + loc, offsetInfo.ptr, offsetInfo.offset, store.getValue(), + store.getMask()); + + store->erase(); + return success(); + }) + .Case([&](auto makeTensorPtr) { + // For block pointers, the base could come from a sequence of + // `tt.addptr`. Accumulate the target offset with the offset we + // have saved. + auto offsetInfo = offsetMap.at(makeTensorPtr.getBase()); + auto baseOffset = offsetInfo.offset; + + makeTensorPtr.getBaseMutable().set(offsetInfo.ptr); + + // Add the existing offset from the base to the offset + // operand in the ops. 
+ auto &offsetOpnd = makeTensorPtr.getOffsetsMutable()[0]; + auto currOffset = offsetOpnd.get(); + + auto baseOffType = baseOffset.getType(); + auto currOffType = currOffset.getType(); + + if (baseOffType != currOffType) { + if (currOffType.isIndex()) { + baseOffset = b.create( + loc, b.getIndexType(), baseOffset); + } else if (currOffType.isInteger()) { + if (baseOffType.getIntOrFloatBitWidth() < + currOffType.getIntOrFloatBitWidth()) { + baseOffset = b.create(loc, currOffType, + baseOffset); + } else { + // MakeTensorPtrOp only takes i32 offsets, so we need + // to truncate if the offsets were already in i64 + makeTensorPtr.emitWarning( + "truncating offsets which may result in data loss"); + baseOffset = b.create(loc, currOffType, + baseOffset); + } + } + } + + auto accumulatedOffset = b.create( + loc, currOffset.getType(), baseOffset, currOffset); + + offsetOpnd.set(accumulatedOffset); + + return success(); + }) + + .Default([&](Operation *op) { + op->emitError("unexpected op in ptr sequence"); + return failure(); + }); + + if (failed(res)) { + return failure(); + } + } + + for (auto op : toDelete) { + auto ptrInfo = offsetMap.at(op->getResult(0)); + op->replaceAllUsesWith(ValueRange{ptrInfo.offset}); + op->erase(); + } + + return success(); + } + + void runOnOperation() override { + if (failed(processUnstructuredPtrs(offsetBitWidth))) { + signalPassFailure(); + return; + } + + PassManager pm(&getContext(), getOperation().getOperationName()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + if (failed(runPipeline(pm, getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +triton::createTritonToUnstructuredPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp index cf55d834..a91c45de 100644 --- a/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp +++ 
b/lib/Dialect/TritonStructured/IR/TritonStructuredOps.cpp @@ -1,17 +1,20 @@ -#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" -#include "triton/Dialect/Triton/IR/Types.h" + #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" + #include #include #include @@ -25,6 +28,77 @@ using namespace mlir::tts; namespace mlir { namespace tts { +namespace utils { +// Extract a scalar value from v. +// If v is a scalar, return that directly. Otherwise, parse through operations +// (currently only support splat, sitofp, and truncf) that produce it to +// extract the underlying scalar value. We then reconstruct the chain of +// operations that can produce this constant with the original type. If no +// scalar value can be extracted, a nullptr is returned. 
+Value getScalarValue(Value operand, Location loc, OpBuilder &builder) { + SmallVector ops; + + auto reconstructScalarValue = [&](Value src) { + for (auto op = ops.rbegin(); op != ops.rend(); ++op) { + src = TypeSwitch(*op) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Case([&](Operation *op) { + auto resType = op->getResults()[0].getType(); + if (auto shapedType = dyn_cast(resType)) { + resType = shapedType.getElementType(); + } + return builder.create(loc, resType, src); + }) + .Default([](Operation *op) { + llvm_unreachable("unsupported op in generating "); + return nullptr; + }); + } + return src; + }; + + while (true) { + if (!dyn_cast(operand.getType())) { + return reconstructScalarValue(operand); + } else if (auto op = operand.getDefiningOp()) { + if (auto attr = dyn_cast(op.getValue())) { + if (!attr.isSplat()) { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load " + "produced by unsupported instruction"; + return nullptr; + } + auto elemValue = attr.getSplatValue(); + auto constOp = arith::ConstantOp::materialize( + builder, elemValue, attr.getElementType(), op.getLoc()); + return reconstructScalarValue(constOp.getResult()); + } + } else if (auto op = operand.getDefiningOp()) { + operand = op.getSrc(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else if (auto op = operand.getDefiningOp()) { + ops.push_back(op.getOperation()); + operand = op.getIn(); + } else { + InFlightDiagnostic diag = emitError(loc) + << "other value used in masked load produced " + "by unsupported instruction"; + return nullptr; + } + } + return nullptr; +} + +} // namespace utils + void MakeTensorPtrOp::build(OpBuilder &b, OperationState &state, Value base, ArrayRef sizes, ArrayRef strides, diff --git 
a/test/Conversion/TritonToUnstructured/gather_nested.mlir b/test/Conversion/TritonToUnstructured/gather_nested.mlir new file mode 100644 index 00000000..8f54ffdf --- /dev/null +++ b/test/Conversion/TritonToUnstructured/gather_nested.mlir @@ -0,0 +1,99 @@ +// RUN: triton-shared-opt --triton-to-unstructured --canonicalize %s | FileCheck %s + +module { + tt.func public @gather(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<4> : tensor<4xi32> + %cst_0 = arith.constant dense<64> : tensor<4xi32> + %cst_1 = arith.constant dense<3> : tensor<4xi32> + %c1_i32 = arith.constant 1 : i32 + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %4 = arith.divsi %arg3, %cst_1 : tensor<4xi32> + %5 = tt.splat %arg2 : i32 -> tensor<4xi32> + %6 = arith.addi %4, %5 : tensor<4xi32> + %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32> + %8 = tt.addptr %1, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + %9 = tt.load %8, %7 : tensor<4x!tt.ptr> + %10 = tt.addptr %2, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %9 : tensor<4x!tt.ptr> + %11 = arith.addi %6, %cst : tensor<4xi32> + %12 = arith.addi %arg4, %cst : tensor<4xi32> + %13 = arith.addi %arg2, %c1_i32 : i32 + %14:2 = scf.for %arg5 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg6 = %11, %arg7 = %12) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %23 = arith.addi %arg5, %c1_i32 : i32 + %24 = arith.muli %13, %23 : i32 + %25 = tt.splat %24 : i32 -> tensor<4xi32> + %26 = arith.divsi %arg6, %25 : tensor<4xi32> + %27 = arith.addi %26, %5 : tensor<4xi32> + %28 = arith.cmpi slt, %27, %cst_0 : tensor<4xi32> + %29 = tt.addptr %1, %27 : tensor<4x!tt.ptr>, tensor<4xi32> + %30 = tt.load 
%29, %28 : tensor<4x!tt.ptr> + %31 = tt.addptr %2, %27 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %31, %30 : tensor<4x!tt.ptr> + %32 = arith.addi %27, %cst : tensor<4xi32> + %33 = arith.addi %arg7, %cst : tensor<4xi32> + scf.yield %32, %33 : tensor<4xi32>, tensor<4xi32> + } + %15 = arith.divsi %14#0, %cst_1 : tensor<4xi32> + %16 = arith.addi %15, %5 : tensor<4xi32> + %17 = arith.cmpi slt, %16, %cst_0 : tensor<4xi32> + %18 = tt.addptr %1, %16 : tensor<4x!tt.ptr>, tensor<4xi32> + %19 = tt.load %18, %17 : tensor<4x!tt.ptr> + %20 = tt.addptr %2, %16 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %20, %19 : tensor<4x!tt.ptr> + %21 = arith.addi %16, %cst : tensor<4xi32> + %22 = arith.addi %14#1, %cst : tensor<4xi32> + scf.yield %21, %22 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK: tt.func public @gather([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<4xi32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<64> : tensor<4xi32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<3> : tensor<4xi32> +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[VAR_0_]], [[VAR_arg4_:%.+]] = [[VAR_0_]]) -> (tensor<4xi32>, tensor<4xi32>) : i32 { +// CHECK-DAG: [[VAR_2_:%.+]] = arith.divsi [[VAR_arg3_]], [[VAR_cst_1_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = tt.splat [[VAR_arg2_]] : i32 -> tensor<4xi32> +// CHECK: [[VAR_4_:%.+]] = arith.addi [[VAR_2_]], [[VAR_3_]] : tensor<4xi32> +// CHECK: [[VAR_5_:%.+]] = arith.cmpi slt, [[VAR_4_]], [[VAR_cst_0_]] : 
tensor<4xi32> +// CHECK: [[VAR_6_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_4_]]{{.}} mask = [[VAR_5_]] : (, tensor<4xi32>) -> tensor<4xf32> +// CHECK: tts.scatter [[VAR_6_]] into [[PARAM_1_]]{{.}}[[VAR_4_]]{{.}} : tensor<4xf32> into (, tensor<4xi32>) +// CHECK-DAG: [[VAR_7_:%.+]] = arith.addi [[VAR_4_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_arg2_]], [[CST_1_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_10_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_7_]], [[VAR_arg7_:%.+]] = [[VAR_8_]]) -> (tensor<4xi32>, tensor<4xi32>) : i32 { +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_1_]] : i32 +// CHECK: [[VAR_18_:%.+]] = arith.muli [[VAR_9_]], [[VAR_17_]] : i32 +// CHECK: [[VAR_19_:%.+]] = tt.splat [[VAR_18_]] : i32 -> tensor<4xi32> +// CHECK: [[VAR_20_:%.+]] = arith.divsi [[VAR_arg6_]], [[VAR_19_]] : tensor<4xi32> +// CHECK: [[VAR_21_:%.+]] = arith.addi [[VAR_20_]], [[VAR_3_]] : tensor<4xi32> +// CHECK: [[VAR_22_:%.+]] = arith.cmpi slt, [[VAR_21_]], [[VAR_cst_0_]] : tensor<4xi32> +// CHECK: [[VAR_23_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_21_]]{{.}} mask = [[VAR_22_]] : (, tensor<4xi32>) -> tensor<4xf32> +// CHECK: tts.scatter [[VAR_23_]] into [[PARAM_1_]]{{.}}[[VAR_21_]]{{.}} : tensor<4xf32> into (, tensor<4xi32>) +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addi [[VAR_21_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK: scf.yield [[VAR_24_]], [[VAR_25_]] : tensor<4xi32>, tensor<4xi32> +// CHECK: } +// CHECK: [[VAR_11_:%.+]] = arith.divsi [[VAR_10_]]#0, [[VAR_cst_1_]] : tensor<4xi32> +// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_3_]] : tensor<4xi32> +// CHECK: [[VAR_13_:%.+]] = arith.cmpi slt, [[VAR_12_]], [[VAR_cst_0_]] : tensor<4xi32> 
+// CHECK: [[VAR_14_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_12_]]{{.}} mask = [[VAR_13_]] : (, tensor<4xi32>) -> tensor<4xf32> +// CHECK: tts.scatter [[VAR_14_]] into [[PARAM_1_]]{{.}}[[VAR_12_]]{{.}} : tensor<4xf32> into (, tensor<4xi32>) +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_12_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_10_]]#1, [[VAR_cst_]] : tensor<4xi32> +// CHECK: scf.yield [[VAR_15_]], [[VAR_16_]] : tensor<4xi32>, tensor<4xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToUnstructured/gather_no_loop.mlir b/test/Conversion/TritonToUnstructured/gather_no_loop.mlir new file mode 100644 index 00000000..77140a93 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/gather_no_loop.mlir @@ -0,0 +1,25 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @gather_simple_no_loop(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<64xi32> + %cst_0 = arith.constant dense<10> : tensor<64xi32> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = arith.divsi %0, %cst_0 : tensor<64xi32> + %2 = arith.addi %1, %cst : tensor<64xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<64x!tt.ptr>, tensor<64xi32> + %5 = tt.load %4 : tensor<64x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %7 = tt.addptr %6, %0 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %7, %5 : tensor<64x!tt.ptr> + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK: [[tensor:%.+]] = tts.gather %arg0 +// CHECK: tts.scatter [[tensor]] into %arg1 diff --git a/test/Conversion/TritonToUnstructured/gather_reuse_loop_results.mlir b/test/Conversion/TritonToUnstructured/gather_reuse_loop_results.mlir new file mode 100644 index 00000000..20eca4db --- /dev/null +++ 
b/test/Conversion/TritonToUnstructured/gather_reuse_loop_results.mlir @@ -0,0 +1,88 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @nested_use_same_level_loop_result(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<2x1xi32> + %3 = arith.muli %1, %2 : tensor<2x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x2xi32> + %6 = arith.muli %4, %5 : tensor<1x2xi32> + %7 = tt.broadcast %3 : tensor<2x1xi32> -> tensor<2x2xi32> + %8 = tt.broadcast %6 : tensor<1x2xi32> -> tensor<2x2xi32> + %9 = arith.addi %7, %8 : tensor<2x2xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<2x1x!tt.ptr> + %13 = tt.addptr %12, %3 : tensor<2x1x!tt.ptr>, tensor<2x1xi32> + %14 = tt.broadcast %13 : tensor<2x1x!tt.ptr> -> tensor<2x2x!tt.ptr> + %15 = tt.addptr %14, %8 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %16 = arith.muli %arg3, %c2_i32 : i32 + %17 = tt.splat %16 : i32 -> tensor<2x2xi32> + %18 = arith.muli %arg3, %c2_i32 : i32 + %19 = tt.splat %18 : i32 -> tensor<2x2xi32> + %20 = arith.muli %arg3, %c2_i32 : i32 + %21 = tt.splat %20 : i32 -> tensor<2x2xi32> + %22:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %23 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %arg5) -> (tensor<2x2x!tt.ptr>) : i32 { + %26 = tt.addptr %arg8, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %26 : tensor<2x2x!tt.ptr> + } + %24:2 = 
scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %23, %arg9 = %arg6) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %26 = tt.load %arg8 : tensor<2x2x!tt.ptr> + %27 = tt.addptr %arg8, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %28 = tt.load %27 : tensor<2x2x!tt.ptr> + tt.store %arg9, %26 : tensor<2x2x!tt.ptr> + %29 = tt.addptr %arg9, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %30 = tt.addptr %29, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %30, %28 : tensor<2x2x!tt.ptr> + %31 = tt.addptr %30, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %32 = tt.addptr %27, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %32, %31 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + %25 = tt.addptr %24#0, %21 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %25, %24#1 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @nested_use_same_level_loop_result([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_2_]] : i32 -> tensor<2x1xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_1_]], [[VAR_2_]] : tensor<2x1xi32> +// CHECK-DAG: [[VAR_4_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[PARAM_3_]] : i32 -> tensor<1x2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_4_]], [[VAR_5_]] : 
tensor<1x2xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.broadcast [[VAR_3_]] : tensor<2x1xi32> -> tensor<2x2xi32> +// CHECK: [[VAR_8_:%.+]] = tt.broadcast [[VAR_6_]] : tensor<1x2xi32> -> tensor<2x2xi32> +// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_7_]], [[VAR_8_]] : tensor<2x2xi32> +// CHECK-DAG: [[VAR_10_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_]] : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_11_:%.+]] = tt.splat [[VAR_10_]] : i32 -> tensor<2x2xi32> +// CHECK-DAG: [[VAR_12_:%.+]] = scf.for [[VAR_arg4_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg5_:%.+]] = [[VAR_9_]]) -> (tensor<2x2xi32>) : i32 { +// CHECK-DAG: [[VAR_13_:%.+]] = scf.for [[VAR_arg6_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg7_:%.+]] = [[VAR_arg5_]]) -> (tensor<2x2xi32>) : i32 { +// CHECK-DAG: [[VAR_14_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_9_]]{{.}} : (, tensor<2x2xi32>) -> tensor<2x2xf32> +// CHECK: [[VAR_15_:%.+]] = arith.addi [[VAR_9_]], [[VAR_11_]] : tensor<2x2xi32> +// CHECK: [[VAR_16_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_15_]]{{.}} : (, tensor<2x2xi32>) -> tensor<2x2xf32> +// CHECK: tts.scatter [[VAR_14_]] into [[PARAM_1_]]{{.}}[[VAR_9_]]{{.}} : tensor<2x2xf32> into (, tensor<2x2xi32>) +// CHECK: [[VAR_17_:%.+]] = arith.addi [[VAR_15_]], [[VAR_11_]] : tensor<2x2xi32> +// CHECK: tts.scatter [[VAR_16_]] into [[PARAM_1_]]{{.}}[[VAR_17_]]{{.}} : tensor<2x2xf32> into (, tensor<2x2xi32>) +// CHECK: [[VAR_18_:%.+]] = arith.addi [[VAR_17_]], [[VAR_11_]] : tensor<2x2xi32> +// CHECK: scf.yield [[VAR_18_]] : tensor<2x2xi32> +// CHECK: } +// CHECK: scf.yield [[VAR_13_]] : tensor<2x2xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToUnstructured/gather_simple_loop.mlir b/test/Conversion/TritonToUnstructured/gather_simple_loop.mlir new file mode 100644 index 00000000..6fea8efe --- /dev/null +++ b/test/Conversion/TritonToUnstructured/gather_simple_loop.mlir @@ -0,0 +1,38 @@ +// RUN: 
triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @gather_simple(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<64xi32> + %c64_i32 = arith.constant 64 : i32 + %c5_i32 = arith.constant 5 : i32 + %cst_0 = arith.constant dense<10> : tensor<64xi32> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<64xi32>, tensor<64xi32>) : i32 { + %4 = arith.divsi %arg3, %cst_0 : tensor<64xi32> + %5 = arith.addi %arg2, %c5_i32 : i32 + %6 = arith.remsi %5, %c64_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<64xi32> + %8 = arith.addi %4, %7 : tensor<64xi32> + %9 = tt.addptr %1, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + %10 = tt.load %9 : tensor<64x!tt.ptr> + %11 = tt.addptr %2, %arg4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %11, %10 : tensor<64x!tt.ptr> + %12 = arith.addi %8, %cst : tensor<64xi32> + %13 = arith.addi %arg4, %cst : tensor<64xi32> + scf.yield %12, %13 : tensor<64xi32>, tensor<64xi32> + } + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK: [[tensor:%.+]] = tts.gather %arg0 +// CHECK: tts.scatter [[tensor]] into %arg1 diff --git a/test/Conversion/TritonToUnstructured/kernel-01-vector-add.mlir b/test/Conversion/TritonToUnstructured/kernel-01-vector-add.mlir new file mode 100644 index 00000000..1086921c --- /dev/null +++ b/test/Conversion/TritonToUnstructured/kernel-01-vector-add.mlir @@ -0,0 +1,34 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @add_kernel_01234(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %c1024_i32 
= arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + %12 = tt.load %11, %6 : tensor<1024x!tt.ptr> + %13 = arith.addf %9, %12 : tensor<1024xf32> + %14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr> + %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %15, %13, %6 : tensor<1024x!tt.ptr> + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK-COUNT-2: tts.gather %arg{{[0-9]+}} +// CHECK-NOT: tts.gather %arg{{[0-9]+}} +// CHECK-COUNT-1: tts.scatter {{.+}} into %arg{{[0-9]+}} +// CHECK-NOT: tts.scatter {{.+}} into %arg{{[0-9]+}} diff --git a/test/Conversion/TritonToUnstructured/kernel-02-fused-softmax.mlir b/test/Conversion/TritonToUnstructured/kernel-02-fused-softmax.mlir new file mode 100644 index 00000000..32b25afd --- /dev/null +++ b/test/Conversion/TritonToUnstructured/kernel-02-fused-softmax.mlir @@ -0,0 +1,48 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @softmax_kernel_012345(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32) { + %cst = arith.constant 0xFF800000 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg2 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %4 = tt.splat %2 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %3 : 
tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.splat %arg4 : i32 -> tensor<128xi32> + %7 = arith.cmpi slt, %3, %6 : tensor<128xi32> + %8 = tt.splat %cst : f32 -> tensor<128xf32> + %9 = tt.load %5, %7, %8 : tensor<128x!tt.ptr> + %10 = "tt.reduce"(%9) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.cmpf ogt, %arg5, %arg6 : f32 + %22 = arith.select %21, %arg5, %arg6 : f32 + tt.reduce.return %22 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %11 = tt.splat %10 : f32 -> tensor<128xf32> + %12 = arith.subf %9, %11 : tensor<128xf32> + %13 = math.exp %12 : tensor<128xf32> + %14 = "tt.reduce"(%13) ({ + ^bb0(%arg5: f32, %arg6: f32): + %21 = arith.addf %arg5, %arg6 : f32 + tt.reduce.return %21 : f32 + }) {axis = 0 : i32} : (tensor<128xf32>) -> f32 + %15 = tt.splat %14 : f32 -> tensor<128xf32> + %16 = arith.divf %13, %15 : tensor<128xf32> + %17 = arith.muli %0, %arg3 : i32 + %18 = tt.addptr %arg0, %17 : !tt.ptr, i32 + %19 = tt.splat %18 : !tt.ptr -> tensor<128x!tt.ptr> + %20 = tt.addptr %19, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %20, %16, %7 : tensor<128x!tt.ptr> + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK-COUNT-1: tts.gather %arg{{[0-9]+}} +// CHECK-NOT: tts.gather %arg{{[0-9]+}} +// CHECK-COUNT-1: tts.scatter {{.+}} into %arg{{[0-9]+}} +// CHECK-NOT: tts.scatter {{.+}} into %arg{{[0-9]+}} diff --git a/test/Conversion/TritonToUnstructured/kernel-03-matrix-multiplication.mlir b/test/Conversion/TritonToUnstructured/kernel-03-matrix-multiplication.mlir new file mode 100644 index 00000000..28cd7d44 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/kernel-03-matrix-multiplication.mlir @@ -0,0 +1,106 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @matmul_kernel_0123456789101112131415(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: 
i32) { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %c256_i32 = arith.constant 256 : i32 + %c128_i32 = arith.constant 128 : i32 + %c8_i32 = arith.constant 8 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %4, %c8_i32 : i32 + %8 = arith.divsi %0, %7 : i32 + %9 = arith.muli %8, %c8_i32 : i32 + %10 = arith.subi %2, %9 : i32 + %11 = arith.cmpi slt, %10, %c8_i32 : i32 + %12 = arith.select %11, %10, %c8_i32 : i32 + %13 = arith.remsi %0, %12 : i32 + %14 = arith.addi %9, %13 : i32 + %15 = arith.remsi %0, %7 : i32 + %16 = arith.divsi %15, %12 : i32 + %17 = arith.muli %14, %c128_i32 : i32 + %18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %19 = tt.splat %17 : i32 -> tensor<128xi32> + %20 = arith.addi %19, %18 : tensor<128xi32> + %21 = arith.muli %16, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %23 = tt.splat %21 : i32 -> tensor<256xi32> + %24 = arith.addi %23, %22 : tensor<256xi32> + %25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %26 = tt.expand_dims %20 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> + %27 = tt.splat %arg6 : i32 -> tensor<128x1xi32> + %28 = arith.muli %26, %27 : tensor<128x1xi32> + %29 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> + %30 = tt.splat %arg7 : i32 -> tensor<1x64xi32> + %31 = arith.muli %29, %30 : tensor<1x64xi32> + %32 = tt.broadcast %28 : tensor<128x1xi32> -> tensor<128x64xi32> + %33 = tt.broadcast %31 : tensor<1x64xi32> -> tensor<128x64xi32> + %34 
= arith.addi %32, %33 : tensor<128x64xi32> + %35 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> + %36 = tt.addptr %35, %34 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %37 = tt.expand_dims %25 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> + %38 = tt.splat %arg8 : i32 -> tensor<64x1xi32> + %39 = arith.muli %37, %38 : tensor<64x1xi32> + %40 = tt.expand_dims %24 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %41 = tt.splat %arg9 : i32 -> tensor<1x256xi32> + %42 = arith.muli %40, %41 : tensor<1x256xi32> + %43 = tt.broadcast %39 : tensor<64x1xi32> -> tensor<64x256xi32> + %44 = tt.broadcast %42 : tensor<1x256xi32> -> tensor<64x256xi32> + %45 = arith.addi %43, %44 : tensor<64x256xi32> + %46 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> + %47 = tt.addptr %46, %45 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + %48 = tt.splat %cst : f32 -> tensor<128x256xf32> + %49 = arith.muli %arg7, %c64_i32 : i32 + %50 = tt.splat %49 : i32 -> tensor<128x64xi32> + %51 = arith.muli %arg8, %c64_i32 : i32 + %52 = tt.splat %51 : i32 -> tensor<64x256xi32> + %53:3 = scf.for %arg12 = %c0_i32 to %6 step %c1_i32 iter_args(%arg13 = %48, %arg14 = %36, %arg15 = %47) -> (tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr>) : i32 { + %71 = tt.load %arg14 : tensor<128x64x!tt.ptr> + %72 = tt.load %arg15 : tensor<64x256x!tt.ptr> + %73 = tt.dot %71, %72, %48 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> + %74 = arith.addf %arg13, %73 : tensor<128x256xf32> + %75 = tt.addptr %arg14, %50 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> + %76 = tt.addptr %arg15, %52 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> + scf.yield %74, %75, %76 : tensor<128x256xf32>, tensor<128x64x!tt.ptr>, tensor<64x256x!tt.ptr> + } + %54 = arith.truncf %53#0 : tensor<128x256xf32> to tensor<128x256xbf16> + %55 = tt.splat %arg10 : i32 -> tensor<128x1xi32> + %56 = arith.muli %55, %26 : tensor<128x1xi32> + %57 = 
tt.splat %arg2 : !tt.ptr -> tensor<128x1x!tt.ptr> + %58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> + %59 = tt.splat %arg11 : i32 -> tensor<1x256xi32> + %60 = arith.muli %59, %40 : tensor<1x256xi32> + %61 = tt.broadcast %58 : tensor<128x1x!tt.ptr> -> tensor<128x256x!tt.ptr> + %62 = tt.broadcast %60 : tensor<1x256xi32> -> tensor<128x256xi32> + %63 = tt.addptr %61, %62 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> + %64 = tt.splat %arg3 : i32 -> tensor<128x1xi32> + %65 = arith.cmpi slt, %26, %64 : tensor<128x1xi32> + %66 = tt.splat %arg4 : i32 -> tensor<1x256xi32> + %67 = arith.cmpi slt, %40, %66 : tensor<1x256xi32> + %68 = tt.broadcast %65 : tensor<128x1xi1> -> tensor<128x256xi1> + %69 = tt.broadcast %67 : tensor<1x256xi1> -> tensor<128x256xi1> + %70 = arith.andi %68, %69 : tensor<128x256xi1> + tt.store %63, %54, %70 : tensor<128x256x!tt.ptr> + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK-COUNT-2: tts.gather %arg{{[0-9]+}} +// CHECK-NOT: tts.gather %arg{{[0-9]+}} +// CHECK-COUNT-1: tts.scatter {{.+}} into %arg{{[0-9]+}} +// CHECK-NOT: tts.scatter {{.+}} into %arg{{[0-9]+}} diff --git a/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-dwdb.mlir b/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-dwdb.mlir new file mode 100644 index 00000000..42d09434 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-dwdb.mlir @@ -0,0 +1,70 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @_layer_norm_bwd_dwdb_0123456(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: i32, %arg5: i32) { + %c0_i32 = arith.constant 0 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c256_i32 : i32 + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %3 = tt.splat %1 : i32 -> tensor<256xi32> + %4 
= arith.addi %3, %2 : tensor<256xi32> + %5 = tt.splat %cst : f32 -> tensor<256x256xf32> + %6 = tt.splat %arg4 : i32 -> tensor<256x1xi32> + %7 = tt.expand_dims %4 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> + %8 = tt.splat %arg5 : i32 -> tensor<1x256xi32> + %9 = arith.cmpi slt, %7, %8 : tensor<1x256xi32> + %10 = tt.broadcast %9 : tensor<1x256xi1> -> tensor<256x256xi1> + %11 = tt.splat %arg5 : i32 -> tensor<256x1xi32> + %12 = tt.broadcast %7 : tensor<1x256xi32> -> tensor<256x256xi32> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<256x256x!tt.ptr> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<256x256x!tt.ptr> + %15:2 = scf.for %arg6 = %c0_i32 to %arg4 step %c256_i32 iter_args(%arg7 = %5, %arg8 = %5) -> (tensor<256x256xf32>, tensor<256x256xf32>) : i32 { + %24 = tt.splat %arg6 : i32 -> tensor<256xi32> + %25 = arith.addi %24, %2 : tensor<256xi32> + %26 = tt.expand_dims %25 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> + %27 = arith.cmpi slt, %26, %6 : tensor<256x1xi32> + %28 = tt.broadcast %27 : tensor<256x1xi1> -> tensor<256x256xi1> + %29 = arith.andi %28, %10 : tensor<256x256xi1> + %30 = arith.muli %26, %11 : tensor<256x1xi32> + %31 = tt.broadcast %30 : tensor<256x1xi32> -> tensor<256x256xi32> + %32 = arith.addi %31, %12 : tensor<256x256xi32> + %33 = tt.addptr %13, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %34 = tt.load %33, %29, %5 : tensor<256x256x!tt.ptr> + %35 = arith.addf %arg7, %34 : tensor<256x256xf32> + %36 = tt.addptr %14, %32 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> + %37 = tt.load %36, %29, %5 : tensor<256x256x!tt.ptr> + %38 = arith.addf %arg8, %37 : tensor<256x256xf32> + scf.yield %35, %38 : tensor<256x256xf32>, tensor<256x256xf32> + } + %16 = "tt.reduce"(%15#0) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %17 = "tt.reduce"(%15#1) ({ + ^bb0(%arg6: f32, %arg7: f32): + %24 = arith.addf %arg6, %arg7 : f32 + 
tt.reduce.return %24 : f32 + }) {axis = 0 : i32} : (tensor<256x256xf32>) -> tensor<256xf32> + %18 = tt.splat %arg5 : i32 -> tensor<256xi32> + %19 = arith.cmpi slt, %4, %18 : tensor<256xi32> + %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %21 = tt.addptr %20, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %21, %16, %19 : tensor<256x!tt.ptr> + %22 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr> + %23 = tt.addptr %22, %4 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %23, %17, %19 : tensor<256x!tt.ptr> + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK-COUNT-2: tts.gather %arg{{[0-9]+}} +// CHECK-NOT: tts.gather %arg{{[0-9]+}} +// CHECK-COUNT-2: tts.scatter {{.+}} into %arg{{[0-9]+}} +// CHECK-NOT: tts.scatter {{.+}} into %arg{{[0-9]+}} diff --git a/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-fwd.mlir b/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-fwd.mlir new file mode 100644 index 00000000..3e1e192c --- /dev/null +++ b/test/Conversion/TritonToUnstructured/kernel-05-layer-norm-fwd.mlir @@ -0,0 +1,96 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @_layer_norm_fwd_fused_0123456789(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: !tt.ptr, %arg4: !tt.ptr, %arg5: !tt.ptr, %arg6: i32, %arg7: i32, %arg8: f32) { + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %arg6 : i32 + %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 + %3 = tt.addptr %arg0, %1 : !tt.ptr, i32 + %4 = tt.splat %cst_0 : f32 -> tensor<256xf32> + %5 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %6 = tt.splat %arg7 : i32 -> tensor<256xi32> + %7 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %8 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> 
(tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %5 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %6 : tensor<256xi32> + %35 = tt.addptr %7, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr> + %37 = arith.addf %arg10, %36 : tensor<256xf32> + scf.yield %37 : tensor<256xf32> + } + %9 = "tt.reduce"(%8) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %10 = arith.sitofp %arg7 : i32 to f32 + %11 = arith.divf %9, %10 : f32 + %12 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %13 = tt.splat %arg7 : i32 -> tensor<256xi32> + %14 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %15 = tt.splat %11 : f32 -> tensor<256xf32> + %16 = scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 iter_args(%arg10 = %4) -> (tensor<256xf32>) : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %12 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %13 : tensor<256xi32> + %35 = tt.addptr %14, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34, %4 : tensor<256x!tt.ptr> + %37 = arith.subf %36, %15 : tensor<256xf32> + %38 = arith.select %34, %37, %4 : tensor<256xi1>, tensor<256xf32> + %39 = arith.mulf %38, %38 : tensor<256xf32> + %40 = arith.addf %arg10, %39 : tensor<256xf32> + scf.yield %40 : tensor<256xf32> + } + %17 = "tt.reduce"(%16) ({ + ^bb0(%arg9: f32, %arg10: f32): + %32 = arith.addf %arg9, %arg10 : f32 + tt.reduce.return %32 : f32 + }) {axis = 0 : i32} : (tensor<256xf32>) -> f32 + %18 = arith.divf %17, %10 : f32 + %19 = arith.addf %18, %arg8 : f32 + %20 = math.sqrt %19 : f32 + %21 = arith.divf %cst, %20 : f32 + %22 = tt.addptr %arg4, %0 : !tt.ptr, i32 + tt.store %22, %11 : !tt.ptr + %23 = tt.addptr %arg5, %0 : !tt.ptr, i32 + tt.store %23, %21 : !tt.ptr + %24 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> + %25 = 
tt.splat %arg7 : i32 -> tensor<256xi32> + %26 = tt.splat %arg2 : !tt.ptr -> tensor<256x!tt.ptr> + %27 = tt.splat %arg3 : !tt.ptr -> tensor<256x!tt.ptr> + %28 = tt.splat %3 : !tt.ptr -> tensor<256x!tt.ptr> + %29 = tt.splat %11 : f32 -> tensor<256xf32> + %30 = tt.splat %21 : f32 -> tensor<256xf32> + %31 = tt.splat %2 : !tt.ptr -> tensor<256x!tt.ptr> + scf.for %arg9 = %c0_i32 to %arg7 step %c256_i32 : i32 { + %32 = tt.splat %arg9 : i32 -> tensor<256xi32> + %33 = arith.addi %32, %24 : tensor<256xi32> + %34 = arith.cmpi slt, %33, %25 : tensor<256xi32> + %35 = tt.addptr %26, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %36 = tt.load %35, %34 : tensor<256x!tt.ptr> + %37 = tt.addptr %27, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %38 = tt.load %37, %34 : tensor<256x!tt.ptr> + %39 = tt.addptr %28, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + %40 = tt.load %39, %34, %4 : tensor<256x!tt.ptr> + %41 = arith.subf %40, %29 : tensor<256xf32> + %42 = arith.mulf %41, %30 : tensor<256xf32> + %43 = arith.mulf %42, %36 : tensor<256xf32> + %44 = arith.addf %43, %38 : tensor<256xf32> + %45 = tt.addptr %31, %33 : tensor<256x!tt.ptr>, tensor<256xi32> + tt.store %45, %44, %34 : tensor<256x!tt.ptr> + } + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store + +// CHECK-COUNT-5: tts.gather %arg{{[0-9]+}} +// CHECK-NOT: tts.gather %arg{{[0-9]+}} diff --git a/test/Conversion/TritonToUnstructured/make_tensor_ptr.mlir b/test/Conversion/TritonToUnstructured/make_tensor_ptr.mlir new file mode 100644 index 00000000..e7e6b794 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/make_tensor_ptr.mlir @@ -0,0 +1,38 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @add_ptr_into_make_block_ptr(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c32768_i64 = arith.constant 32768 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + 
%c512_i64 = arith.constant 512 : i64 + %0 = tt.get_program_id x : i32 + %1 = arith.extsi %0 : i32 to i64 + %2 = arith.muli %1, %c32768_i64 : i64 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i64 + %4 = tt.make_tensor_ptr %3, [%c512_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %5 = tt.addptr %arg1, %2 : !tt.ptr, i64 + %6 = tt.make_tensor_ptr %5, [%c512_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %7 = tt.load %4 : !tt.ptr> + tt.store %6, %7 : !tt.ptr> + tt.return + } +} + +// CHECK: tt.func public @add_ptr_into_make_block_ptr([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_32768_:%.+]] = arith.constant 32768 : i64 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64 +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i64 +// CHECK-DAG: [[CST_512_:%.+]] = arith.constant 512 : i64 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.extsi [[VAR_0_]] : i32 to i64 +// CHECK: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_32768_]] : i64 +// CHECK: [[VAR_3_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 +// CHECK-DAG: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{.}}[[CST_512_]], [[CST_64_]]{{.}}, {{.}}[[CST_64_]], [[CST_1_]]{{.}}, {{.}}[[VAR_3_]], [[CST_0_]]{{.}} {order = array} : > +// CHECK-DAG: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{.}}[[CST_512_]], [[CST_64_]]{{.}}, {{.}}[[CST_64_]], [[CST_1_]]{{.}}, {{.}}[[VAR_3_]], [[CST_0_]]{{.}} {order = array} : > +// CHECK: [[LOAD_VAR_4_MEM_:%.+]] = tt.load [[VAR_4_]] : !tt.ptr> +// CHECK: tt.store [[VAR_5_]], [[LOAD_VAR_4_MEM_]] : !tt.ptr> +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToUnstructured/make_tensor_ptr_structured.mlir b/test/Conversion/TritonToUnstructured/make_tensor_ptr_structured.mlir new file mode 100644 index 00000000..1742e4c6 --- /dev/null +++ 
b/test/Conversion/TritonToUnstructured/make_tensor_ptr_structured.mlir @@ -0,0 +1,38 @@ +// RUN: triton-shared-opt --triton-to-structured --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @add_ptr_into_make_block_ptr(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c32768_i64 = arith.constant 32768 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c64_i64 = arith.constant 64 : i64 + %c512_i64 = arith.constant 512 : i64 + %0 = tt.get_program_id x : i32 + %1 = arith.extsi %0 : i32 to i64 + %2 = arith.muli %1, %c32768_i64 : i64 + %3 = tt.addptr %arg0, %2 : !tt.ptr, i64 + %4 = tt.make_tensor_ptr %3, [%c512_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %5 = tt.addptr %arg1, %2 : !tt.ptr, i64 + %6 = tt.make_tensor_ptr %5, [%c512_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %7 = tt.load %4 : !tt.ptr> + tt.store %6, %7 : !tt.ptr> + tt.return + } +} + +// CHECK: tt.func public @add_ptr_into_make_block_ptr([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_512_:%.+]] = arith.constant 512 : index +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_32768_:%.+]] = arith.constant 32768 : i64 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 +// CHECK: [[VAR_1_:%.+]] = arith.extsi [[VAR_0_]] : i32 to i64 +// CHECK: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_32768_]] : i64 +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i64 to index +// CHECK-DAG: [[VAR_4_:%.+]] = tts.make_tptr [[PARAM_0_]] to sizes: [512, 64], strides: {{.}}[[CST_64_]], [[CST_1_]]{{.}}, offsets: {{.}}[[VAR_3_]], [[CST_0_]]{{.}}, shape: {{.}}[[CST_512_]], [[CST_64_]]{{.}}, order: [1, 0] : to !tt.ptr> +// CHECK-DAG: [[VAR_5_:%.+]] = tts.make_tptr [[PARAM_1_]] to sizes: 
[512, 64], strides: {{.}}[[CST_64_]], [[CST_1_]]{{.}}, offsets: {{.}}[[VAR_3_]], [[CST_0_]]{{.}}, shape: {{.}}[[CST_512_]], [[CST_64_]]{{.}}, order: [1, 0] : to !tt.ptr> +// CHECK: [[VAR_6_:%.+]] = "tts.load"([[VAR_4_]]) <{operandSegmentSizes = array, static_mask_dims = array}> : (!tt.ptr>) -> tensor<512x64xbf16> +// CHECK: "tts.store"([[VAR_5_]], [[VAR_6_]]) <{static_mask_dims = array}> : (!tt.ptr>, tensor<512x64xbf16>) -> () +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToUnstructured/nested_ptr_in_iterargs.mlir b/test/Conversion/TritonToUnstructured/nested_ptr_in_iterargs.mlir new file mode 100644 index 00000000..721d6562 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/nested_ptr_in_iterargs.mlir @@ -0,0 +1,74 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @nested2_complex_body(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<3> : tensor<2x2xi32> + %cst_0 = arith.constant dense<1> : tensor<2x2xi32> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<2x1xi32> + %3 = arith.muli %1, %2 : tensor<2x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x2xi32> + %6 = arith.muli %4, %5 : tensor<1x2xi32> + %7 = tt.broadcast %3 : tensor<2x1xi32> -> tensor<2x2xi32> + %8 = tt.broadcast %6 : tensor<1x2xi32> -> tensor<2x2xi32> + %9 = arith.addi %7, %8 : tensor<2x2xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<2x1x!tt.ptr> + %13 = tt.addptr %12, %3 : tensor<2x1x!tt.ptr>, tensor<2x1xi32> + %14 = 
tt.broadcast %13 : tensor<2x1x!tt.ptr> -> tensor<2x2x!tt.ptr> + %15 = tt.addptr %14, %8 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %16 = arith.muli %arg2, %c2_i32 : i32 + %17 = tt.splat %16 : i32 -> tensor<2x2xi32> + %18:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %19 = tt.addptr %arg5, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %20 = tt.addptr %arg6, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %21:2 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %19, %arg9 = %20) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %26 = tt.load %arg8 : tensor<2x2x!tt.ptr> + tt.store %arg9, %26 : tensor<2x2x!tt.ptr> + %27 = tt.addptr %arg8, %cst : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %28 = tt.addptr %arg9, %cst : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %27, %28 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + %22 = tt.addptr %arg5, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %23 = tt.addptr %22, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %24 = tt.addptr %arg6, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %25 = tt.addptr %24, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %23, %25 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @nested2_complex_body([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1> : tensor<2x2xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> +// CHECK-DAG: 
[[VAR_2_:%.+]] = tt.splat [[PARAM_2_]] : i32 -> tensor<2x1xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_1_]], [[VAR_2_]] : tensor<2x1xi32> +// CHECK-DAG: [[VAR_4_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[PARAM_3_]] : i32 -> tensor<1x2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_4_]], [[VAR_5_]] : tensor<1x2xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.broadcast [[VAR_3_]] : tensor<2x1xi32> -> tensor<2x2xi32> +// CHECK: [[VAR_8_:%.+]] = tt.broadcast [[VAR_6_]] : tensor<1x2xi32> -> tensor<2x2xi32> +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_7_]], [[VAR_8_]] : tensor<2x2xi32> +// CHECK: scf.for [[I_0_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] : i32 { +// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_9_]], [[VAR_cst_]] : tensor<2x2xi32> +// CHECK: scf.for [[I_1_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] : i32 { +// CHECK: [[VAR_11_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_10_]]{{.}} : (, tensor<2x2xi32>) -> tensor<2x2xf32> +// CHECK: tts.scatter [[VAR_11_]] into [[PARAM_1_]]{{.}}[[VAR_10_]]{{.}} : tensor<2x2xf32> into (, tensor<2x2xi32>) +// CHECK: } +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git a/test/Conversion/TritonToUnstructured/ridiculously_nested_loops.mlir b/test/Conversion/TritonToUnstructured/ridiculously_nested_loops.mlir new file mode 100644 index 00000000..c8c36d3f --- /dev/null +++ b/test/Conversion/TritonToUnstructured/ridiculously_nested_loops.mlir @@ -0,0 +1,161 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @nested_who_knows_how_many_levels(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : 
tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<2x1xi32> + %3 = arith.muli %1, %2 : tensor<2x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x2xi32> + %6 = arith.muli %4, %5 : tensor<1x2xi32> + %7 = tt.broadcast %3 : tensor<2x1xi32> -> tensor<2x2xi32> + %8 = tt.broadcast %6 : tensor<1x2xi32> -> tensor<2x2xi32> + %9 = arith.addi %7, %8 : tensor<2x2xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<2x1x!tt.ptr> + %13 = tt.addptr %12, %3 : tensor<2x1x!tt.ptr>, tensor<2x1xi32> + %14 = tt.broadcast %13 : tensor<2x1x!tt.ptr> -> tensor<2x2x!tt.ptr> + %15 = tt.addptr %14, %8 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %16 = arith.muli %arg3, %c2_i32 : i32 + %17 = tt.splat %16 : i32 -> tensor<2x2xi32> + %18 = arith.muli %arg3, %c2_i32 : i32 + %19 = tt.splat %18 : i32 -> tensor<2x2xi32> + %20 = arith.muli %arg3, %c2_i32 : i32 + %21 = tt.splat %20 : i32 -> tensor<2x2xi32> + %22 = arith.muli %arg3, %c2_i32 : i32 + %23 = tt.splat %22 : i32 -> tensor<2x2xi32> + %24:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %25 = tt.load %arg5 : tensor<2x2x!tt.ptr> + %26:3 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %arg5, %arg9 = %arg6, %arg10 = %25) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>) : i32 { + %29 = tt.addptr %arg8, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %30 = tt.load %29 : tensor<2x2x!tt.ptr> + %31:4 = scf.for %arg11 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg12 = %29, %arg13 = %arg9, %arg14 = %arg10, %arg15 = %30) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>) : i32 { + %33 = tt.addptr %arg12, %17 : tensor<2x2x!tt.ptr>, 
tensor<2x2xi32> + %34 = tt.load %33 : tensor<2x2x!tt.ptr> + tt.store %arg13, %arg14 : tensor<2x2x!tt.ptr> + %35 = tt.addptr %arg13, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %35, %arg15 : tensor<2x2x!tt.ptr> + %36 = tt.addptr %35, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %36, %34 : tensor<2x2x!tt.ptr> + %37 = tt.addptr %36, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %38:5 = scf.for %arg16 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg17 = %arg14, %arg18 = %33, %arg19 = %arg15, %arg20 = %34, %arg21 = %37) -> (tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %40 = tt.load %arg18 : tensor<2x2x!tt.ptr> + %41:5 = scf.for %arg22 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg23 = %arg18, %arg24 = %arg19, %arg25 = %arg20, %arg26 = %arg21, %arg27 = %40) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>) : i32 { + %42 = tt.addptr %arg23, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %43 = tt.load %42 : tensor<2x2x!tt.ptr> + %44:5 = scf.for %arg28 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg29 = %42, %arg30 = %arg25, %arg31 = %arg26, %arg32 = %arg27, %arg33 = %43) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>) : i32 { + %45 = tt.addptr %arg29, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %46 = tt.load %45 : tensor<2x2x!tt.ptr> + tt.store %arg31, %arg32 : tensor<2x2x!tt.ptr> + %47 = tt.addptr %arg31, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %47, %arg33 : tensor<2x2x!tt.ptr> + %48 = tt.addptr %47, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %48, %46 : tensor<2x2x!tt.ptr> + %49 = tt.addptr %48, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %50:5 = scf.for %arg34 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg35 = %arg32, %arg36 = %45, %arg37 = %arg33, %arg38 = %46, %arg39 = %49) -> (tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : 
i32 { + %51 = tt.load %arg36 : tensor<2x2x!tt.ptr> + %52:5 = scf.for %arg40 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg41 = %arg36, %arg42 = %arg37, %arg43 = %arg38, %arg44 = %arg39, %arg45 = %51) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>) : i32 { + %53 = tt.addptr %arg41, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %54 = tt.load %53 : tensor<2x2x!tt.ptr> + %55:5 = scf.for %arg46 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg47 = %53, %arg48 = %arg43, %arg49 = %arg44, %arg50 = %arg45, %arg51 = %54) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>) : i32 { + %56 = tt.addptr %arg47, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %57 = tt.load %56 : tensor<2x2x!tt.ptr> + tt.store %arg49, %arg50 : tensor<2x2x!tt.ptr> + %58 = tt.addptr %arg49, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %58, %arg51 : tensor<2x2x!tt.ptr> + %59 = tt.addptr %58, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %59, %57 : tensor<2x2x!tt.ptr> + %60 = tt.addptr %59, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %61:5 = scf.for %arg52 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg53 = %arg50, %arg54 = %56, %arg55 = %arg51, %arg56 = %57, %arg57 = %60) -> (tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %62 = tt.load %arg54 : tensor<2x2x!tt.ptr> + %63:4 = scf.for %arg58 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg59 = %arg54, %arg60 = %arg55, %arg61 = %arg56, %arg62 = %arg57) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %64 = tt.addptr %arg59, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %65 = tt.load %64 : tensor<2x2x!tt.ptr> + %66:3 = scf.for %arg63 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg64 = %64, %arg65 = %arg61, %arg66 = %arg62) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %67 = tt.addptr %arg64, %17 : tensor<2x2x!tt.ptr>, 
tensor<2x2xi32> + %68 = tt.load %67 : tensor<2x2x!tt.ptr> + tt.store %arg66, %62 : tensor<2x2x!tt.ptr> + %69 = tt.addptr %arg66, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %69, %65 : tensor<2x2x!tt.ptr> + %70 = tt.addptr %69, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %70, %68 : tensor<2x2x!tt.ptr> + %71 = tt.addptr %70, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %67, %68, %71 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %66#0, %65, %66#1, %66#2 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %62, %63#0, %63#1, %63#2, %63#3 : tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %61#1, %61#3, %61#4, %61#0, %61#2 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32> + } + scf.yield %55#0, %55#4, %55#1, %55#2, %55#3 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32> + } + scf.yield %52#4, %52#0, %52#1, %52#2, %52#3 : tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %50#1, %50#3, %50#4, %50#0, %50#2 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32> + } + scf.yield %44#0, %44#4, %44#1, %44#2, %44#3 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32> + } + scf.yield %41#4, %41#0, %41#1, %41#2, %41#3 : tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + %39:5 = scf.for %arg16 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg17 = %38#0, %arg18 = %38#1, %arg19 = %38#2, %arg20 = %38#3, %arg21 = %38#4) -> (tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %40 = tt.load %arg18 : tensor<2x2x!tt.ptr> + %41:4 = scf.for %arg22 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg23 = %arg18, %arg24 = 
%arg19, %arg25 = %arg20, %arg26 = %arg21) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %42 = tt.addptr %arg23, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %43 = tt.load %42 : tensor<2x2x!tt.ptr> + %44:3 = scf.for %arg27 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg28 = %42, %arg29 = %arg25, %arg30 = %arg26) -> (tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr>) : i32 { + %45 = tt.addptr %arg28, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %46 = tt.load %45 : tensor<2x2x!tt.ptr> + tt.store %arg30, %40 : tensor<2x2x!tt.ptr> + %47 = tt.addptr %arg30, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %47, %43 : tensor<2x2x!tt.ptr> + %48 = tt.addptr %47, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %48, %46 : tensor<2x2x!tt.ptr> + %49 = tt.addptr %48, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %45, %46, %49 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %44#0, %43, %44#1, %44#2 : tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %40, %41#0, %41#1, %41#2, %41#3 : tensor<2x2xf32>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2x!tt.ptr> + } + scf.yield %39#1, %39#4, %39#0, %39#2 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>, tensor<2x2xf32>, tensor<2x2xf32> + } + %32 = tt.addptr %31#0, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %32, %31#1, %31#2 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>, tensor<2x2xf32> + } + %27:2 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %26#0, %arg9 = %26#1) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %29 = tt.load %arg8 : tensor<2x2x!tt.ptr> + %30:2 = scf.for %arg10 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %32 = tt.addptr %arg11, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %33 = tt.load %32 : tensor<2x2x!tt.ptr> + %34:2 = scf.for %arg13 = %c0_i32 
to %c2_i32 step %c1_i32 iter_args(%arg14 = %32, %arg15 = %arg12) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %35 = tt.addptr %arg14, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %36 = tt.load %35 : tensor<2x2x!tt.ptr> + tt.store %arg15, %29 : tensor<2x2x!tt.ptr> + %37 = tt.addptr %arg15, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %37, %33 : tensor<2x2x!tt.ptr> + %38 = tt.addptr %37, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + tt.store %38, %36 : tensor<2x2x!tt.ptr> + %39 = tt.addptr %38, %19 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %35, %39 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + scf.yield %34#0, %34#1 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + %31 = tt.addptr %30#0, %21 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %31, %30#1 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + %28 = tt.addptr %27#0, %23 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %28, %27#1 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + tt.return + } +} + +// CHECK-NOT: tt.addptr +// CHECK-NOT: tt.load +// CHECK-NOT: tt.store diff --git a/test/Conversion/TritonToUnstructured/scf_for_ptr_in_iterargs.mlir b/test/Conversion/TritonToUnstructured/scf_for_ptr_in_iterargs.mlir new file mode 100644 index 00000000..721d6562 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/scf_for_ptr_in_iterargs.mlir @@ -0,0 +1,74 @@ +// RUN: triton-shared-opt --triton-to-unstructured %s | FileCheck %s + +module { + tt.func public @nested2_complex_body(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<3> : tensor<2x2xi32> + %cst_0 = arith.constant dense<1> : tensor<2x2xi32> + %c2_i32 = arith.constant 2 : i32 + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> + %2 = tt.splat %arg2 : i32 -> tensor<2x1xi32> + %3 = arith.muli %1, 
%2 : tensor<2x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> + %5 = tt.splat %arg3 : i32 -> tensor<1x2xi32> + %6 = arith.muli %4, %5 : tensor<1x2xi32> + %7 = tt.broadcast %3 : tensor<2x1xi32> -> tensor<2x2xi32> + %8 = tt.broadcast %6 : tensor<1x2xi32> -> tensor<2x2xi32> + %9 = arith.addi %7, %8 : tensor<2x2xi32> + %10 = tt.splat %arg0 : !tt.ptr -> tensor<2x2x!tt.ptr> + %11 = tt.addptr %10, %9 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<2x1x!tt.ptr> + %13 = tt.addptr %12, %3 : tensor<2x1x!tt.ptr>, tensor<2x1xi32> + %14 = tt.broadcast %13 : tensor<2x1x!tt.ptr> -> tensor<2x2x!tt.ptr> + %15 = tt.addptr %14, %8 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %16 = arith.muli %arg2, %c2_i32 : i32 + %17 = tt.splat %16 : i32 -> tensor<2x2xi32> + %18:2 = scf.for %arg4 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg5 = %11, %arg6 = %15) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %19 = tt.addptr %arg5, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %20 = tt.addptr %arg6, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %21:2 = scf.for %arg7 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg8 = %19, %arg9 = %20) -> (tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr>) : i32 { + %26 = tt.load %arg8 : tensor<2x2x!tt.ptr> + tt.store %arg9, %26 : tensor<2x2x!tt.ptr> + %27 = tt.addptr %arg8, %cst : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %28 = tt.addptr %arg9, %cst : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %27, %28 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + %22 = tt.addptr %arg5, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %23 = tt.addptr %22, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %24 = tt.addptr %arg6, %17 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + %25 = tt.addptr %24, %cst_0 : tensor<2x2x!tt.ptr>, tensor<2x2xi32> + scf.yield %23, %25 : tensor<2x2x!tt.ptr>, tensor<2x2x!tt.ptr> + } + tt.return + } +} + +// CHECK: tt.func public @nested2_complex_body([[PARAM_0_:%.+]]: !tt.ptr, 
[[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1> : tensor<2x2xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<2xi32> -> tensor<2x1xi32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.splat [[PARAM_2_]] : i32 -> tensor<2x1xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[VAR_1_]], [[VAR_2_]] : tensor<2x1xi32> +// CHECK-DAG: [[VAR_4_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<2xi32> -> tensor<1x2xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[PARAM_3_]] : i32 -> tensor<1x2xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_4_]], [[VAR_5_]] : tensor<1x2xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.broadcast [[VAR_3_]] : tensor<2x1xi32> -> tensor<2x2xi32> +// CHECK: [[VAR_8_:%.+]] = tt.broadcast [[VAR_6_]] : tensor<1x2xi32> -> tensor<2x2xi32> +// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_7_]], [[VAR_8_]] : tensor<2x2xi32> +// CHECK: scf.for [[I_0_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] : i32 { +// CHECK: [[VAR_10_:%.+]] = arith.addi [[VAR_9_]], [[VAR_cst_]] : tensor<2x2xi32> +// CHECK: scf.for [[I_1_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] : i32 { +// CHECK: [[VAR_11_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_10_]]{{.}} : (, tensor<2x2xi32>) -> tensor<2x2xf32> +// CHECK: tts.scatter [[VAR_11_]] into [[PARAM_1_]]{{.}}[[VAR_10_]]{{.}} : tensor<2x2xf32> into (, tensor<2x2xi32>) +// CHECK: } +// CHECK: } +// CHECK: tt.return +// CHECK: } diff --git 
a/test/Conversion/TritonToUnstructured/test_offset_bitwidth.mlir b/test/Conversion/TritonToUnstructured/test_offset_bitwidth.mlir new file mode 100644 index 00000000..a7ea9015 --- /dev/null +++ b/test/Conversion/TritonToUnstructured/test_offset_bitwidth.mlir @@ -0,0 +1,64 @@ +// RUN: triton-shared-opt --triton-to-unstructured="offset-bit-width=64" %s | FileCheck %s + +module { + tt.func public @gather_simple(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<64xi32> + %cst_0 = arith.constant dense<10> : tensor<64xi32> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = tt.get_program_id x : i32 + %2 = arith.divsi %0, %cst_0 : tensor<64xi32> + %3 = arith.extsi %1 : i32 to i64 + %4 = arith.extsi %2 : tensor<64xi32> to tensor<64xi64> + %5 = tt.splat %3 : i64 -> tensor<64xi64> + %6 = arith.addi %4, %5 : tensor<64xi64> + %7 = tt.splat %1 : i32 -> tensor<64xi32> + %8 = arith.addi %2, %7 : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %6 : tensor<64x!tt.ptr>, tensor<64xi64> + %11 = tt.addptr %10, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %13:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %11, %arg4 = %0) -> (tensor<64x!tt.ptr>, tensor<64xi32>) : i32 { + %14 = tt.load %arg3 : tensor<64x!tt.ptr> + %15 = tt.addptr %12, %arg4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %15, %14 : tensor<64x!tt.ptr> + %16 = tt.addptr %arg3, %7 : tensor<64x!tt.ptr>, tensor<64xi32> + %17 = arith.addi %arg4, %cst : tensor<64xi32> + scf.yield %16, %17 : tensor<64x!tt.ptr>, tensor<64xi32> + } + tt.return + } +} + +// CHECK: tt.func public @gather_simple([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : 
i32 +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<64> : tensor<64xi32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<10> : tensor<64xi32> +// CHECK-DAG: [[VAR_0_:%.+]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK-DAG: [[VAR_1_:%.+]] = tt.get_program_id x : i32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = arith.divsi [[VAR_0_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_3_:%.+]] = arith.extsi [[VAR_1_]] : i32 to i64 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.extsi [[VAR_2_]] : tensor<64xi32> to tensor<64xi64> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[VAR_3_]] : i64 -> tensor<64xi64> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addi [[VAR_4_]], [[VAR_5_]] : tensor<64xi64> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.splat [[VAR_1_]] : i32 -> tensor<64xi32> +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_2_]], [[VAR_7_]] : tensor<64xi32> +// CHECK: [[VAR_9_:%.+]] = arith.extsi [[VAR_8_]] : tensor<64xi32> to tensor<64xi64> +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_6_]], [[VAR_9_]] : tensor<64xi64> +// CHECK-DAG: [[VAR_11_:%.+]] = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[VAR_0_]]) -> (tensor<64xi32>) : i32 { +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_12_:%.+]] = tts.gather [[PARAM_0_]]{{.}}[[VAR_10_]]{{.}} : (, tensor<64xi64>) -> tensor<64xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.extsi [[VAR_arg3_]] : tensor<64xi32> to tensor<64xi64> +// CHECK: tts.scatter [[VAR_12_]] into [[PARAM_1_]]{{.}}[[VAR_13_]]{{.}} : tensor<64xf32> into (, tensor<64xi64>) +// CHECK: [[VAR_14_:%.+]] = arith.addi [[VAR_arg3_]], [[VAR_cst_]] : tensor<64xi32> +// CHECK: scf.yield [[VAR_14_]] : tensor<64xi32> +// CHECK: } +// CHECK: tt.return 
+// CHECK: } diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h index 82ba4f39..1d58f6a9 100644 --- a/tools/RegisterTritonSharedDialects.h +++ b/tools/RegisterTritonSharedDialects.h @@ -19,6 +19,7 @@ #include "triton-shared/Conversion/TritonToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h" #include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Conversion/TritonToUnstructured/Passes.h" #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" @@ -45,6 +46,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonToLinalgPass(); mlir::triton::registerTritonToLinalgExperimentalPass(); mlir::triton::registerTritonToStructuredPass(); + mlir::triton::registerTritonToUnstructuredPasses(); mlir::triton::registerTritonArithToLinalgPasses(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerStructuredToMemrefPasses();