Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 248 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"

def AlignmentProp : OptionalProp<I64Prop>;

def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;

//===----------------------------------------------------------------------===//
// Common types
//===----------------------------------------------------------------------===//
Expand All @@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
AnySignlessIntegerOrIndex
]>;

// A shaped pointer type with value semantics.
def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;

// A shaped mask type with value semantics.
def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;

// A shaped mask type with value semantics.
def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;

// A shaped value type of rank 1 of any element type.
def Ptr_Any1DType :
Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
Expand Down Expand Up @@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// ReadOp
//===----------------------------------------------------------------------===//

def Ptr_ReadOp : Pointer_Op<"read", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
::llvm::cast<ShapedType>($_self).clone(
IntegerType::get($_self.getContext(), 1))
}]>,
AllTypesMatch<["result", "passthrough"]>,
// Check the shapes are compatible and both use the same shaped container
// type.
AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
]> {
let summary = "Read operation";
let description = [{
The `read` operation is a high-level operation that performs a read
from multiple memory locations specified by `ptr` based on a mask `mask`.
Elements of the `result`, corresponding to masked-off lanes, are taken from
the `passthrough` operand.

The `mask` operand is a shaped type of `i1` elements that must have the same
shape as the result type.

The `contiguity` property is an integer array with the same rank as `ptr`,
where each element describes memory access contiguity for the corresponding
dimension. The precise semantics of this property are given by:
Let `c1, c2, ..., cn` be the elements of the contiguity array, and
`s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
The following rules and restrictions apply:
1. `ck` must be strictly positive for all k.
2. `ck` must divide `sk` for all k.
3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
given by:
- `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
`l = 0, 1, ..., sk / ck - 1`
are contiguous for all k.

It is undefined behavior if the pointers in `ptr` do not satisfy the
contiguity constraints specified by `contiguity`.

Depending on the values of `mask` and `contiguity`, the operation can be
lowered to either:
1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
the accesses are contiguous.
2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
where all the accesses are contiguous.
3. A `ptr.gather` if the mask is not all ones, and there's no contiguous
dimension.

The alignment property describes the alignment (in bytes) of each contiguous
memory-block being accessed.

Examples:
```mlir
// Read a vector in row-major order
%result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>

// Read a vector in column-major order with alignment
%result = ptr.read %ptr, %mask, %passthrough alignment = 8
contiguity = [4, 1] :
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>

// Gather a vector from memory
%result = ptr.read %ptr, %mask, %passthrough alignment = 8
contiguity = [1, 1] :
vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
```
}];
let arguments = (ins Ptr_ShapedPtrType:$ptr,
Ptr_ShapedMaskType:$mask,
Ptr_ShapedAnyType:$passthrough,
AlignmentProp:$alignment,
ContiguityProp:$contiguity);
let results = (outs Ptr_ShapedAnyType:$result);
let assemblyFormat = [{
$ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
`contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
}];
let builders = [
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
CArg<"unsigned", "0">:$alignment,
CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
];
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns the ptr type of the operation.
PtrType getPtrType() {
return cast<PtrType>(getPtr().getType().getElementType());
}

/// Returns the rank of the shaped operands and result.
unsigned getRank() { return getType().getRank(); }

/// Returns the shape of the shaped operands and result.
ArrayRef<int64_t> getShape() { return getType().getShape(); }

/// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
/// of the `i`-th dimension.
std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
assert(i < getRank() && "Invalid dimension");
return {getContiguity()[i], getShape()[i]};
}

/// Returns true if the `i`-th dimension is contiguous.
bool isContiguous(unsigned i) {
auto [contiguity, size] = getContiguityInfo(i);
return contiguity == size && size > 1;
}

/// Returns true if the read has gather semantics, ie. there's no dimension
/// where all the accesses are contiguous.
bool hasGatherSemantics() {
return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
[this](unsigned i) { return isContiguous(i); });
}
}];
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
}];
}

//===----------------------------------------------------------------------===//
// WriteOp
//===----------------------------------------------------------------------===//

def Ptr_WriteOp : Pointer_Op<"write", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TypesMatchWith<"value and mask must be compatible",
"value", "mask", [{
cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
}]>,
// Check the shapes are compatible and both use the same shaped container
AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
]> {
let summary = "Write operation";
let description = [{
The `write` operation is a high-level operation that performs a write to
multiple memory locations specified by `ptr` based on a mask `mask`.
Elements of the `value`, corresponding to masked-off lanes, are not written
to memory.

The `mask` operand is a shaped type of `i1` elements that must have the same
shape as the `value` type.

The `contiguity` property is an integer array with the same rank as `ptr`,
where each element describes memory access contiguity for the corresponding
dimension. The precise semantics of this property are given by:
Let `c1, c2, ..., cn` be the elements of the contiguity array, and
`s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
The following rules and restrictions apply:
1. `ck` must be strictly positive for all k.
2. `ck` must divide `sk` for all k.
3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
given by:
- `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
`l = 0, 1, ..., sk / ck - 1`
are contiguous for all k.

It is undefined behavior if the pointers in `ptr` do not satisfy the
contiguity constraints specified by `contiguity`.

Depending on the values of `mask` and `contiguity`, the operation can be
lowered to either:
1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
the accesses are contiguous.
2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
where all the accesses are contiguous.
3. A `ptr.scatter` if the mask is not all ones, and there's no contiguous
dimension.

The alignment property describes the alignment (in bytes) of each contiguous
memory-block being accessed.

Example:
```mlir
// Write a vector in row-major order
ptr.write %value, %ptr, %mask contiguity = [1, 4] :
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>

// Write a vector in column-major order with alignment
ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>

// Scatter a vector to memory
ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
```
}];
let arguments = (ins Ptr_ShapedAnyType:$value,
Ptr_ShapedPtrType:$ptr,
Ptr_ShapedMaskType:$mask,
AlignmentProp:$alignment,
ContiguityProp:$contiguity);
let assemblyFormat = [{
$value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
`contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
}];
let builders = [
OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
CArg<"unsigned", "0">:$alignment,
CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
];
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns the ptr type of the operation.
PtrType getPtrType() {
return cast<PtrType>(getPtr().getType().getElementType());
}

/// Returns the rank of the shaped operands.
unsigned getRank() { return getPtr().getType().getRank(); }

/// Returns the shape of the shaped operands.
ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }

/// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
/// of the `i`-th dimension.
std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
assert(i < getRank() && "Invalid dimension");
return {getContiguity()[i], getShape()[i]};
}

/// Returns true if the `i`-th dimension is contiguous.
bool isContiguous(unsigned i) {
auto [contiguity, size] = getContiguityInfo(i);
return contiguity == size && size > 1;
}

/// Returns true if the write has scatter semantics, ie. there's no
/// dimension where all the accesses are contiguous.
bool hasScatterSemantics() {
return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
[this](unsigned i) { return isContiguous(i); });
}
}];
}

#endif // PTR_OPS
105 changes: 105 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
return success();
}

/// Verifies that the contiguity array has the right size, all the elements are
/// positive and divide the corresponding shape dimension.
static LogicalResult
verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
function_ref<InFlightDiagnostic()> emitError) {
if (contiguity.size() != shape.size()) {
return emitError() << "expected contiguity array with " << shape.size()
<< " elements";
}
if (!llvm::all_of(llvm::zip(contiguity, shape), [](auto cs) {
int32_t c = std::get<0>(cs);
return c > 0 && std::get<1>(cs) % c == 0;
})) {
return emitError()
<< "expected contiguity values to be positive and divide the shape";
}
return success();
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
alignment ? std::optional<int64_t>(alignment) : std::nullopt);
}

//===----------------------------------------------------------------------===//
// ReadOp
//===----------------------------------------------------------------------===//

void ReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
}

LogicalResult ReadOp::verify() {
auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };

// Verify that the pointer type's memory space allows loads.
MemorySpaceAttrInterface ms =
cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
getAlignment(), &dataLayout, emitDiag))
return failure();

// Verify the alignment.
if (failed(verifyAlignment(getAlignment(), emitDiag)))
return failure();

// Verify the contiguity array.
return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
}

void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
Value mask, Value passthrough, unsigned alignment,
ArrayRef<int32_t> contiguity) {
if (!contiguity.empty()) {
build(builder, state, ptr, mask, passthrough,
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
contiguity);
return;
}
build(builder, state, ptr, mask, passthrough,
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
}

//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
return dl.getTypeSize(getElementType());
}

//===----------------------------------------------------------------------===//
// WriteOp
//===----------------------------------------------------------------------===//

void WriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
}

LogicalResult WriteOp::verify() {
auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };

// Verify that the pointer type's memory space allows stores.
MemorySpaceAttrInterface ms =
cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
DataLayout dataLayout = DataLayout::closest(*this);
if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
getAlignment(), &dataLayout, emitDiag))
return failure();

// Verify the alignment.
if (failed(verifyAlignment(getAlignment(), emitDiag)))
return failure();

// Verify the contiguity array.
return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
}

void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
Value ptr, Value mask, unsigned alignment,
ArrayRef<int32_t> contiguity) {
if (!contiguity.empty()) {
build(builder, state, value, ptr, mask,
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
contiguity);
return;
}
build(builder, state, value, ptr, mask,
alignment ? std::optional<int64_t>(alignment) : std::nullopt,
SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
}

//===----------------------------------------------------------------------===//
// Pointer API.
//===----------------------------------------------------------------------===//
Expand Down
Loading