Skip to content

Commit d25ebb4

Browse files
Fix test for deprecated spmm_coo
1 parent bbb7063 commit d25ebb4

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

bitsandbytes/functional.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2647,7 +2647,7 @@ def double_quant(
26472647
"""
26482648

26492649
coo_tensor = None
2650-
quant_row, quant_col, row_stats, col_stats, _ = int8_double_quant(
2650+
quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant(
26512651
A,
26522652
col_stats,
26532653
row_stats,
@@ -2657,16 +2657,15 @@ def double_quant(
26572657
)
26582658

26592659
if threshold > 0.0:
2660-
# Build COO tensor for any outliers.
2661-
outlier_mask = A.abs() >= threshold
2662-
outlier_locations = outlier_mask.nonzero()
2663-
outliers = A[outlier_mask]
2660+
# Build a COO tensor including all of the outlier columns.
2661+
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
2662+
outliers = A[:, outlier_cols]
26642663
coo_tensor = COOSparseTensor(
26652664
A.shape[0],
26662665
A.shape[1],
26672666
outliers.numel(),
2668-
outlier_locations[:, 0].int(),
2669-
outlier_locations[:, 1].int(),
2667+
outlier_rows.repeat_interleave(outliers.size(1)),
2668+
outlier_cols.repeat(outliers.size(0)).int(),
26702669
outliers,
26712670
)
26722671

0 commit comments

Comments
 (0)