|
| 1 | +import copy |
| 2 | +import inspect |
| 3 | +import itertools |
| 4 | +import string |
| 5 | +import warnings |
| 6 | + |
| 7 | +from keras.src import layers |
| 8 | +from keras.src import tree |
| 9 | +from keras.src.backend.common.stateless_scope import StatelessScope |
| 10 | +from keras.src.utils.module_utils import tensorflow as tf |
| 11 | + |
| 12 | + |
| 13 | +class JaxExportArchive: |
| 14 | + def __init__(self): |
| 15 | + self._backend_variables = [] |
| 16 | + self._backend_trainable_variables = [] |
| 17 | + self._backend_non_trainable_variables = [] |
| 18 | + |
| 19 | + def track(self, resource): |
| 20 | + if not isinstance(resource, layers.Layer): |
| 21 | + raise ValueError( |
| 22 | + "Invalid resource type. Expected an instance of a " |
| 23 | + "JAX-based Keras `Layer` or `Model`. " |
| 24 | + f"Received instead an object of type '{type(resource)}'. " |
| 25 | + f"Object received: {resource}" |
| 26 | + ) |
| 27 | + |
| 28 | + if isinstance(resource, layers.Layer): |
| 29 | + # Variables in the lists below are actually part of the trackables |
| 30 | + # that get saved, because the lists are created in __init__. |
| 31 | + trainable_variables = resource.trainable_variables |
| 32 | + non_trainable_variables = resource.non_trainable_variables |
| 33 | + |
| 34 | + self._tf_trackable.trainable_variables += tree.map_structure( |
| 35 | + self._convert_to_tf_variable, trainable_variables |
| 36 | + ) |
| 37 | + self._tf_trackable.non_trainable_variables += tree.map_structure( |
| 38 | + self._convert_to_tf_variable, non_trainable_variables |
| 39 | + ) |
| 40 | + self._tf_trackable.variables = ( |
| 41 | + self._tf_trackable.trainable_variables |
| 42 | + + self._tf_trackable.non_trainable_variables |
| 43 | + ) |
| 44 | + |
| 45 | + self._backend_trainable_variables += trainable_variables |
| 46 | + self._backend_non_trainable_variables += non_trainable_variables |
| 47 | + self._backend_variables = ( |
| 48 | + self._backend_trainable_variables |
| 49 | + + self._backend_non_trainable_variables |
| 50 | + ) |
| 51 | + |
| 52 | + def add_endpoint(self, name, fn, input_signature=None, **kwargs): |
| 53 | + jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None) |
| 54 | + # Use `copy.copy()` to avoid modification issues. |
| 55 | + jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {} |
| 56 | + is_static = bool(kwargs.pop("is_static", False)) |
| 57 | + |
| 58 | + # Configure `jax2tf_kwargs` |
| 59 | + if "native_serialization" not in jax2tf_kwargs: |
| 60 | + jax2tf_kwargs["native_serialization"] = ( |
| 61 | + self._check_device_compatible() |
| 62 | + ) |
| 63 | + if "polymorphic_shapes" not in jax2tf_kwargs: |
| 64 | + jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape( |
| 65 | + input_signature |
| 66 | + ) |
| 67 | + |
| 68 | + # Note: we truncate the number of parameters to what is specified by |
| 69 | + # `input_signature`. |
| 70 | + fn_signature = inspect.signature(fn) |
| 71 | + fn_parameters = list(fn_signature.parameters.values()) |
| 72 | + |
| 73 | + if is_static: |
| 74 | + from jax.experimental import jax2tf |
| 75 | + |
| 76 | + jax_fn = jax2tf.convert(fn, **jax2tf_kwargs) |
| 77 | + jax_fn.__signature__ = inspect.Signature( |
| 78 | + parameters=fn_parameters[0 : len(input_signature)], |
| 79 | + return_annotation=fn_signature.return_annotation, |
| 80 | + ) |
| 81 | + |
| 82 | + decorated_fn = tf.function( |
| 83 | + jax_fn, |
| 84 | + input_signature=input_signature, |
| 85 | + autograph=False, |
| 86 | + ) |
| 87 | + else: |
| 88 | + # 1. Create a stateless wrapper for `fn` |
| 89 | + # 2. jax2tf the stateless wrapper |
| 90 | + # 3. Create a stateful function that binds the variables with |
| 91 | + # the jax2tf converted stateless wrapper |
| 92 | + # 4. Make the signature of the stateful function the same as the |
| 93 | + # original function |
| 94 | + # 5. Wrap in a `tf.function` |
| 95 | + def stateless_fn(variables, *args, **kwargs): |
| 96 | + state_mapping = zip(self._backend_variables, variables) |
| 97 | + with StatelessScope(state_mapping=state_mapping) as scope: |
| 98 | + output = fn(*args, **kwargs) |
| 99 | + |
| 100 | + # Gather updated non-trainable variables |
| 101 | + non_trainable_variables = [] |
| 102 | + for var in self._backend_non_trainable_variables: |
| 103 | + new_value = scope.get_current_value(var) |
| 104 | + non_trainable_variables.append(new_value) |
| 105 | + return output, non_trainable_variables |
| 106 | + |
| 107 | + jax2tf_stateless_fn = self._convert_jax2tf_function( |
| 108 | + stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs |
| 109 | + ) |
| 110 | + |
| 111 | + def stateful_fn(*args, **kwargs): |
| 112 | + output, non_trainable_variables = jax2tf_stateless_fn( |
| 113 | + # Change the trackable `ListWrapper` to a plain `list` |
| 114 | + list(self._tf_trackable.variables), |
| 115 | + *args, |
| 116 | + **kwargs, |
| 117 | + ) |
| 118 | + for var, new_value in zip( |
| 119 | + self._tf_trackable.non_trainable_variables, |
| 120 | + non_trainable_variables, |
| 121 | + ): |
| 122 | + var.assign(new_value) |
| 123 | + return output |
| 124 | + |
| 125 | + stateful_fn.__signature__ = inspect.Signature( |
| 126 | + parameters=fn_parameters[0 : len(input_signature)], |
| 127 | + return_annotation=fn_signature.return_annotation, |
| 128 | + ) |
| 129 | + |
| 130 | + decorated_fn = tf.function( |
| 131 | + stateful_fn, |
| 132 | + input_signature=input_signature, |
| 133 | + autograph=False, |
| 134 | + ) |
| 135 | + return decorated_fn |
| 136 | + |
| 137 | + def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None): |
| 138 | + from jax.experimental import jax2tf |
| 139 | + |
| 140 | + variables_shapes = self._to_polymorphic_shape( |
| 141 | + self._backend_variables, allow_none=False |
| 142 | + ) |
| 143 | + input_shapes = list(jax2tf_kwargs["polymorphic_shapes"]) |
| 144 | + jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes |
| 145 | + return jax2tf.convert(fn, **jax2tf_kwargs) |
| 146 | + |
| 147 | + def _to_polymorphic_shape(self, struct, allow_none=True): |
| 148 | + if allow_none: |
| 149 | + # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz |
| 150 | + # for unknown non-batch dims. Defined here to be scope per endpoint. |
| 151 | + dim_names = itertools.chain( |
| 152 | + string.ascii_lowercase, |
| 153 | + itertools.starmap( |
| 154 | + lambda a, b: a + b, |
| 155 | + itertools.product(string.ascii_lowercase, repeat=2), |
| 156 | + ), |
| 157 | + ) |
| 158 | + |
| 159 | + def convert_shape(x): |
| 160 | + poly_shape = [] |
| 161 | + for index, dim in enumerate(list(x.shape)): |
| 162 | + if dim is not None: |
| 163 | + poly_shape.append(str(dim)) |
| 164 | + elif not allow_none: |
| 165 | + raise ValueError( |
| 166 | + f"Illegal None dimension in {x} with shape {x.shape}" |
| 167 | + ) |
| 168 | + elif index == 0: |
| 169 | + poly_shape.append("batch") |
| 170 | + else: |
| 171 | + poly_shape.append(next(dim_names)) |
| 172 | + return "(" + ", ".join(poly_shape) + ")" |
| 173 | + |
| 174 | + return tree.map_structure(convert_shape, struct) |
| 175 | + |
| 176 | + def _check_device_compatible(self): |
| 177 | + from jax import default_backend as jax_device |
| 178 | + |
| 179 | + if ( |
| 180 | + jax_device() == "gpu" |
| 181 | + and len(tf.config.list_physical_devices("GPU")) == 0 |
| 182 | + ): |
| 183 | + warnings.warn( |
| 184 | + "JAX backend is using GPU for export, but installed " |
| 185 | + "TF package cannot access GPU, so reloading the model with " |
| 186 | + "the TF runtime in the same environment will not work. " |
| 187 | + "To use JAX-native serialization for high-performance export " |
| 188 | + "and serving, please install `tensorflow-gpu` and ensure " |
| 189 | + "CUDA version compatibility between your JAX and TF " |
| 190 | + "installations." |
| 191 | + ) |
| 192 | + return False |
| 193 | + else: |
| 194 | + return True |
0 commit comments