diff --git a/jax/_src/config.py b/jax/_src/config.py index 93b211d53c5d..63129697ea10 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -2294,3 +2294,9 @@ def _default_pmap_no_rank_reduction(new_val): ' ragged_dot_general_p.' ), ) + +jax_collectives_common_channel_id = bool_flag( + name='jax_collectives_common_channel_id', + default=True, + help="Should collectives use a common channel ID? Temporary feature flag.", +) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f3dd5b503ed2..90f879372e01 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1182,6 +1182,12 @@ def check_jaxpr_constants(closed_jaxpr: core.ClosedJaxpr): except Exception as exc: warnings.warn(message + f" Exception raised while generating report: {exc}") +# TODO(phawkins): it is my firm belief that: +# a) channel IDs have only a vestigal function when applied to collectives, and +# b) their identity does not matter. The presence or absence of a channel +# changes whether XLA considers collectives to be inter-replica or +# inter-partition, but beyond that we believe they have little effect. +COLLECTIVE_CHANNEL_ID = 1 def lower_jaxpr_to_module( module_name: str, @@ -1274,8 +1280,8 @@ def lower_jaxpr_to_module( if unlowerable_effects: raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') - # HLO channels need to start at 1 - channel_iter = itertools.count(1) + # HLO channels need to start at 1. We reserve 1 for collectives. + channel_iter = itertools.count(COLLECTIVE_CHANNEL_ID + 1) # Create a keepalives list that will be mutated during the lowering. keepalives: list[Any] = [] host_callbacks: list[Any] = [] diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 357a93f0d2f8..2ea3deb6c3e9 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -987,6 +987,13 @@ def _check_axis_names(axes, api_name): f"Found an unbound axis name: {name}. To fix this, please call" f" {api_name} under `jax.shard_map`.") +# TODO(phawkins): remove this function and flag if this doesn't break anyone. +def _get_channel(ctx): + if config.jax_collectives_common_channel_id.value: + return mlir.COLLECTIVE_CHANNEL_ID + else: + return ctx.module_context.new_channel_id() + def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms): len_0 = len(axis_index_groups[0]) @@ -1017,10 +1024,9 @@ def _positional_reduce(aval, arg): def all_reduce(aval, x): if is_spmd: - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + _get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} @@ -1117,9 +1123,11 @@ def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name): and axis_context.manual_axes ) if is_manual: - channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) + channel_handle=hlo.ChannelHandle.get( + _get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE + ) + ) else: other_args = {} return full_perm, other_args @@ -1305,11 +1313,11 @@ def source_to_front(group): (SPMDAxisContext, ShardingContext), ) if is_spmd: - # We want to emit the collective-broadcast with global device IDs and a unique + # We want to emit the collective-broadcast with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() - channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) + channel_handle = hlo.ChannelHandle.get(_get_channel(ctx), + mlir.DEVICE_TO_DEVICE_TYPE) other_args = dict(channel_handle=channel_handle) else: other_args = {} @@ -1358,11 +1366,11 @@ def _all_to_all_lowering( (SPMDAxisContext, ShardingContext), ) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() - channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE) + channel_handle = hlo.ChannelHandle.get(_get_channel(ctx), + mlir.DEVICE_TO_DEVICE_TYPE) other_args = dict(channel_handle=channel_handle) else: other_args = {} @@ -1519,7 +1527,7 @@ def _ragged_all_to_all_lowering( ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext)) if is_spmd: ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(64), ctx.module_context.new_channel() + ir.IntegerType.get_signless(64), _get_channel(ctx) ) return hlo.CustomCallOp( @@ -1746,13 +1754,12 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, axis_index_groups) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + _get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} @@ -1969,13 +1976,12 @@ def _reduce_scatter_lowering( (SPMDAxisContext, ShardingContext), ) if is_spmd: - # We want to emit the all-gather with global device IDs and a unique + # We want to emit the all-gather with global device IDs and a # channel ID, as otherwise it interprets the devices as replicas instead # of partitions - and XLA is configured with only a single replica. - channel = ctx.module_context.new_channel() other_args = dict( channel_handle=hlo.ChannelHandle.get( - channel, mlir.DEVICE_TO_DEVICE_TYPE), + _get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {}