@@ -12,13 +12,10 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-// The wave size is forced to be 32 on ROCm devices in favor
-// of granularity losses reduction.
-constexpr int EMULATED_WARP_SIZE = 32;
 // TODO: Update UNROLL_FACTOR
 constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
 constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
-    GROUP_INDEX_SELECT_UNROLL_FACTOR * EMULATED_WARP_SIZE;
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
 
 // GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
 constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
@@ -46,59 +43,218 @@ __launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t num_work_rows, // number of rows to work on per member
     const int64_t group_size) {
   const auto total_num_warps = warp_offsets_group[group_size];
-  int32_t num_cols = 0;
-  int32_t warps_per_row = 0;
-
-  if constexpr (!USE_VAR_COLS) {
-    num_cols = num_cols_group[0];
-    warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
-  }
+  // USE_INDEX_SELECT is a template argument; the compiler prunes the unused branch.
+  if (USE_INDEX_SELECT) {
+    for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+         warp_id < total_num_warps;
+         warp_id += gridDim.x * blockDim.y) {
+      int32_t member_id, member_warp_id, num_cols, warps_per_row;
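+      // Variable-column groups locate their member via a binary search over
+      // warp_offsets_group; fixed-column groups can derive it arithmetically.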
+      if (USE_VAR_COLS) {
+        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        if (threadIdx.x == 0) {
+          binary_search_range(
+              &member_ids[threadIdx.y],
+              warp_offsets_group + 1,
+              warp_id,
+              group_size);
+        }
+        syncwarp();
+        member_id = member_ids[threadIdx.y];
+        num_cols = num_cols_group[member_id];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_warp_id = warp_id - warp_offsets_group[member_id];
+      } else {
+        // All columns are the same
+        num_cols = num_cols_group[0];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_id = warp_id / (warps_per_row * num_work_rows);
+        member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+      }
+      const auto row = member_warp_id / warps_per_row;
+      const auto col_offset =
+          ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+          (threadIdx.x * UNROLL_FACTOR);
+      scalar_t* input =
+          reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+      scalar_t* output =
+          reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
 
-  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
-       warp_id < total_num_warps;
-       warp_id += gridDim.x * blockDim.y) {
-    int32_t member_id = 0;
-    int32_t member_warp_id = 0;
-    if constexpr (USE_VAR_COLS) {
-      __shared__ int member_ids[kMaxThreads / EMULATED_WARP_SIZE];
-      if (threadIdx.x == 0) {
-        binary_search_range(
-            &member_ids[threadIdx.y],
-            warp_offsets_group + 1,
-            warp_id,
-            group_size);
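+      // Gather path: each thread copies an UNROLL_FACTOR-wide slice of the
+      // selected input row into the matching output row.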
+      index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      const index_t idx = indices[row];
+#pragma unroll
+      for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
       }
-      syncwarp();
-      member_id = member_ids[threadIdx.y];
-      num_cols = num_cols_group[member_id];
-      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
-      member_warp_id = warp_id - warp_offsets_group[member_id];
-    } else {
-      // All columns are the same
-      member_id = warp_id / (warps_per_row * num_work_rows);
-      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
     }
-    const auto row = member_warp_id / warps_per_row;
-    const auto col_offset =
-        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
-        (threadIdx.x * UNROLL_FACTOR);
-    scalar_t* input =
-        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
-    scalar_t* output =
-        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
-
-    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
-    const index_t idx = indices[row];
+  } else {
+    // Cache a handful of scatter destinations per warp so we can merge
+    // consecutive updates that hit the same index before touching global memory.
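+    // Each thread holds its own UNROLL_FACTOR-wide column slice of the cached
+    // rows, so the cache itself needs no shared memory or synchronization.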
+    constexpr int kCacheSlots = 2;
+    index_t cached_idx[kCacheSlots];
+    scalar_t cached_vals[kCacheSlots][UNROLL_FACTOR];
+    bool cached_valid[kCacheSlots];
 #pragma unroll
-    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
-      // Compile time conditional
-      if constexpr (USE_INDEX_SELECT) {
-        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+    for (int slot = 0; slot < kCacheSlots; ++slot) {
+      cached_valid[slot] = false;
+    }
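+    // The active_* variables record which member and column tile the cached
+    // slots currently refer to; any change forces a flush before reuse.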
+    int32_t active_member_id = -1;
+    int32_t active_num_cols = 0;
+    int32_t active_col_offset = -1;
+    scalar_t* active_input_base = nullptr;
+    scalar_t* active_output_base = nullptr;
+    index_t* active_indices = nullptr;
+
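+    // Flush every valid slot to global memory with atomic adds, then mark the
+    // cache empty so its slots can be reused.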
+    auto flush_cache = [&](scalar_t* out_base,
+                           int32_t num_cols,
+                           int32_t col_offset) {
+      if (!out_base) {
+        return;
+      }
+#pragma unroll
+      for (int slot = 0; slot < kCacheSlots; ++slot) {
+        if (!cached_valid[slot]) {
+          continue;
+        }
+        const int64_t row_offset =
+            static_cast<int64_t>(cached_idx[slot]) * num_cols;
+#pragma unroll
+        for (int j = 0; j < UNROLL_FACTOR; ++j) {
+          const int32_t col = col_offset + j;
+          if (col >= num_cols) {
+            break;
+          }
+          gpuAtomicAddNoReturn(
+              out_base + row_offset + col, cached_vals[slot][j]);
+        }
+        cached_valid[slot] = false;
+      }
+    };
+
+    for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+         warp_id < total_num_warps;
+         warp_id += gridDim.x * blockDim.y) {
+      int32_t member_id, member_warp_id, num_cols, warps_per_row;
+      if (USE_VAR_COLS) {
+        __shared__ int member_ids[kMaxThreads / kWarpSize];
+        if (threadIdx.x == 0) {
+          binary_search_range(
+              &member_ids[threadIdx.y],
+              warp_offsets_group + 1,
+              warp_id,
+              group_size);
+        }
+        syncwarp();
+        member_id = member_ids[threadIdx.y];
+        num_cols = num_cols_group[member_id];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_warp_id = warp_id - warp_offsets_group[member_id];
       } else {
-        gpuAtomicAddNoReturn(
-            &output[idx * num_cols + i], input[row * num_cols + i]);
+        // All columns are the same
+        num_cols = num_cols_group[0];
+        warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+        member_id = warp_id / (warps_per_row * num_work_rows);
+        member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+      }
+      const int64_t row = member_warp_id / warps_per_row;
+      const int32_t col_offset =
+          static_cast<int32_t>(((member_warp_id % warps_per_row)
+                                << LOG_COLS_PER_WARP) +
+                               (threadIdx.x * UNROLL_FACTOR));
+
+      const bool member_changed = member_id != active_member_id;
+      const bool num_cols_changed =
+          member_changed ? false : (num_cols != active_num_cols);
+      const bool col_changed =
+          member_changed ? false : (col_offset != active_col_offset);
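+      // Switching to a different member or column tile invalidates the cache,
+      // so flush any pending partial sums before repointing it.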
+      if (member_changed || num_cols_changed || col_changed) {
+        flush_cache(active_output_base, active_num_cols, active_col_offset);
+        active_member_id = member_id;
+        active_num_cols = num_cols;
+        active_col_offset = col_offset;
+        active_input_base =
+            reinterpret_cast<scalar_t*>(input_ptrs[member_id]);
+        active_output_base =
+            reinterpret_cast<scalar_t*>(output_ptrs[member_id]);
+        active_indices =
+            reinterpret_cast<index_t*>(indices_ptrs[member_id]);
+      }
+
+      if (col_offset >= active_num_cols) {
+        continue;
+      }
+
+      const index_t idx = active_indices[row];
+
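+      // Stage this row's UNROLL_FACTOR-wide tile in registers before merging
+      // it into the cache.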
+      scalar_t local_vals[UNROLL_FACTOR];
+#pragma unroll
+      for (int j = 0; j < UNROLL_FACTOR; ++j) {
+        local_vals[j] = static_cast<scalar_t>(0);
+      }
+      const int64_t input_offset =
+          static_cast<int64_t>(row) * active_num_cols + active_col_offset;
+#pragma unroll
+      for (int j = 0; j < UNROLL_FACTOR; ++j) {
+        const int32_t col = active_col_offset + j;
+        if (col >= active_num_cols) {
+          break;
+        }
+        local_vals[j] = active_input_base[input_offset + j];
+      }
+
+      bool appended = false;
+#pragma unroll
+      for (int slot = 0; slot < kCacheSlots; ++slot) {
+        if (cached_valid[slot] && cached_idx[slot] == idx) {
+#pragma unroll
+          for (int j = 0; j < UNROLL_FACTOR; ++j) {
+            const int32_t col = active_col_offset + j;
+            if (col >= active_num_cols) {
+              break;
+            }
+            cached_vals[slot][j] += local_vals[j];
+          }
+          appended = true;
+          break;
+        }
+      }
+
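+      // No slot holds this index yet: take a free slot, or evict slot 0 by
+      // flushing its partial sums first.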
+      if (!appended) {
+        int slot_to_use = -1;
+#pragma unroll
+        for (int slot = 0; slot < kCacheSlots; ++slot) {
+          if (!cached_valid[slot]) {
+            slot_to_use = slot;
+            break;
+          }
+        }
+        if (slot_to_use == -1) {
+          slot_to_use = 0;
+          const int64_t row_offset =
+              static_cast<int64_t>(cached_idx[slot_to_use]) *
+              active_num_cols;
+#pragma unroll
+          for (int j = 0; j < UNROLL_FACTOR; ++j) {
+            const int32_t col = active_col_offset + j;
+            if (col >= active_num_cols) {
+              break;
+            }
+            gpuAtomicAddNoReturn(
+                active_output_base + row_offset + col,
+                cached_vals[slot_to_use][j]);
+          }
+          cached_valid[slot_to_use] = false;
+        }
+
+        cached_idx[slot_to_use] = idx;
+#pragma unroll
+        for (int j = 0; j < UNROLL_FACTOR; ++j) {
+          cached_vals[slot_to_use][j] = local_vals[j];
+        }
+        cached_valid[slot_to_use] = true;
       }
     }
+
+    flush_cache(active_output_base, active_num_cols, active_col_offset);
   }
 }
 
@@ -123,13 +279,13 @@ DLL_PUBLIC void group_index_select_or_add_cuda(
   at::cuda::OptionalCUDAGuard device_guard(device);
 
   // Partition work based on num_work_rows
-  uint32_t num_warps_per_threadblock = kMaxThreads / EMULATED_WARP_SIZE;
+  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
       cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
       max_grid_size);
-  dim3 block_size(EMULATED_WARP_SIZE, num_warps_per_threadblock, 1);
+  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
 
 #define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
   FBGEMM_LAUNCH_KERNEL(                                                  \