@@ -2631,24 +2631,24 @@ def _integer_pow(x, *, y):
26312631def _integer_pow_lowering (ctx , x , * , y ):
26322632 # These cases are subsumed by the general case, but it's faster to emit these
26332633 # common cases directly.
2634- if y == 2 :
2634+ if y == 1 :
2635+ out = x
2636+ elif y == 2 :
26352637 out = hlo .multiply (x , x )
26362638 elif y == 3 :
26372639 out = hlo .multiply (hlo .multiply (x , x ), x )
2640+ elif y == - 1 :
2641+ out = hlo .divide (mlir .full_like_aval (ctx , 1 , ctx .avals_in [0 ]), x )
26382642 else :
26392643 lowering = mlir .lower_fun (_integer_pow , multiple_results = False )
2640- # TODO(b/217551391): emitting an out-of-line call leads to a large
2641- # expansion when the MLIR is lowered to HLO, because the HLO lowering
2642- # clones the callee. Consider unconditionally caching when the MLIR->HLO
2643- # lowering doesn't expand the program.
2644- lowering = mlir .cache_lowering (lowering )
2645- out = lowering (ctx , x , y = y )
2644+ if builtins .abs (y ) >= 3 :
2645+ lowering = mlir .cache_lowering (lowering )
2646+ out , = lowering (ctx , x , y = y )
26462647 if config .sharding_in_types .value :
26472648 aval_out , = ctx .avals_out
26482649 proto = aval_out .sharding ._to_xla_hlo_sharding (aval_out .ndim ).to_proto ()
2649- out = out [0 ] if isinstance (out , list ) else out
26502650 return [mlir .wrap_with_sharding_op (ctx , out , aval_out , proto )]
2651- return out if isinstance ( out , list ) else [out ]
2651+ return [out ]
26522652
26532653mlir .register_lowering (integer_pow_p , _integer_pow_lowering )
26542654
0 commit comments