Skip to content

Commit 785d21f

Browse files
committed
extend ptr_add op
1 parent 34e9f3d commit 785d21f

File tree

8 files changed

+225
-43
lines changed

8 files changed

+225
-43
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
1919
#include "mlir/Dialect/Ptr/IR/PtrTypes.h"
2020
#include "mlir/IR/OpDefinition.h"
21+
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "mlir/Interfaces/SideEffectInterfaces.h"
2223
#include "mlir/Interfaces/ViewLikeInterface.h"
2324

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 66 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include "mlir/Dialect/Ptr/IR/PtrDialect.td"
1313
include "mlir/Dialect/Ptr/IR/PtrAttrDefs.td"
1414
include "mlir/Dialect/Ptr/IR/PtrEnums.td"
1515
include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td"
16+
include "mlir/Interfaces/InferTypeOpInterface.td"
1617
include "mlir/Interfaces/SideEffectInterfaces.td"
1718
include "mlir/Interfaces/ViewLikeInterface.td"
1819
include "mlir/IR/OpAsmInterface.td"
@@ -34,8 +35,15 @@ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
3435
/*descr=*/[{A shaped type with value semantics and rank.}],
3536
/*cppType=*/"::mlir::ShapedType">;
3637

37-
// A shaped pointer type with value semantics and rank.
38-
class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
38+
// A ptr-like type, either scalar or shaped type with value semantics.
39+
def Ptr_PtrLikeType :
40+
AnyTypeOf<[Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>, Ptr_PtrType]>;
41+
42+
// An int-like type, either scalar or shaped type with value semantics.
43+
def Ptr_IntLikeType :AnyTypeOf<[
44+
Ptr_ShapedValueType<[AnySignlessIntegerOrIndex], [HasRankPred]>,
45+
AnySignlessIntegerOrIndex
46+
]>;
3947

4048
// A shaped value type of rank 1 of any element type.
4149
def Ptr_Any1DType :
@@ -167,41 +175,6 @@ def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
167175
}];
168176
}
169177

170-
//===----------------------------------------------------------------------===//
171-
// PtrAddOp
172-
//===----------------------------------------------------------------------===//
173-
174-
def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
175-
Pure, AllTypesMatch<["base", "result"]>, ViewLikeOpInterface
176-
]> {
177-
let summary = "Pointer add operation";
178-
let description = [{
179-
The `ptr_add` operation adds an integer offset to a pointer to produce a new
180-
pointer. The input and output pointer types are always the same.
181-
182-
Example:
183-
184-
```mlir
185-
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
186-
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
187-
```
188-
}];
189-
190-
let arguments = (ins
191-
Ptr_PtrType:$base,
192-
AnySignlessIntegerOrIndex:$offset,
193-
DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
194-
let results = (outs Ptr_PtrType:$result);
195-
let assemblyFormat = [{
196-
($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
197-
}];
198-
let hasFolder = 1;
199-
let extraClassDeclaration = [{
200-
/// `ViewLikeOp::getViewSource` method.
201-
Value getViewSource() { return getBase(); }
202-
}];
203-
}
204-
205178
//===----------------------------------------------------------------------===//
206179
// LoadOp
207180
//===----------------------------------------------------------------------===//
@@ -361,6 +334,62 @@ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
361334
let hasVerifier = 1;
362335
}
363336

337+
//===----------------------------------------------------------------------===//
338+
// PtrAddOp
339+
//===----------------------------------------------------------------------===//
340+
341+
def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
342+
Pure, ViewLikeOpInterface,
343+
DeclareOpInterfaceMethods<InferTypeOpInterface>
344+
]> {
345+
let summary = "Pointer add operation";
346+
let description = [{
347+
The `ptr_add` operation adds an int-like offset to one or more pointers to produce one or more new pointers.
348+
349+
The operation supports both scalar and shaped types with value semantics:
350+
- When both base and offset are scalar: produces a single new pointer
351+
- When base is shaped and offset is scalar: adds the same offset to each
352+
pointer in the base
353+
- When base is scalar and offset is shaped: adds the single pointer to each
354+
offset in the shaped value
355+
- When both are shaped: performs element-wise addition (shapes must be
356+
compatible)
357+
358+
Example:
359+
360+
```mlir
361+
// Scalar base and offset
362+
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
363+
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
364+
365+
// Shaped base with scalar offset
366+
%ptrs_off = ptr.ptr_add %ptrs, %off : vector<4x!ptr.ptr<#ptr.generic_space>>, i32
367+
368+
// Scalar base with shaped offset
369+
%x_offs = ptr.ptr_add %x, %offs : !ptr.ptr<#ptr.generic_space>, vector<4xi32>
370+
371+
// Both base and offset are shaped
372+
%ptrs_offs = ptr.ptr_add %ptrs, %offs : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xi32>
373+
```
374+
}];
375+
let arguments = (ins
376+
Ptr_PtrLikeType:$base,
377+
Ptr_IntLikeType:$offset,
378+
DefaultValuedProp<EnumProp<Ptr_PtrAddFlags>, "PtrAddFlags::none">:$flags);
379+
let results = (outs Ptr_PtrLikeType:$result);
380+
let assemblyFormat = [{
381+
($flags^)? $base `,` $offset attr-dict `:` type($base) `,` type($offset)
382+
}];
383+
let hasFolder = 1;
384+
let extraClassDeclaration = [{
385+
/// `ViewLikeOp::getViewSource` method.
386+
Value getViewSource() { return getBase(); }
387+
388+
/// Returns the ptr type of the operation.
389+
ptr::PtrType getPtrType();
390+
}];
391+
}
392+
364393
//===----------------------------------------------------------------------===//
365394
// ScatterOp
366395
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Ptr/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ add_mlir_dialect_library(
3333
MLIRIR
3434
MLIRDataLayoutInterfaces
3535
MLIRMemorySlotInterfaces
36+
MLIRInferTypeOpInterface
3637
MLIRViewLikeInterface
3738
MLIRPtrMemorySpaceInterfaces
3839
)

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,46 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
346346
return nullptr;
347347
}
348348

349+
LogicalResult PtrAddOp::inferReturnTypes(
350+
MLIRContext *context, std::optional<Location> location, ValueRange operands,
351+
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
352+
SmallVectorImpl<Type> &inferredReturnTypes) {
353+
// Get the base pointer and offset types.
354+
Type baseType = operands[0].getType();
355+
Type offsetType = operands[1].getType();
356+
357+
// If neither are shaped types, result is same as base type.
358+
if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
359+
inferredReturnTypes.push_back(baseType);
360+
return success();
361+
}
362+
363+
// Handle cases with shaped types.
364+
if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
365+
// If both shaped, they must have the same shape.
366+
if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
367+
if (offTy.getShape() != baseTy.getShape()) {
368+
if (location)
369+
mlir::emitError(*location) << "shapes of base and offset must match";
370+
return failure();
371+
}
372+
// Make sure they are the same kind of shaped type.
373+
if (baseType.getTypeID() != offsetType.getTypeID()) {
374+
if (location)
375+
mlir::emitError(*location) << "the shaped containers type must match";
376+
return failure();
377+
}
378+
}
379+
inferredReturnTypes.push_back(baseType);
380+
return success();
381+
}
382+
383+
// Base is scalar, offset is shaped.
384+
auto offsetShapedType = cast<ShapedType>(offsetType);
385+
inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
386+
return success();
387+
}
388+
349389
//===----------------------------------------------------------------------===//
350390
// ToPtrOp
351391
//===----------------------------------------------------------------------===//

mlir/test/Conversion/PtrToLLVM/ptr-to-llvm.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(ptr, ptr, ptr, ptr)>
1717
// CHECK: }
1818
func.func @test_ptr_add(%arg0: !ptr.ptr<#ptr.generic_space>, %arg1: index) -> (!ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>) {
19-
%0 = ptr.ptr_add %arg0, %arg1 : <#ptr.generic_space>, index
20-
%1 = ptr.ptr_add nusw %arg0, %arg1 : <#ptr.generic_space>, index
21-
%2 = ptr.ptr_add nuw %arg0, %arg1 : <#ptr.generic_space>, index
22-
%3 = ptr.ptr_add inbounds %arg0, %arg1 : <#ptr.generic_space>, index
19+
%0 = ptr.ptr_add %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
20+
%1 = ptr.ptr_add nusw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
21+
%2 = ptr.ptr_add nuw %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
22+
%3 = ptr.ptr_add inbounds %arg0, %arg1 : !ptr.ptr<#ptr.generic_space>, index
2323
return %0, %1, %2, %3 : !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>, !ptr.ptr<#ptr.generic_space>
2424
}
2525

@@ -263,7 +263,7 @@ func.func @test_comprehensive_dynamic(%arg0: memref<?x?xf32, strided<[?, ?], off
263263
%0 = ptr.to_ptr %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space> -> <#ptr.generic_space>
264264
%1 = ptr.get_metadata %arg0 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
265265
%2 = ptr.type_offset f32 : index
266-
%3 = ptr.ptr_add inbounds %0, %2 : <#ptr.generic_space>, index
266+
%3 = ptr.ptr_add inbounds %0, %2 : !ptr.ptr<#ptr.generic_space>, index
267267
%4 = ptr.from_ptr %3 metadata %1 : <#ptr.generic_space> -> memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
268268
return %4 : memref<?x?xf32, strided<[?, ?], offset: ?>, #ptr.generic_space>
269269
}
@@ -313,6 +313,6 @@ func.func @test_memref_ptradd_indexing(%arg0: memref<10x?x30xf32, #ptr.generic_s
313313
%0 = ptr.to_ptr %arg0 : memref<10x?x30xf32, #ptr.generic_space> -> <#ptr.generic_space>
314314
%1 = ptr.type_offset f32 : index
315315
%2 = arith.muli %1, %arg1 : index
316-
%3 = ptr.ptr_add %0, %2 : <#ptr.generic_space>, index
316+
%3 = ptr.ptr_add %0, %2 : !ptr.ptr<#ptr.generic_space>, index
317317
return %3 : !ptr.ptr<#ptr.generic_space>
318318
}

mlir/test/Dialect/Ptr/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,19 @@ func.func @llvm_store(%arg0: !ptr.ptr<#llvm.address_space<1>>, %arg1: memref<f32
5454
ptr.store %arg1, %arg0 : memref<f32>, !ptr.ptr<#llvm.address_space<1>>
5555
return
5656
}
57+
58+
// -----
59+
60+
func.func @ptr_add_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
61+
// expected-error@+1 {{the shaped containers type must match}}
62+
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, vector<8xi64>
63+
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
64+
}
65+
66+
// -----
67+
68+
func.func @ptr_add_shape_mismatch(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
69+
// expected-error@+1 {{shapes of base and offset must match}}
70+
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<4xi64>
71+
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
72+
}

mlir/test/Dialect/Ptr/ops.mlir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ func.func @ptr_add_type_offset(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#
1111
return %res : !ptr.ptr<#ptr.generic_space>
1212
}
1313

14+
15+
1416
/// Check cast ops assembly.
1517
func.func @cast_ops(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
1618
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
@@ -126,3 +128,66 @@ func.func @llvm_masked_ops(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector
126128
ptr.masked_store %value, %ptr, %mask alignment = 4 : vector<4xf32>, !ptr.ptr<#llvm.address_space<3>>
127129
return %0 : vector<4xf32>
128130
}
131+
132+
/// Test ptr_add with shaped operands (vectors)
133+
func.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
134+
%res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
135+
%res0 = ptr.ptr_add none %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
136+
%res1 = ptr.ptr_add nusw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
137+
%res2 = ptr.ptr_add nuw %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
138+
%res3 = ptr.ptr_add inbounds %ptrs, %offsets : vector<4x!ptr.ptr<#ptr.generic_space>>, vector<4xindex>
139+
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
140+
}
141+
142+
/// Test ptr_add with shaped operands (tensors)
143+
func.func @ptr_add_tensor(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
144+
%res = ptr.ptr_add %ptrs, %offsets : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi64>
145+
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
146+
}
147+
148+
/// Test ptr_add with 2D tensors
149+
func.func @ptr_add_tensor_2d(%ptrs: tensor<4x8x!ptr.ptr<#ptr.generic_space>>, %offsets: tensor<4x8xindex>) -> tensor<4x8x!ptr.ptr<#ptr.generic_space>> {
150+
%res = ptr.ptr_add %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
151+
%res1 = ptr.ptr_add nuw %ptrs, %offsets : tensor<4x8x!ptr.ptr<#ptr.generic_space>>, tensor<4x8xindex>
152+
return %res : tensor<4x8x!ptr.ptr<#ptr.generic_space>>
153+
}
154+
155+
/// Test ptr_add with scalar base and shaped offsets (vectors)
156+
func.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: vector<4xindex>) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
157+
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
158+
%res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
159+
%res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
160+
%res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
161+
%res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, vector<4xindex>
162+
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
163+
}
164+
165+
/// Test ptr_add with scalar base and shaped offsets (tensors)
166+
func.func @ptr_add_scalar_base_tensor_offsets(%ptr: !ptr.ptr<#ptr.generic_space>, %offsets: tensor<8xi64>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
167+
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
168+
%res0 = ptr.ptr_add none %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
169+
%res1 = ptr.ptr_add nusw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
170+
%res2 = ptr.ptr_add nuw %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
171+
%res3 = ptr.ptr_add inbounds %ptr, %offsets : !ptr.ptr<#ptr.generic_space>, tensor<8xi64>
172+
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
173+
}
174+
175+
/// Test ptr_add with shaped base and scalar offset (vectors)
176+
func.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#ptr.generic_space>>, %offset: index) -> vector<4x!ptr.ptr<#ptr.generic_space>> {
177+
%res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
178+
%res0 = ptr.ptr_add none %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
179+
%res1 = ptr.ptr_add nusw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
180+
%res2 = ptr.ptr_add nuw %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
181+
%res3 = ptr.ptr_add inbounds %ptrs, %offset : vector<4x!ptr.ptr<#ptr.generic_space>>, index
182+
return %res : vector<4x!ptr.ptr<#ptr.generic_space>>
183+
}
184+
185+
/// Test ptr_add with shaped base and scalar offset (tensors)
186+
func.func @ptr_add_tensor_base_scalar_offset(%ptrs: tensor<8x!ptr.ptr<#ptr.generic_space>>, %offset: i64) -> tensor<8x!ptr.ptr<#ptr.generic_space>> {
187+
%res = ptr.ptr_add %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
188+
%res0 = ptr.ptr_add none %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
189+
%res1 = ptr.ptr_add nusw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
190+
%res2 = ptr.ptr_add nuw %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
191+
%res3 = ptr.ptr_add inbounds %ptrs, %offset : tensor<8x!ptr.ptr<#ptr.generic_space>>, i64
192+
return %res : tensor<8x!ptr.ptr<#ptr.generic_space>>
193+
}

mlir/test/Target/LLVMIR/ptr.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,3 +203,33 @@ llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>
203203
ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>>
204204
llvm.return
205205
}
206+
207+
// CHECK-LABEL: define <4 x ptr> @ptr_add_vector
208+
// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], <4 x i32> %[[OFFSETS:.*]]) {
209+
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], <4 x i32> %[[OFFSETS]]
210+
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
211+
// CHECK-NEXT: }
212+
llvm.func @ptr_add_vector(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
213+
%res = ptr.ptr_add %ptrs, %offsets : vector<4x!ptr.ptr<#llvm.address_space<0>>>, vector<4xi32>
214+
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
215+
}
216+
217+
// CHECK-LABEL: define <4 x ptr> @ptr_add_scalar_base_vector_offsets
218+
// CHECK-SAME: (ptr %[[PTR:.*]], <4 x i32> %[[OFFSETS:.*]]) {
219+
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, ptr %[[PTR]], <4 x i32> %[[OFFSETS]]
220+
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
221+
// CHECK-NEXT: }
222+
llvm.func @ptr_add_scalar_base_vector_offsets(%ptr: !ptr.ptr<#llvm.address_space<0>>, %offsets: vector<4xi32>) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
223+
%res = ptr.ptr_add %ptr, %offsets : !ptr.ptr<#llvm.address_space<0>>, vector<4xi32>
224+
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
225+
}
226+
227+
// CHECK-LABEL: define <4 x ptr> @ptr_add_vector_base_scalar_offset
228+
// CHECK-SAME: (<4 x ptr> %[[PTRS:.*]], i32 %[[OFFSET:.*]]) {
229+
// CHECK-NEXT: %[[RES:.*]] = getelementptr i8, <4 x ptr> %[[PTRS]], i32 %[[OFFSET]]
230+
// CHECK-NEXT: ret <4 x ptr> %[[RES]]
231+
// CHECK-NEXT: }
232+
llvm.func @ptr_add_vector_base_scalar_offset(%ptrs: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %offset: i32) -> vector<4x!ptr.ptr<#llvm.address_space<0>>> {
233+
%res = ptr.ptr_add %ptrs, %offset : vector<4x!ptr.ptr<#llvm.address_space<0>>>, i32
234+
llvm.return %res : vector<4x!ptr.ptr<#llvm.address_space<0>>>
235+
}

0 commit comments

Comments
 (0)