Skip to content

Commit bfe865b

Browse files
authored
[CIR][CIRGen][CUDA] Lower PTX synchronization primitives (#1903)
1 parent fe5977d commit bfe865b

File tree

4 files changed

+143
-29
lines changed

4 files changed

+143
-29
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinNVPTX.cpp

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "mlir/Dialect/Func/IR/FuncOps.h"
2121
#include "mlir/IR/Value.h"
22+
#include "mlir/IR/ValueRange.h"
2223
#include "clang/AST/GlobalDecl.h"
2324
#include "clang/Basic/Builtins.h"
2425
#include "clang/Basic/TargetBuiltins.h"
@@ -551,21 +552,67 @@ mlir::Value CIRGenFunction::emitNVPTXBuiltinExpr(unsigned builtinId,
551552
case NVPTX::BI__nvvm_getctarank_shared_cluster:
552553
llvm_unreachable("getctarank_shared_cluster NYI");
553554
case NVPTX::BI__nvvm_barrier_cluster_arrive:
554-
llvm_unreachable("barrier_cluster_arrive NYI");
555+
return builder
556+
.create<cir::LLVMIntrinsicCallOp>(
557+
getLoc(expr->getExprLoc()),
558+
builder.getStringAttr("nvvm.barrier.cluster.arrive"),
559+
builder.getVoidTy())
560+
.getResult();
555561
case NVPTX::BI__nvvm_barrier_cluster_arrive_relaxed:
556-
llvm_unreachable("barrier_cluster_arrive_relaxed NYI");
562+
return builder
563+
.create<cir::LLVMIntrinsicCallOp>(
564+
getLoc(expr->getExprLoc()),
565+
builder.getStringAttr("nvvm.barrier.cluster.arrive.relaxed"),
566+
builder.getVoidTy())
567+
.getResult();
557568
case NVPTX::BI__nvvm_barrier_cluster_wait:
558-
llvm_unreachable("barrier_cluster_wait NYI");
569+
return builder
570+
.create<cir::LLVMIntrinsicCallOp>(
571+
getLoc(expr->getExprLoc()),
572+
builder.getStringAttr("nvvm.barrier.cluster.wait"),
573+
builder.getVoidTy())
574+
.getResult();
559575
case NVPTX::BI__nvvm_fence_sc_cluster:
560-
llvm_unreachable("fence_sc_cluster NYI");
576+
return builder
577+
.create<cir::LLVMIntrinsicCallOp>(
578+
getLoc(expr->getExprLoc()),
579+
builder.getStringAttr("nvvm.fence.sc.cluster"), builder.getVoidTy(),
580+
mlir::ValueRange{})
581+
.getResult();
561582
case NVPTX::BI__nvvm_bar_sync:
562-
llvm_unreachable("bar_sync NYI");
583+
return builder
584+
.create<cir::LLVMIntrinsicCallOp>(
585+
getLoc(expr->getExprLoc()),
586+
builder.getStringAttr("nvvm.barrier.cta.sync.aligned.all"),
587+
builder.getVoidTy(),
588+
mlir::ValueRange{emitScalarExpr(expr->getArg(0))})
589+
.getResult();
563590
case NVPTX::BI__syncthreads:
564-
llvm_unreachable("syncthreads NYI");
591+
return builder
592+
.create<cir::LLVMIntrinsicCallOp>(
593+
getLoc(expr->getExprLoc()),
594+
builder.getStringAttr("nvvm.barrier.cta.sync.aligned.all"),
595+
builder.getVoidTy(),
596+
mlir::ValueRange{
597+
builder.getConstInt(getLoc(expr->getExprLoc()), SInt32Ty, 0)})
598+
.getResult();
565599
case NVPTX::BI__nvvm_barrier_sync:
566-
llvm_unreachable("barrier_sync NYI");
600+
return builder
601+
.create<cir::LLVMIntrinsicCallOp>(
602+
getLoc(expr->getExprLoc()),
603+
builder.getStringAttr("nvvm.barrier.cta.sync.all"),
604+
builder.getVoidTy(),
605+
mlir::ValueRange{emitScalarExpr(expr->getArg(0))})
606+
.getResult();
567607
case NVPTX::BI__nvvm_barrier_sync_cnt:
568-
llvm_unreachable("barrier_sync_cnt NYI");
608+
return builder
609+
.create<cir::LLVMIntrinsicCallOp>(
610+
getLoc(expr->getExprLoc()),
611+
builder.getStringAttr("nvvm.barrier.cta.sync.count"),
612+
builder.getVoidTy(),
613+
mlir::ValueRange{emitScalarExpr(expr->getArg(0)),
614+
emitScalarExpr(expr->getArg(1))})
615+
.getResult();
569616
default:
570617
return nullptr;
571618
}

clang/test/CIR/CodeGen/CUDA/builtin-functions.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010
// RUN: %s -o %t.ll
1111
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
1212

13+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda \
14+
// RUN: -fcuda-is-device -emit-llvm -target-sdk-version=12.3 \
15+
// RUN: %s -o %t.ll
16+
// RUN: FileCheck --check-prefix=OGCHECK --input-file=%t.ll %s
17+
18+
__device__ void sync() {
19+
20+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cta.sync.aligned.all" {{.*}} : (!s32i)
21+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
22+
// OGCHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
23+
__nvvm_bar_sync(0);
24+
}
25+
1326
__device__ void builtins() {
1427
float f1, f2;
1528
double d1, d2;
@@ -59,7 +72,8 @@ __device__ void builtins() {
5972
// LLVM: call void @llvm.nvvm.membar.sys()
6073
__nvvm_membar_sys();
6174

62-
// TODO-CIR: cir.llvm.intrinsic "nvvm.barrier0"
63-
// TODO-LLVM: call void @llvm.nvvm.barrier0()
64-
// __syncthreads();
75+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cta.sync.aligned.all"
76+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
77+
// OGCHECK: call void @llvm.nvvm.barrier.cta.sync.aligned.all(i32 0)
78+
__syncthreads();
6579
}
Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,41 @@
11
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
22
// RUN: -fcuda-is-device -target-feature +ptx60 \
3-
// RUN: -emit-cir -o - -x cuda %s \
4-
// RUN: | FileCheck -check-prefix=CIR %s
3+
// RUN: -emit-cir -o %t.cir -x cuda %s
4+
// RUN: FileCheck -check-prefix=CIR --input-file=%t.cir %s
55
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
66
// RUN: -fcuda-is-device -target-feature +ptx65 \
7-
// RUN: -emit-cir -o - -x cuda %s \
8-
// RUN: | FileCheck -check-prefix=CIR %s
7+
// RUN: -emit-cir -o %t.cir -x cuda %s
8+
// RUN: FileCheck -check-prefix=CIR --input-file=%t.cir %s
99
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
1010
// RUN: -fcuda-is-device -target-feature +ptx70 \
11-
// RUN: -emit-cir -o - -x cuda %s \
12-
// RUN: | FileCheck -check-prefix=CIR %s
11+
// RUN: -emit-cir -o %t.cir -x cuda %s
12+
// RUN: FileCheck -check-prefix=CIR --input-file=%t.cir %s
1313

1414
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
1515
// RUN: -fcuda-is-device -target-feature +ptx60 \
16-
// RUN: -emit-llvm -o - -x cuda %s \
17-
// RUN: | FileCheck -check-prefix=LLVM %s
16+
// RUN: -emit-llvm -o %t.ll -x cuda %s
17+
// RUN: FileCheck -check-prefix=LLVM --input-file=%t.ll %s
1818
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
1919
// RUN: -fcuda-is-device -target-feature +ptx65 \
20-
// RUN: -emit-llvm -o - -x cuda %s \
21-
// RUN: | FileCheck -check-prefix=LLVM %s
20+
// RUN: -emit-llvm -o %t.ll -x cuda %s
21+
// RUN: FileCheck -check-prefix=LLVM --input-file=%t.ll %s
2222
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
2323
// RUN: -fcuda-is-device -target-feature +ptx70 \
24-
// RUN: -emit-llvm -o - -x cuda %s \
25-
// RUN: | FileCheck -check-prefix=LLVM %s
24+
// RUN: -emit-llvm -o %t.ll -x cuda %s
25+
// RUN: FileCheck -check-prefix=LLVM --input-file=%t.ll %s
2626

2727
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_70 \
2828
// RUN: -fcuda-is-device -target-feature +ptx60 \
29-
// RUN: -emit-llvm -o - -x cuda %s \
30-
// RUN: | FileCheck -check-prefix=OGCHECK %s
29+
// RUN: -emit-llvm -o %t_og.ll -x cuda %s
30+
// RUN: FileCheck -check-prefix=OGCHECK --input-file=%t_og.ll %s
3131
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
3232
// RUN: -fcuda-is-device -target-feature +ptx65 \
33-
// RUN: -emit-llvm -o - -x cuda %s \
34-
// RUN: | FileCheck -check-prefix=OGCHECK %s
33+
// RUN: -emit-llvm -o %t_og.ll -x cuda %s
34+
// RUN: FileCheck -check-prefix=OGCHECK --input-file=%t_og.ll %s
3535
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
3636
// RUN: -fcuda-is-device -target-feature +ptx70 \
37-
// RUN: -emit-llvm -o - -x cuda %s \
38-
// RUN: | FileCheck -check-prefix=OGCHECK %s
37+
// RUN: -emit-llvm -o %t_og.ll -x cuda %s
38+
// RUN: FileCheck -check-prefix=OGCHECK --input-file=%t_og.ll %s
3939

4040
#define __device__ __attribute__((device))
4141
#define __global__ __attribute__((global))
@@ -52,4 +52,14 @@ __device__ void nvvm_sync(unsigned mask, int i, float f, int a, int b,
5252
// OGCHECK: call void @llvm.nvvm.bar.warp.sync(i32
5353
__nvvm_bar_warp_sync(mask);
5454

55+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cta.sync.all" {{.*}} : (!u32i)
56+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.all(i32
57+
// OGCHECK: call void @llvm.nvvm.barrier.cta.sync.all(i32
58+
__nvvm_barrier_sync(mask);
59+
60+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cta.sync.count" {{.*}} : (!u32i, !u32i)
61+
// LLVM: call void @llvm.nvvm.barrier.cta.sync.count(i32
62+
// OGCHECK: call void @llvm.nvvm.barrier.cta.sync.count(i32
63+
__nvvm_barrier_sync_cnt(mask, i);
64+
5565
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-feature +ptx80 \
2+
// RUN: -target-cpu sm_90 -fclangir -emit-cir -fcuda-is-device -target-sdk-version=12.3 \
3+
// RUN: %s -o %t.cir
4+
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
5+
6+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-feature +ptx80 \
7+
// RUN: -target-cpu sm_90 -fclangir -emit-llvm -fcuda-is-device -target-sdk-version=12.3 \
8+
// RUN: %s -o %t.ll
9+
// RUN: FileCheck --check-prefix=LLVM --input-file=%t.ll %s
10+
11+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-feature +ptx80 \
12+
// RUN: -target-cpu sm_90 -emit-llvm -fcuda-is-device -target-sdk-version=12.3 \
13+
// RUN: %s -o %t.ll
14+
// RUN: FileCheck --check-prefix=OGCHECK --input-file=%t.ll %s
15+
16+
// CIR-LABEL: _Z6kernelPlPvj(
17+
// LLVM: define{{.*}} void @_Z6kernelPlPvj(
18+
// OGCHECK: define{{.*}} void @_Z6kernelPlPvj(
19+
__attribute__((global)) void kernel(long *out, void *ptr, unsigned u) {
20+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cluster.arrive"
21+
// LLVM: call void @llvm.nvvm.barrier.cluster.arrive()
22+
// OGCHECK: call void @llvm.nvvm.barrier.cluster.arrive()
23+
__nvvm_barrier_cluster_arrive();
24+
25+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cluster.arrive.relaxed"
26+
// LLVM: call void @llvm.nvvm.barrier.cluster.arrive.relaxed()
27+
// OGCHECK: call void @llvm.nvvm.barrier.cluster.arrive.relaxed()
28+
29+
__nvvm_barrier_cluster_arrive_relaxed();
30+
// CIR: cir.llvm.intrinsic "nvvm.barrier.cluster.wait"
31+
// LLVM: call void @llvm.nvvm.barrier.cluster.wait()
32+
// OGCHECK: call void @llvm.nvvm.barrier.cluster.wait()
33+
__nvvm_barrier_cluster_wait();
34+
35+
// CIR: cir.llvm.intrinsic "nvvm.fence.sc.cluster"
36+
// LLVM: call void @llvm.nvvm.fence.sc.cluster()
37+
// OGCHECK: call void @llvm.nvvm.fence.sc.cluster()
38+
__nvvm_fence_sc_cluster();
39+
40+
// CIR: cir.return
41+
// LLVM: ret void
42+
// OGCHECK: ret void
43+
}

0 commit comments

Comments
 (0)