From 83340b269d8a851b3b99281af961539ee2a97794 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 7 Jan 2025 17:57:54 -0500 Subject: [PATCH 1/7] Add UnstructuredToMemref pass --- .../triton-shared/Conversion/CMakeLists.txt | 1 + .../UnstructuredToMemref/CMakeLists.txt | 9 + .../Conversion/UnstructuredToMemref/Passes.h | 22 + .../Conversion/UnstructuredToMemref/Passes.td | 18 + .../UnstructuredToMemref.h | 21 + lib/Conversion/CMakeLists.txt | 1 + .../UnstructuredToMemref/CMakeLists.txt | 27 + .../UnstructuredToMemrefPass.cpp | 463 ++++++++++++++++++ .../gather_mask_no_other.mlir | 78 +++ .../gather_mask_with_other.mlir | 79 +++ .../UnstructuredToMemref/gather_no_mask.mlir | 74 +++ .../gather_scatter_all_mask.mlir | 80 +++ tools/RegisterTritonSharedDialects.h | 2 + 13 files changed, 875 insertions(+) create mode 100644 include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt create mode 100644 include/triton-shared/Conversion/UnstructuredToMemref/Passes.h create mode 100644 include/triton-shared/Conversion/UnstructuredToMemref/Passes.td create mode 100644 include/triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h create mode 100644 lib/Conversion/UnstructuredToMemref/CMakeLists.txt create mode 100644 lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp create mode 100644 test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir create mode 100644 test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir create mode 100644 test/Conversion/UnstructuredToMemref/gather_no_mask.mlir create mode 100644 test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index 45da8aca..ed60f6ca 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonArithToLinalg) add_subdirectory(StructuredToMemref) +add_subdirectory(UnstructuredToMemref) diff --git a/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt b/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt new file mode 100644 index 00000000..291347c4 --- /dev/null +++ b/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt @@ -0,0 +1,9 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name UnstructuredToMemref) +add_public_tablegen_target(UnstructuredToMemrefConversionPassIncGen) diff --git a/include/triton-shared/Conversion/UnstructuredToMemref/Passes.h b/include/triton-shared/Conversion/UnstructuredToMemref/Passes.h new file mode 100644 index 00000000..f2d71174 --- /dev/null +++ b/include/triton-shared/Conversion/UnstructuredToMemref/Passes.h @@ -0,0 +1,22 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES_H +#define UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES_H + +#include "triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton-shared/Conversion/UnstructuredToMemref/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton-shared/Conversion/UnstructuredToMemref/Passes.td b/include/triton-shared/Conversion/UnstructuredToMemref/Passes.td new file mode 100644 index 00000000..a0bf316d --- /dev/null +++ b/include/triton-shared/Conversion/UnstructuredToMemref/Passes.td @@ -0,0 +1,18 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES +#define UNSTRUCTURED_TO_MEMREF_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def UnstructuredToMemref : Pass<"unstructured-to-memref", "mlir::ModuleOp"> { + let summary = "Convert unstructured triton ptr (gather / scatter) to memref"; + let constructor = "triton::createUnstructuredToMemrefPass()"; +} + +#endif diff --git a/include/triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h b/include/triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h new file mode 100644 index 00000000..ad0f5c46 --- /dev/null +++ b/include/triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H +#define TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { + +std::unique_ptr> createUnstructuredToMemrefPass(); + +} // namespace triton +} // namespace mlir + +#endif // TRITON_CONVERSION_UNSTRUCTUREDTOMEMREF_UNSTRUCTUREDTOMEMREF_H diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 45da8aca..ed60f6ca 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonArithToLinalg) add_subdirectory(StructuredToMemref) +add_subdirectory(UnstructuredToMemref) diff --git a/lib/Conversion/UnstructuredToMemref/CMakeLists.txt b/lib/Conversion/UnstructuredToMemref/CMakeLists.txt new file mode 100644 index 00000000..4c08b2cd --- /dev/null +++ b/lib/Conversion/UnstructuredToMemref/CMakeLists.txt @@ -0,0 +1,27 @@ +#===------------------------------------------------------------------------===# +# +# Copyright (c) Triton Project Contributors. +# +#===------------------------------------------------------------------------===# + +add_triton_library(UnstructuredToMemref + UnstructuredToMemrefPass.cpp + + DEPENDS + UnstructuredToMemrefConversionPassIncGen + + LINK_LIBS PUBLIC + TritonTilingExtIR + MLIRArithDialect + MLIRDialectUtils + MLIRIR + MLIRMathDialect + MLIRPass + MLIRTensorDialect + MLIRTransforms + MLIRSupport + TritonAnalysis + TritonIR + TritonTransforms + TritonSharedAnalysis +) diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp new file mode 100644 index 00000000..4f03e6f0 --- /dev/null +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -0,0 +1,463 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +//===----------------------------------------------------------------------===// + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include "triton-shared/Conversion/UnstructuredToMemref/UnstructuredToMemref.h" +#include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" +#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include + +#define DEBUG_TYPE "unstructured-to-memref" + +using namespace mlir; +using namespace triton; + +#define GEN_PASS_CLASSES +#include "triton-shared/Conversion/UnstructuredToMemref/Passes.h.inc" + +namespace { + +class PtrToUnrankedMemrefConverter : public TypeConverter { +public: + PtrToUnrankedMemrefConverter() { + addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrType) { + return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); + }); + addTargetMaterialization([&](OpBuilder &builder, + UnrankedMemRefType resultType, + ValueRange inputs, + Location loc) -> std::optional { + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + } +}; + +static MemRefType getMemrefTypeForScalarPtr(triton::PointerType ptrType, + MLIRContext *context) { + SmallVector strides{1}; + auto layout = StridedLayoutAttr::get(context, ShapedType::kDynamic, strides); + auto elemType = ptrType.getPointeeType(); + auto memrefType = MemRefType::get({1}, elemType, layout); + return memrefType; +} + +struct ScalarLoadConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + ScalarLoadConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + ScalarLoadConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!loadOp.getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = loadOp->getLoc(); + + auto makePtrOp = + loadOp.getPtr().getDefiningOp(); + + auto basePtr = adaptor.getPtr(); + auto offset = makePtrOp.getOffset(); + + Value loadIndex = rewriter.create( + loc, rewriter.getIndexType(), offset); + + auto memref = rewriter.create( + loc, + getMemrefTypeForScalarPtr( + cast(loadOp.getPtr().getType()), + rewriter.getContext()), + basePtr, getAsOpFoldResult(loadIndex) /*offset*/, + ArrayRef{rewriter.getIndexAttr(1)} /*sizes*/, + ArrayRef{rewriter.getIndexAttr(1)} /*strides*/); + + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + + auto scalarLoadOp = rewriter.create( + loc, memref, zeroMap, std::nullopt); + + rewriter.replaceOp(loadOp, scalarLoadOp.getResult()); + + return success(); + } +}; + +struct ScalarStoreConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + ScalarStoreConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + ScalarStoreConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!storeOp.getValue().getType().isIntOrIndexOrFloat()) { + return failure(); + } + + auto loc = storeOp->getLoc(); + + auto makePtrOp = + storeOp.getPtr().getDefiningOp(); + + auto basePtr = adaptor.getPtr(); + auto offset = makePtrOp.getOffset(); + + Value storeIndex = rewriter.create( + loc, rewriter.getIndexType(), offset); + + auto memref = rewriter.create( + loc, + getMemrefTypeForScalarPtr( + cast(storeOp.getPtr().getType()), + rewriter.getContext()), + basePtr, getAsOpFoldResult(storeIndex) /*offset*/, + ArrayRef{rewriter.getIndexAttr(1)} /*sizes*/, + ArrayRef{rewriter.getIndexAttr(1)} /*strides*/); + + auto storeVal = storeOp.getValue(); + auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); + + rewriter.create(loc, storeVal, memref, zeroMap, + std::nullopt); + rewriter.eraseOp(storeOp); + + return success(); + } +}; + +// Lowering an unstructured load op (gather) into a linalg.generic op +struct LoadOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LoadOpConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LoadOpConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = loadOp->getLoc(); + + auto makePtrOp = + loadOp.getPtr().getDefiningOp(); + + auto ptr = adaptor.getPtr(); + auto offsetTensor = makePtrOp.getOffset(); + auto offsetType = dyn_cast(offsetTensor.getType()); + + // This must be a scalar load, skip processing + if (!offsetType) { + return failure(); + } + + auto loadResultType = + dyn_cast(loadOp.getResult().getType()); + + // Treat the base pointer (memref) as 1D because the offsets are all + // relative to a single base pointer (already collapsed). + auto baseMemref = rewriter.create( + loc, + MemRefType::get({ShapedType::kDynamic}, + loadResultType.getElementType()), + ptr); + + auto baseTensor = + rewriter + .create( + loc, + RankedTensorType::get( + SmallVector(1, ShapedType::kDynamic), + loadResultType.getElementType()), + baseMemref, true /* restrict */, false /* writable */) + .getResult(); + + // The linalg.generic op should have the following inputs: + // - the offset tensor + // - an optional mask tensor if the load op contains mask + SmallVector inputs{offsetTensor}; + + if (loadOp.getMask()) { + inputs.push_back(loadOp.getMask()); + } + + auto emptyTensor = + rewriter + .create(loc, loadResultType.getShape(), + loadResultType.getElementType()) + .getResult(); + + // Affine maps for the inputs and output + // If no mask is used, 2 affine maps are generated; one for the input offset + // tensor, the other for the output tensor. + // If mask is used, the first 2 maps are for the offset and mask tensors + // while the last map is for the output tensor. + SmallVector affineMaps( + loadOp.getMask() ? 3 : 2, + rewriter.getMultiDimIdentityMap(loadResultType.getRank())); + + auto genericOp = rewriter.create( + loc, SmallVector({loadResultType}), inputs, + ValueRange{emptyTensor}, affineMaps, + SmallVector(loadResultType.getRank(), + utils::IteratorType::parallel), + [&](OpBuilder &b, Location loc, ValueRange args) { + auto getValueAtIndex = [baseTensor](Value indexValue, Location loc, + OpBuilder &b) -> Value { + Value index0 = + b.create(loc, b.getIndexType(), indexValue); + + return b.create(loc, baseTensor, + ValueRange{index0}); + }; + + if (!loadOp.getMask()) { + // If there is no mask, simply extract the current element from the + // base tensor and use it as the yield value. + auto loadValue = getValueAtIndex(args[0], loc, rewriter); + rewriter.create(loc, loadValue); + } else { + // If the mask value is truthy, the current element is loaded from + // the base tensor using its offset. Otherwise, if `other` is + // present, yield `other`. If `other` is not present, a default + // value of 0 is used. + auto mask = args[1]; + auto ifOp = rewriter.create( + loc, mask, + [&](OpBuilder &b, Location loc) { + // Truthy case, load from the index + auto loadValue = getValueAtIndex(args[0], loc, b); + b.create(loc, loadValue); + }, + [&](OpBuilder &b, Location loc) { + // Falsy case, yield `other` or 0 as the default value + if (loadOp.getOther()) { + auto definingOp = loadOp.getOther().getDefiningOp(); + if (auto constOp = + dyn_cast(definingOp)) { + if (auto attr = + dyn_cast(constOp.getValue())) { + assert(attr.isSplat()); + auto elemValue = attr.getSplatValue(); + auto otherValue = arith::ConstantOp::materialize( + b, elemValue, attr.getElementType(), loc); + b.create(loc, otherValue.getResult()); + } else { + llvm_unreachable("unexpected constant op"); + } + } else if (auto fillOp = + dyn_cast(definingOp)) { + b.create(loc, fillOp.value()); + } else { + definingOp->dump(); + llvm_unreachable("unexpected defining op"); + } + } else { + auto elemType = baseTensor.getType().getElementType(); + Value extract; + if (isa(elemType)) { + extract = rewriter.create( + loc, b.getIntegerAttr(elemType, 0)); + } else if (isa(elemType)) { + extract = rewriter.create( + loc, b.getFloatAttr(elemType, 0)); + } else { + elemType.dump(); + llvm_unreachable("unexpected type"); + } + b.create(loc, extract); + } + }); + + rewriter.create(loc, ifOp->getResult(0)); + } + }); + + rewriter.replaceOp(loadOp, genericOp); + + return success(); + } +}; + +// Lowering an unstructured store op (scatter) into an affine loop nest +struct StoreOpConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + StoreOpConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + StoreOpConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp->getLoc(); + + auto makePtrOp = + storeOp.getPtr().getDefiningOp(); + + auto ptr = adaptor.getPtr(); + auto offsetTensor = makePtrOp.getOffset(); + auto offsetType = dyn_cast(offsetTensor.getType()); + + // This must be a scalar store, skip processing + if (!offsetType) { + return failure(); + } + + auto resultType = dyn_cast(storeOp.getValue().getType()); + + auto storeMemref = rewriter.create( + loc, + MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()), + ptr); + + auto ip = rewriter.saveInsertionPoint(); + + SmallVector ivs; + for (auto dim : resultType.getShape()) { + auto ub = + rewriter.create(loc, rewriter.getIndexAttr(dim)); + auto forOp = rewriter.create(loc, 0, dim); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + if (storeOp.getMask()) { + // Mask case, only store the value if the mask value at `ivs` is truthy + auto maskValue = + rewriter.create(loc, storeOp.getMask(), ivs); + + auto ifOp = rewriter.create(loc, maskValue, + false /* withElseRegion */); + + rewriter.setInsertionPointToStart( + &ifOp.getThenRegion().getBlocks().front()); + } + + // Generate ops to store the value at each index. Note that with masking, + // these ops are created in the `if` block generated above. + auto offsetValue = + rewriter.create(loc, offsetTensor, ivs); + auto storeValue = + rewriter.create(loc, storeOp.getValue(), ivs); + Value storeIndex = rewriter.create( + loc, rewriter.getIndexType(), offsetValue); + rewriter.create(loc, storeValue, storeMemref, storeIndex); + + // Finalize + rewriter.eraseOp(storeOp); + rewriter.restoreInsertionPoint(ip); + return success(); + } +}; + +struct MakePtrConverter + : public OpConversionPattern { + using OpConversionPattern< + tts::MakeUnstructuredTensorPtrOp>::OpConversionPattern; + + MakePtrConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, + context) {} + + MakePtrConverter(MLIRContext *context) + : OpConversionPattern(context) {} + + LogicalResult + matchAndRewrite(tts::MakeUnstructuredTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // The base pointer that is used in load/store comes from + // tts.make_unstructured_tptr's input. Simply replace the op with the base + // pointer. + rewriter.replaceOp(op, adaptor.getInput()); + return success(); + } +}; + +class UnstructuredToMemrefPass + : public UnstructuredToMemrefBase { + +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + + void runOnOperation() override { + auto moduleOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + ConversionTarget target(getContext()); + + target.addLegalDialect< + func::FuncDialect, arith::ArithDialect, math::MathDialect, + linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, + cf::ControlFlowDialect, tensor::TensorDialect, + bufferization::BufferizationDialect, memref::MemRefDialect, + ttx::TritonTilingExtDialect>(); + + target.addIllegalOp(); + + PtrToUnrankedMemrefConverter typeConverter; + patterns.add(typeConverter, patterns.getContext()); + + patterns.add(patterns.getContext()); + + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +triton::createUnstructuredToMemrefPass() { + return std::make_unique(); +} diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir new file mode 100644 index 00000000..6a0ca23b --- /dev/null +++ b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir @@ -0,0 +1,78 @@ +// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s + +module { + tt.func public @gather_simple_mask_no_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<64xi32> + %c16_i32 = arith.constant 16 : i32 + %cst_0 = arith.constant dense<4> : tensor<64xi32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %0, %arg5 = %0) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { + %4 = arith.divsi %arg4, %cst_0 : tensor<64xi32> + %5 = tt.splat %arg3 : i32 -> tensor<64xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> + %7 = tt.addptr %1, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %8 = tt.load %7, %6 : tensor<64x!tt.ptr> + %9 = tt.addptr %2, %arg5 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %9, %8 : tensor<64x!tt.ptr> + %10 = arith.addi %arg3, %c16_i32 : i32 + %11 = arith.addi %arg4, %cst : tensor<64xi32> + %12 = arith.addi %arg5, %cst : tensor<64xi32> + scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> + } + tt.return + } +} + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: tt.func public @gather_simple_mask_no_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %cst = arith.constant 0.000000e+00 : f32 +// CHECK: %c8_i32 = arith.constant 8 : i32 +// CHECK: %cst_0 = arith.constant dense<4> : tensor<64xi32> +// CHECK: %c16_i32 = arith.constant 16 : i32 +// CHECK: %cst_1 = arith.constant dense<64> : tensor<64xi32> +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> +// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> +// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK: %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %2, %arg5 = %2) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK: %4 = arith.divsi %arg4, %cst_0 : tensor<64xi32> +// CHECK: %5 = tt.splat %arg3 : i32 -> tensor<64xi32> +// CHECK: %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> +// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref +// CHECK: %7 = bufferization.to_tensor %cast restrict : memref +// CHECK: %8 = tensor.empty() : tensor<64xf32> +// CHECK: %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %6 : tensor<64xi32>, tensor<64xi1>) outs(%8 : tensor<64xf32>) { +// CHECK: ^bb0(%in: i32, %in_3: i1, %out: f32): +// CHECK: %13 = scf.if %in_3 -> (f32) { +// CHECK: %14 = arith.index_cast %in : i32 to index +// CHECK: %extracted = tensor.extract %7[%14] : tensor +// CHECK: scf.yield %extracted : f32 +// CHECK: } else { +// CHECK: scf.yield %cst : f32 +// CHECK: } +// CHECK: linalg.yield %13 : f32 +// CHECK: } -> tensor<64xf32> +// CHECK: %cast_2 = memref.cast %0 : memref<*xf32> to memref +// CHECK: affine.for %arg6 = 0 to 64 { +// CHECK: %extracted = tensor.extract %arg5[%arg6] : tensor<64xi32> +// CHECK: %extracted_3 = tensor.extract %9[%arg6] : tensor<64xf32> +// CHECK: %13 = arith.index_cast %extracted : i32 to index +// CHECK: memref.store %extracted_3, %cast_2[%13] : memref +// CHECK: } +// CHECK: %10 = arith.addi %arg3, %c16_i32 : i32 +// CHECK: %11 = arith.addi %arg4, %cst_1 : tensor<64xi32> +// CHECK: %12 = arith.addi %arg5, %cst_1 : tensor<64xi32> +// CHECK: scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir new file mode 100644 index 00000000..ffc1bb36 --- /dev/null +++ b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir @@ -0,0 +1,79 @@ +// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s + +module { + tt.func public @gather_simple_mask_with_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<-1.000000e+00> : tensor<64xf32> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<64> : tensor<64xi32> + %c16_i32 = arith.constant 16 : i32 + %cst_1 = arith.constant dense<4> : tensor<64xi32> + %c8_i32 = arith.constant 8 : i32 + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %0, %arg5 = %0) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { + %4 = arith.divsi %arg4, %cst_1 : tensor<64xi32> + %5 = tt.splat %arg3 : i32 -> tensor<64xi32> + %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> + %7 = tt.addptr %1, %4 : tensor<64x!tt.ptr>, tensor<64xi32> + %8 = tt.load %7, %6, %cst : tensor<64x!tt.ptr> + %9 = tt.addptr %2, %arg5 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %9, %8 : tensor<64x!tt.ptr> + %10 = arith.addi %arg3, %c16_i32 : i32 + %11 = arith.addi %arg4, %cst_0 : tensor<64xi32> + %12 = arith.addi %arg5, %cst_0 : tensor<64xi32> + scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> + } + tt.return + } +} + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: tt.func public @gather_simple_mask_with_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %cst = arith.constant -1.000000e+00 : f32 +// CHECK: %c8_i32 = arith.constant 8 : i32 +// CHECK: %cst_0 = arith.constant dense<4> : tensor<64xi32> +// CHECK: %c16_i32 = arith.constant 16 : i32 +// CHECK: %cst_1 = arith.constant dense<64> : tensor<64xi32> +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> +// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> +// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK: %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %2, %arg5 = %2) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK: %4 = arith.divsi %arg4, %cst_0 : tensor<64xi32> +// CHECK: %5 = tt.splat %arg3 : i32 -> tensor<64xi32> +// CHECK: %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> +// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref +// CHECK: %7 = bufferization.to_tensor %cast restrict : memref +// CHECK: %8 = tensor.empty() : tensor<64xf32> +// CHECK: %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %6 : tensor<64xi32>, tensor<64xi1>) outs(%8 : tensor<64xf32>) { +// CHECK: ^bb0(%in: i32, %in_3: i1, %out: f32): +// CHECK: %13 = scf.if %in_3 -> (f32) { +// CHECK: %14 = arith.index_cast %in : i32 to index +// CHECK: %extracted = tensor.extract %7[%14] : tensor +// CHECK: scf.yield %extracted : f32 +// CHECK: } else { +// CHECK: scf.yield %cst : f32 +// CHECK: } +// CHECK: linalg.yield %13 : f32 +// CHECK: } -> tensor<64xf32> +// CHECK: %cast_2 = memref.cast %0 : memref<*xf32> to memref +// CHECK: affine.for %arg6 = 0 to 64 { +// CHECK: %extracted = tensor.extract %arg5[%arg6] : tensor<64xi32> +// CHECK: %extracted_3 = tensor.extract %9[%arg6] : tensor<64xf32> +// CHECK: %13 = arith.index_cast %extracted : i32 to index +// CHECK: memref.store %extracted_3, %cast_2[%13] : memref +// CHECK: } +// CHECK: %10 = arith.addi %arg3, %c16_i32 : i32 +// CHECK: %11 = arith.addi %arg4, %cst_1 : tensor<64xi32> +// CHECK: %12 = arith.addi %arg5, %cst_1 : tensor<64xi32> +// CHECK: scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir new file mode 100644 index 00000000..5f535cc2 --- /dev/null +++ b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir @@ -0,0 +1,74 @@ +// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s + +module { + tt.func public @gather_simple_no_mask(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<64xi32> + %c64_i32 = arith.constant 64 : i32 + %c5_i32 = arith.constant 5 : i32 + %cst_0 = arith.constant dense<10> : tensor<64xi32> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<64xi32>, tensor<64xi32>) : i32 { + %4 = arith.divsi %arg3, %cst_0 : tensor<64xi32> + %5 = arith.addi %arg2, %c5_i32 : i32 + %6 = arith.remsi %5, %c64_i32 : i32 + %7 = tt.splat %6 : i32 -> tensor<64xi32> + %8 = arith.addi %4, %7 : tensor<64xi32> + %9 = tt.addptr %1, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + %10 = tt.load %9 : tensor<64x!tt.ptr> + %11 = tt.addptr %2, %arg4 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %11, %10 : tensor<64x!tt.ptr> + %12 = arith.addi %8, %cst : tensor<64xi32> + %13 = arith.addi %arg4, %cst : tensor<64xi32> + scf.yield %12, %13 : tensor<64xi32>, tensor<64xi32> + } + tt.return + } +} + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: tt.func public @gather_simple_no_mask(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %cst = arith.constant dense<10> : tensor<64xi32> +// CHECK: %c5_i32 = arith.constant 5 : i32 +// CHECK: %c64_i32 = arith.constant 64 : i32 +// CHECK: %cst_0 = arith.constant dense<64> : tensor<64xi32> +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> +// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> +// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK: %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK: %4 = arith.divsi %arg3, %cst : tensor<64xi32> +// CHECK: %5 = arith.addi %arg2, %c5_i32 : i32 +// CHECK: %6 = arith.remsi %5, %c64_i32 : i32 +// CHECK: %7 = tt.splat %6 : i32 -> tensor<64xi32> +// CHECK: %8 = arith.addi %4, %7 : tensor<64xi32> +// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref +// CHECK: %9 = bufferization.to_tensor %cast restrict : memref +// CHECK: %10 = tensor.empty() : tensor<64xf32> +// CHECK: %11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%8 : tensor<64xi32>) outs(%10 : tensor<64xf32>) { +// CHECK: ^bb0(%in: i32, %out: f32): +// CHECK: %14 = arith.index_cast %in : i32 to index +// CHECK: %extracted = tensor.extract %9[%14] : tensor +// CHECK: linalg.yield %extracted : f32 +// CHECK: } -> tensor<64xf32> +// CHECK: %cast_1 = memref.cast %0 : memref<*xf32> to memref +// CHECK: affine.for %arg5 = 0 to 64 { +// CHECK: %extracted = tensor.extract %arg4[%arg5] : tensor<64xi32> +// CHECK: %extracted_2 = tensor.extract %11[%arg5] : tensor<64xf32> +// CHECK: %14 = arith.index_cast %extracted : i32 to index +// CHECK: memref.store %extracted_2, %cast_1[%14] : memref +// CHECK: } +// CHECK: %12 = arith.addi %8, %cst_0 : tensor<64xi32> +// CHECK: %13 = arith.addi %arg4, %cst_0 : tensor<64xi32> +// CHECK: scf.yield %12, %13 : tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } +// CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir new file mode 100644 index 00000000..8e168fe7 --- /dev/null +++ b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir @@ -0,0 +1,80 @@ +// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s + +module { + tt.func public @masked_gather_scatter(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<9.900000e+01> : tensor<4xf32> + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<4> : tensor<4xi32> + %cst_1 = arith.constant dense<64> : tensor<4xi32> + %cst_2 = arith.constant dense<3> : tensor<4xi32> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %2 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %0, %arg4 = %0) -> (tensor<4xi32>, tensor<4xi32>) : i32 { + %4 = arith.divsi %arg3, %cst_2 : tensor<4xi32> + %5 = tt.splat %arg2 : i32 -> tensor<4xi32> + %6 = arith.addi %4, %5 : tensor<4xi32> + %7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32> + %8 = tt.addptr %1, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + %9 = tt.load %8, %7, %cst : tensor<4x!tt.ptr> + %10 = tt.addptr %2, %6 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %9, %7 : tensor<4x!tt.ptr> + %11 = arith.addi %6, %cst_0 : tensor<4xi32> + %12 = arith.addi %arg4, %cst_0 : tensor<4xi32> + scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32> + } + tt.return + } +} + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: tt.func public @masked_gather_scatter(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { +// CHECK: %cst = arith.constant 9.900000e+01 : f32 +// CHECK: %cst_0 = arith.constant dense<3> : tensor<4xi32> +// CHECK: %cst_1 = arith.constant dense<64> : tensor<4xi32> +// CHECK: %cst_2 = arith.constant dense<4> : tensor<4xi32> +// CHECK: %c2_i32 = arith.constant 2 : i32 +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> +// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> +// CHECK: %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK: %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 { +// CHECK: %4 = arith.divsi %arg3, %cst_0 : tensor<4xi32> +// CHECK: %5 = tt.splat %arg2 : i32 -> tensor<4xi32> +// CHECK: %6 = arith.addi %4, %5 : tensor<4xi32> +// CHECK: %7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32> +// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref +// CHECK: %8 = bufferization.to_tensor %cast restrict : memref +// CHECK: %9 = tensor.empty() : tensor<4xf32> +// CHECK: %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) { +// CHECK: ^bb0(%in: i32, %in_4: i1, %out: f32): +// CHECK: %13 = scf.if %in_4 -> (f32) { +// CHECK: %14 = arith.index_cast %in : i32 to index +// CHECK: %extracted = tensor.extract %8[%14] : tensor +// CHECK: scf.yield %extracted : f32 +// CHECK: } else { +// CHECK: scf.yield %cst : f32 +// CHECK: } +// CHECK: linalg.yield %13 : f32 +// CHECK: } -> tensor<4xf32> +// CHECK: %cast_3 = memref.cast %0 : memref<*xf32> to memref +// CHECK: affine.for %arg5 = 0 to 4 { +// CHECK: %extracted = tensor.extract %7[%arg5] : tensor<4xi1> +// CHECK: scf.if %extracted { +// CHECK: %extracted_4 = tensor.extract %6[%arg5] : tensor<4xi32> +// CHECK: %extracted_5 = tensor.extract %10[%arg5] : tensor<4xf32> +// CHECK: %13 = arith.index_cast %extracted_4 : i32 to index +// CHECK: memref.store %extracted_5, %cast_3[%13] : memref +// CHECK: } +// CHECK: } +// CHECK: %11 = arith.addi %6, %cst_2 : tensor<4xi32> +// CHECK: %12 = arith.addi %arg4, %cst_2 : tensor<4xi32> +// CHECK: scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } +// CHECK: } diff --git a/tools/RegisterTritonSharedDialects.h b/tools/RegisterTritonSharedDialects.h index 82ba4f39..af522184 100644 --- a/tools/RegisterTritonSharedDialects.h +++ b/tools/RegisterTritonSharedDialects.h @@ -19,6 +19,7 @@ #include "triton-shared/Conversion/TritonToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h" #include "triton-shared/Conversion/TritonToStructured/Passes.h" +#include "triton-shared/Conversion/UnstructuredToMemref/Passes.h" #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h" #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" @@ -45,6 +46,7 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonToLinalgPass(); mlir::triton::registerTritonToLinalgExperimentalPass(); mlir::triton::registerTritonToStructuredPass(); + mlir::triton::registerUnstructuredToMemref(); mlir::triton::registerTritonArithToLinalgPasses(); mlir::triton::registerConvertTritonToTritonGPUPass(); mlir::triton::registerStructuredToMemrefPasses(); From cbc902c85ff854dae996070dd9cd57048447884f Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 7 Jan 2025 17:59:19 -0500 Subject: [PATCH 2/7] Update copyright and add MIT license --- .../Conversion/UnstructuredToMemref/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt b/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt index 291347c4..f988ac9f 100644 --- a/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt +++ b/include/triton-shared/Conversion/UnstructuredToMemref/CMakeLists.txt @@ -1,6 +1,7 @@ #===------------------------------------------------------------------------===# # -# Copyright (c) Triton Project Contributors. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. # #===------------------------------------------------------------------------===# From 58d217c6ccab46835383688858f21375fa18443b Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 7 Jan 2025 17:59:37 -0500 Subject: [PATCH 3/7] Update copyright and add MIT license --- lib/Conversion/UnstructuredToMemref/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/UnstructuredToMemref/CMakeLists.txt b/lib/Conversion/UnstructuredToMemref/CMakeLists.txt index 4c08b2cd..89905dda 100644 --- a/lib/Conversion/UnstructuredToMemref/CMakeLists.txt +++ b/lib/Conversion/UnstructuredToMemref/CMakeLists.txt @@ -1,6 +1,7 @@ #===------------------------------------------------------------------------===# # -# Copyright (c) Triton Project Contributors. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. # #===------------------------------------------------------------------------===# From 088f30c1ed98d9b06c7b5fcb67aac1c73149e341 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 7 Jan 2025 18:53:52 -0500 Subject: [PATCH 4/7] Fix replacement in matchAndRewrite function --- .../UnstructuredToMemref/UnstructuredToMemrefPass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp index 4f03e6f0..6405a906 100644 --- a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -410,9 +410,9 @@ struct MakePtrConverter matchAndRewrite(tts::MakeUnstructuredTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // The base pointer that is used in load/store comes from - // tts.make_unstructured_tptr's input. Simply replace the op with the base + // tts.make_unstructured_tptr. Simply replace the op with the base // pointer. - rewriter.replaceOp(op, adaptor.getInput()); + rewriter.replaceOp(op, adaptor.getBase()); return success(); } }; From 2a348d0bbc2c1c3c05f5c8f25ed08724b701a31b Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Fri, 10 Jan 2025 18:14:56 -0500 Subject: [PATCH 5/7] Update --- .../UnstructuredToMemrefPass.cpp | 167 ++++++------------ 1 file changed, 56 insertions(+), 111 deletions(-) diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp index 6405a906..4a899cb5 100644 --- a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -69,29 +69,26 @@ static MemRefType getMemrefTypeForScalarPtr(triton::PointerType ptrType, return memrefType; } -struct ScalarLoadConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ScalarLoadConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ScalarLoadConverter(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} + : OpConversionPattern(typeConverter, context) {} ScalarLoadConverter(MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) {} LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!loadOp.getType().isIntOrIndexOrFloat()) { + if (!gatherOp.getType().isIntOrIndexOrFloat()) { return failure(); } - auto loc = loadOp->getLoc(); - - auto makePtrOp = - loadOp.getPtr().getDefiningOp(); + auto loc = gatherOp->getLoc(); auto basePtr = adaptor.getPtr(); - auto offset = makePtrOp.getOffset(); + auto offset = adaptor.getOffset(); Value loadIndex = rewriter.create( loc, rewriter.getIndexType(), offset); @@ -99,7 +96,7 @@ struct ScalarLoadConverter : public OpConversionPattern { auto memref = rewriter.create( loc, getMemrefTypeForScalarPtr( - cast(loadOp.getPtr().getType()), + cast(gatherOp.getPtr().getType()), rewriter.getContext()), basePtr, getAsOpFoldResult(loadIndex) /*offset*/, ArrayRef{rewriter.getIndexAttr(1)} /*sizes*/, @@ -110,36 +107,33 @@ struct ScalarLoadConverter : public OpConversionPattern { auto scalarLoadOp = rewriter.create( loc, memref, zeroMap, std::nullopt); - rewriter.replaceOp(loadOp, scalarLoadOp.getResult()); + rewriter.replaceOp(gatherOp, scalarLoadOp.getResult()); return success(); } }; -struct ScalarStoreConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ScalarStoreConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; ScalarStoreConverter(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} + : OpConversionPattern(typeConverter, context) {} ScalarStoreConverter(MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) {} LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + matchAndRewrite(tts::ScatterOp scatterOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!storeOp.getValue().getType().isIntOrIndexOrFloat()) { + if (!scatterOp.getValue().getType().isIntOrIndexOrFloat()) { return failure(); } - auto loc = storeOp->getLoc(); - - auto makePtrOp = - storeOp.getPtr().getDefiningOp(); + auto loc = scatterOp->getLoc(); auto basePtr = adaptor.getPtr(); - auto offset = makePtrOp.getOffset(); + auto offset = adaptor.getOffset(); Value storeIndex = rewriter.create( loc, rewriter.getIndexType(), offset); @@ -147,44 +141,41 @@ struct ScalarStoreConverter : public OpConversionPattern { auto memref = rewriter.create( loc, getMemrefTypeForScalarPtr( - cast(storeOp.getPtr().getType()), + cast(scatterOp.getPtr().getType()), rewriter.getContext()), basePtr, getAsOpFoldResult(storeIndex) /*offset*/, ArrayRef{rewriter.getIndexAttr(1)} /*sizes*/, ArrayRef{rewriter.getIndexAttr(1)} /*strides*/); - auto storeVal = storeOp.getValue(); + auto storeVal = scatterOp.getValue(); auto zeroMap = AffineMap::getConstantMap(0, rewriter.getContext()); rewriter.create(loc, storeVal, memref, zeroMap, std::nullopt); - rewriter.eraseOp(storeOp); + rewriter.eraseOp(scatterOp); return success(); } }; // Lowering an unstructured load op (gather) into a linalg.generic op -struct LoadOpConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct GatherConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LoadOpConverter(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} + GatherConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} - LoadOpConverter(MLIRContext *context) - : OpConversionPattern(context) {} + GatherConverter(MLIRContext *context) + : OpConversionPattern(context) {} LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = loadOp->getLoc(); - - auto makePtrOp = - loadOp.getPtr().getDefiningOp(); + auto loc = gatherOp->getLoc(); auto ptr = adaptor.getPtr(); - auto offsetTensor = makePtrOp.getOffset(); + auto offsetTensor = adaptor.getOffset(); auto offsetType = dyn_cast(offsetTensor.getType()); // This must be a scalar load, skip processing @@ -193,7 +184,7 @@ struct LoadOpConverter : public OpConversionPattern { } auto loadResultType = - dyn_cast(loadOp.getResult().getType()); + dyn_cast(gatherOp.getResult().getType()); // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). @@ -218,8 +209,8 @@ struct LoadOpConverter : public OpConversionPattern { // - an optional mask tensor if the load op contains mask SmallVector inputs{offsetTensor}; - if (loadOp.getMask()) { - inputs.push_back(loadOp.getMask()); + if (gatherOp.getMask()) { + inputs.push_back(gatherOp.getMask()); } auto emptyTensor = @@ -234,7 +225,7 @@ struct LoadOpConverter : public OpConversionPattern { // If mask is used, the first 2 maps are for the offset and mask tensors // while the last map is for the output tensor. SmallVector affineMaps( - loadOp.getMask() ? 3 : 2, + gatherOp.getMask() ? 3 : 2, rewriter.getMultiDimIdentityMap(loadResultType.getRank())); auto genericOp = rewriter.create( @@ -252,7 +243,7 @@ struct LoadOpConverter : public OpConversionPattern { ValueRange{index0}); }; - if (!loadOp.getMask()) { + if (!gatherOp.getMask()) { // If there is no mask, simply extract the current element from the // base tensor and use it as the yield value. auto loadValue = getValueAtIndex(args[0], loc, rewriter); @@ -272,27 +263,8 @@ struct LoadOpConverter : public OpConversionPattern { }, [&](OpBuilder &b, Location loc) { // Falsy case, yield `other` or 0 as the default value - if (loadOp.getOther()) { - auto definingOp = loadOp.getOther().getDefiningOp(); - if (auto constOp = - dyn_cast(definingOp)) { - if (auto attr = - dyn_cast(constOp.getValue())) { - assert(attr.isSplat()); - auto elemValue = attr.getSplatValue(); - auto otherValue = arith::ConstantOp::materialize( - b, elemValue, attr.getElementType(), loc); - b.create(loc, otherValue.getResult()); - } else { - llvm_unreachable("unexpected constant op"); - } - } else if (auto fillOp = - dyn_cast(definingOp)) { - b.create(loc, fillOp.value()); - } else { - definingOp->dump(); - llvm_unreachable("unexpected defining op"); - } + if (gatherOp.getOther()) { + b.create(loc, gatherOp.getOther()); } else { auto elemType = baseTensor.getType().getElementType(); Value extract; @@ -314,32 +286,29 @@ struct LoadOpConverter : public OpConversionPattern { } }); - rewriter.replaceOp(loadOp, genericOp); + rewriter.replaceOp(gatherOp, genericOp); return success(); } }; // Lowering an unstructured store op (scatter) into an affine loop nest -struct StoreOpConverter : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ScatterConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - StoreOpConverter(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} + ScatterConverter(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} - StoreOpConverter(MLIRContext *context) - : OpConversionPattern(context) {} + ScatterConverter(MLIRContext *context) + : OpConversionPattern(context) {} LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + matchAndRewrite(tts::ScatterOp scatterOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = storeOp->getLoc(); - - auto makePtrOp = - storeOp.getPtr().getDefiningOp(); + auto loc = scatterOp->getLoc(); auto ptr = adaptor.getPtr(); - auto offsetTensor = makePtrOp.getOffset(); + auto offsetTensor = adaptor.getOffset(); auto offsetType = dyn_cast(offsetTensor.getType()); // This must be a scalar store, skip processing @@ -347,7 +316,8 @@ struct StoreOpConverter : public OpConversionPattern { return failure(); } - auto resultType = dyn_cast(storeOp.getValue().getType()); + auto resultType = + dyn_cast(scatterOp.getValue().getType()); auto storeMemref = rewriter.create( loc, @@ -365,10 +335,10 @@ struct StoreOpConverter : public OpConversionPattern { rewriter.setInsertionPointToStart(forOp.getBody()); } - if (storeOp.getMask()) { + if (scatterOp.getMask()) { // Mask case, only store the value if the mask value at `ivs` is truthy auto maskValue = - rewriter.create(loc, storeOp.getMask(), ivs); + rewriter.create(loc, scatterOp.getMask(), ivs); auto ifOp = rewriter.create(loc, maskValue, false /* withElseRegion */); @@ -382,41 +352,18 @@ struct StoreOpConverter : public OpConversionPattern { auto offsetValue = rewriter.create(loc, offsetTensor, ivs); auto storeValue = - rewriter.create(loc, storeOp.getValue(), ivs); + rewriter.create(loc, scatterOp.getValue(), ivs); Value storeIndex = rewriter.create( loc, rewriter.getIndexType(), offsetValue); rewriter.create(loc, storeValue, storeMemref, storeIndex); // Finalize - rewriter.eraseOp(storeOp); + rewriter.eraseOp(scatterOp); rewriter.restoreInsertionPoint(ip); return success(); } }; -struct MakePtrConverter - : public OpConversionPattern { - using OpConversionPattern< - tts::MakeUnstructuredTensorPtrOp>::OpConversionPattern; - - MakePtrConverter(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, - context) {} - - MakePtrConverter(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult - matchAndRewrite(tts::MakeUnstructuredTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // The base pointer that is used in load/store comes from - // tts.make_unstructured_tptr. Simply replace the op with the base - // pointer. - rewriter.replaceOp(op, adaptor.getBase()); - return success(); - } -}; - class UnstructuredToMemrefPass : public UnstructuredToMemrefBase { @@ -442,14 +389,12 @@ class UnstructuredToMemrefPass bufferization::BufferizationDialect, memref::MemRefDialect, ttx::TritonTilingExtDialect>(); - target.addIllegalOp(); + target.addIllegalOp(); PtrToUnrankedMemrefConverter typeConverter; - patterns.add(typeConverter, patterns.getContext()); - patterns.add(patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) signalPassFailure(); From eef069e6e54c5241ecd22b3380d98e6a35342ef0 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 14 Jan 2025 14:37:57 -0500 Subject: [PATCH 6/7] Fix tests --- test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir | 2 +- .../Conversion/UnstructuredToMemref/gather_mask_with_other.mlir | 2 +- test/Conversion/UnstructuredToMemref/gather_no_mask.mlir | 2 +- .../UnstructuredToMemref/gather_scatter_all_mask.mlir | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir index 6a0ca23b..c8d5127b 100644 --- a/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir @@ -1,4 +1,4 @@ -// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s +// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s module { tt.func public @gather_simple_mask_no_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir index ffc1bb36..80198ce6 100644 --- a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir @@ -1,4 +1,4 @@ -// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s +// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s module { tt.func public @gather_simple_mask_with_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { diff --git a/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir index 5f535cc2..32ba5ce3 100644 --- a/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir @@ -1,4 +1,4 @@ -// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s +// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s module { tt.func public @gather_simple_no_mask(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { diff --git a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir index 8e168fe7..3cfcfb9c 100644 --- a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir @@ -1,4 +1,4 @@ -// RUN: triton-shared-opt --fold-unstructured-ptr --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s +// RUN: triton-shared-opt --triton-to-unstructured --canonicalize --unstructured-to-memref --canonicalize %s | FileCheck %s module { tt.func public @masked_gather_scatter(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { From 06c0857b11532c0f520cd50d076be3489efe3999 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Tue, 14 Jan 2025 15:02:13 -0500 Subject: [PATCH 7/7] Fix lit tests --- .../gather_mask_with_other.mlir | 93 +++++++++--------- .../gather_scatter_all_mask.mlir | 96 +++++++++---------- 2 files changed, 95 insertions(+), 94 deletions(-) diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir index 80198ce6..a6eb5cdf 100644 --- a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir @@ -30,50 +30,51 @@ module { } } -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: module { -// CHECK: tt.func public @gather_simple_mask_with_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { -// CHECK: %cst = arith.constant -1.000000e+00 : f32 -// CHECK: %c8_i32 = arith.constant 8 : i32 -// CHECK: %cst_0 = arith.constant dense<4> : tensor<64xi32> -// CHECK: %c16_i32 = arith.constant 16 : i32 -// CHECK: %cst_1 = arith.constant dense<64> : tensor<64xi32> -// CHECK: %c2_i32 = arith.constant 2 : i32 -// CHECK: %c1_i32 = arith.constant 1 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> -// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> -// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> -// CHECK: %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %2, %arg5 = %2) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { -// CHECK: %4 = arith.divsi %arg4, %cst_0 : tensor<64xi32> -// CHECK: %5 = tt.splat %arg3 : i32 -> tensor<64xi32> -// CHECK: %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> -// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref -// CHECK: %7 = bufferization.to_tensor %cast restrict : memref -// CHECK: %8 = tensor.empty() : tensor<64xf32> -// CHECK: %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %6 : tensor<64xi32>, tensor<64xi1>) outs(%8 : tensor<64xf32>) { -// CHECK: ^bb0(%in: i32, %in_3: i1, %out: f32): -// CHECK: %13 = scf.if %in_3 -> (f32) { -// CHECK: %14 = arith.index_cast %in : i32 to index -// CHECK: %extracted = tensor.extract %7[%14] : tensor -// CHECK: scf.yield %extracted : f32 -// CHECK: } else { -// CHECK: scf.yield %cst : f32 +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK: tt.func public @gather_simple_mask_with_other([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<64xi32> +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : i32 +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<64> : tensor<64xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_minus_1_dot_000000_:%.+]] = arith.constant -1.000000e+00 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]]:3 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[CST_8_]], [[VAR_arg4_:%.+]] = [[VAR_2_]], [[VAR_arg5_:%.+]] = [[VAR_2_]]) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK-DAG: [[VAR_4_:%.+]] = arith.divsi [[VAR_arg4_]], [[VAR_cst_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[VAR_arg3_]] : i32 -> tensor<64xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi slt, [[VAR_4_]], [[VAR_5_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[VAR_1_]] : memref<*xf32> to memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref +// CHECK-DAG: [[VAR_8_:%.+]] = tensor.empty() : tensor<64xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_6_]] : tensor<64xi32>, tensor<64xi1>) outs([[VAR_8_]] : tensor<64xf32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i1, [[IN_2_:%.+]]: f32): +// CHECK-DAG: [[VAR_13_:%.+]] = scf.if [[IN_1_]] -> (f32) { +// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[VAR_14_]]{{.}} : tensor +// CHECK: scf.yield [[VAR_extracted_]] : f32 +// CHECK: } else { +// CHECK: scf.yield [[CST_minus_1_dot_000000_]] : f32 +// CHECK: } +// CHECK: linalg.yield [[VAR_13_]] : f32 +// CHECK: } -> tensor<64xf32> +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref +// CHECK: affine.for [[I_0_:%.+]] = 0 to 64 { +// CHECK-DAG: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_arg5_]]{{.}}[[I_0_]]{{.}} : tensor<64xi32> +// CHECK-DAG: [[VAR_extracted_3_:%.+]] = tensor.extract [[VAR_9_]]{{.}}[[I_0_]]{{.}} : tensor<64xf32> +// CHECK: [[VAR_13_1_:%.+]] = arith.index_cast [[VAR_extracted_1_]] : i32 to index +// CHECK: memref.store [[VAR_extracted_3_]], [[VAR_cast_2_]]{{.}}[[VAR_13_1_]]{{.}} : memref +// CHECK: } +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_16_]] : i32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK: scf.yield [[VAR_10_]], [[VAR_11_]], [[VAR_12_]] : i32, tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return // CHECK: } -// CHECK: linalg.yield %13 : f32 -// CHECK: } -> tensor<64xf32> -// CHECK: %cast_2 = memref.cast %0 : memref<*xf32> to memref -// CHECK: affine.for %arg6 = 0 to 64 { -// CHECK: %extracted = tensor.extract %arg5[%arg6] : tensor<64xi32> -// CHECK: %extracted_3 = tensor.extract %9[%arg6] : tensor<64xf32> -// CHECK: %13 = arith.index_cast %extracted : i32 to index -// CHECK: memref.store %extracted_3, %cast_2[%13] : memref -// CHECK: } -// CHECK: %10 = arith.addi %arg3, %c16_i32 : i32 -// CHECK: %11 = arith.addi %arg4, %cst_1 : tensor<64xi32> -// CHECK: %12 = arith.addi %arg5, %cst_1 : tensor<64xi32> -// CHECK: scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> -// CHECK: } -// CHECK: tt.return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir index 3cfcfb9c..864503f0 100644 --- a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir @@ -29,52 +29,52 @@ module { } } -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: module { -// CHECK: tt.func public @masked_gather_scatter(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { -// CHECK: %cst = arith.constant 9.900000e+01 : f32 -// CHECK: %cst_0 = arith.constant dense<3> : tensor<4xi32> -// CHECK: %cst_1 = arith.constant dense<64> : tensor<4xi32> -// CHECK: %cst_2 = arith.constant dense<4> : tensor<4xi32> -// CHECK: %c2_i32 = arith.constant 2 : i32 -// CHECK: %c1_i32 = arith.constant 1 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> -// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> -// CHECK: %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> -// CHECK: %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 { -// CHECK: %4 = arith.divsi %arg3, %cst_0 : tensor<4xi32> -// CHECK: %5 = tt.splat %arg2 : i32 -> tensor<4xi32> -// CHECK: %6 = arith.addi %4, %5 : tensor<4xi32> -// CHECK: %7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32> -// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref -// CHECK: %8 = bufferization.to_tensor %cast restrict : memref -// CHECK: %9 = tensor.empty() : tensor<4xf32> -// CHECK: %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) { -// CHECK: ^bb0(%in: i32, %in_4: i1, %out: f32): -// CHECK: %13 = scf.if %in_4 -> (f32) { -// CHECK: %14 = arith.index_cast %in : i32 to index -// CHECK: %extracted = tensor.extract %8[%14] : tensor -// CHECK: scf.yield %extracted : f32 -// CHECK: } else { -// CHECK: scf.yield %cst : f32 +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK: tt.func public @masked_gather_scatter([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<3> : tensor<4xi32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<64> : tensor<4xi32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<4> : tensor<4xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_9_dot_900000_:%.+]] = arith.constant 9.900000e+01 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[VAR_2_]], [[VAR_arg4_:%.+]] = [[VAR_2_]]) -> (tensor<4xi32>, tensor<4xi32>) : i32 { +// CHECK-DAG: [[VAR_4_:%.+]] = arith.divsi [[VAR_arg3_]], [[VAR_cst_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[VAR_arg2_]] : i32 -> tensor<4xi32> +// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_4_]], [[VAR_5_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi slt, [[VAR_6_]], [[VAR_cst_0_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[VAR_1_]] : memref<*xf32> to memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_8_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref +// CHECK-DAG: [[VAR_9_:%.+]] = tensor.empty() : tensor<4xf32> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_7_]] : tensor<4xi32>, tensor<4xi1>) outs([[VAR_9_]] : tensor<4xf32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i1, [[IN_2_:%.+]]: f32): +// CHECK-DAG: [[VAR_13_:%.+]] = scf.if [[IN_1_]] -> (f32) { +// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[VAR_14_]]{{.}} : tensor +// CHECK: scf.yield [[VAR_extracted_]] : f32 +// CHECK: } else { +// CHECK: scf.yield [[CST_9_dot_900000_]] : f32 +// CHECK: } +// CHECK: linalg.yield [[VAR_13_]] : f32 +// CHECK: } -> tensor<4xf32> +// CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref +// CHECK: affine.for [[I_0_:%.+]] = 0 to 4 { +// CHECK: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[I_0_]]{{.}} : tensor<4xi1> +// CHECK: scf.if [[VAR_extracted_1_]] { +// CHECK-DAG: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[I_0_]]{{.}} : tensor<4xi32> +// CHECK-DAG: [[VAR_extracted_5_:%.+]] = tensor.extract [[VAR_10_]]{{.}}[[I_0_]]{{.}} : tensor<4xf32> +// CHECK: [[VAR_13_1_:%.+]] = arith.index_cast [[VAR_extracted_4_]] : i32 to index +// CHECK: memref.store [[VAR_extracted_5_]], [[VAR_cast_3_]]{{.}}[[VAR_13_1_]]{{.}} : memref +// CHECK: } +// CHECK: } +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_6_]], [[VAR_cst_1_]] : tensor<4xi32> +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_1_]] : tensor<4xi32> +// CHECK: scf.yield [[VAR_11_]], [[VAR_12_]] : tensor<4xi32>, tensor<4xi32> +// CHECK: } +// CHECK: tt.return // CHECK: } -// CHECK: linalg.yield %13 : f32 -// CHECK: } -> tensor<4xf32> -// CHECK: %cast_3 = memref.cast %0 : memref<*xf32> to memref -// CHECK: affine.for %arg5 = 0 to 4 { -// CHECK: %extracted = tensor.extract %7[%arg5] : tensor<4xi1> -// CHECK: scf.if %extracted { -// CHECK: %extracted_4 = tensor.extract %6[%arg5] : tensor<4xi32> -// CHECK: %extracted_5 = tensor.extract %10[%arg5] : tensor<4xf32> -// CHECK: %13 = arith.index_cast %extracted_4 : i32 to index -// CHECK: memref.store %extracted_5, %cast_3[%13] : memref -// CHECK: } -// CHECK: } -// CHECK: %11 = arith.addi %6, %cst_2 : tensor<4xi32> -// CHECK: %12 = arith.addi %arg4, %cst_2 : tensor<4xi32> -// CHECK: scf.yield %11, %12 : tensor<4xi32>, tensor<4xi32> -// CHECK: } -// CHECK: tt.return -// CHECK: } -// CHECK: }