Conversion time for sparse matrix #19

@paLeziart

Hello again!

Thank you for your quick fixes for #16 and #17 !

This is more of a discussion rather than an issue, so feel free to close it at some point.

Since I have to convert large sparse matrices, I am running into long compilation times with convert(foo, compile=True): it does not ignore the zeros and still assigns them to the output, so there are a huge number of assignment operations even though only a few coefficients are non-zero.

For instance, the following code takes 10 seconds to run for N = 100 and 76 seconds for N = 200 (and I will likely have to convert bigger matrices). I guess the compiler struggles with the high number of operations and does not realize it can ignore most of them?

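import time

import casadi as ca
# convert is the conversion function of this repository, assumed to be imported already

N = 200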
x = ca.SX.sym("x", N)
y = ca.SX.sym("y", (N, N)) * 0.0
for i in range(N):
    y[i, i] = 2 * x[i]
foo = ca.Function("foo", [x], [y])

start = time.time()
convert(foo, compile=True)
print("Convert time:", time.time() - start)

Note that this is just an example, of course for this diagonal matrix I could just directly code it in jax.
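
To put numbers on the example above: the output has N x N entries, but only the N diagonal ones are non-zero. Here is a minimal sketch in plain Python, assuming (as the generated code further down suggests) that one output assignment is emitted per entry. `count_assignments` is a hypothetical helper for illustration, not part of the library.

```python
def count_assignments(n):
    """Return (total, nonzero) output entries for an n x n diagonal matrix."""
    total = n * n  # one assignment per entry when the zeros are kept
    nonzero = n    # only the diagonal carries a value
    return total, nonzero


for n in (100, 200):
    total, nonzero = count_assignments(n)
    print(f"N = {n}: {total} assignments, {nonzero} non-zero, {total - nonzero} redundant")
```

Doubling N quadruples the number of emitted assignments, while the number of useful ones only doubles.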

As a quick workaround, I slightly modified the codegen function to skip the operations associated with a zero value:

for layer in sorted_nodes:
    indices = []
    # MODIF HERE: nodes of this layer that hold a constant zero
    zero_nodes = []
    assignment = "["
    for node in layer:
        if len(graph[node]) == 0 and node not in output_map:
            continue
        if node in output_map:
            oo = output_map[node]
            if outputs.get(oo[0], None) is None:
                outputs[oo[0]] = {"rows": [], "cols": [], "values": []}

            # MODIF HERE: skip outputs that only copy a zero-valued work entry.
            # Read the whole index out of e.g. "work[24][0]", not just its last digit.
            index = values[node].split("]")[0].split("[")[-1]
            if index.isdigit() and int(index) in zero_nodes:
                continue

            outputs[oo[0]]["rows"].append(oo[1])
            outputs[oo[0]]["cols"].append(oo[2])
            outputs[oo[0]]["values"].append(values[node])
        else:
            if len(assignment) > 1:
                assignment += ", "

            # MODIF HERE: a jnp constant whose digits are all zeros is a zero node
            if "jnp" in values[node] and not any((char.isdigit() and char != '0') for char in values[node]):
                zero_nodes.append(node)

            assignment += values[node]
            indices += [node]
    if len(indices) == 0:
        continue
    assignment += "]"
    code += f"    work = work.at[jnp.array({indices})].set({assignment})\n"

The tests all seem to pass, and the change brought the convert time down from 76 seconds to less than 2 seconds for N = 200.

I'm sure there is a better way to do that, but I just wanted to share the code in case other people are facing the same struggles with sparse matrices.
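
The filtering idea can also be sketched in isolation, outside the codegen loop. This is only an illustration, assuming output expressions always have the form work[k][0] as in the generated code in this post. `drop_zero_outputs` and `WORK_REF` are hypothetical names, not part of the library.

```python
import re

# Assumed shape of an output expression, e.g. "work[24][0]"
WORK_REF = re.compile(r"work\[(\d+)\]\[0\]")


def drop_zero_outputs(rows, cols, values, zero_nodes):
    """Keep only the output entries that do not read a known-zero work slot."""
    kept = {"rows": [], "cols": [], "values": []}
    for r, c, v in zip(rows, cols, values):
        match = WORK_REF.fullmatch(v)
        if match and int(match.group(1)) in zero_nodes:
            continue  # the output array is already initialised to zero
        kept["rows"].append(r)
        kept["cols"].append(c)
        kept["values"].append(v)
    return kept


rows = [1, 0, 0, 1]
cols = [0, 1, 0, 1]
values = ["work[4][0]", "work[4][0]", "work[2][0]", "work[24][0]"]
filtered = drop_zero_outputs(rows, cols, values, zero_nodes={4})
print(filtered)
```

Matching the whole index with a regular expression keeps work[24][0] even though its index ends with the same digit as the zero node work[4].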

For completeness, the translation (for N = 4 just to illustrate) goes from

def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 1, 2, 3])].set([work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[4][0], work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs

to

def evaluate_foo(*args):
    inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]
    outputs = [jnp.zeros(out) for out in [(4, 4)]]
    work = jnp.zeros((26, 1))
    work = work.at[jnp.array([0, 1, 4, 9, 16, 23])].set([jnp.array([2.0000000000000000]), inputs[0][0], jnp.array([0.0000000000000000]), inputs[0][1], inputs[0][2], inputs[0][3]])
    work = work.at[jnp.array([2, 10, 17, 24])].set([work[0] * work[1], work[0] * work[9], work[0] * work[16], work[0] * work[23]])
    outputs[0] = outputs[0].at[([0, 1, 2, 3], [0, 1, 2, 3])].set([work[2][0], work[10][0], work[17][0], work[24][0]])
    return outputs

We can see it just discards useless assignments of work[4][0] to outputs[0] without affecting the rest.
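
Dropping these entries should be safe because outputs[0] starts out as jnp.zeros, so writing a zero into it changes nothing. Here is a small check in plain Python, using the N = 4 index lists above with x = [1, 2, 3, 4]. Nested lists stand in for the JAX array, and `scatter` is a hypothetical stand-in for `.at[...].set(...)`.

```python
def scatter(shape, rows, cols, values):
    """Write values into a zero-initialised matrix at (rows[i], cols[i])."""
    out = [[0.0] * shape[1] for _ in range(shape[0])]
    for r, c, v in zip(rows, cols, values):
        out[r][c] = v
    return out


x = [1.0, 2.0, 3.0, 4.0]
diag = [2.0 * xi for xi in x]

# Before: 12 explicit zeros followed by the 4 diagonal entries.
rows_before = [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 0, 1, 2, 3]
cols_before = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 1, 2, 3]
dense = scatter((4, 4), rows_before, cols_before, [0.0] * 12 + diag)

# After: only the diagonal entries are written.
sparse = scatter((4, 4), [0, 1, 2, 3], [0, 1, 2, 3], diag)

print(dense == sparse)
```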

Best,
