Commit 683bfb5

opt group_index_select_or_add_2d_kernel
1 parent 588d269 commit 683bfb5

2 files changed: +209 -53 lines

fbgemm_gpu/src/sparse_ops/sparse_group_index.cu

Lines changed: 208 additions & 52 deletions
@@ -12,13 +12,10 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-// The wave size is forced to be 32 on ROCm devices in favor
-// of granularity losses reduction.
-constexpr int EMULATED_WARP_SIZE = 32;
 // TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
 
 // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
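These constants drive the warp-to-work mapping used throughout the kernel hunk below: each flat warp_id is decomposed into a group member, a row within that member, and a starting column. As a rough host-side illustration of the fixed-column case (USE_VAR_COLS == false), the sketch below shows that arithmetic; WarpWork and decompose are names invented for the example, cols_per_warp stands in for GROUP_INDEX_SELECT_COLS_PER_WARP, and a plain division replaces the kernel's shift by LOG_COLS_PER_WARP (equivalent because the constant is a power of two).

// Hypothetical host-side model of the kernel's fixed-column warp decomposition;
// not FBGEMM code.
#include <cstdint>
#include <cstdio>

struct WarpWork {
  int64_t member_id;   // which tensor in the group this warp serves
  int64_t row;         // row within that tensor
  int64_t col_offset;  // first column handled by this lane
};

WarpWork decompose(
    int64_t warp_id,
    int64_t num_cols,
    int64_t num_work_rows,
    int64_t cols_per_warp,   // GROUP_INDEX_SELECT_COLS_PER_WARP
    int64_t lane,            // threadIdx.x
    int64_t unroll_factor) { // GROUP_INDEX_SELECT_UNROLL_FACTOR
  const int64_t warps_per_row = (num_cols + cols_per_warp - 1) / cols_per_warp;
  const int64_t member_id = warp_id / (warps_per_row * num_work_rows);
  const int64_t member_warp_id =
      warp_id - member_id * warps_per_row * num_work_rows;
  const int64_t row = member_warp_id / warps_per_row;
  const int64_t col_offset =
      (member_warp_id % warps_per_row) * cols_per_warp + lane * unroll_factor;
  return {member_id, row, col_offset};
}

int main() {
  // Example: group members with 100 work rows and 70 columns each,
  // 32 columns per warp, lane 5, unroll factor 1.
  const WarpWork w = decompose(/*warp_id=*/431, 70, 100, 32, 5, 1);
  std::printf("member=%lld row=%lld col=%lld\n",
              (long long)w.member_id, (long long)w.row, (long long)w.col_offset);
  return 0;
}

For these numbers the example prints member=1 row=43 col=69: with 3 warps per row and 100 rows per member, warp 431 is warp 131 within member 1, i.e. row 43, third 32-column chunk (columns 64 onward), plus lane 5.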
@@ -46,59 +43,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t num_work_rows, // number of rows to work on per member
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
-  int32_t num_cols = 0;
-  int32_t warps_per_row = 0;
-
-  if constexpr (!USE_VAR_COLS) {
-    num_cols = num_cols_group[0];
-    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
-  }
+  // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
+  if (USE_INDEX_SELECT) {
+    for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+         warp_id < total_num_warps;
+         warp_id += gridDim.x * blockDim.y) {
+      int32_t member_id, member_warp_id, num_cols, warps_per_row;
+      if (USE_VAR_COLS) {
+        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        if (threadIdx.x == 0) {
+          binary_search_range(
+              &member_ids[threadIdx.y],
+              warp_offsets_group + 1,
+              warp_id,
+              group_size);
+        }
+        syncwarp();
+        member_id = member_ids[threadIdx.y];
+        num_cols = num_cols_group[member_id];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_warp_id = warp_id - warp_offsets_group[member_id];
+      } else {
+        // All columns are the same
+        num_cols = num_cols_group[0];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_id = warp_id / (warps_per_row * num_work_rows);
+        member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+      }
+      const auto row = member_warp_id / warps_per_row;
+      const auto col_offset =
+          ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+          (threadIdx.x * UNROLL_FACTOR);
+      scalar_t* input =
+          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+      scalar_t* output =
+          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
 
-  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
-       warp_id < total_num_warps;
-       warp_id += gridDim.x * blockDim.y) {
-    int32_t member_id = 0;
-    int32_t member_warp_id = 0;
-    if constexpr (USE_VAR_COLS) {
-      __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
-      if (threadIdx.x == 0) {
-        binary_search_range(
-            &member_ids[threadIdx.y],
-            warp_offsets_group + 1,
-            warp_id,
-            group_size);
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      const index_t idx = indices[row];
+#pragma unroll
+      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
       }
-      syncwarp();
-      member_id = member_ids[threadIdx.y];
-      num_cols = num_cols_group[member_id];
-      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
-      member_warp_id = warp_id - warp_offsets_group[member_id];
-    } else {
-      // All columns are the same
-      member_id = warp_id / (warps_per_row * num_work_rows);
-      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
     }
-    const auto row = member_warp_id / warps_per_row;
-    const auto col_offset =
-        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
-        (threadIdx.x * UNROLL_FACTOR);
-    scalar_t* input =
-        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-    scalar_t* output =
-        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-    const index_t idx = indices[row];
+  } else {
+    // Cache a handful of scatter destinations per warp so we can merge
+    // consecutive updates that hit the same index before touching global memory.
+    constexpr int kCacheSlots = 2;
+    index_t cached_idx[kCacheSlots];
+    scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
+    bool cached_valid[kCacheSlots];
 #pragma unroll
-    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-      // Compile time conditional
-      if constexpr (USE_INDEX_SELECT) {
-        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+    for (int slot = 0; slot < kCacheSlots; ++slot) {
+      cached_valid[slot] = false;
+    }
+    int32_t active_member_id = -1;
+    int32_t active_num_cols = 0;
+    int32_t active_col_offset = -1;
+    scalar_t* active_input_base = nullptr;
+    scalar_t* active_output_base = nullptr;
+    index_t* active_indices = nullptr;
+
+    auto flush_cache = [&](scalar_t* out_base,
+                           int32_t num_cols,
+                           int32_t col_offset) {
+      if (!out_base) {
+        return;
+      }
+#pragma unroll
+      for (int slot = 0; slot < kCacheSlots; ++slot) {
+        if (!cached_valid[slot]) {
+          continue;
+        }
+        const int64_t row_offset =
+            static_cast<int64_t>(cached_idx[slot]) * num_cols;
+#pragma unroll
+        for (int j = 0; j < UNROLL_FACTOR; ++j) {
+          const int32_t col = col_offset + j;
+          if (col >= num_cols) {
+            break;
+          }
+          gpuAtomicAddNoReturn(
+              out_base + row_offset + col, cached_vals[slot][j]);
+        }
+        cached_valid[slot] = false;
+      }
+    };
+
+    for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+         warp_id < total_num_warps;
+         warp_id += gridDim.x * blockDim.y) {
+      int32_t member_id, member_warp_id, num_cols, warps_per_row;
+      if (USE_VAR_COLS) {
+        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        if (threadIdx.x == 0) {
+          binary_search_range(
+              &member_ids[threadIdx.y],
+              warp_offsets_group + 1,
+              warp_id,
+              group_size);
+        }
+        syncwarp();
+        member_id = member_ids[threadIdx.y];
+        num_cols = num_cols_group[member_id];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        gpuAtomicAddNoReturn(
-            &output[idx * num_cols + i], input[row * num_cols + i]);
+        // All columns are the same
+        num_cols = num_cols_group[0];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_id = warp_id / (warps_per_row * num_work_rows);
+        member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+      }
+      const int64_t row = member_warp_id / warps_per_row;
+      const int32_t col_offset =
+          static_cast<int32_t>(((member_warp_id % warps_per_row)
+                                << LOG_COLS_PER_WARP) +
+                               (threadIdx.x * UNROLL_FACTOR));
+
+      const bool member_changed = member_id != active_member_id;
+      const bool num_cols_changed =
+          member_changed ? false : (num_cols != active_num_cols);
+      const bool col_changed =
+          member_changed ? false : (col_offset != active_col_offset);
+      if (member_changed || num_cols_changed || col_changed) {
+        flush_cache(active_output_base, active_num_cols, active_col_offset);
+        active_member_id = member_id;
+        active_num_cols = num_cols;
+        active_col_offset = col_offset;
+        active_input_base =
+            reinterpret_cast<scalar_t*>(input_ptrs[member_id]);
+        active_output_base =
+            reinterpret_cast<scalar_t*>(output_ptrs[member_id]);
+        active_indices =
+            reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      }
+
+      if (col_offset >= active_num_cols) {
+        continue;
+      }
+
+      const index_t idx = active_indices[row];
+
+      scalar_t local_vals[UNROLL_FACTOR];
+#pragma unroll
+      for (int j = 0; j < UNROLL_FACTOR; ++j) {
+        local_vals[j] = static_cast<scalar_t>(0);
+      }
+      const int64_t input_offset =
+          static_cast<int64_t>(row) * active_num_cols + active_col_offset;
+#pragma unroll
+      for (int j = 0; j < UNROLL_FACTOR; ++j) {
+        const int32_t col = active_col_offset + j;
+        if (col >= active_num_cols) {
+          break;
+        }
+        local_vals[j] = active_input_base[input_offset + j];
+      }
+
+      bool appended = false;
+#pragma unroll
+      for (int slot = 0; slot < kCacheSlots; ++slot) {
+        if (cached_valid[slot] && cached_idx[slot] == idx) {
+#pragma unroll
+          for (int j = 0; j < UNROLL_FACTOR; ++j) {
+            const int32_t col = active_col_offset + j;
+            if (col >= active_num_cols) {
+              break;
+            }
+            cached_vals[slot][j] += local_vals[j];
+          }
+          appended = true;
+          break;
+        }
+      }
+
+      if (!appended) {
+        int slot_to_use = -1;
+#pragma unroll
+        for (int slot = 0; slot < kCacheSlots; ++slot) {
+          if (!cached_valid[slot]) {
+            slot_to_use = slot;
+            break;
+          }
+        }
+        if (slot_to_use == -1) {
+          slot_to_use = 0;
+          const int64_t row_offset =
+              static_cast<int64_t>(cached_idx[slot_to_use]) *
+              active_num_cols;
+#pragma unroll
+          for (int j = 0; j < UNROLL_FACTOR; ++j) {
+            const int32_t col = active_col_offset + j;
+            if (col >= active_num_cols) {
+              break;
+            }
+            gpuAtomicAddNoReturn(
+                active_output_base + row_offset + col,
+                cached_vals[slot_to_use][j]);
+          }
+          cached_valid[slot_to_use] = false;
+        }
+
+        cached_idx[slot_to_use] = idx;
+#pragma unroll
+        for (int j = 0; j < UNROLL_FACTOR; ++j) {
+          cached_vals[slot_to_use][j] = local_vals[j];
+        }
+        cached_valid[slot_to_use] = true;
       }
     }
+
+    flush_cache(active_output_base, active_num_cols, active_col_offset);
   }
 }
 
@@ -123,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
   at::cuda::OptionalCUDAGuard device_guard(device);
 
   // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
+  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
       max_grid_size);
-  dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
+  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
 
 #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
   FBGEMM_LAUNCH_KERNEL( \
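To make the intent of the new scatter-add branch above easier to follow, below is a simplified host-side C++ model of its per-warp cache: up to kCacheSlots (= 2) destination rows are accumulated locally and only written back with atomic adds when both slots are occupied and a new index arrives, or at the end (the kernel additionally flushes whenever the warp moves to a different group member or column block, which this single-tensor model omits). The names scatter_add_cached and flush are invented for this sketch, whole rows stand in for the kernel's UNROLL_FACTOR-wide column slices, and a plain += stands in for gpuAtomicAddNoReturn.

// Hypothetical host-side model of the per-warp scatter cache; not FBGEMM code.
#include <cstdint>
#include <vector>

constexpr int kCacheSlots = 2;

void scatter_add_cached(
    const std::vector<float>& input,     // num_rows x num_cols, row-major
    const std::vector<int64_t>& indices, // destination row for each input row
    std::vector<float>& output,          // num_out_rows x num_cols, row-major
    int64_t num_cols) {
  int64_t cached_idx[kCacheSlots] = {0, 0};
  std::vector<float> cached_vals[kCacheSlots];
  bool cached_valid[kCacheSlots] = {false, false};
  for (auto& vals : cached_vals) {
    vals.assign(num_cols, 0.0f);
  }

  // Writing a slot back is where the real kernel pays for atomics
  // (gpuAtomicAddNoReturn); a plain += stands in for that here.
  auto flush = [&](int slot) {
    if (!cached_valid[slot]) {
      return;
    }
    for (int64_t c = 0; c < num_cols; ++c) {
      output[cached_idx[slot] * num_cols + c] += cached_vals[slot][c];
    }
    cached_valid[slot] = false;
  };

  const int64_t num_rows = static_cast<int64_t>(indices.size());
  for (int64_t row = 0; row < num_rows; ++row) {
    const int64_t idx = indices[row];
    // Merge into a slot that already holds this destination row, if any.
    int slot = -1;
    for (int s = 0; s < kCacheSlots; ++s) {
      if (cached_valid[s] && cached_idx[s] == idx) {
        slot = s;
        break;
      }
    }
    if (slot == -1) {
      // Otherwise claim a free slot, evicting slot 0 when both are occupied.
      for (int s = 0; s < kCacheSlots; ++s) {
        if (!cached_valid[s]) {
          slot = s;
          break;
        }
      }
      if (slot == -1) {
        flush(0);
        slot = 0;
      }
      cached_idx[slot] = idx;
      cached_valid[slot] = true;
      cached_vals[slot].assign(num_cols, 0.0f);
    }
    // Accumulate this row into the cached slot instead of hitting output now.
    for (int64_t c = 0; c < num_cols; ++c) {
      cached_vals[slot][c] += input[row * num_cols + c];
    }
  }
  // Drain whatever is still cached.
  for (int s = 0; s < kCacheSlots; ++s) {
    flush(s);
  }
}

The benefit shows up when the indices tensor contains runs of repeated values: each run collapses into a single write-back per cached slot instead of one atomic update per occurrence.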

fbgemm_gpu/test/sparse/index_select_test.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-#!/usr/bin/env python3
+#!/usr/bin/env python3
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
