Don't know about jaxpr, but you can annotate HLO with `jax.named_scope`:

```python
import jax

@jax.jit
def f(x: float, y: float):
    return x * y

@jax.jit
@jax.named_scope("my_named_scope")
def f_with_named_scope(x: float, y: float):
    return x * y

print("Without named scope:")
print(f.lower(0.3, 0.4).compile().as_text())
print("With named scope:")
print(f_with_named_scope.lower(0.3, 0.4).compile().as_text())
```

Note how the scope name shows up in the `op_name` metadata:

```diff
 HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[], f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}

 ENTRY %main.4 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
   %Arg_0.1 = f32[] parameter(0), metadata={op_name="x"}
   %Arg_1.2 = f32[] parameter(1), metadata={op_name="y"}
-  ROOT %multiply.3 = f32[] multiply(%Arg_0.1, %Arg_1.2), metadata={op_name="jit(f)/jit(main)/mul" source_file="/home/test.py" source_line=6}
+  ROOT %multiply.3 = f32[] multiply(%Arg_0.1, %Arg_1.2), metadata={op_name="jit(f_with_named_scope)/jit(main)/my_named_scope/mul" source_file="/home/test.py" source_line=12}
 }
```
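`jax.named_scope` also works as a context manager inside the function body, so you can label just part of a computation. A minimal sketch, assuming the compiled HLO text carries `op_name` metadata as in the output above; `scaled_product` and the scope name `x_times_y` are made up for illustration:

```python
import jax
import jax.numpy as jnp


def scaled_product(x, y):
    # Only the multiply is traced inside the scope, so only its op_name
    # metadata should contain "x_times_y"; the sin is outside the scope.
    with jax.named_scope("x_times_y"):
        prod = x * y
    return jnp.sin(prod)


hlo_text = jax.jit(scaled_product).lower(0.3, 0.4).compile().as_text()

# Print only the HLO lines that mention the scope name.
for line in hlo_text.splitlines():
    if "x_times_y" in line:
        print(line.strip())
```

XLA may fuse the multiply and the sine into one fusion computation, but the fused instructions keep their metadata, so the scope name should still be visible.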
-
This is not possible in general; for example:

```python
import jax

def f(x):
    for i in range(5):
        x += 1
    return x

print(jax.make_jaxpr(f)(1))
```

If we were trying to replicate the Python variable names, every jaxpr variable in this expression would have to be named `x`. Now, you could imagine various workarounds for this (e.g. call them …), but the current implementation has the benefit of simplicity, even if it may be harder to map back to the original code.
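To make that concrete, here is a small sketch that counts how many jaxpr variables the single Python name `x` turns into. It only uses `jax.make_jaxpr` and the `eqns`, `invars` and `outvars` fields of the returned jaxpr:

```python
import jax


def f(x):
    for i in range(5):
        x += 1
    return x


closed = jax.make_jaxpr(f)(1)
jaxpr = closed.jaxpr

# Each `x += 1` rebinds the Python name `x`, but in the jaxpr every `add`
# equation produces a fresh variable, so there is no single variable `x`.
num_eqns = len(jaxpr.eqns)
num_vars = len(jaxpr.invars) + sum(len(eqn.outvars) for eqn in jaxpr.eqns)
print(closed)
print("equations:", num_eqns)           # 5
print("distinct variables:", num_vars)  # 6
```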
-
Just in case you're not aware: each jaxpr equation already has a full Python traceback attached. We don't display it by default because it's somewhat verbose. But the variable names are not something we currently attempt to guess (and it wouldn't be possible to do reliably in general).
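As a sketch of how to look at that information, assuming a recent JAX version where `ClosedJaxpr.pretty_print` accepts a `source_info` flag (the exact annotation format varies between versions, so don't rely on it as a stable interface):

```python
import jax


def f(x, y):
    z = x * y
    return z + 1.0


closed = jax.make_jaxpr(f)(0.3, 0.4)

# Default printing: no source locations.
plain = closed.pretty_print()
# source_info=True adds a short summary of the recorded traceback
# (roughly file:line and function) next to each equation.
annotated = closed.pretty_print(source_info=True)
print(annotated)
```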
-
For an input file `blah.py` like:

The output I get when running `python blah.py` is:

Is there any way to keep the original variable names, e.g. `x` and `y`, in the jaxpr? Something like a map of invars to source-code variables? Whenever possible, for example for parameter names, as I understand that local variables may be optimised out.