Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2294,3 +2294,9 @@ def _default_pmap_no_rank_reduction(new_val):
' ragged_dot_general_p.'
),
)

jax_collectives_common_channel_id = bool_flag(
name='jax_collectives_common_channel_id',
default=True,
help="Should collectives use a common channel ID? Temporary feature flag.",
)
10 changes: 8 additions & 2 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,12 @@ def check_jaxpr_constants(closed_jaxpr: core.ClosedJaxpr):
except Exception as exc:
warnings.warn(message + f" Exception raised while generating report: {exc}")

# TODO(phawkins): it is my firm belief that:
# a) channel IDs have only a vestigal function when applied to collectives, and
# b) their identity does not matter. The presence or absence of a channel
# changes whether XLA considers collectives to be inter-replica or
# inter-partition, but beyond that we believe they have little effect.
COLLECTIVE_CHANNEL_ID = 1

def lower_jaxpr_to_module(
module_name: str,
Expand Down Expand Up @@ -1274,8 +1280,8 @@ def lower_jaxpr_to_module(
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')

# HLO channels need to start at 1
channel_iter = itertools.count(1)
# HLO channels need to start at 1. We reserve 1 for collectives.
channel_iter = itertools.count(COLLECTIVE_CHANNEL_ID + 1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: list[Any] = []
host_callbacks: list[Any] = []
Expand Down
40 changes: 23 additions & 17 deletions jax/_src/lax/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,13 @@ def _check_axis_names(axes, api_name):
f"Found an unbound axis name: {name}. To fix this, please call"
f" {api_name} under `jax.shard_map`.")

# TODO(phawkins): remove this function and flag if this doesn't break anyone.
def _get_channel(ctx):
if config.jax_collectives_common_channel_id.value:
return mlir.COLLECTIVE_CHANNEL_ID
else:
return ctx.module_context.new_channel_id()

def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
len_0 = len(axis_index_groups[0])
Expand Down Expand Up @@ -1017,10 +1024,9 @@ def _positional_reduce(aval, arg):

def all_reduce(aval, x):
if is_spmd:
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
_get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
Expand Down Expand Up @@ -1117,9 +1123,11 @@ def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name):
and axis_context.manual_axes
)
if is_manual:
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
channel_handle=hlo.ChannelHandle.get(
_get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE
)
)
else:
other_args = {}
return full_perm, other_args
Expand Down Expand Up @@ -1305,11 +1313,11 @@ def source_to_front(group):
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the collective-broadcast with global device IDs and a unique
# We want to emit the collective-broadcast with global device IDs and a
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
channel_handle = hlo.ChannelHandle.get(_get_channel(ctx),
mlir.DEVICE_TO_DEVICE_TYPE)
other_args = dict(channel_handle=channel_handle)
else:
other_args = {}
Expand Down Expand Up @@ -1358,11 +1366,11 @@ def _all_to_all_lowering(
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# We want to emit the all-gather with global device IDs and a
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
channel_handle = hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)
channel_handle = hlo.ChannelHandle.get(_get_channel(ctx),
mlir.DEVICE_TO_DEVICE_TYPE)
other_args = dict(channel_handle=channel_handle)
else:
other_args = {}
Expand Down Expand Up @@ -1519,7 +1527,7 @@ def _ragged_all_to_all_lowering(
ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext))
if is_spmd:
ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(64), ctx.module_context.new_channel()
ir.IntegerType.get_signless(64), _get_channel(ctx)
)

return hlo.CustomCallOp(
Expand Down Expand Up @@ -1746,13 +1754,12 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
axis_index_groups)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# We want to emit the all-gather with global device IDs and a
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
_get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
Expand Down Expand Up @@ -1969,13 +1976,12 @@ def _reduce_scatter_lowering(
(SPMDAxisContext, ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# We want to emit the all-gather with global device IDs and a
# channel ID, as otherwise it interprets the devices as replicas instead
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
_get_channel(ctx), mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
Expand Down
Loading