@@ -338,20 +338,19 @@ def cumprod(input, axis=0, reverse=False):
def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
    """Conditionally exchange each element of ``x`` with its partner.

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]`` where
    ``n_outer = x.numel >> n_dims``; the two entries along the middle axis
    form a pair. Every element is compared with the other member of its
    pair and replaced by it when ``(x > partner) != (flip ^ is_right)``,
    where ``is_right`` is 0 for the first entry of the pair and 1 for the
    second. With ``flip == 0`` the first slot therefore keeps the smaller
    value and the second slot the larger one; ``flip == 1`` reverses that.

    NOTE(review): ``flip`` is presumably 0/1 (scalar or a tensor that
    broadcasts against ``x.shape``) -- confirm against the callers.

    Returns a tensor with the same shape and dtype as ``x``.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]

    # Work on the raw bit pattern so the XOR trick applies to any dtype.
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    bits = core.reshape(x, shape).to(idtype, bitcast=True)

    # XOR-ing a value with the XOR of its pair leaves the *other* member of
    # the pair, i.e. this flips the tensor along the middle axis.
    # (third argument of xor_sum is presumably keep_dims -- TODO confirm)
    pair_xor = xor_sum(bits, 1, True)
    partner_bits = bits ^ pair_xor
    partner = core.reshape(partner_bits.to(x.dtype, bitcast=True), x.shape)

    # 0 for the first ("left") entry of each pair, 1 for the second ("right").
    side = core.arange(0, 2)[None, :, None]
    is_right = core.reshape(core.broadcast_to(side, shape), x.shape)

    # Take the partner's value wherever the pair is out of order for the
    # direction requested by `flip`; otherwise keep the element as it is.
    take_partner = (x > partner) != (flip ^ is_right)
    return core.where(take_partner, partner, x)
355354
356355
357356@jit
@@ -451,14 +450,8 @@ def flip(x, dim=None):
451450
452451 idtype = core .get_int_dtype (bitwidth = x .dtype .primitive_bitwidth , signed = True )
453452 y = core .reshape (x .to (idtype , bitcast = True ), [2 ] * steps )
454- y = core .expand_dims (y , start )
455- flip = (core .arange (0 , 2 )[:, None ] == 1 - core .arange (0 , 2 ))
456453 for i in core .static_range (start , steps ):
457- flip2 = flip
458- for j in core .static_range (0 , steps + 1 ):
459- if j != i and j != i + 1 :
460- flip2 = core .expand_dims (flip2 , j )
461- y = sum (y * flip2 , i + 1 , keep_dims = True , dtype = y .dtype )
454+ y = y ^ xor_sum (y , i , True )
462455 x = core .reshape (y , x .shape ).to (x .dtype , bitcast = True )
463456 return x
464457
0 commit comments