
Conversation

@fabianmcg
Contributor

This patch extends `ptr_add` to work with shaped types that have value semantics, for both the base and the offsets.

Concretely, this patch makes the following changes:

  • Supports scalar-to-scalar, scalar-to-shaped, shaped-to-scalar, and shaped-to-shaped combinations
  • Adds InferTypeOpInterface for automatic result type deduction
  • Adds tests for LLVM IR translation with vector operands

Example:

func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
}

The motivation behind this patch is to lay the groundwork for enabling Triton-style loads and stores, and their variants.
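The result-type inference rule this patch adds (implemented in `PtrAddOp::inferReturnTypes`) can be sketched as a small Python model. This is an illustration only, not MLIR API: types are modeled as tuples, with `("ptr",)` for a scalar pointer, `("index",)` or `("i64",)` for scalar integers, and `(container, shape, elem)` for shaped types such as `("tensor", (4, 8), "ptr")`.

```python
def infer_ptr_add_result(base, offset):
    """Model of ptr_add result-type inference (illustrative, not MLIR API).

    Scalar + scalar yields the base type; a shaped base dictates the
    result; a scalar base with a shaped offset yields the offset's
    container and shape with a pointer element type.
    """
    base_shaped = base[0] in ("tensor", "vector")
    off_shaped = offset[0] in ("tensor", "vector")
    if not base_shaped and not off_shaped:
        return base  # scalar base, scalar offset -> scalar pointer
    if base_shaped:
        if off_shaped:
            # Both shaped: shapes and container kinds must agree.
            if base[1] != offset[1]:
                raise ValueError("shapes of base and offset must match")
            if base[0] != offset[0]:
                raise ValueError("the shaped containers type must match")
        return base  # shaped base dictates the result type
    # Scalar base, shaped offset: clone the offset's shape with a ptr element,
    # mirroring offsetShapedType.clone(baseType) in the C++ implementation.
    return (offset[0], offset[1], base[0])
```

For example, a scalar `!ptr.ptr` base with a `vector<4xi32>` offset infers a `vector<4x!ptr.ptr>` result, matching the `@ptr_add_scalar_base_vector_offsets` test below.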

@llvmbot
Member

llvmbot commented Sep 1, 2025

@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Fabian Mora (fabianmcg)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/156374.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td (+66-37)
  • (modified) mlir/lib/Dialect/Ptr/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp (+40)
  • (modified) mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir (+6-6)
  • (modified) mlir/test/Dialect/Ptr/invalid.mlir (+16)
  • (modified) mlir/test/Dialect/Ptr/ops.mlir (+65)
  • (modified) mlir/test/Target/LLVMIR/ptr.mlir (+30)
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index 8686cc7d316d4..eaf1e6243a74d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Ptr/IR/PtrDialect.h"
 #include "mlir/Dialect/Ptr/IR/PtrTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 59eaaf7c55cce..43e19d0e2917c 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
 include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
 include "mlir/Dialect/Ptr/IR/PtrEnums.td"
 include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
     /*descr=*/[{A shaped type with value semantics and rank.}],
     /*cppType=*/"::mlir::ShapedType">;
 
-// A shaped pointer type with value semantics and rank.
-class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+// A ptr-like type, either scalar or shaped type with value semantics.
+def Ptr_PtrLikeType : 
+  AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// An int-like type, either scalar or shaped type with value semantics.
+def Ptr_IntLikeType :AnyTypeOf<[
+  Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+  AnySignlessIntegerOrIndex
+]>;
 
 // A shaped value type of rank 1 of any element type.
 def Ptr_Any1DType :
@@ -175,41 +183,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// PtrAddOp
-//===----------------------------------------------------------------------===//
-
-def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
-    Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
-  ]> {
-  let summary = "Pointer add operation";
-  let description = [{
-    The `ptr_add` operation adds an integer offset to a pointer to produce a new
-    pointer. The input and output pointer types are always the same.
-
-    Example:
-
-    ```mlir
-    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
-    ```
-  }];
-
-  let arguments = (ins
-    Ptr_PtrType:$base,
-    AnySignlessIntegerOrIndex:$offset,
-    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
-  let results = (outs Ptr_PtrType:$result);
-  let assemblyFormat = [{
-    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
-  }];
-  let hasFolder = 1;
-  let extraClassDeclaration = [{
-    /// `ViewLikeOp::getViewSource` method. 
-    Value getViewSource() { return getBase(); }
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -369,6 +342,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PtrAddOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
+    Pure, ViewLikeOpInterface,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "Pointer add operation";
+  let description = [{
+    The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers.
+
+    The operation supports both scalar and shaped types with value semantics:
+    - When both base and offset are scalar: produces a single new pointer
+    - When base is shaped and offset is scalar: adds the same offset to each
+    pointer in the base
+    - When base is scalar and offset is shaped: adds the single pointer to each
+    offset in the shaped value
+    - When both are shaped: performs element-wise addition (shapes must be
+    compatible)
+
+    Example:
+
+    ```mlir
+    // Scalar base and offset
+    %x_off  = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+    %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+    
+    // Shaped base with scalar offset
+    %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
+    
+    // Scalar base with shaped offset
+    %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
+    
+    // Both base and offset are shaped
+    %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
+    ```
+  }];
+  let arguments = (ins
+    Ptr_PtrLikeType:$base,
+    Ptr_IntLikeType:$offset,
+    DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
+  let results = (outs Ptr_PtrLikeType:$result);
+  let assemblyFormat = [{
+    ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
+  }];
+  let hasFolder = 1;
+  let extraClassDeclaration = [{
+    /// `ViewLikeOp::getViewSource` method. 
+    Value getViewSource() { return getBase(); }
+
+    /// Returns the ptr type of the operation.
+    ptr::PtrType getPtrType();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index bd1e655fc6b5e..a6b0d416a4165 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
   MLIRIR
   MLIRDataLayoutInterfaces
   MLIRMemorySlotInterfaces
+  MLIRInferTypeOpInterface
   MLIRViewLikeInterface
   MLIRPtrMemorySpaceInterfaces
 )
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 81ae4efd8ec87..2b731aa54df05 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -346,6 +346,46 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
   return nullptr;
 }
 
+LogicalResult PtrAddOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Get the base pointer and offset types.
+  Type baseType = operands[0].getType();
+  Type offsetType = operands[1].getType();
+
+  // If neither are shaped types, result is same as base type.
+  if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
+    inferredReturnTypes.push_back(baseType);
+    return success();
+  }
+
+  // Handle cases with shaped types.
+  if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
+    // If both shaped, they must have the same shape.
+    if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
+      if (offTy.getShape() != baseTy.getShape()) {
+        if (location)
+          mlir::emitError(*location) << "shapes of base and offset must match";
+        return failure();
+      }
+      // Make sure they are the same kind of shaped type.
+      if (baseType.getTypeID() != offsetType.getTypeID()) {
+        if (location)
+          mlir::emitError(*location) << "the shaped containers type must match";
+        return failure();
+      }
+    }
+    inferredReturnTypes.push_back(baseType);
+    return success();
+  }
+
+  // Base is scalar, offset is shaped.
+  auto offsetShapedType = cast<ShapedType>(offsetType);
+  inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ToPtrOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
index dc645fe0480fa..5128fd8ccb265 100644
--- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -16,10 +16,10 @@
 // CHECK:           llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
 // CHECK:         }
 func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
-  %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
-  %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
-  %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
-  %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+  %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
   return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
 }
 
@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
   %0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   %2 = ptr.type_offset f32 : index
-  %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   %4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
   return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
 }
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
   %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.type_offset f32 : index
   %2 = arith.muli %1, %arg1 : index
-  %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   return %3 : !ptr.ptr<#ptr.generic_space>
 }
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 0c34ae43bf6be..cc1eeb3cb5744 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
   ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
   return
 }
+
+// -----
+
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{the shaped containers type must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{shapes of base and offset must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index bde2fb22b6aa0..c008b858af0d7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
   return %res : !ptr.ptr<#ptr.generic_space>
 }
 
+
+
 /// Check cast ops assembly.
 func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
   %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
   ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
   return %0 : vector<4xf32>
 }
+
+/// Test ptr_add with shaped operands (vectors)
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped operands (tensors)
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with 2D tensors
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (vectors)
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (tensors)
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (vectors)
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (tensors)
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 545bec5979b2d..4b29be813fa81 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
   ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
   llvm.return
 }
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}

@llvmbot
Member

llvmbot commented Sep 1, 2025

@llvm/pr-subscribers-mlir-llvm

   return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
 }
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
   %0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
   %1 = ptr.type_offset f32 : index
   %2 = arith.muli %1, %arg1 : index
-  %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+  %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
   return %3 : !ptr.ptr<#ptr.generic_space>
 }
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 0c34ae43bf6be..cc1eeb3cb5744 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
   ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
   return
 }
+
+// -----
+
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{the shaped containers type must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  // expected-error@+1 {{shapes of base and offset must match}}
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index bde2fb22b6aa0..c008b858af0d7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
   return %res : !ptr.ptr<#ptr.generic_space>
 }
 
+
+
 /// Check cast ops assembly.
 func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
   %ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
   ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
   return %0 : vector<4xf32>
 }
+
+/// Test ptr_add with shaped operands (vectors)
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped operands (tensors)
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with 2D tensors
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+  return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (vectors)
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (tensors)
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (vectors)
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+  return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (tensors)
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+  %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+  return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 545bec5979b2d..4b29be813fa81 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
   ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
   llvm.return
 }
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
+// CHECK-NEXT:   %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
+// CHECK-NEXT:   ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+  %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+  llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}

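The result-type rules that `PtrAddOp::inferReturnTypes` implements in the diff above can be sketched in plain Python. This is a hypothetical model for illustration only: types are represented as strings (scalars) or `(container, shape, elem)` tuples (shaped types) rather than real MLIR `Type` objects, and the function/variable names are invented, not part of the MLIR API.

```python
def infer_ptr_add_type(base, offset):
    """Model of PtrAddOp result-type inference.

    base/offset: either a scalar type name (str) or a tuple
    (container, shape, elem) standing in for a shaped type.
    """
    base_shaped = isinstance(base, tuple)
    off_shaped = isinstance(offset, tuple)

    # Scalar base, scalar offset: the result is the base pointer type.
    if not base_shaped and not off_shaped:
        return base

    if base_shaped:
        if off_shaped:
            b_container, b_shape, _ = base
            o_container, o_shape, _ = offset
            # Both shaped: shapes must match...
            if b_shape != o_shape:
                raise ValueError("shapes of base and offset must match")
            # ...and so must the container kind (tensor vs. vector).
            if b_container != o_container:
                raise ValueError("the shaped containers type must match")
        # Shaped base (with scalar or matching shaped offset): result is the base type.
        return base

    # Scalar base, shaped offset: clone the offset's container and shape
    # with the pointer element type (ShapedType::clone in the C++ code).
    container, shape, _ = offset
    return (container, shape, base)
```

For example, a scalar `!ptr.ptr` base with a `vector<4xindex>` offset yields `vector<4x!ptr.ptr<...>>`, matching the `@ptr_add_scalar_base_vector_offsets` test in the diff.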
Copilot AI left a comment
Pull Request Overview

This PR extends the ptr_add operation to support shaped operands (vectors and tensors) with value semantics, enabling more flexible pointer arithmetic operations that work element-wise on collections of pointers and offsets.

  • Adds support for scalar-to-scalar, scalar-to-shaped, shaped-to-scalar, and shaped-to-shaped combinations
  • Implements InferTypeOpInterface for automatic result type deduction
  • Adds comprehensive test coverage for LLVM IR translation with vector operands

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Summary per file:

  • mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp: Implements type inference logic for shaped operands
  • mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td: Updates operation definition to support shaped types
  • mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h: Adds InferTypeOpInterface include
  • mlir/lib/Dialect/Ptr/IR/CMakeLists.txt: Adds dependency on InferTypeOpInterface
  • mlir/test/Dialect/Ptr/ops.mlir: Adds tests for shaped operand combinations
  • mlir/test/Dialect/Ptr/invalid.mlir: Adds negative tests for type/shape mismatches
  • mlir/test/Target/LLVMIR/ptr.mlir: Adds LLVM translation tests for vector operands
  • mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir: Fixes type syntax in existing tests

@github-actions

github-actions bot commented Sep 1, 2025

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-op-variants branch from e9301c2 to 16785ab Compare September 3, 2025 14:30
@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-ptradd-nd branch from 6b32155 to af522ed Compare September 3, 2025 14:45
Base automatically changed from users/fabianmcg/ptr-op-variants to main September 3, 2025 14:45
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Sep 3, 2025
@fabianmcg fabianmcg requested a review from joker-eph September 3, 2025 14:45
@fabianmcg fabianmcg force-pushed the users/fabianmcg/ptr-ptradd-nd branch from af522ed to 785d21f Compare September 3, 2025 14:50
@github-actions

github-actions bot commented Sep 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@fabianmcg fabianmcg enabled auto-merge (squash) September 3, 2025 15:47
@fabianmcg fabianmcg merged commit bad2036 into main Sep 3, 2025
9 checks passed
@fabianmcg fabianmcg deleted the users/fabianmcg/ptr-ptradd-nd branch September 3, 2025 15:50