Skip to content

Commit f182c83

Browse files
committed
Update check_equivalence to support conversion if necessary
1 parent fd94fd1 commit f182c83

File tree

2 files changed

+87
-8
lines changed

2 files changed

+87
-8
lines changed

examples/check_equivalence.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ int main(int argc, char** argv) {
101101
bsp_matrix_t matrix1 = bsp_read_matrix(info1.fname, info1.dataset);
102102
bsp_matrix_t matrix2 = bsp_read_matrix(info2.fname, info2.dataset);
103103

104+
// If matrices are not the same format, try to convert.
105+
if (matrix1.format != matrix2.format) {
106+
if (matrix1.format != BSP_COOR) {
107+
bsp_matrix_t intermediate = bsp_convert_matrix(matrix1, BSP_COOR);
108+
bsp_destroy_matrix_t(matrix1);
109+
matrix1 = intermediate;
110+
}
111+
112+
if (matrix2.format != BSP_COOR) {
113+
bsp_matrix_t intermediate = bsp_convert_matrix(matrix2, BSP_COOR);
114+
bsp_destroy_matrix_t(matrix2);
115+
matrix2 = intermediate;
116+
}
117+
}
118+
104119
if (matrix1.format != matrix2.format) {
105120
fprintf(stderr, "Formats do not match. (%s != %s)\n",
106121
bsp_get_matrix_format_string(matrix1.format),

include/binsparse/convert_matrix.h

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,54 @@ bsp_matrix_t bsp_convert_matrix(bsp_matrix_t matrix,
1212

1313
if (format == BSP_COOR) {
1414
// *Convert to COO* from another format.
15-
assert(false);
15+
if (matrix.format == BSP_CSR) {
16+
// Convert CSR -> COOR
17+
bsp_matrix_t result = bsp_construct_default_matrix_t();
18+
19+
result.format = BSP_COOR;
20+
21+
// Inherit NNZ, nrows, ncols, ISO-ness, and structure directly from
22+
// original matrix.
23+
result.nnz = matrix.nnz;
24+
result.nrows = matrix.nrows;
25+
result.ncols = matrix.ncols;
26+
result.is_iso = matrix.is_iso;
27+
result.structure = matrix.structure;
28+
29+
size_t max_dim =
30+
(matrix.nrows > matrix.ncols) ? matrix.nrows : matrix.ncols;
31+
32+
bsp_type_t index_type = bsp_pick_integer_type(max_dim);
33+
34+
result.values = bsp_copy_construct_array_t(matrix.values);
35+
36+
// There is a corner case with tall and skinny matrices where we need a
37+
// higher width for rowind. In order to keep rowind/colind the same type,
38+
// we might upcast.
39+
40+
if (index_type == matrix.indices_0.type) {
41+
result.indices_1 = bsp_copy_construct_array_t(matrix.indices_0);
42+
} else {
43+
result.indices_1 = bsp_construct_array_t(matrix.nnz, index_type);
44+
for (size_t i = 0; i < matrix.nnz; i++) {
45+
bsp_array_awrite(result.indices_1, i, matrix.indices_0, i);
46+
}
47+
}
48+
49+
result.indices_0 = bsp_construct_array_t(matrix.nnz, index_type);
50+
51+
for (size_t i = 0; i < matrix.nrows; i++) {
52+
size_t row_begin, row_end;
53+
bsp_array_read(matrix.pointers_to_1, i, row_begin);
54+
bsp_array_read(matrix.pointers_to_1, i + 1, row_end);
55+
for (size_t j_ptr = row_begin; j_ptr < row_end; j_ptr++) {
56+
bsp_array_write(result.indices_0, j_ptr, i);
57+
}
58+
}
59+
return result;
60+
} else {
61+
assert(false);
62+
}
1663
} else {
1764
// Convert to any another format.
1865

@@ -29,6 +76,14 @@ bsp_matrix_t bsp_convert_matrix(bsp_matrix_t matrix,
2976

3077
bsp_matrix_t result = bsp_construct_default_matrix_t();
3178

79+
result.format = BSP_CSR;
80+
81+
result.nrows = matrix.nrows;
82+
result.ncols = matrix.ncols;
83+
result.nnz = matrix.nnz;
84+
result.is_iso = matrix.is_iso;
85+
result.structure = matrix.structure;
86+
3287
// TODO: consider whether to produce files with varying integer types
3388
// for row indices, column indices, and offsets.
3489

@@ -41,23 +96,32 @@ bsp_matrix_t bsp_convert_matrix(bsp_matrix_t matrix,
4196
bsp_type_t value_type = matrix.values.type;
4297
bsp_type_t index_type = bsp_pick_integer_type(max_value);
4398

44-
result.values = bsp_construct_array_t(matrix.nnz, value_type);
45-
result.indices_0 = bsp_construct_array_t(matrix.nnz, index_type);
99+
// Since COOR is sorted by rows and then by columns, values and column
100+
// indices can be copied exactly. Values' type will not change, but
101+
// column indices might, thus the extra branch.
102+
103+
result.values = bsp_copy_construct_array_t(matrix.values);
104+
105+
if (index_type == matrix.indices_1.type) {
106+
result.indices_0 = bsp_copy_construct_array_t(matrix.indices_1);
107+
} else {
108+
result.indices_0 = bsp_construct_array_t(matrix.nnz, index_type);
109+
110+
for (size_t i = 0; i < matrix.nnz; i++) {
111+
bsp_array_awrite(result.indices_0, i, matrix.indices_1, i);
112+
}
113+
}
114+
46115
result.pointers_to_1 =
47116
bsp_construct_array_t(matrix.nrows + 1, index_type);
48117

49-
bsp_array_t values = result.values;
50-
bsp_array_t colind = result.indices_0;
51118
bsp_array_t rowptr = result.pointers_to_1;
52119

53120
bsp_array_write(rowptr, 0, 0);
54121

55122
size_t r = 0;
56123
size_t c = 0;
57124
for (size_t c = 0; c < matrix.nnz; c++) {
58-
bsp_array_awrite(values, c, matrix.values, c);
59-
bsp_array_awrite(colind, c, matrix.indices_1, c);
60-
61125
size_t j;
62126
bsp_array_read(matrix.indices_0, c, j);
63127

0 commit comments

Comments
 (0)