@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
7070 const int32_t B = (offsets.size (0 ) - 1 ) / T;
7171 TORCH_CHECK (B > 0 );
7272
73- AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " pruned_hashmap_insert_{{ wdesc }}_cpu" , [&] {
74- using uidx_t =
75- std::conditional_t <std::is_same_v<index_t , int64_t >, uint64_t , uint32_t >;
76-
77- const auto * indices_acc = indices.data_ptr <index_t >();
78- const auto * dense_indices_acc = dense_indices.data_ptr <index_t >();
79- const auto * offsets_acc = offsets.data_ptr <index_t >();
80-
81- auto hash_table_acc = hash_table.accessor <int32_t , 2 >();
82- const auto hash_table_offsets_acc = hash_table_offsets.accessor <int64_t , 1 >();
83-
84- for (const auto t : c10::irange (T)) {
85- const auto table_start = hash_table_offsets_acc[t];
86- const auto table_end = hash_table_offsets_acc[t + 1 ];
87- if (table_start == table_end) {
88- continue ;
89- }
90- const auto capacity = table_end - table_start;
91-
92- for (const auto b : c10::irange (B)) {
93- const auto indices_start = offsets_acc[t * B + b];
94- const auto indices_end = offsets_acc[t * B + b + 1 ];
95- const auto L = indices_end - indices_start;
96-
97- for (const auto l : c10::irange (L)) {
98- const auto idx = indices_acc[indices_start + l];
99- const auto dense_idx = dense_indices_acc[indices_start + l];
100- if (dense_idx == -1 ) {
101- // -1 means this row has been pruned, do not insert it.
102- continue ;
103- }
73+ AT_DISPATCH_INDEX_TYPES (hash_table.scalar_type (), " pruned_hashmap_insert_{{ wdesc }}_cpu_0" , [&] {
74+ using hash_t = index_t ;
10475
105- auto slot = pruned_hash_function (static_cast <uidx_t >(idx)) % capacity;
106- while (true ) {
107- const auto ht_idx = table_start + static_cast <int64_t >(slot);
108- const auto slot_sparse_idx = hash_table_acc[ht_idx][0 ];
109-
110- // Empty slot
111- if (slot_sparse_idx == -1 ) {
112- hash_table_acc[ht_idx][0 ] = idx;
113- hash_table_acc[ht_idx][1 ] = dense_idx;
114- break ;
76+ AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " pruned_hashmap_insert_{{ wdesc }}_cpu_1" , [&] {
77+ using uidx_t =
78+ std::conditional_t <std::is_same_v<index_t , int64_t >, uint64_t , uint32_t >;
79+
80+ const auto * indices_acc = indices.data_ptr <index_t >();
81+ const auto * dense_indices_acc = dense_indices.data_ptr <index_t >();
82+ const auto * offsets_acc = offsets.data_ptr <index_t >();
83+
84+ auto hash_table_acc = hash_table.accessor <hash_t , 2 >();
85+ const auto hash_table_offsets_acc = hash_table_offsets.accessor <int64_t , 1 >();
86+
87+ for (const auto t : c10::irange (T)) {
88+ const auto table_start = hash_table_offsets_acc[t];
89+ const auto table_end = hash_table_offsets_acc[t + 1 ];
90+ if (table_start == table_end) {
91+ continue ;
92+ }
93+ const auto capacity = table_end - table_start;
94+
95+ for (const auto b : c10::irange (B)) {
96+ const auto indices_start = offsets_acc[t * B + b];
97+ const auto indices_end = offsets_acc[t * B + b + 1 ];
98+ const auto L = indices_end - indices_start;
99+
100+ for (const auto l : c10::irange (L)) {
101+ const auto idx = indices_acc[indices_start + l];
102+ const auto dense_idx = dense_indices_acc[indices_start + l];
103+ if (dense_idx == -1 ) {
104+ // -1 means this row has been pruned, do not insert it.
105+ continue ;
115106 }
116-
117- // Already exists (shouldn't happen in practice)
118- if (slot_sparse_idx == idx) {
119- hash_table_acc[ht_idx][1 ] = dense_idx;
120- break ;
107+
108+ auto slot = pruned_hash_function (static_cast <uidx_t >(idx)) % capacity;
109+ while (true ) {
110+ const auto ht_idx = table_start + static_cast <int64_t >(slot);
111+ const auto slot_sparse_idx = hash_table_acc[ht_idx][0 ];
112+
113+ // Empty slot
114+ if (slot_sparse_idx == -1 ) {
115+ hash_table_acc[ht_idx][0 ] = static_cast <hash_t >(idx);
116+ hash_table_acc[ht_idx][1 ] = static_cast <hash_t >(dense_idx);
117+ break ;
118+ }
119+
120+ // Already exists (shouldn't happen in practice)
121+ if (slot_sparse_idx == idx) {
122+ hash_table_acc[ht_idx][1 ] = static_cast <hash_t >(dense_idx);
123+ break ;
124+ }
125+
126+ // Linear probe
127+ slot = (slot + 1 ) % capacity;
121128 }
122-
123- // Linear probe
124- slot = (slot + 1 ) % capacity;
125129 }
126130 }
127131 }
128- }
132+ });
129133 });
130134
131135 return ;
@@ -519,14 +523,14 @@ Tensor pruned_array_lookup_cpu(
519523 auto dense_indices = empty_like (indices);
520524
521525 AT_DISPATCH_INDEX_TYPES (index_remappings.scalar_type (), " pruned_array_lookup_cpu_0" , [&] {
522- using hash_t = index_t ;
526+ using remap_t = index_t ;
523527
524528 AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " pruned_array_lookup_cpu_1" , [&] {
525529 const auto * indices_acc = indices.data_ptr <index_t >();
526530 auto * dense_indices_acc = dense_indices.data_ptr <index_t >();
527531 const auto * offsets_acc = offsets.data_ptr <index_t >();
528532
529- const auto index_remappings_acc = index_remappings.data_ptr <hash_t >();
533+ const auto index_remappings_acc = index_remappings.data_ptr <remap_t >();
530534 const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr <int64_t >();
531535
532536 at::parallel_for (0 , T, 1 , [&](int64_t begin, int64_t end) {
0 commit comments