Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 48 additions & 42 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,20 @@ class NVVM_IntrOp<string mnem, list<Trait> traits = [],
// NVVM special register op definitions
//===----------------------------------------------------------------------===//

class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
class NVVM_PureSpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, !listconcat(traits, [Pure]), 1> {
let arguments = (ins);
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_SpecialRegisterOp<mnemonic,
class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_IntrOp<mnemonic, traits, 1> {
let arguments = (ins);
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_PureSpecialRegisterOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
Expand Down Expand Up @@ -189,63 +195,63 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>

//===----------------------------------------------------------------------===//
// Lane, Warp, SM, Grid index and range
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
def NVVM_LaneIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_WarpIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
def NVVM_WarpDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
def NVVM_SmIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
def NVVM_SmDimOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
def NVVM_GridIdOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;

//===----------------------------------------------------------------------===//
// Lane Mask Comparison Ops
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
def NVVM_LaneMaskGtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;
def NVVM_LaneMaskEqOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
def NVVM_LaneMaskLeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
def NVVM_LaneMaskLtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
def NVVM_LaneMaskGeOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
def NVVM_LaneMaskGtOp : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.lanemask.gt">;

//===----------------------------------------------------------------------===//
// Thread index and range
def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
def NVVM_ThreadIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_BlockDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_BlockDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_BlockDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;

//===----------------------------------------------------------------------===//
// Block index and range
def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
def NVVM_BlockIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_GridDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA Cluster index and range
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
def NVVM_ClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
def NVVM_ClusterDimYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
def NVVM_ClusterDimZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;


//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
def NVVM_BlockInClusterIdXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksXOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksYOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksZOp : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA index and across Cluster dimensions
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
def NVVM_ClusterId : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDim : NVVM_PureSpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;

//===----------------------------------------------------------------------===//
// Clock registers
Expand All @@ -256,7 +262,7 @@ def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
//===----------------------------------------------------------------------===//
// envreg registers
foreach index = !range(0, 32) in {
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
def NVVM_EnvReg # index # Op : NVVM_PureSpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
}

//===----------------------------------------------------------------------===//
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Dialect/LLVMIR/cse-nvvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: mlir-opt %s -cse -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @nvvm_special_regs_clock
llvm.func @nvvm_special_regs_clock() -> !llvm.struct<(i32, i32)> {
%0 = llvm.mlir.zero: !llvm.struct<(i32, i32)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
%1 = nvvm.read.ptx.sreg.clock : i32
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock
%2 = nvvm.read.ptx.sreg.clock : i32
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i32, i32)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i32, i32)>
llvm.return %5: !llvm.struct<(i32, i32)>
}

// CHECK-LABEL: @nvvm_special_regs_clock64
llvm.func @nvvm_special_regs_clock64() -> !llvm.struct<(i64, i64)> {
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
%1 = nvvm.read.ptx.sreg.clock64 : i64
// CHECK: {{.*}} = nvvm.read.ptx.sreg.clock64
%2 = nvvm.read.ptx.sreg.clock64 : i64
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
llvm.return %5: !llvm.struct<(i64, i64)>
}

// CHECK-LABEL: @nvvm_special_regs_globaltimer
llvm.func @nvvm_special_regs_globaltimer() -> !llvm.struct<(i64, i64)> {
%0 = llvm.mlir.zero: !llvm.struct<(i64, i64)>
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
%1 = nvvm.read.ptx.sreg.globaltimer : i64
// CHECK: {{.*}} = nvvm.read.ptx.sreg.globaltimer
%2 = nvvm.read.ptx.sreg.globaltimer : i64
%4 = llvm.insertvalue %1, %0[0]: !llvm.struct<(i64, i64)>
%5 = llvm.insertvalue %2, %4[1]: !llvm.struct<(i64, i64)>
llvm.return %5: !llvm.struct<(i64, i64)>
}