Skip to content

Commit 2e67eef

Browse files
committed
[flang][cuda] Set alignment for tma bulk store
1 parent 2fb2d7e commit 2e67eef

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ static const char __ldlu_r2x2[] = "__ldlu_r2x2_";
5353
static const char __ldlu_r4x4[] = "__ldlu_r4x4_";
5454
static const char __ldlu_r8x2[] = "__ldlu_r8x2_";
5555

56+
static constexpr unsigned kTMAAlignment = 16;
57+
5658
// CUDA specific intrinsic handlers.
5759
static constexpr IntrinsicHandler cudaHandlers[]{
5860
{"__ldca_i4x4",
@@ -1505,7 +1507,7 @@ static void genTMABulkLoad(fir::FirOpBuilder &builder, mlir::Location loc,
15051507
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, nelem, eleSize);
15061508
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
15071509
barrier = builder.createConvert(loc, llvmPtrTy, barrier);
1508-
setAlignment(dst, 16);
1510+
setAlignment(dst, kTMAAlignment);
15091511
dst = convertPtrToNVVMSpace(builder, loc, dst,
15101512
mlir::NVVM::NVVMMemorySpace::Shared);
15111513
src = convertPtrToNVVMSpace(builder, loc, src,
@@ -1611,6 +1613,7 @@ static void genTMABulkStore(fir::FirOpBuilder &builder, mlir::Location loc,
16111613
mlir::Value src, mlir::Value dst, mlir::Value count,
16121614
mlir::Value eleSize) {
16131615
mlir::Value size = mlir::arith::MulIOp::create(builder, loc, eleSize, count);
1616+
setAlignment(src, kTMAAlignment);
16141617
src = convertPtrToNVVMSpace(builder, loc, src,
16151618
mlir::NVVM::NVVMMemorySpace::Shared);
16161619
dst = convertPtrToNVVMSpace(builder, loc, dst,

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,7 @@ attributes(global) subroutine test_tma_bulk_store_c4(c, n)
675675
end subroutine
676676

677677
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c4
678+
! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f32>> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_c4Etmpa"} -> !fir.ref<!fir.array<1024xcomplex<f32>>>
678679
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
679680
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
680681
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -688,6 +689,7 @@ attributes(global) subroutine test_tma_bulk_store_c8(c, n)
688689
end subroutine
689690

690691
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_c8
692+
! CHECK: cuf.shared_memory !fir.array<1024xcomplex<f64>> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_c8Etmpa"} -> !fir.ref<!fir.array<1024xcomplex<f64>>>
691693
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
692694
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
693695
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -701,6 +703,7 @@ attributes(global) subroutine test_tma_bulk_store_i4(c, n)
701703
end subroutine
702704

703705
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i4
706+
! CHECK: cuf.shared_memory !fir.array<1024xi32> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_i4Etmpa"} -> !fir.ref<!fir.array<1024xi32>>
704707
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
705708
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
706709
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -714,6 +717,7 @@ attributes(global) subroutine test_tma_bulk_store_i8(c, n)
714717
end subroutine
715718

716719
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_i8
720+
! CHECK: cuf.shared_memory !fir.array<1024xi64> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_i8Etmpa"} -> !fir.ref<!fir.array<1024xi64>>
717721
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
718722
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
719723
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -728,6 +732,7 @@ attributes(global) subroutine test_tma_bulk_store_r2(c, n)
728732
end subroutine
729733

730734
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r2
735+
! CHECK: cuf.shared_memory !fir.array<1024xf16> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r2Etmpa"} -> !fir.ref<!fir.array<1024xf16>>
731736
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
732737
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
733738
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -741,6 +746,7 @@ attributes(global) subroutine test_tma_bulk_store_r4(c, n)
741746
end subroutine
742747

743748
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r4
749+
! CHECK: cuf.shared_memory !fir.array<1024xf32> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r4Etmpa"} -> !fir.ref<!fir.array<1024xf32>>
744750
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
745751
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
746752
! CHECK: nvvm.cp.async.bulk.wait_group 0
@@ -754,6 +760,7 @@ attributes(global) subroutine test_tma_bulk_store_r8(c, n)
754760
end subroutine
755761

756762
! CHECK-LABEL: func.func @_QPtest_tma_bulk_store_r8
763+
! CHECK: cuf.shared_memory !fir.array<1024xf64> align 16 {bindc_name = "tmpa", uniq_name = "_QFtest_tma_bulk_store_r8Etmpa"} -> !fir.ref<!fir.array<1024xf64>>
757764
! CHECK: nvvm.cp.async.bulk.global.shared.cta %{{.*}}, %{{.*}}, %{{.*}} : <1>, <3>
758765
! CHECK: nvvm.inline_ptx "cp.async.bulk.commit_group;"
759766
! CHECK: nvvm.cp.async.bulk.wait_group 0

0 commit comments

Comments
 (0)