
Commit d32fc6a

q10 authored and facebook-github-bot committed
Add support for int64_t indices and offsets in TBE inference [10/N] (#3263)
Summary:
Pull Request resolved: #3263
X-link: facebookresearch/FBGEMM#364

- Add int64_t support for `pruned_hashmap_insert_{{ wdesc }}_cpu` to prevent runtime errors in tests

Reviewed By: spcyppt

Differential Revision: D64705072

fbshipit-source-id: cccc7ea306316e15058f7a31bb481044a4011b00
1 parent 2cf3606 commit d32fc6a
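
The core of the change is a nested AT_DISPATCH_INDEX_TYPES: the outer dispatch binds the hash table's index type, the inner one binds the indices' type. Below is a minimal standalone sketch of that pattern, not the committed code; the function name nested_dispatch_sketch, the dispatch labels, and the tensor arguments are illustrative assumptions.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sketch of the nested-dispatch pattern this commit introduces. Each
// AT_DISPATCH_INDEX_TYPES block defines its own `index_t` (int32_t or
// int64_t), so the outer type must be aliased (here `hash_t`) before the
// inner dispatch shadows `index_t` with the indices' type. Assumes a 2-D
// hash table tensor, as in the kernel below.
void nested_dispatch_sketch(const at::Tensor& hash_table,
                            const at::Tensor& indices) {
  AT_DISPATCH_INDEX_TYPES(
      hash_table.scalar_type(), "sketch_outer", [&] {
        using hash_t = index_t; // capture before the inner dispatch shadows it
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "sketch_inner", [&] {
              // Here index_t is the indices' type, hash_t the table's type;
              // every combination of {int32_t, int64_t} x {int32_t, int64_t}
              // gets its own instantiation.
              auto ht = hash_table.accessor<hash_t, 2>();
              const auto* idx = indices.data_ptr<index_t>();
              (void)ht;
              (void)idx;
            });
      });
}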

File tree

2 files changed: +60 -56 lines changed

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 56 additions & 52 deletions
@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);

-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
-    using uidx_t =
-        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
-
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-
-    auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-
-    for (const auto t : c10::irange(T)) {
-      const auto table_start = hash_table_offsets_acc[t];
-      const auto table_end = hash_table_offsets_acc[t + 1];
-      if (table_start == table_end) {
-        continue;
-      }
-      const auto capacity = table_end - table_start;
-
-      for (const auto b : c10::irange(B)) {
-        const auto indices_start = offsets_acc[t * B + b];
-        const auto indices_end = offsets_acc[t * B + b + 1];
-        const auto L = indices_end - indices_start;
-
-        for (const auto l : c10::irange(L)) {
-          const auto idx = indices_acc[indices_start + l];
-          const auto dense_idx = dense_indices_acc[indices_start + l];
-          if (dense_idx == -1) {
-            // -1 means this row has been pruned, do not insert it.
-            continue;
-          }
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu_0", [&] {
+    using hash_t = index_t;

-          auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
-          while (true) {
-            const auto ht_idx = table_start + static_cast<int64_t>(slot);
-            const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
-
-            // Empty slot
-            if (slot_sparse_idx == -1) {
-              hash_table_acc[ht_idx][0] = idx;
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu_1", [&] {
+      using uidx_t =
+          std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      auto hash_table_acc = hash_table.accessor<hash_t, 2>();
+      const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+      for (const auto t : c10::irange(T)) {
+        const auto table_start = hash_table_offsets_acc[t];
+        const auto table_end = hash_table_offsets_acc[t + 1];
+        if (table_start == table_end) {
+          continue;
+        }
+        const auto capacity = table_end - table_start;
+
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+
+          for (const auto l : c10::irange(L)) {
+            const auto idx = indices_acc[indices_start + l];
+            const auto dense_idx = dense_indices_acc[indices_start + l];
+            if (dense_idx == -1) {
+              // -1 means this row has been pruned, do not insert it.
+              continue;
             }
-
-            // Already exists (shouldn't happen in practice)
-            if (slot_sparse_idx == idx) {
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+
+            auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+            while (true) {
+              const auto ht_idx = table_start + static_cast<int64_t>(slot);
+              const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+              // Empty slot
+              if (slot_sparse_idx == -1) {
+                hash_table_acc[ht_idx][0] = static_cast<hash_t>(idx);
+                hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
+                break;
+              }
+
+              // Already exists (shouldn't happen in practice)
+              if (slot_sparse_idx == idx) {
+                hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
+                break;
+              }
+
+              // Linear probe
+              slot = (slot + 1) % capacity;
             }
-
-            // Linear probe
-            slot = (slot + 1) % capacity;
           }
         }
       }
-    }
+    });
   });

   return;
@@ -519,14 +523,14 @@ Tensor pruned_array_lookup_cpu(
   auto dense_indices = empty_like(indices);

   AT_DISPATCH_INDEX_TYPES(index_remappings.scalar_type(), "pruned_array_lookup_cpu_0", [&] {
-    using hash_t = index_t;
+    using remap_t = index_t;

     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu_1", [&] {
       const auto* indices_acc = indices.data_ptr<index_t>();
       auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
       const auto* offsets_acc = offsets.data_ptr<index_t>();

-      const auto index_remappings_acc = index_remappings.data_ptr<hash_t>();
+      const auto index_remappings_acc = index_remappings.data_ptr<remap_t>();
       const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();

       at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
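
For context on the `uidx_t` alias kept in the new code: the hash is computed through an unsigned type so the mixing arithmetic has well-defined wraparound and the resulting slot is never negative, and the alias widens to uint64_t exactly when the dispatched index_t is int64_t, so wide indices keep all their bits before the reduction modulo capacity. A minimal standalone sketch follows; probe_slot and its mixing constants are illustrative stand-ins, not FBGEMM's pruned_hash_function.

#include <cstddef>
#include <cstdint>
#include <type_traits>

// Illustrative stand-in for the hashing step. The mixer below is NOT
// FBGEMM's pruned_hash_function; it only shows how the unsigned alias
// tracks the signed index type.
template <typename index_t>
std::size_t probe_slot(index_t idx, std::size_t capacity) {
  // Widen to uint64_t only for int64_t indices, mirroring the diff:
  using uidx_t =
      std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
  uidx_t h = static_cast<uidx_t>(idx);
  h ^= h >> 16;                          // placeholder mixing; unsigned
  h *= static_cast<uidx_t>(0x45d9f3bU);  // overflow is well defined here
  return static_cast<std::size_t>(h) % capacity;
}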

fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu

Lines changed: 4 additions & 4 deletions
@@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }

-template <typename index_t, typename hash_t>
+template <typename index_t, typename remap_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(

   AT_DISPATCH_INDEX_TYPES(
       index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
-        using hash_t = index_t;
+        using remap_t = index_t;

         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
@@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
                 MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
                 MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
                 MAKE_PTA_WITH_NAME(
-                    func_name, index_remappings, hash_t, 1, 32),
+                    func_name, index_remappings, remap_t, 1, 32),
                 MAKE_PTA_WITH_NAME(
                     func_name, index_remappings_offsets, int64_t, 1, 32),
                 B,
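
On the CUDA side the rename hash_t to remap_t is for readability: remap_t is still just the index type dispatched from index_remappings, now independent of the indices' type. A hypothetical skeleton of how the two dispatched types meet in a kernel is below; lookup_kernel_sketch and its flat-pointer signature are illustrative assumptions, not the generated FBGEMM kernel, which takes packed tensor accessors as shown in the diff.

// Hypothetical skeleton: `index_t` comes from dispatching on the indices
// tensor, `remap_t` from dispatching on index_remappings, so int64_t
// remappings no longer need to match the indices' dtype.
template <typename index_t, typename remap_t>
__global__ void lookup_kernel_sketch(
    const index_t* indices,
    const remap_t* index_remappings,
    index_t* dense_indices,
    int64_t n) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    // Remap the sparse index; a pruned row stored as -1 passes through
    // unchanged under the cast.
    dense_indices[i] = static_cast<index_t>(index_remappings[indices[i]]);
  }
}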
