From c95478447184e5cb7ee31432452d85f0d715a226 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 6 Mar 2025 17:30:48 -0800 Subject: [PATCH 1/2] [flang][cuda] Add more interfaces for __ldca, __ldcs, __ldlu and __ldcv --- flang/module/cudadevice.f90 | 100 +++++++++++++++++++++ flang/test/Lower/CUDA/cuda-device-proc.cuf | 18 ++++ 2 files changed, 118 insertions(+) diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90 index 4491e9f653270..baaa112f5d8c2 100644 --- a/flang/module/cudadevice.f90 +++ b/flang/module/cudadevice.f90 @@ -1118,6 +1118,31 @@ attributes(device) integer function match_any_syncjd(mask, val) !dir$ ignore_tkr (d) x complex(8), intent(in) :: x end function + attributes(device) pure function __ldca_i4x4(x) result(y) + !dir$ ignore_tkr (d) x + integer(4), dimension(4), intent(in) :: x + integer(4), dimension(4) :: y + end function + attributes(device) pure function __ldca_i8x2(x) result(y) + !dir$ ignore_tkr (d) x + integer(8), dimension(2), intent(in) :: x + integer(8), dimension(2) :: y + end function + attributes(device) pure function __ldca_r2x2(x) result(y) + !dir$ ignore_tkr (d) x + real(2), dimension(2), intent(in) :: x + real(2), dimension(2) :: y + end function + attributes(device) pure function __ldca_r4x4(x) result(y) + !dir$ ignore_tkr (d) x + real(4), dimension(4), intent(in) :: x + real(4), dimension(4) :: y + end function + attributes(device) pure function __ldca_r8x2(x) result(y) + !dir$ ignore_tkr (d) x + real(8), dimension(2), intent(in) :: x + real(8), dimension(2) :: y + end function end interface ! LDCS @@ -1158,6 +1183,31 @@ attributes(device) integer function match_any_syncjd(mask, val) !dir$ ignore_tkr (d) x complex(8), intent(in) :: x end function + attributes(device) pure function __ldcs_i4x4(x) result(y) + !dir$ ignore_tkr (d) x + integer(4), dimension(4), intent(in) :: x + integer(4), dimension(4) :: y + end function + attributes(device) pure function __ldcs_i8x2(x) result(y) + !dir$ ignore_tkr (d) x + integer(8), dimension(2), intent(in) :: x + integer(8), dimension(2) :: y + end function + attributes(device) pure function __ldcs_r2x2(x) result(y) + !dir$ ignore_tkr (d) x + real(2), dimension(2), intent(in) :: x + real(2), dimension(2) :: y + end function + attributes(device) pure function __ldcs_r4x4(x) result(y) + !dir$ ignore_tkr (d) x + real(4), dimension(4), intent(in) :: x + real(4), dimension(4) :: y + end function + attributes(device) pure function __ldcs_r8x2(x) result(y) + !dir$ ignore_tkr (d) x + real(8), dimension(2), intent(in) :: x + real(8), dimension(2) :: y + end function end interface ! LDLU @@ -1198,6 +1248,31 @@ attributes(device) integer function match_any_syncjd(mask, val) !dir$ ignore_tkr (d) x complex(8), intent(in) :: x end function + attributes(device) pure function __ldlu_i4x4(x) result(y) + !dir$ ignore_tkr (d) x + integer(4), dimension(4), intent(in) :: x + integer(4), dimension(4) :: y + end function + attributes(device) pure function __ldlu_i8x2(x) result(y) + !dir$ ignore_tkr (d) x + integer(8), dimension(2), intent(in) :: x + integer(8), dimension(2) :: y + end function + attributes(device) pure function __ldlu_r2x2(x) result(y) + !dir$ ignore_tkr (d) x + real(2), dimension(2), intent(in) :: x + real(2), dimension(2) :: y + end function + attributes(device) pure function __ldlu_r4x4(x) result(y) + !dir$ ignore_tkr (d) x + real(4), dimension(4), intent(in) :: x + real(4), dimension(4) :: y + end function + attributes(device) pure function __ldlu_r8x2(x) result(y) + !dir$ ignore_tkr (d) x + real(8), dimension(2), intent(in) :: x + real(8), dimension(2) :: y + end function end interface ! LDCV @@ -1238,6 +1313,31 @@ attributes(device) integer function match_any_syncjd(mask, val) !dir$ ignore_tkr (d) x complex(8), intent(in) :: x end function + attributes(device) pure function __ldcv_i4x4(x) result(y) + !dir$ ignore_tkr (d) x + integer(4), dimension(4), intent(in) :: x + integer(4), dimension(4) :: y + end function + attributes(device) pure function __ldcv_i8x2(x) result(y) + !dir$ ignore_tkr (d) x + integer(8), dimension(2), intent(in) :: x + integer(8), dimension(2) :: y + end function + attributes(device) pure function __ldcv_r2x2(x) result(y) + !dir$ ignore_tkr (d) x + real(2), dimension(2), intent(in) :: x + real(2), dimension(2) :: y + end function + attributes(device) pure function __ldcv_r4x4(x) result(y) + !dir$ ignore_tkr (d) x + real(4), dimension(4), intent(in) :: x + real(4), dimension(4) :: y + end function + attributes(device) pure function __ldcv_r8x2(x) result(y) + !dir$ ignore_tkr (d) x + real(8), dimension(2), intent(in) :: x + real(8), dimension(2) :: y + end function end interface ! STWB diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index c651d34c55093..feff9c31f12f2 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -198,3 +198,21 @@ end subroutine ! CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %{{.*}}#1 : !fir.ref to !llvm.ptr ! CHECK: %[[ATOMIC:.*]] = llvm.cmpxchg %[[CAST]], %[[BCAST1]], %[[BCAST2]] acq_rel monotonic : !llvm.ptr, i64 ! CHECK: %[[RES:.*]] = llvm.extractvalue %[[ATOMIC]][1] : !llvm.struct<(i64, i1)> + + +attributes(global) subroutine __ldXX(b) + integer, device :: b(*) + integer, device :: x(4) + + x(1:4) = __ldca(b(i:j)) + x = __ldcg(b(i:j)) + x = __ldcs(b(i:j)) + x(1:4) = __ldlu(b(i:j)) + x(1:4) = __ldcv(b(i:j)) +end + +! CHECK-LABEL: func.func @_QP__ldxx +! CHECK: __ldca_i4x4 +! CHECK: __ldcg_i4x4 +! CHECK: __ldcs_i4x4 +! CHECK: __ldlu_i4x4 From 2dc746b395244cb540a11685196f40c9d7570db9 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 6 Mar 2025 17:54:16 -0800 Subject: [PATCH 2/2] Add more tests --- flang/test/Lower/CUDA/cuda-device-proc.cuf | 70 ++++++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index feff9c31f12f2..5f39f78f8ecae 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -199,11 +199,9 @@ end subroutine ! CHECK: %[[ATOMIC:.*]] = llvm.cmpxchg %[[CAST]], %[[BCAST1]], %[[BCAST2]] acq_rel monotonic : !llvm.ptr, i64 ! CHECK: %[[RES:.*]] = llvm.extractvalue %[[ATOMIC]][1] : !llvm.struct<(i64, i1)> - -attributes(global) subroutine __ldXX(b) +attributes(global) subroutine __ldXXi4(b) integer, device :: b(*) integer, device :: x(4) - x(1:4) = __ldca(b(i:j)) x = __ldcg(b(i:j)) x = __ldcs(b(i:j)) @@ -211,8 +209,72 @@ attributes(global) subroutine __ldXX(b) x(1:4) = __ldcv(b(i:j)) end -! CHECK-LABEL: func.func @_QP__ldxx +! CHECK-LABEL: func.func @_QP__ldxxi4 ! CHECK: __ldca_i4x4 ! CHECK: __ldcg_i4x4 ! CHECK: __ldcs_i4x4 ! CHECK: __ldlu_i4x4 + +attributes(global) subroutine __ldXXi8(b) + integer(8), device :: b(*) + integer(8), device :: x(2) + x(1:2) = __ldca(b(i:j)) + x = __ldcg(b(i:j)) + x = __ldcs(b(i:j)) + x(1:2) = __ldlu(b(i:j)) + x(1:2) = __ldcv(b(i:j)) +end + +! CHECK-LABEL: func.func @_QP__ldxxi8 +! CHECK: __ldca_i8x2 +! CHECK: __ldcg_i8x2 +! CHECK: __ldcs_i8x2 +! CHECK: __ldlu_i8x2 + +attributes(global) subroutine __ldXXr4(b) + real, device :: b(*) + real, device :: x(4) + x(1:4) = __ldca(b(i:j)) + x = __ldcg(b(i:j)) + x = __ldcs(b(i:j)) + x(1:4) = __ldlu(b(i:j)) + x(1:4) = __ldcv(b(i:j)) +end + +! CHECK-LABEL: func.func @_QP__ldxxr4 +! CHECK: __ldca_r4x4 +! CHECK: __ldcg_r4x4 +! CHECK: __ldcs_r4x4 +! CHECK: __ldlu_r4x4 + +attributes(global) subroutine __ldXXr2(b) + real(2), device :: b(*) + real(2), device :: x(2) + x(1:2) = __ldca(b(i:j)) + x = __ldcg(b(i:j)) + x = __ldcs(b(i:j)) + x(1:2) = __ldlu(b(i:j)) + x(1:2) = __ldcv(b(i:j)) +end + +! CHECK-LABEL: func.func @_QP__ldxxr2 +! CHECK: __ldca_r2x2 +! CHECK: __ldcg_r2x2 +! CHECK: __ldcs_r2x2 +! CHECK: __ldlu_r2x2 + +attributes(global) subroutine __ldXXr8(b) + real(8), device :: b(*) + real(8), device :: x(2) + x(1:2) = __ldca(b(i:j)) + x = __ldcg(b(i:j)) + x = __ldcs(b(i:j)) + x(1:2) = __ldlu(b(i:j)) + x(1:2) = __ldcv(b(i:j)) +end + +! CHECK-LABEL: func.func @_QP__ldxxr8 +! CHECK: __ldca_r8x2 +! CHECK: __ldcg_r8x2 +! CHECK: __ldcs_r8x2 +! CHECK: __ldlu_r8x2