Description
Hello again!
Thank you for your quick fixes for #16 and #17!
This is more of a discussion rather than an issue, so feel free to close it at some point.
As I have to convert big sparse matrices, I am facing long compilation times with convert(foo, compile=True). The generated code does not skip the zeros; it assigns every one of them to the output, so there are a ton of assignment operations even though there are few non-zero coefficients.
For instance, the following code takes 10 seconds to run for N = 100 and 76 seconds for N = 200 (and I will likely have to convert bigger matrices). I guess the compiler struggles with the high number of operations and does not realize that most of them can be skipped?
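import time

import casadi as ca
from jaxadi import convert  # assumed: convert() from the jaxadi package

N = 200
x = ca.SX.sym("x", N)
y = ca.SX.sym("y", (N, N)) * 0.0  # dense N x N matrix of numeric zeros
for i in range(N):
    y[i, i] = 2 * x[i]
foo = ca.Function("foo", [x], [y])

start = time.time()
convert(foo, compile=True)
print("Convert time:", time.time() - start)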
Note that this is just an example; for this diagonal matrix I could of course code it directly in JAX.
As a quick workaround, I slightly modified the codegen function to skip the output assignments associated with a zero value:
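To put the number of assignments in perspective, here is a back-of-the-envelope count. It assumes, as in the N = 4 translation further down, that the output is written with one list entry per matrix element; scatter_sizes is a hypothetical helper used only for this count.

```python
def scatter_sizes(n: int) -> tuple[int, int]:
    """Hypothetical helper: list entries in the output .set([...]) call
    for an n-by-n diagonal output.

    Returns (dense, pruned): dense writes every element, zeros included;
    pruned writes only the n diagonal non-zeros.
    """
    return n * n, n


for n in (4, 100, 200):
    dense, pruned = scatter_sizes(n)
    print(f"N = {n}: {dense} dense entries, {pruned} pruned, {dense - pruned} zero writes")
```

For N = 4 this gives 16 entries, 12 of which are zero writes, which matches the twelve work[4][0] entries in the translation below; for N = 200 it is 40000 entries for only 200 non-zeros.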
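# Fragment of the codegen loop: sorted_nodes, graph, output_map, values,
# outputs and code are defined by the surrounding function.
import re

# MODIF HERE: pattern for a zero constant such as "jnp.array([0.0000000000000000])"
ZERO_CONST = re.compile(r"jnp\.array\(\[-?0(\.0*)?\]\)")
# MODIF HERE: collected across all layers, because an output that copies a
# zero constant always sits in a later layer than the constant itself
zero_nodes = set()
for layer in sorted_nodes:
    indices = []
    assignment = "["
    for node in layer:
        if len(graph[node]) == 0 and node not in output_map:
            continue
        if node in output_map:
            oo = output_map[node]
            if outputs.get(oo[0], None) is None:
                outputs[oo[0]] = {"rows": [], "cols": [], "values": []}
            # MODIF HERE: outputs start as jnp.zeros, so an entry that only copies
            # a zero constant (e.g. "work[4][0]") can be skipped. Parse the full
            # work index: its last digit alone would confuse work[24] with work[4].
            value = values[node]
            if value.startswith("work[") and int(value[5:value.index("]")]) in zero_nodes:
                continue
            outputs[oo[0]]["rows"].append(oo[1])
            outputs[oo[0]]["cols"].append(oo[2])
            outputs[oo[0]]["values"].append(values[node])
        else:
            if len(assignment) > 1:
                assignment += ", "
            # MODIF HERE: remember constant nodes whose value is exactly zero
            # (a digit scan would also flag a string such as "jnp.sin(work[0])")
            if ZERO_CONST.fullmatch(values[node]):
                zero_nodes.add(node)
            assignment += values[node]
            indices += [node]
    if len(indices) == 0:
        continue
    assignment += "]"
    code += f"    work = work.at[jnp.array({indices})].set({assignment})\n"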
All the tests seem to pass, and it brought the convert time down from 76 seconds to less than 2 seconds for N = 200.
I'm sure there is a better way to do this, but I just wanted to share the code in case other people are facing the same struggles with sparse matrices.
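The pruning rule can also be pulled out of the codegen loop and tested on its own. The sketch below uses a hypothetical helper, prune_zero_outputs, that is not part of jaxadi; it assumes the string formats seen in the N = 4 translation below, i.e. zero constants printed as jnp.array([0.0...]) and output entries printed as work[k][0].

```python
import re

# Assumed format of a zero constant, e.g. "jnp.array([0.0000000000000000])".
ZERO_CONST = re.compile(r"jnp\.array\(\[-?0(\.0*)?\]\)")
# Assumed format of an output entry that copies a work slot, e.g. "work[24][0]".
WORK_REF = re.compile(r"work\[(\d+)\]\[0\]")


def prune_zero_outputs(work_values, entries):
    """Hypothetical helper: drop output entries that only copy a zero constant.

    work_values: dict mapping work index -> generated expression string.
    entries: list of (row, col, value_string) tuples for one output.
    Returns the entries that do not reference a zero constant.
    """
    zero_slots = {k for k, v in work_values.items() if ZERO_CONST.fullmatch(v)}
    kept = []
    for row, col, value in entries:
        ref = WORK_REF.fullmatch(value)
        if ref is not None and int(ref.group(1)) in zero_slots:
            continue  # outputs start as zeros, so this write changes nothing
        kept.append((row, col, value))
    return kept
```

Parsing the whole index matters here: taking only the last character of "work[24" would treat work[24][0] as a reference to the zero slot work[4] and wrongly drop a diagonal entry.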
For completeness, the translation (for N = 4, just to illustrate) goes from
def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 1, 2, 3])].set([work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs
to
def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([0, 1, 2, 3], [0, 1, 2, 3])].set([work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs
As we can see, it simply discards the redundant assignments of work[4][0] (the zero constant) to outputs[0], without affecting the rest.
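Dropping those writes is safe because outputs[0] starts as jnp.zeros, so writing a zero into it is a no-op. A minimal sketch, assuming jax is installed and using the hypothetical function names diag_dense and diag_pruned, mirrors the two output scatters above for a diagonal matrix and checks that they agree:

```python
import jax.numpy as jnp


def diag_dense(x):
    """Scatter every element of the n x n output, zeros included (like the first version)."""
    n = x.shape[0]
    zero = jnp.array(0.0)
    rows, cols, vals = [], [], []
    for i in range(n):
        for j in range(n):
            rows.append(i)
            cols.append(j)
            vals.append(2.0 * x[i] if i == j else zero)
    out = jnp.zeros((n, n))
    return out.at[(jnp.array(rows), jnp.array(cols))].set(jnp.stack(vals))


def diag_pruned(x):
    """Scatter only the diagonal non-zeros (like the pruned version)."""
    n = x.shape[0]
    idx = jnp.arange(n)
    return jnp.zeros((n, n)).at[(idx, idx)].set(2.0 * x)
```

The dense version builds a list of n * n entries that the tracer and compiler have to process, while the pruned one writes n entries and yields the same matrix.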
Best,