Following the benchmark results, which show that:
- For the CSR backward pass, torch spmm outperforms tsgu spmm in both time and peak memory; tsgu spmm incurs extra memory due to an internal CSR→COO conversion for gradient computation.
- torch spmm supports CSR gradients, contrary to current public PyTorch documentation statements.
And given the confirmation, raised in pytorch/pytorch#172550, that the PyTorch documentation does not accurately reflect this,
the function `sparse_mm` should be updated to use PyTorch's spmm for unbatched CSR matrices.
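The proposed dispatch could look like the following minimal sketch. It assumes `sparse_mm` takes a sparse matrix and a dense matrix; the fallback branch is a placeholder for the existing tsgu spmm path, whose real signature is not shown here.

```python
import torch

def sparse_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Sparse-dense matmul, routing unbatched CSR inputs to torch.sparse.mm.

    Sketch only: per the benchmark above, PyTorch's native spmm is faster and
    uses less peak memory for the CSR backward pass, so unbatched CSR matrices
    go through it while everything else keeps the existing kernel.
    """
    if a.layout == torch.sparse_csr and a.dim() == 2:
        # torch.sparse.mm computes gradients for CSR inputs even though the
        # documentation suggests otherwise (see pytorch/pytorch#172550).
        return torch.sparse.mm(a, b)
    # Placeholder for the existing tsgu spmm path (batched or non-CSR inputs);
    # the real fallback implementation is not reproduced here.
    raise NotImplementedError("batched / non-CSR inputs handled by tsgu spmm")
```

For example, `sparse_mm(torch.eye(3).to_sparse_csr(), torch.randn(3, 2))` would take the native path, and gradients would still flow through the CSR operand.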