Hey, a question about string arguments in JAX: I am trying to define a custom backward pass for convolution by wrapping it with `@jax.custom_vjp`:
```python
from typing import Optional, Sequence, Tuple, Union

import jax

@jax.custom_vjp
def custom_conv_general_dilated(lhs, rhs, window_strides: Sequence[int],
                                padding: Union[str, Sequence[Tuple[int, int]]],
                                lhs_dilation: Optional[Sequence[int]] = None,
                                rhs_dilation: Optional[Sequence[int]] = None,
                                dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                                feature_group_count: int = 1, batch_group_count: int = 1,
                                precision=None, preferred_element_type=None):
    return jax.lax.conv_general_dilated(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=feature_group_count, batch_group_count=batch_group_count,
        precision=precision, preferred_element_type=preferred_element_type)
```
```python
def custom_conv_general_fwd(inputs, W, window_strides, padding,
                            lhs_dilation: Optional[Sequence[int]] = None,
                            rhs_dilation: Optional[Sequence[int]] = None,
                            dimension_numbers: jax.lax.ConvGeneralDilatedDimensionNumbers = None,
                            feature_group_count: int = 1, batch_group_count: int = 1,
                            precision=None, preferred_element_type=None):
    print("Custom Conv general forward", "=" * 20)
    # custom modification to inputs and W goes here
    outputs = custom_conv_general_dilated(inputs, W, window_strides, padding,
                                          lhs_dilation, rhs_dilation,
                                          dimension_numbers=dimension_numbers)
    return outputs, (inputs, W, window_strides, padding,
                     lhs_dilation, rhs_dilation, dimension_numbers)
```

The error message is:

```
--> 211     return custom_conv_general_dilated(inputs, W, strides, padding, one, one,
    212                                        dimension_numbers=dimension_numbers) + b
    213   return init_fun, apply_fun

[... skipping hidden 4 frame]

~/anaconda3/envs/jax/lib/python3.8/site-packages/jax/core.py in concrete_aval(x)
    960   if hasattr(x, '__jax_array__'):
    961     return concrete_aval(x.__jax_array__())
--> 962   raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
    963                   "type")
    964

TypeError: Value 'VALID' with type <class 'str'> is not a valid JAX type
```

Below is the definition of `jax.lax.conv_general_dilated` for reference:

```python
def conv_general_dilated(
    lhs: Array, rhs: Array, window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
    feature_group_count: int = 1, batch_group_count: int = 1,
    precision: PrecisionLike = None,
    preferred_element_type: Optional[DType] = None) -> Array:
"""General n-dimensional convolution operator, with optional dilation.
Wraps XLA's `Conv
<https://www.tensorflow.org/xla/operation_semantics#conv_convolution>`_
operator.
Args:
lhs: a rank `n+2` dimensional input array.
rhs: a rank `n+2` dimensional array of kernel weights.
window_strides: a sequence of `n` integers, representing the inter-window
strides.
padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
`n` `(low, high)` integer pairs that give the padding to apply before and
after each spatial dimension.
lhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
is also known as transposed convolution.
rhs_dilation: `None`, or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
is also known as atrous convolution.
dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or
a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a
string of length `n+2`.
feature_group_count: integer, default 1. See XLA HLO docs.
batch_group_count: integer, default 1. See XLA HLO docs.
precision: Optional. Either ``None``, which means the default precision for
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
'fastest', see the ``jax.default_matmul_precision`` context manager), or a
tuple of two ``lax.Precision`` enums or strings indicating precision of
``lhs`` and ``rhs``.
preferred_element_type: Optional. Either ``None``, which means the default
accumulation type for the input types, or a datatype, indicating to
accumulate results to and return a result with that datatype.
Returns:
An array containing the convolution result.
In the string case of ``dimension_numbers``, each character identifies by
position:
- the batch dimensions in ``lhs``, ``rhs``, and the output with the character
'N',
- the feature dimensions in `lhs` and the output with the character 'C',
- the input and output feature dimensions in rhs with the characters 'I'
and 'O' respectively, and
- spatial dimension correspondences between lhs, rhs, and the output using
any distinct characters.
For example, to indicate dimension numbers consistent with the ``conv``
function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
'NCHW')``. As another example, to indicate dimension numbers consistent with
the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.
When using the latter form of convolution dimension specification, window
strides are associated with spatial dimension character labels according to
the order in which the labels appear in the ``rhs_spec`` string, so that
``window_strides[0]`` is matched with the dimension corresponding to the first
character appearing in rhs_spec that is not ``'I'`` or ``'O'``.
If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
'NCHW')`` (for a 2D convolution).
"""
  dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
  if lhs_dilation is None:
    lhs_dilation = (1,) * (lhs.ndim - 2)
  elif isinstance(padding, str) and not len(lhs_dilation) == lhs_dilation.count(1):
    raise ValueError(
        "String padding is not implemented for transposed convolution "
        "using this op. Please either exactly specify the required padding or "
        "use conv_transpose.")
  if rhs_dilation is None:
    rhs_dilation = (1,) * (rhs.ndim - 2)
  if isinstance(padding, str):
    lhs_perm, rhs_perm, _ = dnums
    rhs_shape = np.take(rhs.shape, rhs_perm)[2:]  # type: ignore[index]
    effective_rhs_shape = [(k-1) * r + 1 for k, r in zip(rhs_shape, rhs_dilation)]
    padding = padtype_to_pads(
        np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,  # type: ignore[index]
        window_strides, padding)
  preferred_element_type = (None if preferred_element_type is None else
                            np.dtype(preferred_element_type))
  return conv_general_dilated_p.bind(
      lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
      lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
      dimension_numbers=dnums,
      feature_group_count=feature_group_count,
      batch_group_count=batch_group_count,
      lhs_shape=lhs.shape, rhs_shape=rhs.shape,
      precision=canonicalize_precision(precision),
      preferred_element_type=preferred_element_type)
```
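By default, `custom_vjp` treats every argument as differentiable, so each one must be convertible to a JAX type when the function is differentiated; a plain Python string like `'VALID'` has no abstract value, hence the `TypeError` from `concrete_aval`. A minimal sketch that reproduces the same failure without any convolution (the names `f` and `mode` are made up for illustration):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, mode):  # mode is a plain Python string such as 'VALID'
    return 2.0 * x

def f_fwd(x, mode):
    return f(x, mode), None

def f_bwd(res, g):
    # Never reached: tracing fails before this rule runs, because the
    # string argument cannot be converted to a JAX value.
    return 2.0 * g, None

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(jnp.ones(3), 'VALID')
# TypeError: Value 'VALID' with type <class 'str'> is not a valid JAX type
```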
Replies: 1 comment
If you want a custom JVP/VJP rule with some arguments that are non-differentiable, you can use the `nondiff_argnums` argument to mark which ones those are. Note that this requires all arguments to be passed by position, so for functions like this one with many keyword arguments you might choose to define the custom JVP/VJP on a private function, and wrap it with another function that accepts keywords.
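As a minimal sketch of that pattern, assuming only `lhs` and `rhs` are differentiable and everything else (including the string `padding`) is non-differentiable; the names `_conv`/`conv` and the zero-gradient backward rule are placeholders, not a real convolution VJP:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Mark everything except lhs and rhs as non-differentiable, so string-valued
# arguments such as padding='VALID' are never treated as JAX values.
@partial(jax.custom_vjp, nondiff_argnums=(2, 3, 4, 5, 6))
def _conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          dimension_numbers):
    return jax.lax.conv_general_dilated(
        lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers=dimension_numbers)

# The fwd rule takes the same arguments as the primal function.
def _conv_fwd(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
              dimension_numbers):
    out = _conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
                dimension_numbers)
    return out, (lhs, rhs)  # residuals must be valid JAX types

# The bwd rule receives the non-differentiable arguments first (in the order
# given by nondiff_argnums), then the residuals, then the output cotangent,
# and returns one cotangent per differentiable argument.
def _conv_bwd(window_strides, padding, lhs_dilation, rhs_dilation,
              dimension_numbers, res, g):
    lhs, rhs = res
    # Placeholder gradients; a real rule would compute the transposed
    # convolutions for d(lhs) and d(rhs) here.
    return jnp.zeros_like(lhs), jnp.zeros_like(rhs)

_conv.defvjp(_conv_fwd, _conv_bwd)

# Keyword-friendly public wrapper around the positional-only private function.
def conv(lhs, rhs, window_strides, padding='VALID', lhs_dilation=None,
         rhs_dilation=None, dimension_numbers=None):
    return _conv(lhs, rhs, window_strides, padding, lhs_dilation,
                 rhs_dilation, dimension_numbers)
```

With this arrangement the string `padding` is threaded through as static Python data rather than as a traced value, so differentiating `conv` no longer hits the `TypeError` above.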