Static Arguments as a Means of Increasing Performance. #14460
Jordan-Dennis asked this question in Q&A
-
Hi, thanks for the question. I think there's not enough information here to give an answer (I'm honestly not even sure what you're asking!), and the discussion you linked to is very long, so it's difficult for me to tell which comments and code snippets are relevant to your question. Perhaps you could edit your question to add a minimal example of the situation you're curious about?
-
I found the following functions useful while trying to answer this question myself; perhaps they will also be useful for you.

```python
import difflib
import re
from typing import Callable


def get_mhlo(func: Callable, *args: object, **kwargs: object) -> str:
    # `func` must be jax.jit-wrapped so that it has a `.lower()` method.
    # `.compile().as_text()` returns the optimized, compiled HLO text.
    mhlo: str = func.lower(*args, **kwargs).compile().as_text()
    # Strip source-location metadata so diffs only show real differences.
    return re.sub(r"metadata=\{[^}]*\}", "", mhlo)


def print_diff(mhlo: str, comp: str) -> None:
    # Print a unified diff between two HLO dumps, line by line.
    diff = difflib.unified_diff(mhlo.splitlines(), comp.splitlines(), lineterm="")
    for line in diff:
        print(line)
```
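As a usage sketch (not from the original thread), the helpers can be used to compare what XLA compiles when an argument is traced versus marked static with `static_argnums`. `scale` is a hypothetical toy function, and the exact HLO text will vary between JAX/XLA versions and backends.

```python
import difflib
import re

import jax
import jax.numpy as jnp


def get_mhlo(func, *args, **kwargs) -> str:
    # Compiled HLO text of a jitted function, with metadata stripped
    # (same idea as the helper above, repeated so this block runs alone).
    hlo = func.lower(*args, **kwargs).compile().as_text()
    return re.sub(r"metadata=\{[^}]*\}", "", hlo)


def scale(x, factor):
    # Hypothetical toy function, used only for illustration.
    return x * factor


x = jnp.arange(4.0)

# `factor` is traced: it stays a runtime input of the compiled program.
traced = jax.jit(scale)
# `factor` is static: its Python value is baked into the program as a
# constant, and each new (hashable) value triggers a fresh compilation.
static = jax.jit(scale, static_argnums=1)

hlo_traced = get_mhlo(traced, x, 2.0)
hlo_static = get_mhlo(static, x, 2.0)

# Print only the lines that differ between the two compiled programs.
for line in difflib.unified_diff(
    hlo_traced.splitlines(), hlo_static.splitlines(), lineterm=""
):
    print(line)
```

You would expect the static version's entry computation to take one parameter fewer and to contain the value as a constant. Whether that actually makes the program faster depends on what XLA can do with the constant, so timing both versions is still the real test.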
-
Hi JAX Community,
If the answer is yes, then I am also curious about:
I am interested in the subset of functions with static shapes, i.e. functions for which jax.jit(func) succeeds. This question is based on the discussion in this issue. Below I have included a minimal(ish) working example based on our discussions:
Regards
Jordan
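To make the "static shapes, i.e. jax.jit(func) succeeds" condition concrete, here is a minimal sketch (not the example from the linked discussion, which did not survive here). `make_ones` and `normalize` are hypothetical toy functions: the first has an output shape that depends on an argument's value and only compiles once that argument is marked static, while the second has shapes that depend only on input shapes and compiles under plain `jax.jit`.

```python
import jax
import jax.numpy as jnp


def make_ones(n):
    # Output shape depends on the *value* of n.
    return jnp.ones(n)


def normalize(x):
    # Output shape depends only on x.shape, which is known at trace time.
    return x / jnp.sum(x)


# Under plain jax.jit, n is a tracer, so jnp.ones(n) has no concrete
# shape and tracing raises a TypeError (or a JAX subclass of it).
try:
    jax.jit(make_ones)(3)
    plain_jit_ok = True
except TypeError:
    plain_jit_ok = False

# Marking n as static makes its Python value part of the compilation key,
# so the shape is concrete while tracing. Each new n recompiles.
make_ones_static = jax.jit(make_ones, static_argnums=0)
y = make_ones_static(3)

# No static arguments needed here: plain jax.jit succeeds.
z = jax.jit(normalize)(jnp.arange(1.0, 5.0))
```

In the second case nothing needs to be static for compilation to succeed, which is the subset of functions the question is about; whether additionally marking such arguments static helps performance is what the diff helpers above can be used to investigate.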