-
Looking at https://docs.nvidia.com/cuda/cuda-math-api/modules.html#modules, there don't seem to be any operations that take operands of different datatypes, so I think assumption 1 is probably true. But TPUs could be different! Is there any information about that?
-
In general, binary operations between arrays of two types will result in one or both being cast to a common type following the documented type promotion rules. If you ever want to see exactly how a particular function is computed, you can do so by examining the jaxpr or compiled HLO using ahead-of-time compilation. For example:

```python
import jax
import jax.numpy as jnp

def f(a, b):
    return a + b

a = jnp.arange(10, dtype='float32')
b = jnp.arange(10, dtype='bfloat16')

# jaxpr: JAX-level IR
print(jax.make_jaxpr(f)(a, b))

# un-optimized HLO
print(jax.jit(f).lower(a, b).as_text())

# optimized HLO
print(jax.jit(f).lower(a, b).compile().as_text())
```
In all cases, you see that the bfloat16 input is converted to float32 before the addition is performed.
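As a side note (a minimal sketch, not part of the original answer), the promotion rules themselves can also be queried directly, without constructing any arrays:

```python
import jax.numpy as jnp

# Query the documented type promotion rules directly.
print(jnp.result_type('float32', 'bfloat16'))   # expected: float32
print(jnp.promote_types('float16', 'float32'))  # expected: float32
```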
-
I just want to know how memory is used when we apply a binary operation to two arrays of different dtypes.
For example, if we add two arrays (w1 and w2) in fp16 and fp32, what happens in GPU VRAM?
Assumption 1: the output is fp32, so w1 is upcast to fp32 and then added to w2. In this case, additional GPU VRAM is required.
Assumption 2: there is a magic function with the signature add: fp16 -> fp32 -> fp32, so no additional GPU VRAM is needed.
Which one is true?
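One way to check this yourself (a sketch, assuming JAX with a GPU backend; the names w1/w2 and the shapes below are made up for illustration) is to lower the addition ahead of time and look at the optimized HLO and the compiled memory analysis:

```python
import jax
import jax.numpy as jnp

def add(w1, w2):
    return w1 + w2

# Illustrative arrays matching the fp16 + fp32 case in the question.
w1 = jnp.ones((1024, 1024), dtype='float16')
w2 = jnp.ones((1024, 1024), dtype='float32')

compiled = jax.jit(add).lower(w1, w2).compile()

# The optimized HLO shows whether the fp16 operand is converted to fp32
# before the add (assumption 1) or consumed directly by a mixed-dtype op.
print(compiled.as_text())

# On recent JAX versions, memory_analysis() reports argument/output/temp
# buffer sizes, which hints at whether an extra fp32 copy is materialized.
# (Availability and fields may vary by JAX version and backend.)
print(compiled.memory_analysis())
```

If the HLO contains a convert from f16 to f32 feeding the add, that would support assumption 1.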