Skip to content

Commit 4b08356

Browse files
committed
Add replicate and centralize tiling distributions
1 parent effce44 commit 4b08356

File tree

1 file changed

+71
-2
lines changed

1 file changed

+71
-2
lines changed

include/kernel_float/tiling.h

Lines changed: 71 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,11 @@ struct block_size {
4141
};
4242

4343
template<size_t... Ns>
44-
struct virtual_block_size {
44+
struct unravel_block_size {
4545
static constexpr size_t rank = sizeof...(Ns);
4646

4747
KERNEL_FLOAT_INLINE
48-
virtual_block_size(dim3 thread_index) {
48+
unravel_block_size(dim3 thread_index) {
4949
thread_index_ = thread_index.x;
5050
}
5151

@@ -160,6 +160,75 @@ struct block_cyclic {
160160
template<size_t N, size_t K>
161161
using type = cyclic_impl<M, N, K>;
162162
};
163+
164+
/**
 * Index mapping in which every thread reports all `N` items as present and
 * local indices coincide with global indices.
 *
 * @tparam N Number of items; also exposed as `items_per_thread`.
 *
 * Parameter names that a function does not read are left commented out so
 * that including this header does not produce unused-parameter warnings.
 */
template<size_t N>
struct replicate_impl {
    /** Always `true` for this mapping (`local_is_present` never returns false). */
    static constexpr bool is_exhaustive = true;

    /** Number of local slots per thread; equal to `N`. */
    static constexpr size_t items_per_thread = N;

    /**
     * Reports whether a local slot holds an item on the given thread.
     *
     * @return Always `true`, independent of both arguments.
     */
    KERNEL_FLOAT_INLINE
    static constexpr bool local_is_present(size_t /*thread_index*/, size_t /*local_index*/) {
        return true;
    }

    /**
     * Maps a thread-local index to a global index.
     *
     * @param local_index Index of the slot within the thread.
     * @return `local_index` unchanged; the thread index does not influence the result.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t local_to_global(size_t /*thread_index*/, size_t local_index) {
        return local_index;
    }

    /**
     * Maps a global index to a thread-local index.
     *
     * @param global_index Global index of the item.
     * @return `global_index` unchanged.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t global_to_local(size_t global_index) {
        return global_index;
    }

    /**
     * Reports the owning thread of a global index.
     *
     * @return Always `0`, independent of the argument.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t global_to_owner(size_t /*global_index*/) {
        return 0;
    }
};
189+
190+
struct replicate {
191+
template<size_t N, size_t K>
192+
using type = replicate_impl<N>;
193+
};
194+
195+
/**
 * Index mapping in which only the thread with index `Root` reports items as
 * present (or every queried thread when `K == 1`); local indices coincide
 * with global indices.
 *
 * @tparam N    Number of items; also exposed as `items_per_thread`.
 * @tparam Root Index of the thread reported as owner of every item.
 * @tparam K    Thread block size that `Root` is validated against.
 *
 * Parameter names that a function does not read are left commented out so
 * that including this header does not produce unused-parameter warnings.
 */
template<size_t N, size_t Root, size_t K>
struct centralize_impl {
    // The comparison is strict: valid root indices are 0 .. K-1.
    static_assert(Root < K, "index of root thread must be less than the thread block size");

    /** `true` only when `K == 1`; otherwise threads other than `Root` report no items. */
    static constexpr bool is_exhaustive = K == 1;

    /** Number of local slots per thread; equal to `N`. */
    static constexpr size_t items_per_thread = N;

    /**
     * Reports whether a local slot holds an item on the given thread.
     *
     * @param thread_index Index of the querying thread.
     * @return `true` if `K == 1` or `thread_index == Root`; the local index is not consulted.
     */
    KERNEL_FLOAT_INLINE
    static constexpr bool local_is_present(size_t thread_index, size_t /*local_index*/) {
        return K == 1 || thread_index == Root;
    }

    /**
     * Maps a thread-local index to a global index.
     *
     * @param local_index Index of the slot within the thread.
     * @return `local_index` unchanged; the thread index does not influence the result.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t local_to_global(size_t /*thread_index*/, size_t local_index) {
        return local_index;
    }

    /**
     * Maps a global index to a thread-local index.
     *
     * @param global_index Global index of the item.
     * @return `global_index` unchanged.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t global_to_local(size_t global_index) {
        return global_index;
    }

    /**
     * Reports the owning thread of a global index.
     *
     * @return Always `Root`, independent of the argument.
     */
    KERNEL_FLOAT_INLINE
    static constexpr size_t global_to_owner(size_t /*global_index*/) {
        return Root;
    }
};
221+
222+
struct at_thread0 {
223+
template<size_t N, size_t K>
224+
using type = centralize_impl<N, 0, K>;
225+
};
226+
227+
template<size_t I>
228+
struct at_thread {
229+
template<size_t N, size_t K>
230+
using type = centralize_impl<N, I, K>;
231+
};
163232
} // namespace dist
164233

165234
template<typename... Ds>

0 commit comments

Comments
 (0)