Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/triton-shared/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(TritonToUnstructured)
add_subdirectory(StructuredToMemref)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToUnstructured)
add_public_tablegen_target(TritonToUnstructuredConversionPassIncGen)
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonToUnstructured/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES_H

#include "triton-shared/Conversion/TritonToUnstructured/TritonToUnstructured.h"

namespace mlir {
namespace triton {

#define GEN_PASS_REGISTRATION
#include "triton-shared/Conversion/TritonToUnstructured/Passes.h.inc"

} // namespace triton
} // namespace mlir

#endif
15 changes: 15 additions & 0 deletions include/triton-shared/Conversion/TritonToUnstructured/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES
#define TRITON_TO_UNSTRUCTURED_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonToUnstructured : Pass<"triton-to-unstructured", "mlir::ModuleOp"> {
let summary = "Transforms tt.addptr ops into offset accumulation ops";
let constructor = "triton::createTritonToUnstructuredPass()";
let options = [
Option<"offsetBitWidth", "offset-bit-width", "size_t", /*default*/"32",
"Bitwidth used for the starting offset of each pointer">
];
}

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
#define TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createTritonToUnstructuredPass();

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITON_TO_UNSTRUCTURED_TRITON_TO_UNSTRUCTURED_H
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#ifndef MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_
#define MLIR_DIALECT_TRITON_STRUCTURED_IR_TRITON_STRUCTURED_DIALECT_H_

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "mlir/IR/Dialect.h"
namespace mlir {
namespace tts {
namespace utils {
mlir::Value getScalarValue(mlir::Value operand, mlir::Location loc,
mlir::OpBuilder &builder);
}
} // namespace tts
} // namespace mlir

//===----------------------------------------------------------------------===//
// TritonStructured Operations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ def TTS_MakeTensorPtrOp
//let hasCanonicalizer = 1;
}

// SameVariadicResultSize
// AttrSizedResultSegments
def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSegments, Pure]> {
let summary = "Placeholder for the structured pointer states computed during PtrAnalysis.";
let description = "Used to pass the offsets and strides to scf.for op to simplify IR rewrites.";
Expand All @@ -145,6 +143,51 @@ def TTS_GetStructuredStateOp : TTS_Op<"get_structured_state", [AttrSizedResultSe
let hasVerifier = 1;
}

def TTS_GatherOp : TTS_Op<"gather", [
MemoryEffects<[MemRead]>,
AttrSizedOperandSegments,
OptionalTypesMatchWith<"mask type matches ptr type", "offset", "mask", "triton::getI1SameShape($_self)">,
OptionalTypesMatchWith<"other matches ptr type", "ptr", "other", "triton::getPointeeType($_self)">
]> {
let summary = "optionally load data from in memory to fill a portion of the tensor";

let arguments = (
ins
TT_Ptr:$ptr,
TT_IntLike:$offset,
Optional<TT_BoolLike>:$mask,
Optional<TT_Type>:$other
);

let results = (outs TT_Type:$result);

let assemblyFormat = [{
$ptr `[` $offset `]` (`mask` `=` $mask^)? (`default` `=` $other^)?
attr-dict `:` `(` type($ptr) `,` type($offset) `)` `->` type($result)
}];
}

def TTS_ScatterOp : TTS_Op<"scatter", [
MemoryEffects<[MemWrite]>,
OptionalTypesMatchWith<"mask type matches offset type", "offset", "mask",
"triton::getI1SameShape($_self)">
]> {
let summary = "optionally store data from in memory to fill a portion of the tensor";

let arguments = (
ins
TT_Ptr:$ptr,
TT_IntLike:$offset,
TT_Type:$value,
Optional<TT_BoolLike>:$mask
);

let assemblyFormat = [{
$value `into` $ptr `[` $offset `]` (`mask` `=` $mask^)?
attr-dict `:` type($value) `into` ` ` `(` type($ptr) `,` type($offset) `)`
}];
}

def TTS_LoadOp : TTS_Op<"load", [
MemoryEffects<[MemRead]>,
AttrSizedOperandSegments
Expand All @@ -170,7 +213,7 @@ def TTS_LoadOp : TTS_Op<"load", [
}

bool hasMask() {
return !getStaticMaskDims().empty();
return !getMixedMaskDims().empty();
}
}];

Expand All @@ -182,7 +225,7 @@ def TTS_LoadOp : TTS_Op<"load", [
def TTS_StoreOp : TTS_Op<"store", [
MemoryEffects<[MemWrite]>
]> {
let summary = "optionally load data from in memory to fill a portion of the tensor";
let summary = "optionally store data from in memory to fill a portion of the tensor";

let arguments = (ins TT_PtrLike:$ptr,
TT_Tensor:$value,
Expand All @@ -201,7 +244,7 @@ def TTS_StoreOp : TTS_Op<"store", [
}

bool hasMask() {
return !getStaticMaskDims().empty();
return !getMixedMaskDims().empty();
}
}];

Expand Down
74 changes: 1 addition & 73 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
#include "triton-shared/AnalysisStructured/PtrAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
Expand All @@ -33,7 +30,6 @@
#include "llvm/Support/LogicalResult.h"
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <queue>
#include <string>
Expand All @@ -42,74 +38,6 @@

namespace mlir {

// Extract a scalar value from v.
// If v is a scalar, return that directly. Otherwise, parse through operations
// (currently only support splat, sitofp, and truncf) that produce it to
// extract the underlying scalar value. We then reconstruct the chain of
// operations that can produce this constant with the original type. If no
// scalar value can be extracted, a nullptr is returned.
static Value getScalarValue(Value operand, Location loc, OpBuilder &builder) {
SmallVector<Operation *> ops;

auto reconstructScalarValue = [&](Value src) {
for (auto op = ops.rbegin(); op != ops.rend(); ++op) {
src = TypeSwitch<Operation *, Value>(*op)
.Case<arith::SIToFPOp>([&](Operation *op) {
auto resType = op->getResults()[0].getType();
if (auto shapedType = dyn_cast<ShapedType>(resType)) {
resType = shapedType.getElementType();
}
return builder.create<arith::SIToFPOp>(loc, resType, src);
})
.Case<arith::TruncFOp>([&](Operation *op) {
auto resType = op->getResults()[0].getType();
if (auto shapedType = dyn_cast<ShapedType>(resType)) {
resType = shapedType.getElementType();
}
return builder.create<arith::TruncFOp>(loc, resType, src);
})
.Default([](Operation *op) {
llvm_unreachable("unsupported op in generating ");
return nullptr;
});
}
return src;
};

while (true) {
if (!dyn_cast<ShapedType>(operand.getType())) {
return reconstructScalarValue(operand);
} else if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = dyn_cast<DenseElementsAttr>(op.getValue())) {
if (!attr.isSplat()) {
InFlightDiagnostic diag = emitError(loc)
<< "other value used in masked load "
"produced by unsupported instruction";
return nullptr;
}
auto elemValue = attr.getSplatValue<Attribute>();
auto constOp = arith::ConstantOp::materialize(
builder, elemValue, attr.getElementType(), op.getLoc());
return reconstructScalarValue(constOp.getResult());
}
} else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
operand = op.getSrc();
} else if (auto op = operand.getDefiningOp<arith::SIToFPOp>()) {
ops.push_back(op.getOperation());
operand = op.getIn();
} else if (auto op = operand.getDefiningOp<arith::TruncFOp>()) {
ops.push_back(op.getOperation());
operand = op.getIn();
} else {
InFlightDiagnostic diag = emitError(loc)
<< "other value used in masked load produced "
"by unsupported instruction";
return nullptr;
}
}
return nullptr;
}

namespace tts {

int32_t PtrState::getRank() const {
Expand Down Expand Up @@ -1126,7 +1054,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op) {
if (other) {
assert(mask && "other value used while no masks are specified");

scalarOther = getScalarValue(other, loc, builder);
scalarOther = utils::getScalarValue(other, loc, builder);
if (!scalarOther) {
op->emitRemark("other value used in masked load produced by "
"unsupported instruction");
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_subdirectory(TritonToLinalg)
add_subdirectory(TritonToLinalgExperimental)
add_subdirectory(TritonToStructured)
add_subdirectory(TritonToUnstructured)
add_subdirectory(TritonArithToLinalg)
add_subdirectory(StructuredToMemref)
22 changes: 22 additions & 0 deletions lib/Conversion/TritonToUnstructured/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
add_triton_library(TritonToUnstructured
TritonToUnstructuredPass.cpp

DEPENDS
TritonStructuredTableGen
TritonToUnstructuredConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRDialectUtils
MLIRIR
MLIRMathDialect
MLIRPass
MLIRTensorDialect
MLIRTransforms
MLIRSupport
MLIRReconcileUnrealizedCasts
TritonIR
TritonTransforms
TritonSharedAnalysisStructured
TritonStructuredIR
)
Loading
Loading