Skip to content

Commit abcc7fd

Browse files
yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Initial commit to add `varying_manual_axes: frozenset[AxisName]` to `ShapedArray`. Also add the `jax_varying_axes_in_types` config option to keep this feature hidden behind a flag while we develop it.
PiperOrigin-RevId: 736141670
1 parent 8b7cfcb commit abcc7fd

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

jax/_src/config.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,7 @@ def trace_context():
235235
threefry_partitionable.value,
236236
threefry_gpu_kernel_lowering.value,
237237
use_direct_linearize.value,
238+
varying_axes_in_types.value,
238239
softmax_custom_jvp.value,
239240
disable_jit.value,
240241
debug_key_reuse.value,
@@ -1084,6 +1085,14 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
10841085
help=('Use direct linearization instead JVP followed by partial eval'),
10851086
include_in_jit_key=True)
10861087

1088+
varying_axes_in_types = bool_state(
1089+
name='jax_varying_axes_in_types',
1090+
default=False,
1091+
help=('Adds varying manual axes to ShapedArray to track which mesh axes the'
1092+
' array is varying over. This will help to remove the efficient'
1093+
' transpose rewrite machinery in shard_map'),
1094+
include_in_jit_key=True)
1095+
10871096
data_dependent_tracing_fallback = bool_state(
10881097
name='jax_data_dependent_tracing_fallback',
10891098
default=False,

jax/_src/core.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1893,14 +1893,17 @@ def get_sharding(sharding, shape):
18931893

18941894

18951895
class ShapedArray(UnshapedArray):
1896-
__slots__ = ['shape', 'sharding'] # inherits slots from parent
1896+
__slots__ = ['shape', 'sharding', 'varying_manual_axes'] # inherits slots from parent
18971897
array_abstraction_level = 2
18981898

1899-
def __init__(self, shape, dtype, weak_type=False, *, sharding=None):
1899+
def __init__(self, shape, dtype, weak_type=False, *, sharding=None,
1900+
varying_manual_axes: frozenset[AxisName] = frozenset()):
19001901
self.shape = canonicalize_shape(shape)
19011902
self.dtype = _dtype_object(dtype)
19021903
self.weak_type = weak_type
19031904
self.sharding = get_sharding(sharding, self.shape)
1905+
if config.varying_axes_in_types.value:
1906+
self.varying_manual_axes = varying_manual_axes
19041907

19051908
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
19061909
if shape is None:
@@ -1911,6 +1914,9 @@ def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
19111914
weak_type = self.weak_type
19121915
if 'sharding' not in kwargs:
19131916
kwargs['sharding'] = self.sharding
1917+
if 'varying_manual_axes' not in kwargs:
1918+
kwargs['varying_manual_axes'] = getattr(self, 'varying_manual_axes',
1919+
frozenset())
19141920
return ShapedArray(shape, dtype, weak_type, **kwargs)
19151921

19161922
ndim = property(lambda self: len(self.shape))
@@ -1927,17 +1933,22 @@ def __eq__(self, other):
19271933
return (type(self) is type(other)
19281934
and self.dtype == other.dtype and self.shape == other.shape
19291935
and self.weak_type == other.weak_type
1930-
and self.sharding == other.sharding)
1936+
and self.sharding == other.sharding
1937+
and (getattr(self, 'varying_manual_axes', frozenset()) ==
1938+
getattr(other, 'varying_manual_axes', frozenset())))
19311939

19321940
def __hash__(self):
19331941
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
19341942
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
19351943
# the unique character code via hash(self.dtype.char)
1936-
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
1944+
return hash((self.shape, self.dtype, self.weak_type, self.sharding,
1945+
getattr(self, 'varying_manual_axes', frozenset())))
19371946

19381947
def to_tangent_aval(self):
1939-
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
1940-
self.weak_type, sharding=self.sharding)
1948+
return ShapedArray(
1949+
self.shape, primal_dtype_to_tangent_dtype(self.dtype),
1950+
self.weak_type, sharding=self.sharding,
1951+
varying_manual_axes=getattr(self, 'varying_manual_axes', frozenset()))
19411952

19421953
def str_short(self, short_dtypes=False, mesh_axis_types=False):
19431954
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else

jax/_src/pallas/core.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -343,7 +343,7 @@ def to_block_mapping(
343343
if self.block_shape is None:
344344
block_shape = array_aval.shape
345345
else:
346-
block_shape = self.block_shape
346+
block_shape = self.block_shape # type: ignore
347347
if len(array_aval.shape) != len(block_shape):
348348
raise ValueError(
349349
f"Block shape for {origin} (= {block_shape}) "

0 commit comments

Comments
 (0)