Skip to content

Commit 24eea91

Browse files
clementvalaokblast
authored andcommitted
[flang][cuda] Convert src and dst to llvm.ptr in tma_bulk_load (llvm#165618)
1 parent 5734253 commit 24eea91

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9362,6 +9362,8 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
93629362
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
93639363
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
93649364
barrier = builder.createConvert(loc, llvmPtrTy, barrier);
9365+
dst = builder.createConvert(loc, llvmPtrTy, dst);
9366+
src = builder.createConvert(loc, llvmPtrTy, src);
93659367
mlir::NVVM::InlinePtxOp::create(
93669368
builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
93679369
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ end subroutine
544544
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
545545
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
546546
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
547-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f32>>>, !fir.ref<complex<f32>>, i32, !llvm.ptr)
547+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
548548
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
549549

550550
attributes(global) subroutine test_tma_bulk_load_c8(a, n)
@@ -563,7 +563,7 @@ end subroutine
563563
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
564564
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
565565
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
566-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xcomplex<f64>>>, !fir.ref<complex<f64>>, i32, !llvm.ptr)
566+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
567567
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
568568

569569
attributes(global) subroutine test_tma_bulk_load_i4(a, n)
@@ -582,7 +582,7 @@ end subroutine
582582
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
583583
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
584584
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
585-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, i32, !llvm.ptr)
585+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
586586
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
587587

588588
attributes(global) subroutine test_tma_bulk_load_i8(a, n)
@@ -601,7 +601,7 @@ end subroutine
601601
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
602602
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
603603
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
604-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xi64>>, !fir.ref<i64>, i32, !llvm.ptr)
604+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
605605
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
606606

607607
attributes(global) subroutine test_tma_bulk_load_r2(a, n)
@@ -620,7 +620,7 @@ end subroutine
620620
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
621621
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
622622
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
623-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf16>>, !fir.ref<f16>, i32, !llvm.ptr)
623+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
624624
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
625625

626626
attributes(global) subroutine test_tma_bulk_load_r4(a, n)
@@ -639,7 +639,7 @@ end subroutine
639639
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
640640
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
641641
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
642-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf32>>, !fir.ref<f32>, i32, !llvm.ptr)
642+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
643643
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
644644

645645
attributes(global) subroutine test_tma_bulk_load_r8(a, n)
@@ -658,7 +658,7 @@ end subroutine
658658
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
659659
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
660660
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
661-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !fir.ref<!fir.array<1024xf64>>, !fir.ref<f64>, i32, !llvm.ptr)
661+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
662662
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
663663

664664
attributes(global) subroutine test_tma_bulk_store_c4(c, n)

0 commit comments

Comments
 (0)