@@ -17,6 +17,38 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717include "mlir/Interfaces/ViewLikeInterface.td"
1818include "mlir/IR/OpAsmInterface.td"
1919
20+ //===----------------------------------------------------------------------===//
21+ // Common props
22+ //===----------------------------------------------------------------------===//
23+
24+ def AlignmentProp : OptionalProp<I64Prop>;
25+
26+ //===----------------------------------------------------------------------===//
27+ // Common types
28+ //===----------------------------------------------------------------------===//
29+
30+ // A shaped value type with value semantics and rank.
31+ class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
32+ ShapedContainerType<allowedTypes,
33+ /*containerPred=*/And<[HasValueSemanticsPred] # preds>,
34+ /*descr=*/[{A shaped type with value semantics and rank.}],
35+ /*cppType=*/"::mlir::ShapedType">;
36+
37+ // A shaped pointer type with value semantics and rank.
38+ class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
39+
40+ // A shaped value type of rank 1 of any element type.
41+ def Ptr_Any1DType :
42+ Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
43+
44+ // A shaped value type of rank 1 of `i1` element type.
45+ def Ptr_Mask1DType :
46+ Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
47+
48+ // A shaped value type of rank 1 of `i1` element type.
49+ def Ptr_Ptr1DType :
50+ Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
51+
2052//===----------------------------------------------------------------------===//
2153// FromPtrOp
2254//===----------------------------------------------------------------------===//
@@ -56,6 +88,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
5688 let hasVerifier = 1;
5789}
5890
91+ //===----------------------------------------------------------------------===//
92+ // GatherOp
93+ //===----------------------------------------------------------------------===//
94+
95+ def Ptr_GatherOp : Pointer_Op<"gather", [
96+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
97+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
98+ ::llvm::cast<ShapedType>($_self).clone(
99+ IntegerType::get($_self.getContext(), 1))
100+ }]>,
101+ AllTypesMatch<["result", "passthrough"]>,
102+ // Check the shapes are compatible and both use the same shaped container
103+ // type.
104+ AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
105+ ]> {
106+ let summary = "Gather operation";
107+ let description = [{
108+ The `gather` operation performs conditional loads from multiple memory
109+ locations specified by `ptrs` based on a mask `mask`. Elements of the
110+ result corresponding to masked-off lanes are taken from the passthrough
111+ operand.
112+
113+ The mask operand is a shaped type of `i1` elements that must have the same
114+ shape as the result type.
115+
116+ Examples:
117+ ```mlir
118+ // Gather values from multiple memory locations
119+ %result = ptr.gather %ptrs, %mask, %passthrough :
120+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
121+
122+ // Gather with alignment
123+ %result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
124+ vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
125+ ```
126+ }];
127+ let arguments = (ins Ptr_Ptr1DType:$ptrs,
128+ Ptr_Mask1DType:$mask,
129+ Ptr_Any1DType:$passthrough,
130+ AlignmentProp:$alignment);
131+ let results = (outs Ptr_Any1DType:$result);
132+ let assemblyFormat = [{
133+ $ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
134+ attr-dict `:` type($ptrs) `->` type($result)
135+ }];
136+ let builders = [
137+ OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
138+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
139+ ];
140+ let hasVerifier = 1;
141+ }
142+
59143//===----------------------------------------------------------------------===//
60144// GetMetadataOp
61145//===----------------------------------------------------------------------===//
@@ -122,8 +206,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
122206// LoadOp
123207//===----------------------------------------------------------------------===//
124208
125- def AlignmentProp : OptionalProp<I64Prop>;
126-
127209def Ptr_LoadOp : Pointer_Op<"load", [
128210 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
129211 ]> {
@@ -184,6 +266,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
184266 let hasVerifier = 1;
185267}
186268
269+ //===----------------------------------------------------------------------===//
270+ // MaskedLoadOp
271+ //===----------------------------------------------------------------------===//
272+
273+ def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
274+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
275+ TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
276+ ::llvm::cast<ShapedType>($_self).clone(
277+ IntegerType::get($_self.getContext(), 1))
278+ }]>,
279+ AllTypesMatch<["result", "passthrough"]>
280+ ]> {
281+ let summary = "Masked load operation";
282+ let description = [{
283+ The `masked_load` operation performs a conditional load from memory based
284+ on a mask. Elements of the result corresponding to masked-off lanes are
285+ taken from the passthrough operand.
286+
287+ The mask operand is a shaped type of `i1` elements that must have the same
288+ shape as the result type.
289+
290+ Examples:
291+ ```mlir
292+ // Masked load with passthrough on vectors
293+ %result = ptr.masked_load %ptr, %mask, %passthrough :
294+ !ptr.ptr<#ptr.generic_space> -> vector<4xf32>
295+
296+ // Masked load with passthrough on tensors
297+ %result = ptr.masked_load %ptr, %mask, %passthrough :
298+ !ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
299+ ```
300+ }];
301+ let arguments = (ins Ptr_PtrType:$ptr,
302+ Ptr_Mask1DType:$mask,
303+ Ptr_Any1DType:$passthrough,
304+ AlignmentProp:$alignment);
305+ let results = (outs Ptr_Any1DType:$result);
306+ let assemblyFormat = [{
307+ $ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
308+ attr-dict `:` qualified(type($ptr)) `->` type($result)
309+ }];
310+ let builders = [
311+ OpBuilder<(ins "Type":$resultType, "Value":$ptr, "Value":$mask,
312+ "Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
313+ ];
314+ let hasVerifier = 1;
315+ }
316+
317+ //===----------------------------------------------------------------------===//
318+ // MaskedStoreOp
319+ //===----------------------------------------------------------------------===//
320+
321+ def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
322+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
323+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
324+ ::llvm::cast<ShapedType>($_self).clone(
325+ IntegerType::get($_self.getContext(), 1))
326+ }]>
327+ ]> {
328+ let summary = "Masked store operation";
329+ let description = [{
330+ The `masked_store` operation performs a conditional store to memory based
331+ on a mask. Only elements corresponding to set bits in the mask are written
332+ to memory.
333+
334+ The mask operand is a shaped type of `i1` elements that must have the same
335+ shape as the value being stored.
336+
337+ Examples:
338+ ```mlir
339+ // Masked store
340+ ptr.masked_store %value, %ptr, %mask :
341+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
342+
343+ // Masked store with alignment
344+ ptr.masked_store %value, %ptr, %mask alignment = 8 :
345+ vector<4xf32>, !ptr.ptr<#ptr.generic_space>
346+ ```
347+ }];
348+
349+ let arguments = (ins Ptr_Any1DType:$value,
350+ Ptr_PtrType:$ptr,
351+ Ptr_Mask1DType:$mask,
352+ AlignmentProp:$alignment);
353+ let assemblyFormat = [{
354+ $value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
355+ type($value) `,` qualified(type($ptr))
356+ }];
357+ let builders = [
358+ OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
359+ CArg<"unsigned", "0">:$alignment)>
360+ ];
361+ let hasVerifier = 1;
362+ }
363+
364+ //===----------------------------------------------------------------------===//
365+ // ScatterOp
366+ //===----------------------------------------------------------------------===//
367+
368+ def Ptr_ScatterOp : Pointer_Op<"scatter", [
369+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
370+ TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
371+ ::llvm::cast<ShapedType>($_self).clone(
372+ IntegerType::get($_self.getContext(), 1))
373+ }]>,
374+ // Check the shapes are compatible and both use the same shaped container
375+ // type.
376+ AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
377+ ]> {
378+ let summary = "Scatter operation";
379+ let description = [{
380+ The `scatter` operation performs a conditional store of a value `value` to
381+ multiple memory locations specified by `ptrs` based on a mask `mask`.
382+
383+ Only elements corresponding to set bits in the mask are written to memory.
384+ The mask operand is a shaped type of `i1` elements that must have the same
385+ shape as the value being stored.
386+
387+ Examples:
388+ ```mlir
389+ // Scatter values to multiple memory locations
390+ ptr.scatter %value, %ptrs, %mask :
391+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
392+
393+ // Scatter with alignment
394+ ptr.scatter %value, %ptrs, %mask alignment = 8 :
395+ vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
396+ ```
397+ }];
398+ let arguments = (ins Ptr_Any1DType:$value,
399+ Ptr_Ptr1DType:$ptrs,
400+ Ptr_Mask1DType:$mask,
401+ AlignmentProp:$alignment);
402+ let assemblyFormat = [{
403+ $value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
404+ attr-dict `:` type($value) `,` type($ptrs)
405+ }];
406+ let builders = [
407+ OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
408+ CArg<"unsigned", "0">:$alignment)>
409+ ];
410+ let hasVerifier = 1;
411+ }
412+
187413//===----------------------------------------------------------------------===//
188414// StoreOp
189415//===----------------------------------------------------------------------===//
0 commit comments