Skip to content

Commit 790f6d5

Browse files
authored
Preserve gcxs compression (#601)
1 parent e1990b2 commit 790f6d5

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

sparse/_umath.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def __init__(self, func, *args, **kwargs):
412412

413413
processed_args = []
414414
out_type = GCXS
415+
out_kwargs = {}
415416

416417
sparse_args = [arg for arg in args if isinstance(arg, SparseArray)]
417418

@@ -421,6 +422,8 @@ def __init__(self, func, *args, **kwargs):
421422
out_type = DOK
422423
elif all(isinstance(arg, GCXS) for arg in sparse_args):
423424
out_type = GCXS
425+
if len({arg.compressed_axes for arg in sparse_args}) == 1:
426+
out_kwargs["compressed_axes"] = sparse_args[0].compressed_axes
424427
else:
425428
out_type = COO
426429

@@ -441,6 +444,7 @@ def __init__(self, func, *args, **kwargs):
441444
return
442445

443446
self.out_type = out_type
447+
self.out_kwargs = out_kwargs
444448
self.args = tuple(processed_args)
445449
self.func = func
446450
self.dtype = kwargs.pop("dtype", None)
@@ -497,7 +501,7 @@ def get_result(self):
497501
shape=self.shape,
498502
has_duplicates=False,
499503
fill_value=self.fill_value,
500-
).asformat(self.out_type)
504+
).asformat(self.out_type, **self.out_kwargs)
501505

502506
def _get_fill_value(self):
503507
"""

0 commit comments

Comments
 (0)