Skip to content

Commit ffaba75

Browse files
authored
[MLIR][ROCDL] Add permlane16.swap and permlane32.swap (#153804)
add rocdl.permlane16.swap and rocdl.permlane32.swap
1 parent 38eb14f commit ffaba75

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,53 @@ def ROCDL_PermlaneX16Op : ROCDL_IntrOp<"permlanex16", [], [0],
780780
}];
781781
}
782782

783+
// Type constraint for an LLVM dialect struct with exactly two members, where
// member 0 satisfies `elem0` and member 1 satisfies `elem1`.
//
// The constraint is buildable: the builder call creates the literal struct
// `{elem0, elem1}` from the element types' own builder calls.
class ROCDL_ConcretePair<Type elem0, Type elem1> :
  Type<And<[
      LLVM_AnyStruct.predicate,
      // Check the member count before indexing into the body below, so that
      // a struct with fewer than two members is rejected instead of being
      // indexed out of bounds. A "pair" must also not have extra members.
      CPred<"::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody().size() == 2">,
      SubstLeaves<
        "$_self",
        "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[0]",
        elem0.predicate>,
      SubstLeaves<
        "$_self",
        "::llvm::cast<::mlir::LLVM::LLVMStructType>($_self).getBody()[1]",
        elem1.predicate>
    ]>,
    // Spaces around "and" keep the two element summaries from running together.
    "LLVM dialect-compatible struct of " # elem0.summary # " and " # elem1.summary,
    "::mlir::LLVM::LLVMStructType">,
  BuildableType<"::mlir::LLVM::LLVMStructType::getLiteral($_builder.getContext(), "
                "{" # elem0.builderCall # ", " # elem1.builderCall # "})">;
799+
800+
// Permlane16 swap intrinsic operation.
//
// Takes two i32 values plus two i1 attributes and yields a struct of two i32
// values. The attributes `fi` and `boundControl` are registered as immediate
// arguments at positions 2 and 3 of the intrinsic call.
def ROCDL_Permlane16SwapOp : ROCDL_IntrOp<"permlane16.swap", [], [],
    [], 1, 0, 0, 0,
    [2, 3], ["fi", "boundControl"]>,
  Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
  let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
  let assemblyFormat = [{
    attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
  }];
  let description = [{
    Performs a `permlane16.swap` operation on the two `i32` operands `$old`
    and `$src` and returns both resulting values as an LLVM struct of two
    `i32`s.

    `$fi` and `$boundControl` are boolean attributes that are forwarded to the
    `llvm.amdgcn.permlane16.swap` intrinsic as immediate `i1` arguments. They
    are flags that modify the operation, not a permutation. See the AMDGPU
    backend documentation of the intrinsic for the exact lane-swapping
    semantics.

    Example:
    ```mlir
    %res = rocdl.permlane16.swap %old, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
    ```
  }];
}
814+
815+
// Permlane32 swap intrinsic operation.
//
// Takes two i32 values plus two i1 attributes and yields a struct of two i32
// values. The attributes `fi` and `boundControl` are registered as immediate
// arguments at positions 2 and 3 of the intrinsic call.
def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
    [], 1, 0, 0, 0,
    [2, 3], ["fi", "boundControl"]>,
  Arguments<(ins I32:$old, I32:$src, I1Attr:$fi, I1Attr:$boundControl)> {
  let results = (outs ROCDL_ConcretePair<I32, I32>:$res);
  let assemblyFormat = [{
    attr-dict $old `,` $src `,` $fi `,` $boundControl `:` `(` type($old) `,` type($src) `)` `->` type($res)
  }];
  let description = [{
    Performs a `permlane32.swap` operation on the two `i32` operands `$old`
    and `$src` and returns both resulting values as an LLVM struct of two
    `i32`s.

    `$fi` and `$boundControl` are boolean attributes that are forwarded to the
    `llvm.amdgcn.permlane32.swap` intrinsic as immediate `i1` arguments. They
    are flags that modify the operation, not a permutation. See the AMDGPU
    backend documentation of the intrinsic for the exact lane-swapping
    semantics.

    Example:
    ```mlir
    %res = rocdl.permlane32.swap %old, %src, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
    ```
  }];
}
829+
783830
class ROCDL_ConcreteVector<Type elem, int length> :
784831
FixedVectorOfLengthAndType<[length], [elem]>,
785832
BuildableType<

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,22 @@ llvm.func @rocdl.permlanex16(%src : f32) -> f32 {
10091009

10101010
// -----
10111011

1012+
// Dialect-level check for rocdl.permlane16.swap: one i32 value feeds both
// operands, with fi = 0 and boundControl = -1, and the struct result is
// returned from the function.
llvm.func @rocdl.permlane16.swap(%val : i32) -> !llvm.struct<(i32, i32)> {
  // CHECK-LABEL: rocdl.permlane16.swap
  // CHECK: rocdl.permlane16.swap %{{.*}} %{{.*}}
  %swapped = rocdl.permlane16.swap %val, %val, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %swapped : !llvm.struct<(i32, i32)>
}
1018+
1019+
// Dialect-level check for rocdl.permlane32.swap: one i32 value feeds both
// operands, with fi = 0 and boundControl = -1, and the struct result is
// returned from the function.
llvm.func @rocdl.permlane32.swap(%val : i32) -> !llvm.struct<(i32, i32)> {
  // CHECK-LABEL: rocdl.permlane32.swap
  // CHECK: rocdl.permlane32.swap %{{.*}} %{{.*}}
  %swapped = rocdl.permlane32.swap %val, %val, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %swapped : !llvm.struct<(i32, i32)>
}
1025+
1026+
// -----
1027+
10121028
// expected-error@below {{attribute attached to unexpected op}}
10131029
func.func private @expected_llvm_func() attributes { rocdl.kernel }
10141030

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,20 @@ llvm.func @rocdl.permlanex16(%src0 : f32, %src1 : i32, %src2 : vector<2 x f32>,
941941
llvm.return %ret0 : f32
942942
}
943943

944+
// Translation check for rocdl.permlane16.swap: the op is expected to become a
// call to @llvm.amdgcn.permlane16.swap returning { i32, i32 }, with the
// attributes 0 and -1 emitted as `i1 false` and `i1 true`.
llvm.func @rocdl.permlane16.swap(%val : i32) -> !llvm.struct<(i32, i32)> {
  // CHECK-LABEL: rocdl.permlane16.swap
  // CHECK: call { i32, i32 } @llvm.amdgcn.permlane16.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
  %pair = rocdl.permlane16.swap %val, %val, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %pair : !llvm.struct<(i32, i32)>
}
950+
951+
// Translation check for rocdl.permlane32.swap: the op is expected to become a
// call to @llvm.amdgcn.permlane32.swap returning { i32, i32 }, with the
// attributes 0 and -1 emitted as `i1 false` and `i1 true`.
llvm.func @rocdl.permlane32.swap(%val : i32) -> !llvm.struct<(i32, i32)> {
  // CHECK-LABEL: rocdl.permlane32.swap
  // CHECK: call { i32, i32 } @llvm.amdgcn.permlane32.swap(i32 %{{.*}}, i32 %{{.*}}, i1 false, i1 true)
  %pair = rocdl.permlane32.swap %val, %val, 0, -1 : (i32, i32) -> !llvm.struct<(i32, i32)>
  llvm.return %pair : !llvm.struct<(i32, i32)>
}
957+
944958
llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
945959
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
946960
%r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>

0 commit comments

Comments
 (0)