TL;DR: it seems the problem was caused by the model weights being captured by the decoding function as a closure instead of being passed in as an argument. From jit's perspective they were a huge constant, and, I guess, the constant-folding optimization wasn't particularly happy about them.

How I found this out:

  • I compared the jaxprs of the decoding functions for models with different dimensions and established that the jaxprs are identical (tensor dimensions aside), so the problem wasn't caused by different computation graphs.
  • I then manually split compilation into separate stages using the AOT functionality and measured the time of each stage. As suspected, it was the HLO compilation stage that slowe…

Answer selected by hr0nix