Prevent tf.function retracing and minor TF2 fixes.#353
virajbshah merged 2 commits into google:main
* The main `tf.function` (`_compute_and_apply_gradients`) was forcing re-traces on every call for some model configurations, which came with a significant performance penalty.
* Passing `reduce_retracing=True` seems to be enough to prevent this: TF is able to construct a `ConcreteFunction` with the right input signature after seeing a few invocations.
* Mark model functions using `match-case` as excluded from AutoGraph conversion. AutoGraph fails to convert them anyway, and this suppresses related warnings.
* Make `TrainingEpochStats` print correctly as their actual values, rather than the enclosing Tensors.
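The `reduce_retracing` fix can be illustrated with a minimal standalone sketch (the function and shapes below are illustrative, not from this repo): each new input shape would normally trigger a fresh trace, while `reduce_retracing=True` lets TF relax the traced input signature after the first few calls.

```python
import tensorflow as tf

trace_count = 0

@tf.function(reduce_retracing=True)
def scale(x):
    # This Python body runs only while TF traces a new ConcreteFunction,
    # not on every call, so the counter tracks the number of traces.
    global trace_count
    trace_count += 1
    return x * 2.0

# Five calls with five different input shapes. Without reduce_retracing=True,
# each new shape could force a retrace; with it, TF generalizes the traced
# signature (e.g. to shape [None]) after the first couple of traces.
for n in range(1, 6):
    scale(tf.ones([n]))

print(trace_count)  # fewer traces than the 5 calls
```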
Should we factor to avoid …

I tested it against an AutoGraph-ed …
Interesting that it doesn't significantly impact performance. Not doing the conversion would mean it has to re-execute the Python rather than just chaining operators though, right? I would think that could have a big impact if we try to enable XLA JIT compilation.
I'm not entirely sure, but it looks like it only has to re-execute Python while tracing/re-tracing. While tracing, only the … Relevant snippet from the graph:

```
--- snip ---
# Assertion ensuring `loss_type == options.LossType.HUBER` holds.
['assert_equal_11/Assert/AssertGuard'] -> assert_equal_11/Assert/AssertGuard/Identity
['truediv_1', '^assert_equal_11/Assert/AssertGuard/Identity'] -> control_dependency_11
# Prepare arguments for `_huber(tf.abs(normalized_data))`.
['control_dependency_11'] -> Abs_1
# Huber loss computation.
[] -> Const_18
['Abs_1', 'Const_18'] -> Minimum
['Abs_1'] -> Shape_9
['Shape_9'] -> Cast_6
['Minimum'] -> Shape_10
['Shape_10'] -> Cast_7
['Abs_1', 'Minimum'] -> Sub_1
['Minimum'] -> Square
[] -> Const_19
['Const_19', 'Square'] -> Mul
['Const_18', 'Sub_1'] -> mul_1/Mul
['Mul'] -> Shape_11
['Shape_11'] -> Cast_8
['mul_1/Mul'] -> Shape_12
['Shape_12'] -> Cast_9
['Mul', 'mul_1/Mul'] -> Add_6
--- snip ---
```

AutoGraph shouldn't be able to convert an equivalent … Nevertheless, I'm not sure if this affects the XLA JIT compiler and we might as well use the …