@@ -2633,24 +2633,24 @@ def _integer_pow(x, *, y):
def _integer_pow_lowering(ctx, x, *, y):
  """Lower `integer_pow` (x raised to the static integer power `y`) to HLO ops.

  Args:
    ctx: lowering rule context; this function reads `ctx.avals_in[0]` and
      `ctx.avals_out` from it.
    x: the operand value being raised to a power.
    y: the exponent (keyword-only), compared against Python ints here.

  Returns:
    A one-element list holding the lowered result. When
    `config.sharding_in_types.value` is truthy, that result is first wrapped
    with a sharding op derived from the output aval's sharding.
  """
  # These cases are subsumed by the general case, but it's faster to emit these
  # common cases directly.
  if y == 1:
    # x ** 1 is x itself; no op is emitted.
    out = x
  elif y == 2:
    out = hlo.multiply(x, x)
  elif y == 3:
    out = hlo.multiply(hlo.multiply(x, x), x)
  elif y == -1:
    # Reciprocal: divide a value built by full_like_aval(ctx, 1, <input aval>)
    # by x (presumably a ones-filled constant matching the input aval).
    out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x)
  else:
    # General case: lower the `_integer_pow` implementation itself.
    lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
    # Only exponents with |y| >= 3 go through cache_lowering; the remaining
    # exponents that reach this branch (0 and -2) are lowered directly.
    # NOTE(review): presumably because those expansions are too small for
    # caching to pay off -- confirm.
    if builtins.abs(y) >= 3:
      lowering = mlir.cache_lowering(lowering)
    # The lowered function is expected to yield exactly one result; the
    # unpacking fails loudly otherwise.
    out, = lowering(ctx, x, y=y)
  if config.sharding_in_types.value:
    aval_out, = ctx.avals_out
    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
  return [out]
26542654
# Register `_integer_pow_lowering` as the lowering rule for `integer_pow_p`.
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
26562656
0 commit comments