@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
2424
2525def AlignmentProp : OptionalProp<I64Prop>;
2626
27+ def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;
28+
2729//===----------------------------------------------------------------------===//
2830// Common types
2931//===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
4547 AnySignlessIntegerOrIndex
4648]>;
4749
50+ // A shaped pointer type with value semantics.
51+ def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
52+
53+ // A shaped mask type with value semantics.
54+ def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
55+
56+ // A shaped mask type with value semantics.
57+ def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
58+
4859// A shaped value type of rank 1 of any element type.
4960def Ptr_Any1DType :
5061 Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
472483 let hasVerifier = 1;
473484}
474485
486+ //===----------------------------------------------------------------------===//
487+ // ReadOp
488+ //===----------------------------------------------------------------------===//
489+
490+ def Ptr_ReadOp : Pointer_Op<"read", [
491+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
492+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
493+ ::llvm::cast<ShapedType>($_self).clone(
494+ IntegerType::get($_self.getContext(), 1))
495+ }]>,
496+ AllTypesMatch<["result", "passthrough"]>,
497+ // Check the shapes are compatible and both use the same shaped container
498+ // type.
499+ AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
500+ ]> {
501+ let summary = "Read operation";
502+ let description = [{
503+ The `read` operation is a high-level operation that performs a read
504+ from multiple memory locations specified by `ptr` based on a mask `mask`.
505+ Elements of the `result`, corresponding to masked-off lanes, are taken from
506+ the `passthrough` operand.
507+
508+ The `mask` operand is a shaped type of `i1` elements that must have the same
509+ shape as the result type.
510+
511+ The `contiguity` property is an integer array with the same rank as `ptr`,
512+ where each element describes memory access contiguity for the corresponding
513+ dimension. The precise semantics of this property are given by:
514+ Let `c1, c2, ..., cn` be the elements of the contiguity array, and
515+ `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
516+ The following rules and restrictions apply:
517+ 1. `ck` must be strictly positive for all k.
518+ 2. `ck` must divide `sk` for all k.
519+ 3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
520+ given by:
521+ - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
522+ `l = 0, 1, ..., sk / ck - 1`
523+ are contiguous for all k.
524+
525+ It is undefined behavior if the pointers in `ptr` do not satisfy the
526+ contiguity constraints specified by `contiguity`.
527+
528+ Depending on the values of `mask` and `contiguity`, the operation can be
529+ lowered to either:
530+ 1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
531+ the accesses are contiguous.
532+ 2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
533+ where all the accesses are contiguous.
534+ 3. A `ptr.gather` if the mask is not all ones, and there's no contiguous
535+ dimension.
536+
537+ The alignment property describes the alignment (in bytes) of each contiguous
538+ memory-block being accessed.
539+
540+ Examples:
541+ ```mlir
542+ // Read a vector in row-major order
543+ %result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
544+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
545+
546+ // Read a vector in column-major order with alignment
547+ %result = ptr.read %ptr, %mask, %passthrough alignment = 8
548+ contiguity = [4, 1] :
549+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
550+
551+ // Gather a vector from memory
552+ %result = ptr.read %ptr, %mask, %passthrough alignment = 8
553+ contiguity = [1, 1] :
554+ vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
555+ ```
556+ }];
557+ let arguments = (ins Ptr_ShapedPtrType:$ptr,
558+ Ptr_ShapedMaskType:$mask,
559+ Ptr_ShapedAnyType:$passthrough,
560+ AlignmentProp:$alignment,
561+ ContiguityProp:$contiguity);
562+ let results = (outs Ptr_ShapedAnyType:$result);
563+ let assemblyFormat = [{
564+ $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
565+ `contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
566+ }];
567+ let builders = [
568+ OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
569+ CArg<"unsigned", "0">:$alignment,
570+ CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
571+ ];
572+ let hasVerifier = 1;
573+ let extraClassDeclaration = [{
574+ /// Returns the ptr type of the operation.
575+ PtrType getPtrType() {
576+ return cast<PtrType>(getPtr().getType().getElementType());
577+ }
578+
579+ /// Returns the rank of the shaped operands and result.
580+ unsigned getRank() { return getType().getRank(); }
581+
582+ /// Returns the shape of the shaped operands and result.
583+ ArrayRef<int64_t> getShape() { return getType().getShape(); }
584+
585+ /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
586+ /// of the `i`-th dimension.
587+ std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
588+ assert(i < getRank() && "Invalid dimension");
589+ return {getContiguity()[i], getShape()[i]};
590+ }
591+
592+ /// Returns true if the `i`-th dimension is contiguous.
593+ bool isContiguous(unsigned i) {
594+ auto [contiguity, size] = getContiguityInfo(i);
595+ return contiguity == size && size > 1;
596+ }
597+
598+ /// Returns true if the read has gather semantics, ie. there's no dimension
599+ /// where all the accesses are contiguous.
600+ bool hasGatherSemantics() {
601+ return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
602+ [this](unsigned i) { return isContiguous(i); });
603+ }
604+ }];
605+ }
606+
475607//===----------------------------------------------------------------------===//
476608// ScatterOp
477609//===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
645777 }];
646778}
647779
780+ //===----------------------------------------------------------------------===//
781+ // WriteOp
782+ //===----------------------------------------------------------------------===//
783+
784+ def Ptr_WriteOp : Pointer_Op<"write", [
785+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
786+ TypesMatchWith<"value and mask must be compatible",
787+ "value", "mask", [{
788+ cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
789+ }]>,
790+ // Check the shapes are compatible and both use the same shaped container
791+ AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
792+ ]> {
793+ let summary = "Write operation";
794+ let description = [{
795+ The `write` operation is a high-level operation that performs a write to
796+ multiple memory locations specified by `ptr` based on a mask `mask`.
797+ Elements of the `value`, corresponding to masked-off lanes, are not written
798+ to memory.
799+
800+ The `mask` operand is a shaped type of `i1` elements that must have the same
801+ shape as the `value` type.
802+
803+ The `contiguity` property is an integer array with the same rank as `ptr`,
804+ where each element describes memory access contiguity for the corresponding
805+ dimension. The precise semantics of this property are given by:
806+ Let `c1, c2, ..., cn` be the elements of the contiguity array, and
807+ `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
808+ The following rules and restrictions apply:
809+ 1. `ck` must be strictly positive for all k.
810+ 2. `ck` must divide `sk` for all k.
811+ 3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
812+ given by:
813+ - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, jk+1, ..., jn]` for
814+ `l = 0, 1, ..., sk / ck - 1`
815+ are contiguous for all k.
816+
817+ It is undefined behavior if the pointers in `ptr` do not satisfy the
818+ contiguity constraints specified by `contiguity`.
819+
820+ Depending on the values of `mask` and `contiguity`, the operation can be
821+ lowered to either:
822+ 1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
823+ the accesses are contiguous.
824+ 2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
825+ where all the accesses are contiguous.
826+ 3. A `ptr.scatter` if the mask is not all ones, and there's no contiguous
827+ dimension.
828+
829+ The alignment property describes the alignment (in bytes) of each contiguous
830+ memory-block being accessed.
831+
832+ Example:
833+ ```mlir
834+ // Write a vector in row-major order
835+ ptr.write %value, %ptr, %mask contiguity = [1, 4] :
836+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
837+
838+ // Write a vector in column-major order with alignment
839+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
840+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
841+
842+ // Scatter a vector to memory
843+ ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
844+ vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
845+ ```
846+ }];
847+ let arguments = (ins Ptr_ShapedAnyType:$value,
848+ Ptr_ShapedPtrType:$ptr,
849+ Ptr_ShapedMaskType:$mask,
850+ AlignmentProp:$alignment,
851+ ContiguityProp:$contiguity);
852+ let assemblyFormat = [{
853+ $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
854+ `contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
855+ }];
856+ let builders = [
857+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
858+ CArg<"unsigned", "0">:$alignment,
859+ CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
860+ ];
861+ let hasVerifier = 1;
862+ let extraClassDeclaration = [{
863+ /// Returns the ptr type of the operation.
864+ PtrType getPtrType() {
865+ return cast<PtrType>(getPtr().getType().getElementType());
866+ }
867+
868+ /// Returns the rank of the shaped operands.
869+ unsigned getRank() { return getPtr().getType().getRank(); }
870+
871+ /// Returns the shape of the shaped operands.
872+ ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
873+
874+ /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
875+ /// of the `i`-th dimension.
876+ std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
877+ assert(i < getRank() && "Invalid dimension");
878+ return {getContiguity()[i], getShape()[i]};
879+ }
880+
881+ /// Returns true if the `i`-th dimension is contiguous.
882+ bool isContiguous(unsigned i) {
883+ auto [contiguity, size] = getContiguityInfo(i);
884+ return contiguity == size && size > 1;
885+ }
886+
887+ /// Returns true if the write has scatter semantics, ie. there's no
888+ /// dimension where all the accesses are contiguous.
889+ bool hasScatterSemantics() {
890+ return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
891+ [this](unsigned i) { return isContiguous(i); });
892+ }
893+ }];
894+ }
895+
648896#endif // PTR_OPS
0 commit comments