Skip to content

Commit 0d6f12d

Browse files
committed
Rename the class/function from compare to less in operator kernels.
PR: USTC-KnowledgeComputingLab/qmb#53 Signed-off-by: Hao Zhang <[email protected]>
2 parents d9fc9c8 + 6d8b069 commit 0d6f12d

File tree

2 files changed

+40
-40
lines changed

2 files changed

+40
-40
lines changed

qmb/_hamiltonian_cpu.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ void apply_within_kernel(
108108
std::int64_t low = 0;
109109
std::int64_t high = result_batch_size - 1;
110110
std::int64_t mid = 0;
111-
auto compare = array_less<std::uint8_t, n_qubytes>();
111+
auto less = array_less<std::uint8_t, n_qubytes>();
112112
while (low <= high) {
113113
mid = (low + high) / 2;
114-
if (compare(current_configs, result_configs[mid])) {
114+
if (less(current_configs, result_configs[mid])) {
115115
high = mid - 1;
116-
} else if (compare(result_configs[mid], current_configs)) {
116+
} else if (less(result_configs[mid], current_configs)) {
117117
low = mid + 1;
118118
} else {
119119
success = true;
@@ -256,11 +256,11 @@ auto apply_within_interface(
256256
return result_psi;
257257
}
258258

259-
template<typename T, typename Compare = std::less<T>>
259+
template<typename T, typename Less = std::less<T>>
260260
void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
261-
auto compare = Compare();
261+
auto less = Less();
262262
std::int64_t index = 0;
263-
if (compare(value, heap[index])) {
263+
if (less(value, heap[index])) {
264264
} else {
265265
while (true) {
266266
// Calculate the indices of the left and right children
@@ -271,8 +271,8 @@ void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
271271
if (left_present) {
272272
if (right_present) {
273273
// Both left and right children are present
274-
if (compare(value, heap[left])) {
275-
if (compare(value, heap[right])) {
274+
if (less(value, heap[left])) {
275+
if (less(value, heap[right])) {
276276
// Both children are greater than the value, break
277277
break;
278278
} else {
@@ -281,12 +281,12 @@ void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
281281
index = right;
282282
}
283283
} else {
284-
if (compare(value, heap[right])) {
284+
if (less(value, heap[right])) {
285285
// The right child is greater than the value
286286
heap[index] = heap[left];
287287
index = left;
288288
} else {
289-
if (compare(heap[left], heap[right])) {
289+
if (less(heap[left], heap[right])) {
290290
heap[index] = heap[left];
291291
index = left;
292292
} else {
@@ -297,7 +297,7 @@ void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
297297
}
298298
} else {
299299
// Only the left child is present
300-
if (compare(value, heap[left])) {
300+
if (less(value, heap[left])) {
301301
break;
302302
} else {
303303
heap[index] = heap[left];
@@ -307,7 +307,7 @@ void add_into_heap(T* heap, std::int64_t heap_size, const T& value) {
307307
} else {
308308
if (right_present) {
309309
// Only the right child is present
310-
if (compare(value, heap[right])) {
310+
if (less(value, heap[right])) {
311311
break;
312312
} else {
313313
heap[index] = heap[right];
@@ -370,12 +370,12 @@ void find_relative_kernel(
370370
std::int64_t low = 0;
371371
std::int64_t high = exclude_size - 1;
372372
std::int64_t mid = 0;
373-
auto compare = array_less<std::uint8_t, n_qubytes>();
373+
auto less = array_less<std::uint8_t, n_qubytes>();
374374
while (low <= high) {
375375
mid = (low + high) / 2;
376-
if (compare(current_configs, exclude_configs[mid])) {
376+
if (less(current_configs, exclude_configs[mid])) {
377377
high = mid - 1;
378-
} else if (compare(exclude_configs[mid], current_configs)) {
378+
} else if (less(exclude_configs[mid], current_configs)) {
379379
low = mid + 1;
380380
} else {
381381
success = false;

qmb/_hamiltonian_cuda.cu

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ __device__ void apply_within_kernel(
113113
std::int64_t low = 0;
114114
std::int64_t high = result_batch_size - 1;
115115
std::int64_t mid = 0;
116-
auto compare = array_less<std::uint8_t, n_qubytes>();
116+
auto less = array_less<std::uint8_t, n_qubytes>();
117117
while (low <= high) {
118118
mid = (low + high) / 2;
119-
if (compare(current_configs, result_configs[mid])) {
119+
if (less(current_configs, result_configs[mid])) {
120120
high = mid - 1;
121-
} else if (compare(result_configs[mid], current_configs)) {
121+
} else if (less(result_configs[mid], current_configs)) {
122122
low = mid + 1;
123123
} else {
124124
success = true;
@@ -309,14 +309,14 @@ __device__ void mutex_unlock(int* mutex1, int* mutex2) {
309309
_mutex_unlock(mutex2);
310310
}
311311

312-
template<typename T, typename Compare = thrust::less<T>>
312+
template<typename T, typename Less = thrust::less<T>>
313313
__device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const T& value) {
314-
auto compare = Compare();
314+
auto less = Less();
315315
std::int64_t index = 0;
316-
if (compare(value, heap[index])) {
316+
if (less(value, heap[index])) {
317317
} else {
318318
mutex_lock(&mutex[index]);
319-
if (compare(value, heap[index])) {
319+
if (less(value, heap[index])) {
320320
mutex_unlock(&mutex[index]);
321321
} else {
322322
while (true) {
@@ -329,14 +329,14 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
329329
if (left_present) {
330330
if (right_present) {
331331
// Both left and right children are present
332-
if (compare(value, heap[left])) {
333-
if (compare(value, heap[right])) {
332+
if (less(value, heap[left])) {
333+
if (less(value, heap[right])) {
334334
// Both children are greater than the value, break
335335
break;
336336
} else {
337337
// The left child is greater than the value, treat it as if only the right child is present
338338
mutex_lock(&mutex[right]);
339-
if (compare(value, heap[right])) {
339+
if (less(value, heap[right])) {
340340
mutex_unlock(&mutex[right]);
341341
break;
342342
} else {
@@ -346,10 +346,10 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
346346
}
347347
}
348348
} else {
349-
if (compare(value, heap[right])) {
349+
if (less(value, heap[right])) {
350350
// The right child is greater than the value, treat it as if only the left child is present
351351
mutex_lock(&mutex[left]);
352-
if (compare(value, heap[left])) {
352+
if (less(value, heap[left])) {
353353
mutex_unlock(&mutex[left]);
354354
break;
355355
} else {
@@ -359,8 +359,8 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
359359
}
360360
} else {
361361
mutex_lock(&mutex[left], &mutex[right]);
362-
if (compare(heap[left], heap[right])) {
363-
if (compare(value, heap[left])) {
362+
if (less(heap[left], heap[right])) {
363+
if (less(value, heap[left])) {
364364
mutex_unlock(&mutex[left], &mutex[right]);
365365
break;
366366
} else {
@@ -369,7 +369,7 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
369369
index = left;
370370
}
371371
} else {
372-
if (compare(value, heap[right])) {
372+
if (less(value, heap[right])) {
373373
mutex_unlock(&mutex[left], &mutex[right]);
374374
break;
375375
} else {
@@ -382,11 +382,11 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
382382
}
383383
} else {
384384
// Only the left child is present
385-
if (compare(value, heap[left])) {
385+
if (less(value, heap[left])) {
386386
break;
387387
} else {
388388
mutex_lock(&mutex[left]);
389-
if (compare(value, heap[left])) {
389+
if (less(value, heap[left])) {
390390
mutex_unlock(&mutex[left]);
391391
break;
392392
} else {
@@ -399,11 +399,11 @@ __device__ void add_into_heap(T* heap, int* mutex, std::int64_t heap_size, const
399399
} else {
400400
if (right_present) {
401401
// Only the right child is present
402-
if (compare(value, heap[right])) {
402+
if (less(value, heap[right])) {
403403
break;
404404
} else {
405405
mutex_lock(&mutex[right]);
406-
if (compare(value, heap[right])) {
406+
if (less(value, heap[right])) {
407407
mutex_unlock(&mutex[right]);
408408
break;
409409
} else {
@@ -473,12 +473,12 @@ __device__ void find_relative_kernel(
473473
std::int64_t low = 0;
474474
std::int64_t high = exclude_size - 1;
475475
std::int64_t mid = 0;
476-
auto compare = array_less<std::uint8_t, n_qubytes>();
476+
auto less = array_less<std::uint8_t, n_qubytes>();
477477
while (low <= high) {
478478
mid = (low + high) / 2;
479-
if (compare(current_configs, exclude_configs[mid])) {
479+
if (less(current_configs, exclude_configs[mid])) {
480480
high = mid - 1;
481-
} else if (compare(exclude_configs[mid], current_configs)) {
481+
} else if (less(exclude_configs[mid], current_configs)) {
482482
low = mid + 1;
483483
} else {
484484
success = false;
@@ -704,12 +704,12 @@ __device__ void single_relative_kernel(
704704
std::int64_t low = 0;
705705
std::int64_t high = exclude_size - 1;
706706
std::int64_t mid = 0;
707-
auto compare = array_less<std::uint8_t, n_qubytes>();
707+
auto less = array_less<std::uint8_t, n_qubytes>();
708708
while (low <= high) {
709709
mid = (low + high) / 2;
710-
if (compare(current_configs, exclude_configs[mid])) {
710+
if (less(current_configs, exclude_configs[mid])) {
711711
high = mid - 1;
712-
} else if (compare(exclude_configs[mid], current_configs)) {
712+
} else if (less(exclude_configs[mid], current_configs)) {
713713
low = mid + 1;
714714
} else {
715715
success = false;

0 commit comments

Comments
 (0)