Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1082,13 +1082,28 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
case UR_DEVICE_INFO_COMPOSITE_DEVICE:
case UR_DEVICE_INFO_MAX_READ_WRITE_IMAGE_ARGS:
case UR_DEVICE_INFO_GPU_EU_COUNT:
case UR_DEVICE_INFO_GPU_EU_SIMD_WIDTH:
case UR_DEVICE_INFO_GPU_EU_SLICES:
case UR_DEVICE_INFO_GPU_SUBSLICES_PER_SLICE:
case UR_DEVICE_INFO_GPU_EU_COUNT_PER_SUBSLICE:
case UR_DEVICE_INFO_GPU_HW_THREADS_PER_EU:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;

case UR_DEVICE_INFO_GPU_EU_SIMD_WIDTH: {
  // NVIDIA GPU SIMD units are as wide as a warp, so the device's warp size
  // is what gets reported for this query.
  const uint32_t SimdWidth = hDevice->getWarpSize();
  return ReturnValue(SimdWidth);
}
case UR_DEVICE_INFO_GPU_HW_THREADS_PER_EU: {
  // Ask the driver for the maximum number of threads per multiprocessor.
  // MaxHwThreads is already an int, so its address is passed directly; no
  // pointer cast is needed.
  int MaxHwThreads{0};
  UR_CHECK_ERROR(cuDeviceGetAttribute(
      &MaxHwThreads, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR,
      hDevice->get()));
  detail::ur::assertion(MaxHwThreads > 0);
  // Calculate the maximum number of resident warps per SM: the thread limit
  // divided by the warp size. The warp size is asserted to be non-zero
  // before it is used as the divisor.
  const uint32_t WarpSize = hDevice->getWarpSize();
  detail::ur::assertion(WarpSize > 0);
  uint32_t ResidentWarpCount = static_cast<uint32_t>(MaxHwThreads) / WarpSize;
  return ReturnValue(ResidentWarpCount);
}

case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP:
case UR_DEVICE_INFO_COMMAND_BUFFER_EVENT_SUPPORT_EXP:
return ReturnValue(true);
Expand Down
7 changes: 7 additions & 0 deletions source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct ur_device_handle_t_ {
int MaxChosenLocalMem{0};
bool MaxLocalMemSizeChosen{false};
uint32_t NumComputeUnits{0};
uint32_t WarpSize{0};

public:
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
Expand All @@ -59,6 +60,10 @@ struct ur_device_handle_t_ {
reinterpret_cast<int *>(&NumComputeUnits),
CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, cuDevice));

UR_CHECK_ERROR(cuDeviceGetAttribute(reinterpret_cast<int *>(&WarpSize),
CU_DEVICE_ATTRIBUTE_WARP_SIZE,
cuDevice));

// Set local mem max size if env var is present
static const char *LocalMemSizePtrUR =
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE");
Expand Down Expand Up @@ -114,6 +119,8 @@ struct ur_device_handle_t_ {
bool maxLocalMemSizeChosen() { return MaxLocalMemSizeChosen; };

// Returns NumComputeUnits, the value stored from the
// CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT device attribute query.
uint32_t getNumComputeUnits() const noexcept { return NumComputeUnits; }

// Returns WarpSize, the value stored from the CU_DEVICE_ATTRIBUTE_WARP_SIZE
// device attribute query.
uint32_t getWarpSize() const noexcept { return WarpSize; }
};

int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);
Loading