Skip to content

Conversation

@fabianmcg
Copy link
Contributor

@fabianmcg fabianmcg commented Sep 28, 2025

Add ptr.read and ptr.write operations to the pointer dialect. These operations
provide a high-level interface for reading from and writing to memory with:

  • Masked access semantics (conditional loads/stores)
  • Contiguity information for optimized lowering
  • Support for both vector and tensor types
  • Ability to express row-major, column-major, gather/scatter patterns, and other access patterns

Example:

func.func @read(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
  // Row-major styled read
  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Column-major styled read
  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  // Gather styled read
  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
  return
}

func.func @write(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
  // Row-major styled write
  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Column-major styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  // Scatter styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
  return
}

It's future work to add lowerings to ptr.load/store, ptr.masked_load/store, or
ptr.gather/scatter depending on mask and contiguity properties.

@llvmbot llvmbot added the mlir label Sep 28, 2025
@fabianmcg fabianmcg requested a review from Copilot September 28, 2025 13:46
@fabianmcg fabianmcg added the mlir:ptr MLIR ptr dialect label Sep 28, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 28, 2025

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes

Add ptr.read and ptr.write operations to the pointer dialect. These operations
provide a high-level interface for reading from and writing to memory with:

  • Masked access semantics (conditional loads/stores)
  • Contiguity information for optimized lowering
  • Support for both vector and tensor types
  • Ability to express row-major, column-major, and gather/scatter patterns

It's future work to add lowerings to ptr.load/store, ptr.masked_load/store, or
ptr.gather/scatter depending on mask and contiguity properties.

Example:

func.func @<!-- -->read(%ptr: vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt;, %mask: vector&lt;4x4xi1&gt;, %passthrough: vector&lt;4x4xf32&gt;) {
  // Row-major styled read
  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] : vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt; -&gt; vector&lt;4x4xf32&gt;
  // Column-major styled read
  %1 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [4, 1] : vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt; -&gt; vector&lt;4x4xf32&gt;
  // Gather styled read
  %2 = ptr.read %ptr, %mask, %passthrough alignment = 8 contiguity = [1, 1] : vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt; -&gt; vector&lt;4x4xf32&gt;
  return
}

func.func @<!-- -->write(%value: vector&lt;4x4xf32&gt;, %ptr: vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt;, %mask: vector&lt;4x4xi1&gt;) {
  // Row-major styled write
  ptr.write %value, %ptr, %mask contiguity = [1, 4] : vector&lt;4x4xf32&gt;, vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt;
  // Column-major styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] : vector&lt;4x4xf32&gt;, vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt;
  // Scatter styled write
  ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] : vector&lt;4x4xf32&gt;, vector&lt;4x4x!ptr.ptr&lt;#ptr.generic_space&gt;&gt;
  return
}

Patch is 21.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/161081.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+248)
  • (modified) mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp (+105)
  • (modified) mlir/test/Dialect/Ptr/invalid.mlir (+48)
  • (modified) mlir/test/Dialect/Ptr/ops.mlir (+36)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index e14f64330c294..c3a5415d0cbc8 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -24,6 +24,8 @@ include "mlir/IR/OpAsmInterface.td"
 
 def AlignmentProp : OptionalProp<I64Prop>;
 
+def ContiguityProp : IntArrayProp<I32Prop, "memory access contiguity information">;
+
 //===----------------------------------------------------------------------===//
 // Common types
 //===----------------------------------------------------------------------===//
@@ -45,6 +47,15 @@ def Ptr_IntLikeType :AnyTypeOf<[
   AnySignlessIntegerOrIndex
 ]>;
 
+// A shaped pointer type with value semantics.
+def Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+
+// A shaped mask type with value semantics.
+def Ptr_ShapedMaskType : Ptr_ShapedValueType<[I1], [HasRankPred]>;
+
+// A shaped mask type with value semantics.
+def Ptr_ShapedAnyType : Ptr_ShapedValueType<[AnyType], [HasRankPred]>;
+
 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
   Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
@@ -472,6 +483,127 @@ def Ptr_PtrDiffOp : Pointer_Op<"ptr_diff", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_ReadOp : Pointer_Op<"read", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
+      ::llvm::cast<ShapedType>($_self).clone(
+        IntegerType::get($_self.getContext(), 1))
+    }]>,
+    AllTypesMatch<["result", "passthrough"]>,
+    // Check the shapes are compatible and both use the same shaped container
+    // type.
+    AllShapesMatch<["result", "ptr"]>, AllTypeIDsMatch<["result", "ptr"]>
+  ]> {
+  let summary = "Read operation";
+  let description = [{
+    The `read` operation is a high-level operation that performs a read
+    from multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `result`, corresponding to masked-off lanes, are taken from
+    the `passthrough` operand.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the result type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+      1. `ck` must be strictly positive for all k.
+      2. `ck` must divide `sk` for all k.
+      3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+         given by:
+           - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, ji+1, ..., jn]` for
+           `l = 0, 1, ..., sk / ck - 1`
+         are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.load`, if the mask is all ones, and there's a dimension where all
+       the accesses are contiguous.
+    2. A `ptr.masked_load`, if the mask is not all ones, and there's a dimension
+       where all the accesses are contiguous.
+    3. A `ptr.gather` if the mask is not all ones, and there's no contiguous
+       dimension.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Examples:
+    ```mlir
+    // Read a vector in row-major order
+    %result = ptr.read %ptr, %mask, %passthrough contiguity = [1, 4] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Read a vector in column-major order with alignment
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [4, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+
+    // Gather a vector from memory
+    %result = ptr.read %ptr, %mask, %passthrough alignment = 8
+      contiguity = [1, 1] :
+      vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       Ptr_ShapedAnyType:$passthrough,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let results = (outs Ptr_ShapedAnyType:$result);
+  let assemblyFormat = [{
+    $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($ptr) `->` type($result)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$passthrough,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType()  {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands and result.
+    unsigned getRank() { return getType().getRank(); }
+
+    /// Returns the shape of the shaped operands and result.
+    ArrayRef<int64_t> getShape() { return getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the read has gather semantics, ie. there's no dimension
+    /// where all the accesses are contiguous.
+    bool hasGatherSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -645,4 +777,120 @@ def Ptr_TypeOffsetOp : Pointer_Op<"type_offset", [Pure]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_WriteOp : Pointer_Op<"write", [
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    TypesMatchWith<"value and mask must be compatible",
+                                          "value", "mask", [{
+      cast<ShapedType>($_self).clone(IntegerType::get($_self.getContext(), 1))
+    }]>,
+    // Check the shapes are compatible and both use the same shaped container
+    AllShapesMatch<["value", "ptr"]>, AllTypeIDsMatch<["value", "ptr"]>
+  ]> {
+  let summary = "Write operation";
+  let description = [{
+    The `write` operation is a high-level operation that performs a write to
+    multiple memory locations specified by `ptr` based on a mask `mask`.
+    Elements of the `value`, corresponding to masked-off lanes, are not written
+    to memory.
+
+    The `mask` operand is a shaped type of `i1` elements that must have the same
+    shape as the `value` type.
+
+    The `contiguity` property is an integer array with the same rank as `ptr`,
+    where each element describes memory access contiguity for the corresponding
+    dimension. The precise semantics of this property are given by:
+    Let `c1, c2, ..., cn` be the elements of the contiguity array, and
+    `s1, s2, ..., sn` be the corresponding elements of the `ptr` shape.
+    The following rules and restrictions apply:
+      1. `ck` must be strictly positive for all k.
+      2. `ck` must divide `sk` for all k.
+      3. Given arbitrary but valid indices `j1, ..., jn`, then the memory ranges
+         given by:
+           - `ptr[j1, ..., jk-1, l * ck : (l + 1) * ck, ji+1, ..., jn]` for
+           `l = 0, 1, ..., sk / ck - 1`
+         are contiguous for all k.
+
+    It is undefined behavior if the pointers in `ptr` do not satisfy the
+    contiguity constraints specified by `contiguity`.
+
+    Depending on the values of `mask` and `contiguity`, the operation can be
+    lowered to either:
+    1. A `ptr.store`, if the mask is all ones, and there's a dimension where all
+       the accesses are contiguous.
+    2. A `ptr.masked_store`, if the mask is not all ones, and there's a dimension
+       where all the accesses are contiguous.
+    3. A `ptr.scatter` if the mask is not all ones, and there's no contiguous
+       dimension.
+
+    The alignment property describes the alignment (in bytes) of each contiguous
+    memory-block being accessed.
+
+    Example:
+    ```mlir
+    // Write a vector in row-major order
+    ptr.write %value, %ptr, %mask contiguity = [1, 4] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Write a vector in column-major order with alignment
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [4, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+
+    // Scatter a vector to memory
+    ptr.write %value, %ptr, %mask alignment = 8 contiguity = [1, 1] :
+      vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+    ```
+  }];
+  let arguments = (ins Ptr_ShapedAnyType:$value,
+                       Ptr_ShapedPtrType:$ptr,
+                       Ptr_ShapedMaskType:$mask,
+                       AlignmentProp:$alignment,
+                       ContiguityProp:$contiguity);
+  let assemblyFormat = [{
+    $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)?
+    `contiguity` `=` $contiguity attr-dict `:` type($value) `,` type($ptr)
+  }];
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
+      CArg<"unsigned", "0">:$alignment,
+      CArg<"ArrayRef<int32_t>", "{}">:$contiguity)>
+  ];
+  let hasVerifier = 1;
+  let extraClassDeclaration = [{
+    /// Returns the ptr type of the operation.
+    PtrType getPtrType()  {
+      return cast<PtrType>(getPtr().getType().getElementType());
+    }
+
+    /// Returns the rank of the shaped operands.
+    unsigned getRank() { return getPtr().getType().getRank(); }
+
+    /// Returns the shape of the shaped operands.
+    ArrayRef<int64_t> getShape() { return getPtr().getType().getShape(); }
+
+    /// Returns a pair `(c, s)` where `c` is the contiguity and `s` the size
+    /// of the `i`-th dimension.
+    std::pair<int64_t, int64_t> getContiguityInfo(unsigned i) {
+      assert(i < getRank() && "Invalid dimension");
+      return {getContiguity()[i], getShape()[i]};
+    }
+
+    /// Returns true if the `i`-th dimension is contiguous.
+    bool isContiguous(unsigned i) {
+      auto [contiguity, size] = getContiguityInfo(i);
+      return contiguity == size && size > 1;
+    }
+
+    /// Returns true if the write has scatter semantics, ie. there's no
+    /// dimension where all the accesses are contiguous.
+    bool hasScatterSemantics() {
+      return !llvm::any_of(llvm::seq<unsigned>(0, getRank()),
+        [this](unsigned i) { return isContiguous(i); });
+    }
+  }];
+}
+
 #endif // PTR_OPS
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 51f25f755a8a6..ecfbd957bbe24 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -57,6 +57,25 @@ verifyAlignment(std::optional<int64_t> alignment,
   return success();
 }
 
+/// Verifies that the contiguity array has the right size, all the elements are
+/// positive and divide the corresponding shape dimension.
+static LogicalResult
+verifyContiguityProp(ArrayRef<int32_t> contiguity, ArrayRef<int64_t> shape,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  if (contiguity.size() != shape.size()) {
+    return emitError() << "expected contiguity array with " << shape.size()
+                       << " elements";
+  }
+  if (!llvm::all_of(llvm::zip(contiguity, shape), [](auto cs) {
+        int32_t c = std::get<0>(cs);
+        return c > 0 && std::get<1>(cs) % c == 0;
+      })) {
+    return emitError()
+           << "expected contiguity values to be positive and divide the shape";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -264,6 +283,49 @@ void MaskedStoreOp::build(OpBuilder &builder, OperationState &state,
         alignment ? std::optional<int64_t>(alignment) : std::nullopt);
 }
 
+//===----------------------------------------------------------------------===//
+// ReadOp
+//===----------------------------------------------------------------------===//
+
+void ReadOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable());
+}
+
+LogicalResult ReadOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows loads.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidLoad(getResult().getType(), AtomicOrdering::not_atomic,
+                      getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void ReadOp::build(OpBuilder &builder, OperationState &state, Value ptr,
+                   Value mask, Value passthrough, unsigned alignment,
+                   ArrayRef<int32_t> contiguity) {
+  if (!contiguity.empty()) {
+    build(builder, state, ptr, mask, passthrough,
+          alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+          contiguity);
+    return;
+  }
+  build(builder, state, ptr, mask, passthrough,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
@@ -470,6 +532,49 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) {
   return dl.getTypeSize(getElementType());
 }
 
+//===----------------------------------------------------------------------===//
+// WriteOp
+//===----------------------------------------------------------------------===//
+
+void WriteOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable());
+}
+
+LogicalResult WriteOp::verify() {
+  auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); };
+
+  // Verify that the pointer type's memory space allows stores.
+  MemorySpaceAttrInterface ms =
+      cast<PtrType>(getPtr().getType().getElementType()).getMemorySpace();
+  DataLayout dataLayout = DataLayout::closest(*this);
+  if (!ms.isValidStore(getValue().getType(), AtomicOrdering::not_atomic,
+                       getAlignment(), &dataLayout, emitDiag))
+    return failure();
+
+  // Verify the alignment.
+  if (failed(verifyAlignment(getAlignment(), emitDiag)))
+    return failure();
+
+  // Verify the contiguity array.
+  return verifyContiguityProp(getContiguity(), getShape(), emitDiag);
+}
+
+void WriteOp::build(OpBuilder &builder, OperationState &state, Value value,
+                    Value ptr, Value mask, unsigned alignment,
+                    ArrayRef<int32_t> contiguity) {
+  if (!contiguity.empty()) {
+    build(builder, state, value, ptr, mask,
+          alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+          contiguity);
+    return;
+  }
+  build(builder, state, value, ptr, mask,
+        alignment ? std::optional<int64_t>(alignment) : std::nullopt,
+        SmallVector<int32_t>(cast<ShapedType>(ptr.getType()).getRank(), 1));
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 83e1c880650c5..54332a5632808 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -78,3 +78,51 @@ func.func @ptr_diff_mismatch(%lhs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %rhs:
   %res = ptr.ptr_diff %lhs, %rhs : tensor<8x!ptr.ptr<#ptr.generic_space>> -> vector<8xi64>
   return %res : vector<8xi64>
 }
+
+// -----
+
+func.func @read_contiguity_does_not_divide(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, 3] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_contiguity_is_not_positive(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1, -1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @read_invalid_contiguity_size(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) -> vector<4x4xf32> {
+  // expected-error@+1 {{expected contiguity array with 2 elements}}
+  %0 = ptr.read %ptr, %mask, %passthrough contiguity = [1] : vector<4x4x!ptr.ptr<#ptr.generic_space>> -> vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// -----
+
+func.func @write_contiguity_does_not_divide(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [1, 7] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_contiguity_is_not_positive(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity values to be positive and divide the shape}}
+  ptr.write %value, %ptr, %mask contiguity = [0, 4] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
+
+// -----
+
+func.func @write_invalid_contiguity_size(%value: vector<4x4xf32>, %ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>) {
+  // expected-error@+1 {{expected contiguity array with 2 elements}}
+  ptr.write %value, %ptr, %mask contiguity = [1] : vector<4x4xf32>, vector<4x4x!ptr.ptr<#ptr.generic_space>>
+  return
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index 0a906ad559e21..d0c0390d6932e 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -239,3 +239,39 @@ func.func @ptr_diff_tensor_2d_ops(%ptrs1: tensor<4x8x!ptr.ptr<#ptr.generic_space
   %diff = ptr.ptr_diff %ptrs1, %ptrs2 : tensor<4x8x!ptr.ptr<#ptr.generic_space>> -> tensor<4x8xi64>
   return %diff : tensor<4x8xi64>
 }
+
+/// Check read op assembly.
+func.func @read_ops(%ptr: vector<4x4x!ptr.ptr<#ptr.generic_space>>, %mask: vector<4x4xi1>, %passthrough: vector<4x4xf32>) {
+  // Row-major ...
[truncated]

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR introduces high-level ptr.read and ptr.write operations to the MLIR pointer dialect, providing a unified interface for memory access patterns with masked access semantics and contiguity information.

Key changes:

  • Added ReadOp and WriteOp with support for masked access, alignment, and contiguity properties
  • Implemented comprehensive verification logic for contiguity constraints
  • Added extensive test coverage for both valid operations and error cases

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td Defines the new ReadOp and WriteOp operations with their assembly format, constraints, and helper methods
mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp Implements verification logic, memory effects, and builder methods for the new operations
mlir/test/Dialect/Ptr/ops.mlir Adds positive test cases demonstrating valid usage of read and write operations
mlir/test/Dialect/Ptr/invalid.mlir Adds negative test cases for invalid contiguity array configurations

@github-actions
Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

@joker-eph
Copy link
Collaborator

I'm afraid of bloating the ptr dialect and making it more than a low-level/thin thing around memory.
Bundling higher-level concepts in the same dialect as lower-level concepts is something that easily scope creep and leads to intra-dialect lowering: this makes a bunch of things less natural, in particular forming a simple mental model around a dialect (and its canonicalization).

@fabianmcg
Copy link
Contributor Author

I'm afraid of bloating the ptr dialect and making it more than a low-level/thin thing around memory.
Bundling higher-level concepts in the same dialect as lower-level concepts is something that easily scope creep and leads to intra-dialect lowering: this makes a bunch of things less natural, in particular forming a simple mental model around a dialect (and its canonicalization).

On the high-level op side, I was only thinking of adding read, write and make_offsets (a helper to create int offsets from strides and indices make_offsets [%i0, %i1][%stride0, %stride1] : tensor<3x3xi32>). I originally didn't think of this multi-level concepts as an issue, as the vector, memref and other dialect do the same.

A new dialect could be created, but it would only contain those 3 ops.

@joker-eph
Copy link
Collaborator

joker-eph commented Sep 28, 2025

Even the naming looks out of place: how would one figure why we can "load" and "read" from a pointer? The names looks synonymous and just arbitrarily introduced "just because".

This just does not belong to ptr IMO: let's keep things simple and generalizable by avoiding ad-hoc constructions.

@fabianmcg
Copy link
Contributor Author

fabianmcg commented Sep 30, 2025

Even the naming looks out of place: how would one figure why we can "load" and "read" from a pointer? The names looks synonymous and just arbitrarily introduced "just because".

FWIW, I took the naming from vector, where there are load, store and transfer_read, transfer_write ops.
I can change the names of these ops to something like transfer_read, transfer_write, or find another suitable name.

This just does not belong to ptr IMO: let's keep things simple and generalizable by avoiding ad-hoc constructions.

These are still ops acting on pointers with well-defined semantics in the dialect, so I wouldn't call them ad-hoc constructions. However, I can give you that this does introduce 2 different semantic levels in the dialect, and while there are other dialects that do that (eg. vector), it still might not be desirable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:ptr MLIR ptr dialect mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants