Skip to content

Commit fbd7c72

Browse files
clementval authored and kcloudy0717 committed
[flang][cuda] Do not use address cast for src and dst in TMA bulk load (llvm#170564)
1 parent 2bc6e1a commit fbd7c72

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1508,10 +1508,8 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
15081508
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
15091509
barrier = builder.createConvert(loc, llvmPtrTy, barrier);
15101510
setAlignment(dst, kTMAAlignment);
1511-
dst = convertPtrToNVVMSpace(builder, loc, dst,
1512-
mlir::NVVM::NVVMMemorySpace::Shared);
1513-
src = convertPtrToNVVMSpace(builder, loc, src,
1514-
mlir::NVVM::NVVMMemorySpace::Shared);
1511+
dst = builder.createConvert(loc, llvmPtrTy, dst);
1512+
src = builder.createConvert(loc, llvmPtrTy, src);
15151513
mlir::NVVM::InlinePtxOp::create(
15161514
builder, loc, mlir::TypeRange{}, {dst, src, size, barrier}, {},
15171515
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], "

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -543,7 +543,7 @@ end subroutine
543543
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
544544
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
545545
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
546-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
546+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
547547
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
548548

549549
attributes(global) subroutine test_tma_bulk_load_c8(a, n)
@@ -563,7 +563,7 @@ end subroutine
563563
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 16 : i32
564564
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
565565
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
566-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
566+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
567567
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
568568

569569
attributes(global) subroutine test_tma_bulk_load_i4(a, n)
@@ -583,7 +583,7 @@ end subroutine
583583
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
584584
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
585585
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
586-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
586+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
587587
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
588588

589589
attributes(global) subroutine test_tma_bulk_load_i8(a, n)
@@ -603,7 +603,7 @@ end subroutine
603603
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
604604
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
605605
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
606-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
606+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
607607
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
608608

609609
attributes(global) subroutine test_tma_bulk_load_r2(a, n)
@@ -623,7 +623,7 @@ end subroutine
623623
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 2 : i32
624624
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
625625
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
626-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
626+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
627627
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
628628

629629
attributes(global) subroutine test_tma_bulk_load_r4(a, n)
@@ -643,7 +643,7 @@ end subroutine
643643
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 4 : i32
644644
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
645645
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
646-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
646+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
647647
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
648648

649649
attributes(global) subroutine test_tma_bulk_load_r8(a, n)
@@ -663,7 +663,7 @@ end subroutine
663663
! CHECK: %[[ELEM_SIZE:.*]] = arith.constant 8 : i32
664664
! CHECK: %[[SIZE:.*]] = arith.muli %[[COUNT]], %[[ELEM_SIZE]] : i32
665665
! CHECK: %[[BARRIER_PTR:.*]] = fir.convert %[[BARRIER]]#0 : (!fir.ref<i64>) -> !llvm.ptr
666-
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr<3>, !llvm.ptr<3>, i32, !llvm.ptr)
666+
! CHECK: nvvm.inline_ptx "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ro(%{{.*}}, %{{.*}}, %[[SIZE]], %[[BARRIER_PTR]] : !llvm.ptr, !llvm.ptr, i32, !llvm.ptr)
667667
! CHECK: nvvm.inline_ptx "mbarrier.expect_tx.relaxed.cta.shared::cta.b64 [%0], %1;" ro(%[[BARRIER_PTR]], %[[SIZE]] : !llvm.ptr, i32)
668668

669669
attributes(global) subroutine test_tma_bulk_store_c4(c, n)

0 commit comments

Comments
 (0)