Skip to content

Commit 34e9f3d

Browse files
authored
[mlir][ptr] Add gather, masked_load, masked_store, and scatter ops (#156368)
This patch adds the `gather`, `masked_load`, `masked_store`, and `scatter` operations to the `ptr` dialect. It also implements translation from these operations to LLVM intrinsics: - ptr.gather -> llvm.masked.gather - ptr.masked_load -> llvm.masked.load - ptr.masked_store -> llvm.masked.store - ptr.scatter -> llvm.masked.scatter Example: ```mlir llvm.func @mixed_masked_ops_address_spaces(%ptr: !ptr.ptr<#llvm.address_space<3>>, %ptrs: vector<4x!ptr.ptr<#llvm.address_space<3>>>, %mask: vector<4xi1>, %value: vector<4xf64>, %passthrough: vector<4xf64>) { %0 = ptr.gather %ptrs, %mask, %passthrough alignment = 8 : vector<4x!ptr.ptr<#llvm.address_space<3>>> -> vector<4xf64> ptr.scatter %value, %ptrs, %mask alignment = 8 : vector<4xf64>, vector<4x!ptr.ptr<#llvm.address_space<3>>> %1 = ptr.masked_load %ptr, %mask, %passthrough alignment = 8 : !ptr.ptr<#llvm.address_space<3>> -> vector<4xf64> ptr.masked_store %value, %ptr, %mask alignment = 8 : vector<4xf64>, !ptr.ptr<#llvm.address_space<3>> llvm.return } ``` Translates to: ```llvm define void @mixed_masked_ops_address_spaces(ptr addrspace(3) %0, <4 x ptr addrspace(3)> %1, <4 x i1> %2, <4 x double> %3, <4 x double> %4) { %6 = call <4 x double> @llvm.masked.gather.v4f64.v4p3(<4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2, <4 x double> %4) call void @llvm.masked.scatter.v4f64.v4p3(<4 x double> %3, <4 x ptr addrspace(3)> %1, i32 8, <4 x i1> %2) %7 = call <4 x double> @llvm.masked.load.v4f64.p3(ptr addrspace(3) %0, i32 8, <4 x i1> %2, <4 x double> %4) call void @llvm.masked.store.v4f64.p3(<4 x double> %3, ptr addrspace(3) %0, i32 8, <4 x i1> %2) ret void } ```
1 parent e6c63d9 commit 34e9f3d

File tree

6 files changed

+668
-15
lines changed

6 files changed

+668
-15
lines changed

mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td

Lines changed: 228 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,38 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/Interfaces/ViewLikeInterface.td"
1818
include "mlir/IR/OpAsmInterface.td"
1919

20+
//===----------------------------------------------------------------------===//
21+
// Common props
22+
//===----------------------------------------------------------------------===//
23+
24+
def AlignmentProp : OptionalProp<I64Prop>;
25+
26+
//===----------------------------------------------------------------------===//
27+
// Common types
28+
//===----------------------------------------------------------------------===//
29+
30+
// A shaped value type with value semantics and rank.
31+
class Ptr_ShapedValueType<list<Type> allowedTypes, list<Pred> preds = []> :
32+
ShapedContainerType<allowedTypes,
33+
/*containerPred=*/And<[HasValueSemanticsPred] # preds>,
34+
/*descr=*/[{A shaped type with value semantics and rank.}],
35+
/*cppType=*/"::mlir::ShapedType">;
36+
37+
// A shaped pointer type with value semantics and rank.
38+
class Ptr_ShapedPtrType : Ptr_ShapedValueType<[Ptr_PtrType], [HasRankPred]>;
39+
40+
// A shaped value type of rank 1 of any element type.
41+
def Ptr_Any1DType :
42+
Ptr_ShapedValueType<[AnyType], [HasAnyRankOfPred<[1]>]>;
43+
44+
// A shaped value type of rank 1 of `i1` element type.
45+
def Ptr_Mask1DType :
46+
Ptr_ShapedValueType<[I1], [HasAnyRankOfPred<[1]>]>;
47+
48+
// A shaped value type of rank 1 of `i1` element type.
49+
def Ptr_Ptr1DType :
50+
Ptr_ShapedValueType<[Ptr_PtrType], [HasAnyRankOfPred<[1]>]>;
51+
2052
//===----------------------------------------------------------------------===//
2153
// FromPtrOp
2254
//===----------------------------------------------------------------------===//
@@ -56,6 +88,58 @@ def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
5688
let hasVerifier = 1;
5789
}
5890

91+
//===----------------------------------------------------------------------===//
92+
// GatherOp
93+
//===----------------------------------------------------------------------===//
94+
95+
def Ptr_GatherOp : Pointer_Op<"gather", [
96+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
97+
TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
98+
::llvm::cast<ShapedType>($_self).clone(
99+
IntegerType::get($_self.getContext(), 1))
100+
}]>,
101+
AllTypesMatch<["result", "passthrough"]>,
102+
// Check the shapes are compatible and both use the same shaped container
103+
// type.
104+
AllShapesMatch<["result", "ptrs"]>, AllTypeIDsMatch<["result", "ptrs"]>
105+
]> {
106+
let summary = "Gather operation";
107+
let description = [{
108+
The `gather` operation performs conditional loads from multiple memory
109+
locations specified by `ptrs` based on a mask `mask`. Elements of the
110+
result corresponding to masked-off lanes are taken from the passthrough
111+
operand.
112+
113+
The mask operand is a shaped type of `i1` elements that must have the same
114+
shape as the result type.
115+
116+
Examples:
117+
```mlir
118+
// Gather values from multiple memory locations
119+
%result = ptr.gather %ptrs, %mask, %passthrough :
120+
vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
121+
122+
// Gather with alignment
123+
%result = ptr.gather %ptrs, %mask, %passthrough alignment = 8 :
124+
vector<4x!ptr.ptr<#ptr.generic_space>> -> vector<4xf32>
125+
```
126+
}];
127+
let arguments = (ins Ptr_Ptr1DType:$ptrs,
128+
Ptr_Mask1DType:$mask,
129+
Ptr_Any1DType:$passthrough,
130+
AlignmentProp:$alignment);
131+
let results = (outs Ptr_Any1DType:$result);
132+
let assemblyFormat = [{
133+
$ptrs `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
134+
attr-dict `:` type($ptrs) `->` type($result)
135+
}];
136+
let builders = [
137+
OpBuilder<(ins "Type":$resultType, "Value":$ptrs, "Value":$mask,
138+
"Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
139+
];
140+
let hasVerifier = 1;
141+
}
142+
59143
//===----------------------------------------------------------------------===//
60144
// GetMetadataOp
61145
//===----------------------------------------------------------------------===//
@@ -122,8 +206,6 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
122206
// LoadOp
123207
//===----------------------------------------------------------------------===//
124208

125-
def AlignmentProp : OptionalProp<I64Prop>;
126-
127209
def Ptr_LoadOp : Pointer_Op<"load", [
128210
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
129211
]> {
@@ -184,6 +266,150 @@ def Ptr_LoadOp : Pointer_Op<"load", [
184266
let hasVerifier = 1;
185267
}
186268

269+
//===----------------------------------------------------------------------===//
270+
// MaskedLoadOp
271+
//===----------------------------------------------------------------------===//
272+
273+
def Ptr_MaskedLoadOp : Pointer_Op<"masked_load", [
274+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
275+
TypesMatchWith<"result and mask must be compatible", "result", "mask", [{
276+
::llvm::cast<ShapedType>($_self).clone(
277+
IntegerType::get($_self.getContext(), 1))
278+
}]>,
279+
AllTypesMatch<["result", "passthrough"]>
280+
]> {
281+
let summary = "Masked load operation";
282+
let description = [{
283+
The `masked_load` operation performs a conditional load from memory based
284+
on a mask. Elements of the result corresponding to masked-off lanes are
285+
taken from the passthrough operand.
286+
287+
The mask operand is a shaped type of `i1` elements that must have the same
288+
shape as the result type.
289+
290+
Examples:
291+
```mlir
292+
// Masked load with passthrough on vectors
293+
%result = ptr.masked_load %ptr, %mask, %passthrough :
294+
!ptr.ptr<#ptr.generic_space> -> vector<4xf32>
295+
296+
// Masked load with passthrough on tensors
297+
%result = ptr.masked_load %ptr, %mask, %passthrough :
298+
!ptr.ptr<#ptr.generic_space> -> tensor<4xf32>
299+
```
300+
}];
301+
let arguments = (ins Ptr_PtrType:$ptr,
302+
Ptr_Mask1DType:$mask,
303+
Ptr_Any1DType:$passthrough,
304+
AlignmentProp:$alignment);
305+
let results = (outs Ptr_Any1DType:$result);
306+
let assemblyFormat = [{
307+
$ptr `,` $mask `,` $passthrough (`alignment` `=` $alignment^)?
308+
attr-dict `:` qualified(type($ptr)) `->` type($result)
309+
}];
310+
let builders = [
311+
OpBuilder<(ins "Type":$resultType, "Value":$ptr, "Value":$mask,
312+
"Value":$passthrough, CArg<"unsigned", "0">:$alignment)>
313+
];
314+
let hasVerifier = 1;
315+
}
316+
317+
//===----------------------------------------------------------------------===//
318+
// MaskedStoreOp
319+
//===----------------------------------------------------------------------===//
320+
321+
def Ptr_MaskedStoreOp : Pointer_Op<"masked_store", [
322+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
323+
TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
324+
::llvm::cast<ShapedType>($_self).clone(
325+
IntegerType::get($_self.getContext(), 1))
326+
}]>
327+
]> {
328+
let summary = "Masked store operation";
329+
let description = [{
330+
The `masked_store` operation performs a conditional store to memory based
331+
on a mask. Only elements corresponding to set bits in the mask are written
332+
to memory.
333+
334+
The mask operand is a shaped type of `i1` elements that must have the same
335+
shape as the value being stored.
336+
337+
Examples:
338+
```mlir
339+
// Masked store
340+
ptr.masked_store %value, %ptr, %mask :
341+
vector<4xf32>, !ptr.ptr<#ptr.generic_space>
342+
343+
// Masked store with alignment
344+
ptr.masked_store %value, %ptr, %mask alignment = 8 :
345+
vector<4xf32>, !ptr.ptr<#ptr.generic_space>
346+
```
347+
}];
348+
349+
let arguments = (ins Ptr_Any1DType:$value,
350+
Ptr_PtrType:$ptr,
351+
Ptr_Mask1DType:$mask,
352+
AlignmentProp:$alignment);
353+
let assemblyFormat = [{
354+
$value `,` $ptr `,` $mask (`alignment` `=` $alignment^)? attr-dict `:`
355+
type($value) `,` qualified(type($ptr))
356+
}];
357+
let builders = [
358+
OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask,
359+
CArg<"unsigned", "0">:$alignment)>
360+
];
361+
let hasVerifier = 1;
362+
}
363+
364+
//===----------------------------------------------------------------------===//
365+
// ScatterOp
366+
//===----------------------------------------------------------------------===//
367+
368+
def Ptr_ScatterOp : Pointer_Op<"scatter", [
369+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
370+
TypesMatchWith<"value and mask must be compatible", "value", "mask", [{
371+
::llvm::cast<ShapedType>($_self).clone(
372+
IntegerType::get($_self.getContext(), 1))
373+
}]>,
374+
// Check the shapes are compatible and both use the same shaped container
375+
// type.
376+
AllShapesMatch<["value", "ptrs"]>, AllTypeIDsMatch<["value", "ptrs"]>
377+
]> {
378+
let summary = "Scatter operation";
379+
let description = [{
380+
The `scatter` operation performs a conditional store of a value `value` to
381+
multiple memory locations specified by `ptrs` based on a mask `mask`.
382+
383+
Only elements corresponding to set bits in the mask are written to memory.
384+
The mask operand is a shaped type of `i1` elements that must have the same
385+
shape as the value being stored.
386+
387+
Examples:
388+
```mlir
389+
// Scatter values to multiple memory locations
390+
ptr.scatter %value, %ptrs, %mask :
391+
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
392+
393+
// Scatter with alignment
394+
ptr.scatter %value, %ptrs, %mask alignment = 8 :
395+
vector<4xf32>, vector<4x!ptr.ptr<#ptr.generic_space>>
396+
```
397+
}];
398+
let arguments = (ins Ptr_Any1DType:$value,
399+
Ptr_Ptr1DType:$ptrs,
400+
Ptr_Mask1DType:$mask,
401+
AlignmentProp:$alignment);
402+
let assemblyFormat = [{
403+
$value `,` $ptrs `,` $mask (`alignment` `=` $alignment^)?
404+
attr-dict `:` type($value) `,` type($ptrs)
405+
}];
406+
let builders = [
407+
OpBuilder<(ins "Value":$value, "Value":$ptrs, "Value":$mask,
408+
CArg<"unsigned", "0">:$alignment)>
409+
];
410+
let hasVerifier = 1;
411+
}
412+
187413
//===----------------------------------------------------------------------===//
188414
// StoreOp
189415
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpBase.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,10 @@ class AllShapesMatch<list<string> names> :
556556
class AllTypesMatch<list<string> names> :
557557
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
558558

559+
// Checks that all type IDs match.
560+
class AllTypeIDsMatch<list<string> names> :
561+
AllMatchSameOperatorTrait<names, "$_self.getType().getTypeID()", "type IDs">;
562+
559563
// A type constraint that verifies that a shaped type matches the size and
560564
// element type of a container with element types. More specifically, it denotes
561565
// shapedArg.getType().getNumElements() == elementsArg.size() &&

0 commit comments

Comments
 (0)