Skip to content

Commit 07ed101

Browse files
authored
[flang][cuda] Support logical(4) in syncthreads_and|count|or functions (#164706)
1 parent 81a9d75 commit 07ed101

File tree

3 files changed

+55
-19
lines changed

3 files changed

+55
-19
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,18 @@ static constexpr IntrinsicHandler handlers[]{
989989
{"mask", asBox, handleDynamicOptional}}},
990990
/*isElemental=*/false},
991991
{"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
992-
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
993-
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
994-
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
992+
{"syncthreads_and_i4", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
993+
{"syncthreads_and_l4", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
994+
{"syncthreads_count_i4",
995+
&I::genSyncThreadsCount,
996+
{},
997+
/*isElemental=*/false},
998+
{"syncthreads_count_l4",
999+
&I::genSyncThreadsCount,
1000+
{},
1001+
/*isElemental=*/false},
1002+
{"syncthreads_or_i4", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
1003+
{"syncthreads_or_l4", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
9951004
{"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
9961005
{"system",
9971006
&I::genSystem,

flang/module/cudadevice.f90

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,32 @@ module cudadevice
2121
procedure :: syncthreads
2222
end interface
2323

24-
interface
25-
attributes(device) integer function syncthreads_and(value)
26-
integer, value :: value
24+
interface syncthreads_and
25+
attributes(device) integer function syncthreads_and_i4(value)
26+
integer(4), value :: value
2727
end function
28-
end interface
28+
attributes(device) integer function syncthreads_and_l4(value)
29+
logical(4), value :: value
30+
end function
31+
end interface syncthreads_and
2932

30-
interface
31-
attributes(device) integer function syncthreads_count(value)
32-
integer, value :: value
33+
interface syncthreads_count
34+
attributes(device) integer function syncthreads_count_i4(value)
35+
integer(4), value :: value
3336
end function
34-
end interface
37+
attributes(device) integer function syncthreads_count_l4(value)
38+
logical(4), value :: value
39+
end function
40+
end interface syncthreads_count
3541

36-
interface
37-
attributes(device) integer function syncthreads_or(value)
38-
integer, value :: value
42+
interface syncthreads_or
43+
attributes(device) integer function syncthreads_or_i4(value)
44+
integer(4), value :: value
3945
end function
40-
end interface
46+
attributes(device) integer function syncthreads_or_l4(value)
47+
logical(4), value :: value
48+
end function
49+
end interface syncthreads_or
4150

4251
interface
4352
attributes(device) subroutine syncwarp(mask)

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,23 @@ attributes(global) subroutine devsub()
1212
integer(8) :: al
1313
integer(8) :: time
1414
integer :: smalltime
15-
integer(4) :: res
15+
integer(4) :: res, offset
1616
integer(8) :: resl
1717

18+
integer :: tid
19+
tid = threadIdx%x
20+
1821
call syncthreads()
1922
call syncwarp(1)
2023
call threadfence()
2124
call threadfence_block()
2225
call threadfence_system()
2326
ret = syncthreads_and(1)
27+
res = syncthreads_and(tid > offset)
2428
ret = syncthreads_count(1)
29+
ret = syncthreads_count(tid > offset)
2530
ret = syncthreads_or(1)
31+
ret = syncthreads_or(tid > offset)
2632

2733
ai = atomicadd(ai, 1_4)
2834
al = atomicadd(al, 1_8)
@@ -100,9 +106,21 @@ end
100106
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
101107
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
102108
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
103-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1_i32_0) fastmath<contract> : (i32) -> i32
104-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1_i32_1) fastmath<contract> : (i32) -> i32
105-
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1_i32_2) fastmath<contract> : (i32) -> i32
109+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
110+
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
111+
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
112+
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
113+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CMP]])
114+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
115+
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
116+
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
117+
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
118+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CMP]]) fastmath<contract> : (i1) -> i32
119+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
120+
! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref<i32>
121+
! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref<i32>
122+
! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32
123+
! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CMP]]) fastmath<contract> : (i1) -> i32
106124
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
107125
! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64
108126
! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32

0 commit comments

Comments
 (0)