Skip to content

Conversation

@schwarzschild-radius
Copy link
Contributor

This commit adds support for the following PTX predefined special registers

  • warpid
  • nwarpid
  • smid
  • nsmid
  • gridid
  • lanemask.*
  • globaltimer
  • envreg* And added lit tests under nvvmir.mlir

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for the following PTX predefined special registers

  • warpid
  • nwarpid
  • smid
  • nsmid
  • gridid
  • lanemask.*
  • globaltimer
  • envreg* And added lit tests under nvvmir.mlir

Full diff: https://github.com/llvm/llvm-project/pull/112343.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+21-1)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+88-4)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 152715f281088e..e67f5fc8f9347b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -139,9 +139,22 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []>
 }
 
 //===----------------------------------------------------------------------===//
-// Lane index and range
+// Lane, Warp, SM, Grid index and range
 def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
 def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
+def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
+def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
+def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
+def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
+
+//===----------------------------------------------------------------------===//
+// Lane Mask Comparison Ops
+def NVVM_LaneMaskEqOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.lanemask.eq">;
+def NVVM_LaneMaskLeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.lanemask.le">;
+def NVVM_LaneMaskLtOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.lanemask.lt">;
+def NVVM_LaneMaskGeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.lanemask.ge">;
+def NVVM_LaneMaskGtOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.lanemask.gt">;
 
 //===----------------------------------------------------------------------===//
 // Thread index and range
@@ -189,6 +202,13 @@ def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nct
 // Clock registers
 def NVVM_ClockOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock">;
 def NVVM_Clock64Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.clock64">;
+def NVVM_GlobalTimerOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.globaltimer">;
+
+//===----------------------------------------------------------------------===//
+// envreg registers
+foreach index = !range(0, 32) in {
+  def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
+}
 
 //===----------------------------------------------------------------------===//
 // NVVM approximate op definitions
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 7fd082a5eb3c75..0471e5faf84578 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -62,10 +62,94 @@ llvm.func @nvvm_special_regs() -> i32 {
   %29 = nvvm.read.ptx.sreg.clock : i32
   // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
   %30 = nvvm.read.ptx.sreg.clock64 : i64
-
-  // CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
-  %31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
-
+  // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.globaltimer
+  %31 = nvvm.read.ptx.sreg.globaltimer : i64
+  // CHECK: %32 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %32 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.warpid
+  %33 = nvvm.read.ptx.sreg.warpid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nwarpid
+  %34 = nvvm.read.ptx.sreg.nwarpid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.smid
+  %35 = nvvm.read.ptx.sreg.smid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nsmid
+  %36 = nvvm.read.ptx.sreg.nsmid : i32
+  // CHECK: call i32 @llvm.nvvm.read.ptx.sreg.gridid
+  %37 = nvvm.read.ptx.sreg.gridid : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg0
+  %38 = nvvm.read.ptx.sreg.envreg0 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg1
+  %39 = nvvm.read.ptx.sreg.envreg1 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg2
+  %40 = nvvm.read.ptx.sreg.envreg2 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg3
+  %41 = nvvm.read.ptx.sreg.envreg3 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg4
+  %42 = nvvm.read.ptx.sreg.envreg4 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg5
+  %43 = nvvm.read.ptx.sreg.envreg5 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg6
+  %44 = nvvm.read.ptx.sreg.envreg6 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg7
+  %45 = nvvm.read.ptx.sreg.envreg7 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg8
+  %46 = nvvm.read.ptx.sreg.envreg8 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg9
+  %47 = nvvm.read.ptx.sreg.envreg9 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg10
+  %48 = nvvm.read.ptx.sreg.envreg10 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg11
+  %49 = nvvm.read.ptx.sreg.envreg11 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg12
+  %50 = nvvm.read.ptx.sreg.envreg12 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg13
+  %51 = nvvm.read.ptx.sreg.envreg13 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg14
+  %52 = nvvm.read.ptx.sreg.envreg14 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg15
+  %53 = nvvm.read.ptx.sreg.envreg15 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg16
+  %54 = nvvm.read.ptx.sreg.envreg16 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg17
+  %55 = nvvm.read.ptx.sreg.envreg17 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg18
+  %56 = nvvm.read.ptx.sreg.envreg18 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg19
+  %57 = nvvm.read.ptx.sreg.envreg19 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg20
+  %58 = nvvm.read.ptx.sreg.envreg20 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg21
+  %59 = nvvm.read.ptx.sreg.envreg21 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg22
+  %60 = nvvm.read.ptx.sreg.envreg22 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg23
+  %61 = nvvm.read.ptx.sreg.envreg23 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg24
+  %62 = nvvm.read.ptx.sreg.envreg24 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg25
+  %63 = nvvm.read.ptx.sreg.envreg25 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg26
+  %64 = nvvm.read.ptx.sreg.envreg26 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg27
+  %65 = nvvm.read.ptx.sreg.envreg27 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg28
+  %66 = nvvm.read.ptx.sreg.envreg28 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg29
+  %67 = nvvm.read.ptx.sreg.envreg29 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg30
+  %68 = nvvm.read.ptx.sreg.envreg30 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.envreg31
+  %69 = nvvm.read.ptx.sreg.envreg31 : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.eq
+  %70 = nvvm.read.ptx.sreg.lanemask.eq : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.le
+  %71 = nvvm.read.ptx.sreg.lanemask.le : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.lt
+  %72 = nvvm.read.ptx.sreg.lanemask.lt : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.ge
+  %73 = nvvm.read.ptx.sreg.lanemask.ge : i32
+  //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt
+  %74 = nvvm.read.ptx.sreg.lanemask.gt : i32
   llvm.return %1 : i32
 }
 

@schwarzschild-radius
Copy link
Contributor Author

@grypp @durga4github had a question on whether the lanemask.* Ops needs to be a SpecialRegisterOp vs a SpecialRangeableRegisterOp

This commit adds support for the following PTX predefined special
registers
* warpid
* nwarpid
* smid
* nsmid
* gridid
* lanemask.*
* globaltimer
* envreg*
And added lit tests under nvvmir.mlir
@schwarzschild-radius schwarzschild-radius force-pushed the ptx_special_register_support branch from bcab57a to fecd365 Compare October 17, 2024 05:33
@schwarzschild-radius schwarzschild-radius merged commit 9b713f5 into llvm:main Oct 17, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants