Skip to content

Commit fd6f06c

Browse files
committed
[EK-VT] Adding a workgroup class to VecUtils
This diff adds a new class called `WorkgroupSize` to the `VecUtils` header file. The `WorkgroupSize` class takes three `uint32_t` values as parameters and stores them in a single `uint32_t` variable using bitwise operations. This class is used in the Vulkan backend to specify the size of a workgroup for a given operation. Differential Revision: [D70021019](https://our.internmc.facebook.com/intern/diff/D70021019/) ghstack-source-id: 267790990 Pull Request resolved: #8632
1 parent f87940d commit fd6f06c

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

backends/vulkan/runtime/utils/VecUtils.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,5 +479,49 @@ inline int64_t multiply_integers(Iter begin, Iter end) {
479479
begin, end, static_cast<int64_t>(1), std::multiplies<>());
480480
}
481481

482+
class WorkgroupSize final {
483+
uint32_t val;
484+
485+
public:
486+
explicit WorkgroupSize() : val(0) {}
487+
explicit WorkgroupSize(const uint32_t x, const uint32_t y, const uint32_t z) {
488+
// shift numbers by multiple of 11 bits, since each local workgroup axis can
489+
// be 1024 at most and which is 0x400. only z axis can't store 1024, because
490+
// it would overflow uint32_t storage.
491+
if (z == 1024) {
492+
throw std::runtime_error(
493+
"Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage");
494+
}
495+
val = x | (y << 11) | (z << 22);
496+
}
497+
498+
explicit WorkgroupSize(const uvec3& vec) {
499+
// shift numbers by multiple of 11 bits, since each local workgroup axis can
500+
// be 1024 at most and which is 0x400. only z axis can't store 1024, because
501+
// it would overflow uint32_t storage.
502+
if (z == 1024) {
503+
throw std::runtime_error(
504+
"Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage");
505+
}
506+
val = vec[0u] | (vec[1u] << 11) | (vec[2u] << 22);
507+
}
508+
509+
explicit inline operator uvec3() const {
510+
return {
511+
val & 0x7ffu,
512+
(val >> 11) & 0x7ffu,
513+
(val >> 22),
514+
};
515+
}
516+
517+
explicit inline operator uint32_t() const {
518+
return val;
519+
}
520+
521+
inline constexpr uint32_t operator[](const int idx) const {
522+
return (val >> (11 * idx)) & 0x7ffu;
523+
}
524+
};
525+
482526
} // namespace utils
483527
} // namespace vkcompute

0 commit comments

Comments
 (0)