I'm very confused as to how array donation is working in jit compiled functions. It seems that, even though buffers can be donated, they can't be returned when execution actually occurs. I've tried to make a script that shows what I mean:
from functools import partial
import mmap

import psutil
import numpy as np
import jax

process = psutil.Process()


def dump_mappings(s):
    print("Large mappings:")
    for x in process.memory_maps(grouped=False):
        if x.size >= s:
            print(x)


# Jax jit compiled function
# Donate input buffer to output
@partial(jax.jit, donate_argnums=1)
def doit(arr, out):
    return out


print(f"Virtual memory size (start): {process.memory_info().vms:_d}")

# Set up data buffer
sz = 1024 * 16
# (I am unreliably informed that mmap will be page boundary aligned.)
buf_sz = np.dtype('complex64').itemsize * sz**2
print(f"numpy_array byte size {buf_sz:_d}")
data = mmap.mmap(-1, buf_sz)
numpy_array = np.frombuffer(data, dtype=np.complex64).reshape((sz,) * 2)
numpy_array[0, 0] = np.nan + 1j * np.nan
numpy_array[0, 1] = np.inf
numpy_array[1, 0] = complex(0., np.inf)
orig_pointer = numpy_array.ctypes.data
print(f"numpy_array location {hex(orig_pointer)} ")
print(f"Virtual memory size (numpy array): {process.memory_info().vms:_d}")
dump_mappings(buf_sz)

cpu_device = jax.devices('cpu')[0]
print(f"Commiting array to {cpu_device}")
big_array = jax.dlpack.from_dlpack(numpy_array, device=cpu_device, copy=False)
put_pointer = big_array.unsafe_buffer_pointer()
assert put_pointer == orig_pointer
print(f"big_array2 location {hex(put_pointer)}")
print(big_array.at[:2, :2].get())
print(f"Virtual memory size (commited array): {process.memory_info().vms:_d}")
dump_mappings(buf_sz)

big_array_old = big_array
big_array = doit(10. - 10j, big_array)
print(f"Is the input buffer deleted? {big_array_old.is_deleted()}")
new_pointer = big_array.unsafe_buffer_pointer()
print(f" out array location {hex(new_pointer)}")
print(big_array[:2, :2])
print(f"Virtual memory size (function run): {process.memory_info().vms:_d}")
dump_mappings(buf_sz)
assert new_pointer == orig_pointer
The returned buffer from the trivial function is not the donated one from the input, but is copied into a new buffer. This is probably ok for small buffers, but how is one meant to deal with this situation for large buffers?
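For comparison, the same check can be reduced to a much smaller sketch that leaves out mmap and dlpack. It is only a sketch: `donation_report` and `add_one` are hypothetical names introduced for illustration, and the function does real work (`a + 1.0`) instead of returning its input unchanged. The code makes no claim about what the result will be, since whether the pointers match may depend on the backend and JAX version. It only reports what happened.

```python
import jax
import jax.numpy as jnp


def donation_report(fn, x):
    """Call fn(x) and report what happened to x's buffer (hypothetical helper)."""
    # Read the input pointer before the call; a donated array may be
    # invalidated afterwards.
    in_ptr = x.unsafe_buffer_pointer()
    y = fn(x)
    y.block_until_ready()
    report = {
        "input_deleted": x.is_deleted(),
        "same_pointer": y.unsafe_buffer_pointer() == in_ptr,
    }
    return report, y


# Hypothetical jitted function that donates its only argument.
add_one = jax.jit(lambda a: a + 1.0, donate_argnums=0)

x = jnp.zeros((256, 256), dtype=jnp.float32)
report, y = donation_report(add_one, x)
print(report)
```

If `input_deleted` is true while `same_pointer` is false, the input was consumed but the output still landed in a new buffer, which is the situation described above.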