Skip to content

Commit e3faf85

Browse files
committed
[export] Cleaned up types of [in|out]_shardings
Previously we declared Exported.in_shardings to be a sequence of `core.AbstractValue`, but in reality we only support `core.ShapedArray`. We change the type declaration and this allowed us to clean up some `# type: ignore"
1 parent 11370b7 commit e3faf85

File tree

5 files changed

+33
-37
lines changed

5 files changed

+33
-37
lines changed

jax/_src/export/_export.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ class Exported:
286286
"""
287287
fun_name: str
288288
in_tree: tree_util.PyTreeDef
289-
in_avals: tuple[core.AbstractValue, ...]
289+
in_avals: tuple[core.ShapedArray, ...]
290290
out_tree: tree_util.PyTreeDef
291-
out_avals: tuple[core.AbstractValue, ...]
291+
out_avals: tuple[core.ShapedArray, ...]
292292

293293
in_shardings_hlo: tuple[HloSharding | None, ...]
294294
out_shardings_hlo: tuple[HloSharding | None, ...]
@@ -1257,8 +1257,8 @@ def _call_exported_abstract_eval(
12571257
assert len(in_avals) == len(exported.in_avals) # since the pytrees have the same structure
12581258
# Check that the expected shapes match the actual ones
12591259
for arg_idx, (exp_aval, actual_aval) in enumerate(zip(exported.in_avals, in_avals)):
1260-
exp_aval: core.ShapedArray = exp_aval # type: ignore
1261-
actual_aval: core.ShapedArray = actual_aval # type: ignore
1260+
if not isinstance(actual_aval, core.ShapedArray):
1261+
raise ValueError(f"Expected ShapedArray but got: {actual_aval}")
12621262
def pp_arg_dim(dim_idx: int | None) -> str:
12631263
return shape_poly.pretty_print_dimension_descriptor(exported.in_tree,
12641264
arg_idx, dim_idx)
@@ -1302,10 +1302,10 @@ def pp_arg_dim(dim_idx: int | None) -> str:
13021302
exported_dim_values = [synthetic_eval.evaluate(solution[var])
13031303
for var in exported_dim_vars]
13041304
out_avals = tuple(
1305-
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars, # type: ignore
1305+
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
13061306
*exported_dim_values),
1307-
dtype=out_aval.dtype, weak_type=out_aval.weak_type, # type: ignore
1308-
named_shape=out_aval.named_shape) # type: ignore
1307+
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
1308+
named_shape=out_aval.named_shape)
13091309
for out_aval in exported.out_avals)
13101310
return out_avals, set(exported.ordered_effects + exported.unordered_effects)
13111311

jax/_src/export/serialization.fbs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ table PyTreeDef {
3838

3939
enum AbstractValueKind: byte {
4040
shapedArray = 0,
41-
abstractToken = 1,
41+
abstractToken = 1, // unused
4242
}
4343

4444
enum DType: byte {

jax/_src/export/serialization.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -331,28 +331,24 @@ def _deserialize_pytreedef_to_pytree(p: ser_flatbuf.PyTreeDef):
331331

332332

333333
def _serialize_aval(
334-
builder: flatbuffers.Builder, aval: core.AbstractValue
334+
builder: flatbuffers.Builder, aval: core.ShapedArray
335335
) -> int:
336-
aval_type = type(aval)
337-
if aval_type is core.ShapedArray:
338-
aval_kind = ser_flatbuf.AbstractValueKind.shapedArray
339-
shape_offsets = [builder.CreateString(str(d)) for d in aval.shape]
340-
ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape))
341-
for d in reversed(shape_offsets):
342-
builder.PrependUOffsetTRelative(d)
343-
shape_vector_offset = builder.EndVector()
344-
345-
ser_flatbuf.AbstractValueStart(builder)
346-
ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
347-
ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
348-
ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
349-
return ser_flatbuf.AbstractValueEnd(builder)
350-
else:
351-
raise NotImplementedError(f"serializing AbstractValue: {aval}")
336+
aval_kind = ser_flatbuf.AbstractValueKind.shapedArray
337+
shape_offsets = [builder.CreateString(str(d)) for d in aval.shape]
338+
ser_flatbuf.AbstractValueStartShapeVector(builder, len(aval.shape))
339+
for d in reversed(shape_offsets):
340+
builder.PrependUOffsetTRelative(d)
341+
shape_vector_offset = builder.EndVector()
342+
343+
ser_flatbuf.AbstractValueStart(builder)
344+
ser_flatbuf.AbstractValueAddKind(builder, aval_kind)
345+
ser_flatbuf.AbstractValueAddShape(builder, shape_vector_offset)
346+
ser_flatbuf.AbstractValueAddDtype(builder, _dtype_to_dtype_kind[aval.dtype])
347+
return ser_flatbuf.AbstractValueEnd(builder)
352348

353349

354350
def _deserialize_aval(aval: ser_flatbuf.AbstractValue,
355-
scope) -> core.AbstractValue:
351+
scope) -> core.ShapedArray:
356352
aval_kind = aval.Kind()
357353
if aval_kind == ser_flatbuf.AbstractValueKind.shapedArray:
358354
dtype = _dtype_kind_to_dtype[aval.Dtype()]

jax/_src/export/shape_poly.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,10 +1717,10 @@ def _dimension_size_lowering_rule(ctx, arg, *, dimension):
17171717
mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)
17181718

17191719

1720-
def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
1720+
def all_dim_vars(args_avals: Sequence[core.ShapedArray]) -> Sequence[str]:
17211721
dim_vars: set[str] = set()
17221722
for a in args_avals:
1723-
for d in a.shape: # type: ignore[attribute-error,unused-ignore]
1723+
for d in a.shape:
17241724
if is_symbolic_dim(d):
17251725
dim_vars = dim_vars.union(d._get_vars())
17261726
return sorted(dim_vars)
@@ -1911,7 +1911,7 @@ def pretty_print_dimension_descriptor(
19111911

19121912
@util.cache()
19131913
def solve_dim_vars(
1914-
args_avals: Sequence[core.AbstractValue],
1914+
args_avals: Sequence[core.ShapedArray],
19151915
args_kwargs_tree: tree_util.PyTreeDef,
19161916
) -> tuple[DimVarEnv, ShapeConstraints, Sequence[tuple[str, int, int]]]:
19171917
"""Solves dimension variables in a called function's avals in terms of actual argument shapes.
@@ -1956,12 +1956,12 @@ def solve_dim_vars(
19561956
# tuples with argument name and its polymorphic shape ('args[0]', '(a, a + b'))
19571957
polymorphic_shape_specs: list[tuple[str, str]] = []
19581958
for arg_idx, aval in enumerate(args_avals):
1959-
if all(not is_symbolic_dim(d) for d in aval.shape): # type: ignore[attribute-error,unused-ignore]
1959+
if all(not is_symbolic_dim(d) for d in aval.shape):
19601960
continue
19611961
polymorphic_shape_specs.append(
19621962
(pretty_print_dimension_descriptor(args_kwargs_tree, arg_idx, None),
1963-
str(aval.shape))) # type: ignore[attribute-error,unused-ignore]
1964-
for dim_idx, aval_d in enumerate(aval.shape): # type: ignore[attribute-error,unused-ignore]
1963+
str(aval.shape)))
1964+
for dim_idx, aval_d in enumerate(aval.shape):
19651965
if is_symbolic_dim(aval_d):
19661966
synth_dim_var = pretty_print_dimension_descriptor(args_kwargs_tree,
19671967
arg_idx, dim_idx)
@@ -1976,7 +1976,7 @@ def solve_dim_vars(
19761976

19771977

19781978
def compute_dim_vars_from_arg_shapes(
1979-
args_avals: Sequence[core.AbstractValue],
1979+
args_avals: Sequence[core.ShapedArray],
19801980
*actual_args: jax.Array,
19811981
args_kwargs_tree: tree_util.PyTreeDef) -> Sequence[jax.Array]:
19821982
"""Computes values of dimension variables to unify args_avals with actual arguments.

tests/export_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,10 +1489,10 @@ def f_outer(x):
14891489
self.assertAllClose(2. * 2. * x + 10. + 4. * 2. * x, res)
14901490

14911491
@jtu.parameterized_filterable(
1492-
kwargs=[
1493-
dict(v=v)
1494-
for v in range(export.minimum_supported_serialization_version,
1495-
export.maximum_supported_serialization_version + 1)])
1492+
kwargs=[
1493+
dict(v=v)
1494+
for v in range(export.minimum_supported_serialization_version,
1495+
export.maximum_supported_serialization_version + 1)])
14961496
def test_ordered_effects_poly(self, *, v: int):
14971497
with config.jax_serialization_version(v):
14981498
logging.info(

0 commit comments

Comments
 (0)