I've been playing around with JAX/jit for implementing certain stream-processing patterns (typical in async timeseries) by using the lax while_loop to run an event loop, and trying to extract the results with the experimental id_tap features in host_callback. It was working really well, and I thought I'd grokked how it all worked — and then suddenly it was 1000x slower after I changed something, and it took me ages to find a repro.
So apologies that this repro is not smaller, but it was a real pain to find which scenario triggered the huge slowdown.
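(For context, a minimal sketch of the general pattern — a `lax.while_loop` "event loop" consuming one input event per iteration. This is illustrative only; the carry layout and names here are stand-ins, not the repro's.)

```python
# Minimal sketch of a lax.while_loop event loop: consume one event per
# iteration until the input is exhausted. Carry = (index, accumulator).
import jax
import jax.numpy as jnp
from jax import lax

def make_loop(in_array):
    n = in_array.shape[0]

    def cond_fn(ctx):
        i, _ = ctx
        return i < n

    def body_fn(ctx):
        i, acc = ctx
        # "process" one event: here we just accumulate its value
        return i + 1, acc + in_array[i]

    return jax.jit(lambda: lax.while_loop(cond_fn, body_fn,
                                          (jnp.int32(0), jnp.float32(0.0))))

events = jnp.arange(5, dtype=jnp.float32)
steps, total = make_loop(events)()
```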
My questions:
- How can I make sense of the 1000x slowdown here, just because I use a cond and elsewhere use id_print? To be that slow, I feel it must be copying a massive array around constantly, or reallocating memory constantly, but the jit compiler should have realised that everything can be updated in place.
- Is there a good way to 'see' what it's spending all its time on?

Grateful for any advice about this — I can file it as an 'issue' if appropriate, but I'm new around here so starting with a discussion.
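(On the "see where time goes" question, two quick general-purpose checks, sketched here on a trivial stand-in function rather than the repro: printing the jaxpr to see what the traced program actually contains, and timing the compiled function with `block_until_ready()`. For deeper inspection, `jax.profiler.trace` can record a trace viewable in TensorBoard.)

```python
# Two quick ways to inspect a jitted function (stand-in function, not the repro).
import time
import jax
import jax.numpy as jnp

def f(x):
    return jnp.where(x % 9 == 0, x * 2, x)

# 1) Print the jaxpr to see what the traced program actually contains
#    (e.g. whether a cond captured more than you expected).
print(jax.make_jaxpr(f)(jnp.arange(10)))

# 2) Time the compiled function; block_until_ready() matters because
#    dispatch is asynchronous and un-blocked timings can look instant.
g = jax.jit(f)
g(jnp.arange(10)).block_until_ready()   # warm-up / compile
t0 = time.perf_counter()
out = g(jnp.arange(10)).block_until_ready()
elapsed = time.perf_counter() - t0
```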
```python
# (imports added for completeness; InIter, OutIter and MAX_EVENT are
# defined elsewhere in the full repro)
from functools import partial

import jax
import jax.numpy as jnp
from jax.lax import cond, while_loop
import jax.experimental.host_callback as hcb

def a_funky_function(x):
    return jnp.logical_not(x % 9)

def one_step(context, static):
    # read one message, and increment the iterator
    msg = jax.tree_map(lambda x: x[context['in_iter'].begin], context['in_array'])
    context['in_iter'] = InIter(context['in_iter'].begin + 1, context['in_iter'].end)
    keep_this_msg = a_funky_function(msg['value'])
    if not static.CONDITION_1:
        context['out'] = push(context['out'], msg, static)
    else:
        context['out'] = cond(keep_this_msg,
                              lambda q, val: push(q, val, static),
                              lambda q, val: q,
                              context['out'], msg)
    return context

def flush_return_none(buf, static):
    if not static.CONDITION_2:
        return None
    # hcb.id_tap also evidences it
    hcb.id_print(buf['value'][0])
    hcb.id_print(buf['time'][0])

def push(out, val, static):
    ''' Push val into the out-buffer defined by out '''
    it, buf = out
    new_it = OutIter((it.end + 1) % it.N, it.N)
    new_buf = jax.tree_multimap(lambda q, m: q.at[it.end].set(m), buf, val)
    # if we exhausted and looped back to start after N, 'tap out' the ring buffer to host
    cond(new_it.end == 0,
         lambda buf: flush_return_none(buf, static),
         lambda buf: None,
         new_buf)
    return new_it, new_buf

def next_exists(context):
    in_iter = context['in_iter']
    in_array = context['in_array']
    next_event_time = cond(in_iter.begin != in_iter.end,
                           lambda i_it, i_ar: jax.tree_map(lambda x: x[i_it.begin], i_ar)['time'],
                           lambda i_it, i_ar: MAX_EVENT,
                           in_iter, in_array)
    return next_event_time < MAX_EVENT

def main_loop_body(context, static):
    in_iter = context['in_iter']
    in_array = context['in_array']
    next_event = jax.tree_map(lambda x: x[in_iter.begin], in_array)['time']
    do_next_step = next_event < MAX_EVENT
    context = cond(do_next_step,
                   partial(one_step, static=static),
                   lambda x: x,
                   context)
    return context

def main(context, static):
    context = while_loop(lambda context: next_exists(context),
                         partial(main_loop_body, static=static),
                         context)
    return context
```
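(On the "updated in place" expectation: under jit, XLA usually does perform `q.at[i].set(v)` updates in place when the old buffer is dead, and `donate_argnums` lets you donate top-level input buffers explicitly — though donation is ignored, with a warning, on CPU. A minimal sketch, with a hypothetical `update` function:)

```python
# Sketch of functional in-place updates under jit. A q.at[i].set(v)
# typically reuses the buffer when the old value is no longer needed;
# donate_argnums makes donation explicit for top-level arguments
# (note: donation is a no-op, with a warning, on CPU backends).
import jax
import jax.numpy as jnp

def update(buf, i, v):
    return buf.at[i].set(v)

updated = jax.jit(update, donate_argnums=(0,))(jnp.zeros(8), 3, 1.0)
```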
The operation is basically to iterate over an input array and copy it to the output array, conditional on the value of a function — this is the simplest version of the repro I could make.
(The out-buffer is big enough in this case that evaluation never even hits the flush_return_none branch — but the mere presence of the branch in the code is enough to cause a big slowdown.)
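(One general JAX behaviour that may be relevant here: `lax.cond` traces and compiles both branches even when one is never taken at runtime, so a host callback sitting in a "dead" branch still ends up in the compiled program. A quick demonstration, using trace-time side effects:)

```python
# Demonstrates that lax.cond traces BOTH branches at compile time,
# even though only one executes at runtime.
import jax
import jax.numpy as jnp
from jax import lax

traced = []

def true_fn(v):
    traced.append('true')    # runs at trace time
    return v + 1

def false_fn(v):
    traced.append('false')   # also runs at trace time
    return v - 1

out = jax.jit(lambda x: lax.cond(x > 0, true_fn, false_fn, x))(jnp.int32(3))
# both branch functions were traced, even though only true_fn "executes"
```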
Now run it with different static settings:
It prints the following, showing a significant slowdown when both conditions are true.
(I'm running on CPU only.)
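(A hedged sketch of how one could time variants against each other — the `use_cond` flag stands in for the post's static settings, and the loop bodies are toy stand-ins, not the repro. Compile once outside the timed region, then time a second call with `block_until_ready()`.)

```python
# Toy timing harness: build each variant, warm it up (compile), then time
# a second call with block_until_ready so async dispatch doesn't hide cost.
import time
import jax
import jax.numpy as jnp
from jax import lax

def build(use_cond):
    def body(carry):
        i, acc = carry
        if use_cond:
            # static Python flag: this cond is only traced in when use_cond=True
            acc = lax.cond(i % 2 == 0, lambda a: a + i, lambda a: a, acc)
        else:
            acc = acc + i
        return i + 1, acc

    def run(n):
        return lax.while_loop(lambda c: c[0] < n, body,
                              (jnp.int32(0), jnp.int32(0)))[1]
    return jax.jit(run)

for use_cond in (False, True):
    f = build(use_cond)
    f(jnp.int32(1000)).block_until_ready()        # warm-up / compile
    t0 = time.perf_counter()
    result = f(jnp.int32(1000)).block_until_ready()
    print(f"use_cond={use_cond}: {time.perf_counter() - t0:.6f}s, result={int(result)}")
```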