Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Optimizer/Builder/IntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ struct IntrinsicLibrary {
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genTranspose(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genThisGrid(mlir::Type, llvm::ArrayRef<mlir::Value>);
void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
Expand Down
85 changes: 85 additions & 0 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ static constexpr IntrinsicHandler handlers[]{
{{{"count", asAddr}, {"count_rate", asAddr}, {"count_max", asAddr}}},
/*isElemental=*/false},
{"tand", &I::genTand},
{"this_grid", &I::genThisGrid, {}, /*isElemental=*/false},
{"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
{"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
{"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
Expand Down Expand Up @@ -8109,6 +8110,90 @@ mlir::Value IntrinsicLibrary::genTand(mlir::Type resultType,
return getRuntimeCallGenerator("tan", ftype)(builder, loc, {arg});
}

// THIS_GRID
//
// Lowers the CUDA Fortran `this_grid()` function from the cooperative_groups
// module. A temporary of the derived result type is created with fir.alloca
// and its `size` and `rank` components are filled in from the NVVM
// thread/block/grid special registers:
//   - size: total number of threads in the grid.
//   - rank: 1-based linear index of the calling thread within the grid.
// The first component of the record (position 0) is not written here.
//
// \param resultType the derived type of the result; must be a fir::RecordType
//        whose components at positions 1 and 2 are the i32 `size` and `rank`.
// \param args must be empty, this_grid takes no arguments.
// \return the address of the temporary holding the result.
mlir::Value IntrinsicLibrary::genThisGrid(mlir::Type resultType,
                                          llvm::ArrayRef<mlir::Value> args) {
  // Positions of the components written below inside the result record.
  constexpr unsigned sizeFieldPos = 1;
  constexpr unsigned rankFieldPos = 2;

  assert(args.empty() && "this_grid takes no arguments");
  // Use dyn_cast so that the assertion below is the check that actually
  // fires: mlir::cast would already have asserted on a type mismatch.
  auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
  assert(recTy && "RecordType expected");
  assert(recTy.getTypeList().size() > rankFieldPos &&
         "result type must have size and rank components");
  mlir::Value res = builder.create<fir::AllocaOp>(loc, resultType);
  mlir::Type i32Ty = builder.getI32Type();

  mlir::Value threadIdX = builder.create<mlir::NVVM::ThreadIdXOp>(loc, i32Ty);
  mlir::Value threadIdY = builder.create<mlir::NVVM::ThreadIdYOp>(loc, i32Ty);
  mlir::Value threadIdZ = builder.create<mlir::NVVM::ThreadIdZOp>(loc, i32Ty);

  mlir::Value blockIdX = builder.create<mlir::NVVM::BlockIdXOp>(loc, i32Ty);
  mlir::Value blockIdY = builder.create<mlir::NVVM::BlockIdYOp>(loc, i32Ty);
  mlir::Value blockIdZ = builder.create<mlir::NVVM::BlockIdZOp>(loc, i32Ty);

  mlir::Value blockDimX = builder.create<mlir::NVVM::BlockDimXOp>(loc, i32Ty);
  mlir::Value blockDimY = builder.create<mlir::NVVM::BlockDimYOp>(loc, i32Ty);
  mlir::Value blockDimZ = builder.create<mlir::NVVM::BlockDimZOp>(loc, i32Ty);
  mlir::Value gridDimX = builder.create<mlir::NVVM::GridDimXOp>(loc, i32Ty);
  mlir::Value gridDimY = builder.create<mlir::NVVM::GridDimYOp>(loc, i32Ty);
  mlir::Value gridDimZ = builder.create<mlir::NVVM::GridDimZOp>(loc, i32Ty);

  // this_grid.size = ((blockDim.z * gridDim.z) * (blockDim.y * gridDim.y)) *
  //                  (blockDim.x * gridDim.x);
  mlir::Value resZ =
      builder.create<mlir::arith::MulIOp>(loc, blockDimZ, gridDimZ);
  mlir::Value resY =
      builder.create<mlir::arith::MulIOp>(loc, blockDimY, gridDimY);
  mlir::Value resX =
      builder.create<mlir::arith::MulIOp>(loc, blockDimX, gridDimX);
  mlir::Value resZY = builder.create<mlir::arith::MulIOp>(loc, resZ, resY);
  mlir::Value size = builder.create<mlir::arith::MulIOp>(loc, resZY, resX);

  // tmp = ((blockIdx.z * gridDim.y * gridDim.x) + (blockIdx.y * gridDim.x)) +
  //       blockIdx.x;
  // this_grid.rank = tmp * ((blockDim.x * blockDim.y) * blockDim.z) +
  //                  ((threadIdx.z * blockDim.y) * blockDim.x) +
  //                  (threadIdx.y * blockDim.x) + threadIdx.x + 1;
  mlir::Value r1 = builder.create<mlir::arith::MulIOp>(loc, blockIdZ, gridDimY);
  mlir::Value r2 = builder.create<mlir::arith::MulIOp>(loc, r1, gridDimX);
  mlir::Value r3 = builder.create<mlir::arith::MulIOp>(loc, blockIdY, gridDimX);
  mlir::Value r2r3 = builder.create<mlir::arith::AddIOp>(loc, r2, r3);
  mlir::Value tmp = builder.create<mlir::arith::AddIOp>(loc, r2r3, blockIdX);

  mlir::Value bXbY =
      builder.create<mlir::arith::MulIOp>(loc, blockDimX, blockDimY);
  mlir::Value bXbYbZ =
      builder.create<mlir::arith::MulIOp>(loc, bXbY, blockDimZ);
  mlir::Value tZbY =
      builder.create<mlir::arith::MulIOp>(loc, threadIdZ, blockDimY);
  mlir::Value tZbYbX =
      builder.create<mlir::arith::MulIOp>(loc, tZbY, blockDimX);
  mlir::Value tYbX =
      builder.create<mlir::arith::MulIOp>(loc, threadIdY, blockDimX);
  mlir::Value rank = builder.create<mlir::arith::MulIOp>(loc, tmp, bXbYbZ);
  rank = builder.create<mlir::arith::AddIOp>(loc, rank, tZbYbX);
  rank = builder.create<mlir::arith::AddIOp>(loc, rank, tYbX);
  rank = builder.create<mlir::arith::AddIOp>(loc, rank, threadIdX);
  // The rank is 1-based, hence the final increment.
  mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1);
  rank = builder.create<mlir::arith::AddIOp>(loc, rank, one);

  // Store `value` into the component at position `pos` of the result
  // temporary. The operations are emitted in the order field_index,
  // coordinate_of, store.
  mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext());
  auto storeComponent = [&](unsigned pos, mlir::Value value) {
    const auto &component = recTy.getTypeList()[pos];
    mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
        loc, fieldIndexType, component.first, recTy,
        /*typeParams=*/mlir::ValueRange{});
    mlir::Value coord = builder.create<fir::CoordinateOp>(
        loc, builder.getRefType(component.second), res, fieldIndex);
    builder.create<fir::StoreOp>(loc, value, coord);
  };

  storeComponent(sizeFieldPos, size);
  storeComponent(rankFieldPos, rank);
  return res;
}

// TRAILZ
mlir::Value IntrinsicLibrary::genTrailz(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
Expand Down
30 changes: 30 additions & 0 deletions flang/module/cooperative_groups.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
!===-- module/cooperative_groups.f90 ---------------------------------------===!
!
! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
! See https://llvm.org/LICENSE.txt for license information.
! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
!
!===------------------------------------------------------------------------===!

! CUDA Fortran cooperative groups

! Declares the grid_group derived type and the interface of the this_grid()
! device function. No implementation is given here: the "this_grid" name is
! registered in the intrinsic handler table of IntrinsicCall.cpp and lowered
! by IntrinsicLibrary::genThisGrid.
module cooperative_groups

use, intrinsic :: __fortran_builtins, only: c_devptr => __builtin_c_devptr

implicit none

! Result type of this_grid().
! NOTE: genThisGrid addresses the size and rank components by position
! (1 and 2, zero-based), so the component order below must be kept in sync
! with that lowering code.
type :: grid_group
! Private component; it is not written by genThisGrid.
type(c_devptr), private :: handle
! Total number of threads in the grid.
integer(4) :: size
! 1-based linear index of the calling thread within the grid.
integer(4) :: rank
end type grid_group

interface
! Returns a grid_group whose size and rank components are filled in for the
! calling thread.
attributes(device) function this_grid()
import
type(grid_group) :: this_grid
end function
end interface

end module
52 changes: 52 additions & 0 deletions flang/test/Lower/CUDA/cuda-cooperative.cuf
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s

! Test CUDA Fortran procedures available in cooperative_groups module.

! Kernel that assigns the result of this_grid() to a local grid_group variable
! so that the lowering of the call shows up in the emitted IR matched by the
! FileCheck patterns below.
attributes(grid_global) subroutine g1()
use cooperative_groups
type(grid_group) :: gg
gg = this_grid()
end subroutine

! CHECK: %{{.*}} = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
! CHECK: %[[RES:.*]] = fir.alloca !fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>
! CHECK: %[[THREAD_ID_X:.*]] = nvvm.read.ptx.sreg.tid.x : i32
! CHECK: %[[THREAD_ID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
! CHECK: %[[THREAD_ID_Z:.*]] = nvvm.read.ptx.sreg.tid.z : i32
! CHECK: %[[BLOCK_ID_X:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
! CHECK: %[[BLOCK_ID_Y:.*]] = nvvm.read.ptx.sreg.ctaid.y : i32
! CHECK: %[[BLOCK_ID_Z:.*]] = nvvm.read.ptx.sreg.ctaid.z : i32
! CHECK: %[[BLOCK_DIM_X:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
! CHECK: %[[BLOCK_DIM_Y:.*]] = nvvm.read.ptx.sreg.ntid.y : i32
! CHECK: %[[BLOCK_DIM_Z:.*]] = nvvm.read.ptx.sreg.ntid.z : i32
! CHECK: %[[GRID_DIM_X:.*]] = nvvm.read.ptx.sreg.nctaid.x : i32
! CHECK: %[[GRID_DIM_Y:.*]] = nvvm.read.ptx.sreg.nctaid.y : i32
! CHECK: %[[GRID_DIM_Z:.*]] = nvvm.read.ptx.sreg.nctaid.z : i32

! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_Z]], %[[GRID_DIM_Z]] : i32
! CHECK: %[[R2:.*]] = arith.muli %[[BLOCK_DIM_Y]], %[[GRID_DIM_Y]] : i32
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[GRID_DIM_X]] : i32
! CHECK: %[[R4:.*]] = arith.muli %[[R1]], %[[R2]] : i32
! CHECK: %[[SIZE:.*]] = arith.muli %[[R4]], %[[R3]] : i32

! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[GRID_DIM_Y]] : i32
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[GRID_DIM_X]] : i32
! CHECK: %[[R3:.*]] = arith.muli %[[BLOCK_ID_Y]], %[[GRID_DIM_X]] : i32
! CHECK: %[[R4:.*]] = arith.addi %[[R2]], %[[R3]] : i32
! CHECK: %[[TMP:.*]] = arith.addi %[[R4]], %[[BLOCK_ID_X]] : i32

! CHECK: %[[R1:.*]] = arith.muli %[[BLOCK_DIM_X]], %[[BLOCK_DIM_Y]] : i32
! CHECK: %[[R2:.*]] = arith.muli %[[R1]], %[[BLOCK_DIM_Z]] : i32
! CHECK: %[[R3:.*]] = arith.muli %[[THREAD_ID_Z]], %[[BLOCK_DIM_Y]] : i32
! CHECK: %[[R4:.*]] = arith.muli %[[R3]], %[[BLOCK_DIM_X]] : i32
! CHECK: %[[R5:.*]] = arith.muli %[[THREAD_ID_Y]], %[[BLOCK_DIM_X]] : i32
! CHECK: %[[RES0:.*]] = arith.muli %[[TMP]], %[[R2]] : i32
! CHECK: %[[RES1:.*]] = arith.addi %[[RES0]], %[[R4]] : i32
! CHECK: %[[RES2:.*]] = arith.addi %[[RES1]], %[[R5]] : i32
! CHECK: %[[RES3:.*]] = arith.addi %[[RES2]], %[[THREAD_ID_X]] : i32
! CHECK: %[[ONE:.*]] = arith.constant 1 : i32
! CHECK: %[[RANK:.*]] = arith.addi %[[RES3]], %[[ONE]] : i32
! CHECK: %[[COORD_SIZE:.*]] = fir.coordinate_of %[[RES]], size : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[SIZE]] to %[[COORD_SIZE]] : !fir.ref<i32>
! CHECK: %[[COORD_RANK:.*]] = fir.coordinate_of %[[RES]], rank : (!fir.ref<!fir.type<_QMcooperative_groupsTgrid_group{_QMcooperative_groupsTgrid_group.handle:!fir.type<_QM__fortran_builtinsT__builtin_c_devptr{cptr:!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>}>,size:i32,rank:i32}>>) -> !fir.ref<i32>
! CHECK: fir.store %[[RANK]] to %[[COORD_RANK]] : !fir.ref<i32>
16 changes: 11 additions & 5 deletions flang/tools/f18/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(MODULES
"mma"
"__cuda_builtins"
"__cuda_device"
"cooperative_groups"
"cudadevice"
"ieee_arithmetic"
"ieee_exceptions"
Expand Down Expand Up @@ -60,12 +61,17 @@ if (NOT CMAKE_CROSSCOMPILING)
elseif(${filename} STREQUAL "__ppc_intrinsics" OR
${filename} STREQUAL "mma")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
elseif(${filename} STREQUAL "__cuda_device")
elseif(${filename} STREQUAL "__cuda_device" OR
${filename} STREQUAL "cudadevice" OR
${filename} STREQUAL "cooperative_groups")
set(opts -fc1 -xcuda)
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
elseif(${filename} STREQUAL "cudadevice")
set(opts -fc1 -xcuda)
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
if(${filename} STREQUAL "__cuda_device")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)
elseif(${filename} STREQUAL "cudadevice")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_device.mod)
elseif(${filename} STREQUAL "cooperative_groups")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/cudadevice.mod)
endif()
else()
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__fortran_builtins.mod)
if(${filename} STREQUAL "iso_fortran_env")
Expand Down