Commit 45c1841

Use kernel device-specific descriptor to determine max-wg-size for this kernel
This resolves

```
RuntimeError: Exceeded the number of registers available on the hardware.
The number registers per work-group cannot exceed 65536 for this kernel on this device.
The kernel uses 108 registers per work-item for a total of 1024 work-items per work-group.
-54 (PI_ERROR_INVALID_WORK_GROUP_SIZE)
```

when running example:

```python
import dpctl.tensor as dpt

m1 = dpt.ones((1000, 1000), dtype="i4", device="cuda")
m2 = dpt.ones((1000, 1003), dtype="i4", device="cuda")
r = dpt.matmul(m1[:, :900], m2[:900, :])
```
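The numbers in the error message can be checked by hand. The sketch below is plain C++ with no SYCL dependency and only reproduces the arithmetic: 108 registers per work-item times 1024 work-items exceeds the 65536-register budget. The helper names are hypothetical, and the sub-group size of 32 used in the usage note is an assumption. The value the runtime actually reports for `kernel_device_specific::work_group_size` may be computed differently.

```cpp
#include <cassert>
#include <cstddef>

// Registers needed by one work-group: registers per work-item times
// the number of work-items in the group.
std::size_t registers_per_work_group(std::size_t regs_per_item,
                                     std::size_t wg_size)
{
    return regs_per_item * wg_size;
}

// Hypothetical helper: largest work-group size that stays within the
// register budget, rounded down to a whole number of sub-groups.
std::size_t largest_fitting_wg_size(std::size_t reg_budget,
                                    std::size_t regs_per_item,
                                    std::size_t sg_size)
{
    assert(regs_per_item > 0 && sg_size > 0);
    const std::size_t fit = reg_budget / regs_per_item;
    return (fit / sg_size) * sg_size;
}
```

With the figures from the message, `registers_per_work_group(108, 1024)` gives 110592, which is above 65536. Assuming a sub-group size of 32, `largest_fitting_wg_size(65536, 108, 32)` gives 576, so a device-wide maximum of 1024 is too large for this particular kernel.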
1 parent 6efb2c9 commit 45c1841

File tree

1 file changed: +5 -2 lines changed
  • dpctl/tensor/libtensor/include/kernels/linalg_functions

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 5 additions & 2 deletions
@@ -1365,10 +1365,13 @@ sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q,
     const std::uint32_t max_sg_size = krn.template get_info<
         sycl::info::kernel_device_specific::max_sub_group_size>(dev);
 
+    const size_t k_wg_sz = krn.template get_info<
+        sycl::info::kernel_device_specific::work_group_size>(dev);
+
     // Limit work-group size
     constexpr size_t wg_sz_limit(2048);
-    const size_t max_wg_sz = std::min<size_t>(
-        dev.get_info<sycl::info::device::max_work_group_size>(), wg_sz_limit);
+    const size_t max_wg_sz = std::min(wg_sz_limit, k_wg_sz);
+
     const std::uint32_t max_subgroups_per_wg =
         static_cast<std::uint32_t>(max_wg_sz / max_sg_size);
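To see the effect of the change in isolation, the clamping logic from the hunk can be sketched as a plain C++ function that takes the two queried values as arguments instead of calling `get_info`. The `WgLimits` struct and the `compute_wg_limits` name are hypothetical and do not appear in `gemm.hpp`. Only the `wg_sz_limit` constant and the two expressions are taken from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical container for the two derived limits.
struct WgLimits
{
    std::size_t max_wg_sz;
    std::uint32_t max_subgroups_per_wg;
};

// k_wg_sz stands in for the value queried through
// sycl::info::kernel_device_specific::work_group_size, and max_sg_size
// for kernel_device_specific::max_sub_group_size.
WgLimits compute_wg_limits(std::size_t k_wg_sz, std::uint32_t max_sg_size)
{
    assert(max_sg_size > 0);

    // Limit work-group size, as in the patched code
    constexpr std::size_t wg_sz_limit(2048);
    const std::size_t max_wg_sz = std::min(wg_sz_limit, k_wg_sz);

    const std::uint32_t max_subgroups_per_wg =
        static_cast<std::uint32_t>(max_wg_sz / max_sg_size);

    return {max_wg_sz, max_subgroups_per_wg};
}
```

For example, if the kernel-specific query returned 576 with a sub-group size of 32 (illustrative values, not measured ones), the function yields a work-group limit of 576 and 18 sub-groups per work-group. A kernel-specific value above 2048 is still capped by `wg_sz_limit`.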
