|
7 | 7 | # constexpr utilities
|
8 | 8 |
|
9 | 9 |
|
10 |
| -def _log2(i: core.constexpr): |
| 10 | +@core.constexpr_function |
| 11 | +def _log2(i): |
11 | 12 | log2 = 0
|
12 |
| - n = core.constexpr(i).value |
| 13 | + n = i |
13 | 14 | while n > 1:
|
14 | 15 | n >>= 1
|
15 | 16 | log2 += 1
|
16 |
| - return core.constexpr(log2) |
| 17 | + return log2 |
17 | 18 |
|
18 | 19 |
|
19 |
| -def _is_power_of_two(i: core.constexpr): |
20 |
| - n = i.value |
21 |
| - return core.constexpr((n & (n - 1)) == 0 and n != 0) |
| 20 | +@core.constexpr_function |
| 21 | +def _is_power_of_two(i): |
| 22 | + return (i & (i - 1)) == 0 and i != 0 |
22 | 23 |
|
23 | 24 |
|
24 | 25 | # -----------------------
|
@@ -263,8 +264,8 @@ def _sum_combine(a, b):
|
263 | 264 | # sum
|
264 | 265 |
|
265 | 266 |
|
266 |
| -def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr): |
267 |
| - dtype = core._unwrap_if_constexpr(dtype) |
| 267 | +@core.constexpr_function |
| 268 | +def _pick_sum_dtype(in_dtype, dtype): |
268 | 269 | if dtype is not None:
|
269 | 270 | return dtype
|
270 | 271 |
|
@@ -476,14 +477,13 @@ def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = co
|
476 | 477 | return _bitonic_merge(x, n_dims, descending, n_dims)
|
477 | 478 |
|
478 | 479 |
|
| 480 | +@core.constexpr_function |
479 | 481 | def _get_flip_dim(dim, shape):
|
480 |
| - dim = core._unwrap_if_constexpr(dim) |
481 |
| - shape = core._unwrap_if_constexpr(shape) |
482 | 482 | if dim is None:
|
483 | 483 | dim = len(shape) - 1
|
484 | 484 | if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
|
485 | 485 | dim += len(shape)
|
486 |
| - return core.constexpr(dim) |
| 486 | + return dim |
487 | 487 |
|
488 | 488 |
|
489 | 489 | @core._tensor_member_fn
|
|
0 commit comments