-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][ptr] Extend ptr_add operation to support shaped operands
#156374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir Author: Fabian Mora (fabianmcg) ChangesThis patch extends Concretely this patch makes the following changes:
Example: func.func @<!-- -->ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
%res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
}The motivation behind this patch is to lay the groundwork for enabling Full diff: https://github.com/llvm/llvm-project/pull/156374.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index 8686cc7d316d4..eaf1e6243a74d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 59eaaf7c55cce..43e19d0e2917c 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
include "mlir/Dialect/Ptr/IR/PtrEnums.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
/*descr=*/[{A shaped type with value semantics and rank.}],
/*cppType=*/"::mlir::ShapedType">;
-// A shaped pointer type with value semantics and rank.
-class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+// A ptr-like type, either scalar or shaped type with value semantics.
+def Ptr_PtrLikeType :
+ AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// An int-like type, either scalar or shaped type with value semantics.
+def Ptr_IntLikeType :AnyTypeOf<[
+ Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+ AnySignlessIntegerOrIndex
+]>;
// A shaped value type of rank 1 of any element type.
def Ptr_Any1DType :
@@ -175,41 +183,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
}];
}
-//===----------------------------------------------------------------------===//
-// PtrAddOp
-//===----------------------------------------------------------------------===//
-
-def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
- Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
- ]> {
- let summary = "Pointer add operation";
- let description = [{
- The `ptr_add` operation adds an integer offset to a pointer to produce a new
- pointer. The input and output pointer types are always the same.
-
- Example:
-
- ```mlir
- %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
- %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
- ```
- }];
-
- let arguments = (ins
- Ptr_PtrType:$base,
- AnySignlessIntegerOrIndex:$offset,
- DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
- let results = (outs Ptr_PtrType:$result);
- let assemblyFormat = [{
- ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
- }];
- let hasFolder = 1;
- let extraClassDeclaration = [{
- /// `ViewLikeOp::getViewSource` method.
- Value getViewSource() { return getBase(); }
- }];
-}
-
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -369,6 +342,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PtrAddOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
+ Pure, ViewLikeOpInterface,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>
+ ]> {
+ let summary = "Pointer add operation";
+ let description = [{
+ The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers.
+
+ The operation supports both scalar and shaped types with value semantics:
+ - When both base and offset are scalar: produces a single new pointer
+ - When base is shaped and offset is scalar: adds the same offset to each
+ pointer in the base
+ - When base is scalar and offset is shaped: adds the single pointer to each
+ offset in the shaped value
+ - When both are shaped: performs element-wise addition (shapes must be
+ compatible)
+
+ Example:
+
+ ```mlir
+ // Scalar base and offset
+ %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+ %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+
+ // Shaped base with scalar offset
+ %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
+
+ // Scalar base with shaped offset
+ %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
+
+ // Both base and offset are shaped
+ %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
+ ```
+ }];
+ let arguments = (ins
+ Ptr_PtrLikeType:$base,
+ Ptr_IntLikeType:$offset,
+ DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
+ let results = (outs Ptr_PtrLikeType:$result);
+ let assemblyFormat = [{
+ ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
+ }];
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// `ViewLikeOp::getViewSource` method.
+ Value getViewSource() { return getBase(); }
+
+ /// Returns the ptr type of the operation.
+ ptr::PtrType getPtrType();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index bd1e655fc6b5e..a6b0d416a4165 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
MLIRIR
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
+ MLIRInferTypeOpInterface
MLIRViewLikeInterface
MLIRPtrMemorySpaceInterfaces
)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 81ae4efd8ec87..2b731aa54df05 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -346,6 +346,46 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
+LogicalResult PtrAddOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // Get the base pointer and offset types.
+ Type baseType = operands[0].getType();
+ Type offsetType = operands[1].getType();
+
+ // If neither are shaped types, result is same as base type.
+ if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
+ inferredReturnTypes.push_back(baseType);
+ return success();
+ }
+
+ // Handle cases with shaped types.
+ if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
+ // If both shaped, they must have the same shape.
+ if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
+ if (offTy.getShape() != baseTy.getShape()) {
+ if (location)
+ mlir::emitError(*location) << "shapes of base and offset must match";
+ return failure();
+ }
+ // Make sure they are the same kind of shaped type.
+ if (baseType.getTypeID() != offsetType.getTypeID()) {
+ if (location)
+ mlir::emitError(*location) << "the shaped containers type must match";
+ return failure();
+ }
+ }
+ inferredReturnTypes.push_back(baseType);
+ return success();
+ }
+
+ // Base is scalar, offset is shaped.
+ auto offsetShapedType = cast<ShapedType>(offsetType);
+ inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
index dc645fe0480fa..5128fd8ccb265 100644
--- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -16,10 +16,10 @@
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
// CHECK: }
func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
- %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
- %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
- %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
- %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+ %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
}
@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
%0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
%2 = ptr.type_offset f32 : index
- %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+ %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
%4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
}
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
%0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.type_offset f32 : index
%2 = arith.muli %1, %arg1 : index
- %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+ %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
return %3 : !ptr.ptr<#ptr.generic_space>
}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 0c34ae43bf6be..cc1eeb3cb5744 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
return
}
+
+// -----
+
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ // expected-error@+1 {{the shaped containers type must match}}
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ // expected-error@+1 {{shapes of base and offset must match}}
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index bde2fb22b6aa0..c008b858af0d7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
return %res : !ptr.ptr<#ptr.generic_space>
}
+
+
/// Check cast ops assembly.
func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
return %0 : vector<4xf32>
}
+
+/// Test ptr_add with shaped operands (vectors)
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped operands (tensors)
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with 2D tensors
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+ %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+ return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (vectors)
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (tensors)
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (vectors)
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (tensors)
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 545bec5979b2d..4b29be813fa81 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
llvm.return
}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
|
|
@llvm/pr-subscribers-mlir-llvm Author: Fabian Mora (fabianmcg) ChangesThis patch extends Concretely this patch makes the following changes:
Example: func.func @<!-- -->ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
%res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
%res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
}The motivation behind this patch is to lay the groundwork for enabling Full diff: https://github.com/llvm/llvm-project/pull/156374.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
index 8686cc7d316d4..eaf1e6243a74d 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
index 59eaaf7c55cce..43e19d0e2917c 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
include "mlir/Dialect/Ptr/IR/PtrEnums.td"
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
/*descr=*/[{A shaped type with value semantics and rank.}],
/*cppType=*/"::mlir::ShapedType">;
-// A shaped pointer type with value semantics and rank.
-class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
+// A ptr-like type, either scalar or shaped type with value semantics.
+def Ptr_PtrLikeType :
+ AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
+
+// An int-like type, either scalar or shaped type with value semantics.
+def Ptr_IntLikeType :AnyTypeOf<[
+ Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
+ AnySignlessIntegerOrIndex
+]>;
// A shaped value type of rank 1 of any element type.
def Ptr_Any1DType :
@@ -175,41 +183,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
}];
}
-//===----------------------------------------------------------------------===//
-// PtrAddOp
-//===----------------------------------------------------------------------===//
-
-def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
- Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
- ]> {
- let summary = "Pointer add operation";
- let description = [{
- The `ptr_add` operation adds an integer offset to a pointer to produce a new
- pointer. The input and output pointer types are always the same.
-
- Example:
-
- ```mlir
- %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
- %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
- ```
- }];
-
- let arguments = (ins
- Ptr_PtrType:$base,
- AnySignlessIntegerOrIndex:$offset,
- DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
- let results = (outs Ptr_PtrType:$result);
- let assemblyFormat = [{
- ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
- }];
- let hasFolder = 1;
- let extraClassDeclaration = [{
- /// `ViewLikeOp::getViewSource` method.
- Value getViewSource() { return getBase(); }
- }];
-}
-
//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
@@ -369,6 +342,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PtrAddOp
+//===----------------------------------------------------------------------===//
+
+def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
+ Pure, ViewLikeOpInterface,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>
+ ]> {
+ let summary = "Pointer add operation";
+ let description = [{
+ The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers.
+
+ The operation supports both scalar and shaped types with value semantics:
+ - When both base and offset are scalar: produces a single new pointer
+ - When base is shaped and offset is scalar: adds the same offset to each
+ pointer in the base
+ - When base is scalar and offset is shaped: adds the single pointer to each
+ offset in the shaped value
+ - When both are shaped: performs element-wise addition (shapes must be
+ compatible)
+
+ Example:
+
+ ```mlir
+ // Scalar base and offset
+ %x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+ %x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
+
+ // Shaped base with scalar offset
+ %ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
+
+ // Scalar base with shaped offset
+ %x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
+
+ // Both base and offset are shaped
+ %ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
+ ```
+ }];
+ let arguments = (ins
+ Ptr_PtrLikeType:$base,
+ Ptr_IntLikeType:$offset,
+ DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
+ let results = (outs Ptr_PtrLikeType:$result);
+ let assemblyFormat = [{
+ ($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
+ }];
+ let hasFolder = 1;
+ let extraClassDeclaration = [{
+ /// `ViewLikeOp::getViewSource` method.
+ Value getViewSource() { return getBase(); }
+
+ /// Returns the ptr type of the operation.
+ ptr::PtrType getPtrType();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
index bd1e655fc6b5e..a6b0d416a4165 100644
--- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
MLIRIR
MLIRDataLayoutInterfaces
MLIRMemorySlotInterfaces
+ MLIRInferTypeOpInterface
MLIRViewLikeInterface
MLIRPtrMemorySpaceInterfaces
)
diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
index 81ae4efd8ec87..2b731aa54df05 100644
--- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
+++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
@@ -346,6 +346,46 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
+LogicalResult PtrAddOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location, ValueRange operands,
+ DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // Get the base pointer and offset types.
+ Type baseType = operands[0].getType();
+ Type offsetType = operands[1].getType();
+
+ // If neither are shaped types, result is same as base type.
+ if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
+ inferredReturnTypes.push_back(baseType);
+ return success();
+ }
+
+ // Handle cases with shaped types.
+ if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
+ // If both shaped, they must have the same shape.
+ if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
+ if (offTy.getShape() != baseTy.getShape()) {
+ if (location)
+ mlir::emitError(*location) << "shapes of base and offset must match";
+ return failure();
+ }
+ // Make sure they are the same kind of shaped type.
+ if (baseType.getTypeID() != offsetType.getTypeID()) {
+ if (location)
+ mlir::emitError(*location) << "the shaped containers type must match";
+ return failure();
+ }
+ }
+ inferredReturnTypes.push_back(baseType);
+ return success();
+ }
+
+ // Base is scalar, offset is shaped.
+ auto offsetShapedType = cast<ShapedType>(offsetType);
+ inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
index dc645fe0480fa..5128fd8ccb265 100644
--- a/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
+++ b/mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir
@@ -16,10 +16,10 @@
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
// CHECK: }
func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
- %0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
- %1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
- %2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
- %3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
+ %0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
+ %3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
}
@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
%0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
%2 = ptr.type_offset f32 : index
- %3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
+ %3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
%4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
}
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
%0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
%1 = ptr.type_offset f32 : index
%2 = arith.muli %1, %arg1 : index
- %3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
+ %3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
return %3 : !ptr.ptr<#ptr.generic_space>
}
diff --git a/mlir/test/Dialect/Ptr/invalid.mlir b/mlir/test/Dialect/Ptr/invalid.mlir
index 0c34ae43bf6be..cc1eeb3cb5744 100644
--- a/mlir/test/Dialect/Ptr/invalid.mlir
+++ b/mlir/test/Dialect/Ptr/invalid.mlir
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
return
}
+
+// -----
+
+func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ // expected-error@+1 {{the shaped containers type must match}}
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+// -----
+
+func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ // expected-error@+1 {{shapes of base and offset must match}}
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Dialect/Ptr/ops.mlir b/mlir/test/Dialect/Ptr/ops.mlir
index bde2fb22b6aa0..c008b858af0d7 100644
--- a/mlir/test/Dialect/Ptr/ops.mlir
+++ b/mlir/test/Dialect/Ptr/ops.mlir
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
return %res : !ptr.ptr<#ptr.generic_space>
}
+
+
/// Check cast ops assembly.
func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
return %0 : vector<4xf32>
}
+
+/// Test ptr_add with shaped operands (vectors)
+func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ %res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped operands (tensors)
+func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with 2D tensors
+func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+ %res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
+ return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (vectors)
+func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with scalar base and shaped offsets (tensors)
+func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ %res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (vectors)
+func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ %res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
+ return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
+}
+
+/// Test ptr_add with shaped base and scalar offset (tensors)
+func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
+ %res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ %res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
+ return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
+}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 545bec5979b2d..4b29be813fa81 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
llvm.return
}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
+// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
+
+// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
+// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
+// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
+// CHECK-NEXT: ret <4 x ptr> %[[RES]]
+// CHECK-NEXT: }
+llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
+ %res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
+ llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR extends the ptr_add operation to support shaped operands (vectors and tensors) with value semantics, enabling more flexible pointer arithmetic operations that work element-wise on collections of pointers and offsets.
- Adds support for scalar-to-scalar, scalar-to-shaped, shaped-to-scalar, and shaped-to-shaped combinations
- Implements InferTypeOpInterface for automatic result type deduction
- Adds comprehensive test coverage for LLVM IR translation with vector operands
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp | Implements type inference logic for shaped operands |
| mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td | Updates operation definition to support shaped types |
| mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h | Adds InferTypeOpInterface include |
| mlir/lib/Dialect/Ptr/IR/CMakeLists.txt | Adds dependency on InferTypeOpInterface |
| mlir/test/Dialect/Ptr/ops.mlir | Adds tests for shaped operand combinations |
| mlir/test/Dialect/Ptr/invalid.mlir | Adds negative tests for type/shape mismatches |
| mlir/test/Target/LLVMIR/ptr.mlir | Adds LLVM translation tests for vector operands |
| mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir | Fixes type syntax in existing tests |
|
|
e9301c2 to
16785ab
Compare
6b32155 to
af522ed
Compare
af522ed to
785d21f
Compare
Co-authored-by: Mehdi Amini <[email protected]>
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
This patch extends
ptr_addto work with shaped types with value semantics, both for the offsets and base.Concretely this patch makes the following changes:
Example:
The motivation behind this patch is to lay the groundwork for enabling
tritonstyled loads and stores, and their variants.