
Prevent tf.function retracing and minor TF2 fixes.#353

Merged
virajbshah merged 2 commits into google:main from virajbshah:fix-model-nits
Jun 24, 2025

Conversation

@virajbshah
Contributor

  • The main tf.function (_compute_and_apply_gradients) was forcing re-traces on every call for some model configurations, incurring a significant performance penalty. Passing reduce_retracing=True seems to be enough to prevent this: TF is able to construct a ConcreteFunction with the right input signature after seeing a few invocations.
  • Marks model functions using match-case as excluded from AutoGraph conversion. AutoGraph fails to convert them anyway, and this suppresses related warnings.
  • Makes TrainingEpochStats print its fields as actual values rather than the enclosing Tensors.
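The retracing behavior the first bullet fixes can be demonstrated in isolation. This is a minimal sketch, independent of the repo's `_compute_and_apply_gradients`: by default, tensors of new shapes force a fresh trace on every call, while `reduce_retracing=True` lets TF generalize the traced signature after a few shape variations.

```python
import tensorflow as tf

trace_count = 0

@tf.function(reduce_retracing=True)
def total(x):
    global trace_count
    trace_count += 1  # runs only while tracing, not on graph execution
    return tf.reduce_sum(x)

# Tensors of different shapes would each force a retrace by default;
# with reduce_retracing=True, TF generalizes the traced input signature
# (here to shape (None,)) after seeing a couple of shape variations.
for n in range(1, 8):
    result = total(tf.ones([n]))
```

With `reduce_retracing=True` the loop above triggers only a couple of traces instead of one per distinct shape.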

@virajbshah virajbshah merged commit fcba9f7 into google:main Jun 24, 2025
7 checks passed
@boomanaiden154
Collaborator

Marks model functions using match-case as excluded from AutoGraph conversion. AutoGraph fails to convert them anyway, and this suppresses related warnings.

Should we refactor to avoid match/case here? This could significantly impact performance.

@virajbshah
Contributor Author

I tested it against an AutoGraph-ed if-else ladder and using that didn't seem to significantly impact performance. I believe this is because converting the match-case control flow to TF Ops doesn't do anything meaningful here; the branch is not really data dependent. Reference.

@boomanaiden154
Collaborator

Interesting that it doesn't significantly impact performance. Not doing the conversion would mean it has to re-execute the Python rather than just chaining operators, though, right? I would think that could have a big impact if we try to enable XLA JIT compilation.

@virajbshah
Contributor Author

I'm not entirely sure, but it looks like it only has to re-execute Python while tracing/re-tracing. While tracing, only the match-case branch taken is added to the graph (in contrast to tensor-data-dependent ifs, which would be converted to tf.conds with both branches added to the graph). Then, the ops for the branch taken at trace time are run whenever the graph is executed.

Relevant snippet from the graph (_apply_loss_function with Huber loss):

     --- snip ---
# Assertion ensuring `loss_type == options.LossType.HUBER` holds.
['assert_equal_11/Assert/AssertGuard'] -> assert_equal_11/Assert/AssertGuard/Identity
['truediv_1', '^assert_equal_11/Assert/AssertGuard/Identity'] -> control_dependency_11
# Prepare arguments for `_huber(tf.abs(normalized_data))`.
['control_dependency_11'] -> Abs_1
# Huber loss computation.
[] -> Const_18
['Abs_1', 'Const_18'] -> Minimum
['Abs_1'] -> Shape_9
['Shape_9'] -> Cast_6
['Minimum'] -> Shape_10
['Shape_10'] -> Cast_7
['Abs_1', 'Minimum'] -> Sub_1
['Minimum'] -> Square
[] -> Const_19
['Const_19', 'Square'] -> Mul
['Const_18', 'Sub_1'] -> mul_1/Mul
['Mul'] -> Shape_11
['Shape_11'] -> Cast_8
['mul_1/Mul'] -> Shape_12
['Shape_12'] -> Cast_9
['Mul', 'mul_1/Mul'] -> Add_6
     --- snip ---

AutoGraph shouldn't be able to convert an equivalent if-elif chain to tf.conds either, since the conditions (loss_type and normalization) aren't tensors.

Nevertheless, I'm not sure whether this affects the XLA JIT compiler, and we might as well use the if-elif version just in case, since the refactor is pretty small.
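The trace-time branch selection described above can be demonstrated in isolation (the names here are illustrative, not from the repo): a Python-level condition is resolved while tracing, so only the taken branch's ops end up in the graph.

```python
import tensorflow as tf

ops_added = []

def branch_a(x):
    ops_added.append("a")  # runs only if this branch is traced
    return x + 1.0

def branch_b(x):
    ops_added.append("b")
    return x - 1.0

use_a = True  # a plain Python bool: resolved at trace time

@tf.function
def f(x):
    # Only the taken branch is traced into the graph; a tensor-valued
    # condition would instead become a tf.cond with both branches.
    if use_a:
        return branch_a(x)
    else:
        return branch_b(x)

y = f(tf.constant(0.0))
print(ops_added)
```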
