diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index e14f64330c294..05edc5966975d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
 
 def AlignmentProp : OptionalProp<I64Prop>;
 
+def ContiguityProp : IntArrayProp<I32Prop>;
+
 //===----------------------------------------------------------------------===//
 // Common types
 //===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
   AnySignlessIntegerOrIndex
 ]>;
 
+// A shaped pointer type with value semantics.
+def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped mask type with value semantics.
+def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
+
+// A shaped type of any element type with value semantics.
+def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
+
 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
   Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ReadOp : Pointer_Op<"read", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
+  ]> {
+  let summary = "Read operation";
+  let description = [{
+    The `read` operation is a high-level operation that performs a read
+    from multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `result`, corresponding to masked-off lanes, are taken from
+    the `passthrough` operand.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+    1. `ck` must be strictly positive for all k.
+    2. `ck` must divide `sk` for all k.
+    3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+      given by:
+      - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+        `l = 0, 1, ..., sk / ck - 1`
+      are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
+      the accesses are contiguous.
+    2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
+      where all the accesses are contiguous.
+    3. A `ptr.gather`, if there's no dimension where all the accesses are
+      contiguous, regardless of the mask.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Examples:
+    ```mlir
+    // Read a vector in row-major order
+    %result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Read a vector in column-major order with alignment
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [4, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Gather a vector from memory
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [1, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       Ptr_ShapedAnyType:$passthrough,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let results = (outs Ptr_ShapedAnyType:$result);
+  let assemblyFormat = [{
+    $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType() {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands and result.
+    unsigned getRank() { return getType().getRank(); }
+
+    /// Returns the shape of the shaped operands and result.
+    ArrayRef<int64_t> getShape() { return getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int32_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the read has gather semantics, i.e. there's no dimension
+    /// where all the accesses are contiguous.
+    bool hasGatherSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_WriteOp : Pointer_Op<"write", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible",
+      "value", "mask", [{
+      cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
+    }]>,
+    // Check the shapes are compatible and the shaped container types match.
+    AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
+  ]> {
+  let summary = "Write operation";
+  let description = [{
+    The `write` operation is a high-level operation that performs a write to
+    multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `value`, corresponding to masked-off lanes, are not written
+    to memory.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the `value` type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+    1. `ck` must be strictly positive for all k.
+    2. `ck` must divide `sk` for all k.
+    3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+      given by:
+      - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
+        `l = 0, 1, ..., sk / ck - 1`
+      are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
+      the accesses are contiguous.
+    2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
+      where all the accesses are contiguous.
+    3. A `ptr.scatter`, if there's no dimension where all the accesses are
+      contiguous, regardless of the mask.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Examples:
+    ```mlir
+    // Write a vector in row-major order
+    ptr.write %value, %ptr, %mask contiguity = [1, 4] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Write a vector in column-major order with alignment
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Scatter a vector to memory
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedAnyType:$value,
+                       Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let assemblyFormat = [{
+    $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType() {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands.
+    unsigned getRank() { return getPtr().getType().getRank(); }
+
+    /// Returns the shape of the shaped operands.
+    ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int32_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the write has scatter semantics, i.e. there's no
+    /// dimension where all the accesses are contiguous.
+    bool hasScatterSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 #endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 51f25f755a8a6..ecfbd957bbe24 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
   return success();
 }
 
+/// Verifies that the contiguity array has the right size, all the elements are
+/// positive and divide the corresponding shape dimension.
+static LogicalResult
+verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  if (contiguity.size() != shape.size()) {
+    return emitError() << "expected contiguity array with " << shape.size()
+                       << " elements";
+  }
+  // Divisibility can only be checked statically; skip dynamic extents.
+  for (auto [c, s] : llvm::zip_equal(contiguity, shape)) {
+    if (c > 0 && (ShapedType::isDynamic(s) || s % c == 0))
+      continue;
+    return emitError()
+           << "expected contiguity values to be positive and divide the shape";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
         alignment ? std::optional<int64_t>(alignment) : std::nullopt);
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+void ReadOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult ReadOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the memory space of the pointer element type allows loading
+  // the result type with the requested alignment.
+  MemorySpaceAttrInterface ms = getPtrType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
+                   Value mask, Value passthrough, unsigned alignment,
+                   ArrayRef<int32_t> contiguity) {
+  // An empty contiguity array means "unknown": default every dimension to 1.
+  SmallVector<int32_t> defaultContiguity;
+  if (contiguity.empty()) {
+    defaultContiguity.assign(cast<ShapedType>(ptr.getType()).getRank(), 1);
+    contiguity = defaultContiguity;
+  }
+  build(builder, state, ptr, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        contiguity);
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
   return dl.getTypeSize(getElementType());
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+void WriteOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult WriteOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the memory space of the pointer element type allows storing
+  // the value type with the requested alignment.
+  MemorySpaceAttrInterface ms = getPtrType().getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
+                    Value ptr, Value mask, unsigned alignment,
+                    ArrayRef<int32_t> contiguity) {
+  // An empty contiguity array means "unknown": default every dimension to 1.
+  SmallVector<int32_t> defaultContiguity;
+  if (contiguity.empty()) {
+    defaultContiguity.assign(cast<ShapedType>(ptr.getType()).getRank(), 1);
+    contiguity = defaultContiguity;
+  }
+  build(builder, state, value, ptr, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        contiguity);
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 83e1c880650c5..54332a5632808 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -78,3 +78,51 @@ func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs:
   %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
   return %res : vector<8xi64>
 }
+
+// -----
+
+func.func @read_contiguity_does_not_divide(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 3] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_contiguity_is_not_positive(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, -1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_invalid_contiguity_size(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity array with 2 elements}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @write_contiguity_does_not_divide(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [1, 7] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_contiguity_is_not_positive(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [0, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_invalid_contiguity_size(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity array with 2 elements}}
+  ptr.write %value, %ptr, %mask contiguity = [1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 0a906ad559e21..d0c0390d6932e 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -239,3 +239,39 @@ func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space
   %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
   return %diff : tensor<4x8xi64>
 }
+
+/// Check read op assembly.
+func.func @read_ops(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
+  // Row-major styled read: contiguity equals the extent of the last dimension.
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  // Column-major styled read: contiguity equals the extent of the first dimension; also sets the optional alignment.
+  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  // Gather styled read: contiguity is 1 on every dimension.
+  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return
+}
+
+/// Check read op assembly with rank-1 tensors, using contiguity 1 and contiguity equal to the full extent.
+func.func @read_ops_tensor(%ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>, %passthrough: tensor<8xf32>) {
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+  %1 = ptr.read %ptr, %mask, %passthrough alignment = 4 contiguity = [8] : tensor<8x!ptr.ptr<#ptr.generic_space>> -> tensor<8xf32>
+  return
+}
+
+/// Check write op assembly.
+func.func @write_ops(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // Row-major styled write: contiguity equals the extent of the last dimension.
+  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  // Column-major styled write: contiguity equals the extent of the first dimension; also sets the optional alignment.
+  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  // Scatter styled write: contiguity is 1 on every dimension.
+  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+/// Check write op assembly with rank-1 tensors, using contiguity 1 and contiguity equal to the full extent.
+func.func @write_ops_tensor(%value: tensor<8xf32>, %ptr: tensor<8x!ptr.ptr<#ptr.generic_space>>, %mask: tensor<8xi1>) {
+  ptr.write %value, %ptr, %mask contiguity = [1] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  ptr.write %value, %ptr, %mask alignment = 4 contiguity = [8] : tensor<8xf32>, tensor<8x!ptr.ptr<#ptr.generic_space>>
+  return
+}