@@ -1591,20 +1591,33 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
   axis_name, = axis_name
   if axis_name not in axis_env.names:
     raise NameError(f"unbound axis name: {axis_name}")
+  axis_context = ctx.module_context.axis_context
   axis_pos = list(axis_env.names).index(axis_name)
+
+  # For partial auto, lower using iota.
+  if (isinstance(axis_context, SPMDAxisContext) and
+      axis_context.manual_axes and
+      axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
+    x = hlo.iota(ir.RankedTensorType.get(
+        [axis_env.sizes[axis_pos]], ir.IntegerType.get_signless(32)), mlir.i64_attr(0))
+    sharding_proto = (
+        NamedSharding(axis_context.mesh, P(axis_name))
+        ._to_xla_hlo_sharding(1).to_proto())
+    aval_in = ShapedArray((axis_env.sizes[axis_pos],), np.int32)
+    aval_out = ShapedArray((1,), np.int32)
+    sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto)
+    proto = pxla.manual_proto(aval_in, axis_context.manual_axes, axis_context.mesh)
+    x = mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto)
+    return hlo.reshape(ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), x)
+
   nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
   div = mlir.ir_constant(
       np.array(
           nreplicas * math.prod(axis_env.sizes[axis_pos + 1:]), dtype=np.uint32
       )
   )
   mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
-  axis_context = ctx.module_context.axis_context
-  is_spmd = isinstance(
-      axis_context,
-      (SPMDAxisContext, ShardingContext),
-  )
-  if is_spmd:
+  if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
     device_id = hlo.partition_id()
   else:
     device_id = hlo.replica_id()
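
For context (not part of the diff): the new branch fires when lax.axis_index is lowered inside a shard_map whose mesh mixes manual and auto axes, so the axis index can no longer be derived from the raw partition id alone. A minimal usage sketch of that scenario, assuming at least two devices and a JAX version whose shard_map exposes the auto= parameter (the mesh layout and function names below are illustrative, not from the diff); check_rep is disabled to keep the partial-auto case simple:

  import numpy as np
  import jax
  from jax.sharding import Mesh, PartitionSpec as P
  from jax.experimental.shard_map import shard_map

  # Assumes >= 2 devices, e.g. run with
  #   XLA_FLAGS=--xla_force_host_platform_device_count=2
  mesh = Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('i', 'j'))

  @jax.jit
  def f(x):
    # 'i' is manual inside the body; 'j' is left to the compiler ("auto"),
    # which is the partial-auto case the iota-based lowering above handles.
    g = shard_map(lambda y: y * jax.lax.axis_index('i'),
                  mesh, in_specs=P('i'), out_specs=P('i'),
                  check_rep=False, auto=frozenset({'j'}))
    return g(x)

  print(f(np.arange(8.)))

Under these assumptions, each shard multiplies its block by its position along the manual 'i' axis, exercising the axis_index lowering on a mesh where only some axes are manual.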