Skip to content

Commit c3e8bdc

Browse files
authored
Fix estimate_pandas_size on pd.MultiIndex (#2707)
1 parent 792aa9c commit c3e8bdc

File tree

3 files changed

+74
-18
lines changed

3 files changed

+74
-18
lines changed

mars/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
from typing import NamedTuple, Optional
1818

19-
version_info = (0, 9, 0, "a2")
19+
version_info = (0, 9, 0, "b1")
2020
_num_index = max(idx if isinstance(v, int) else 0 for idx, v in enumerate(version_info))
2121
__version__ = ".".join(map(str, version_info[: _num_index + 1])) + "".join(
2222
version_info[_num_index + 1 :]

mars/tests/test_utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,11 +508,52 @@ def test_estimate_pandas_size():
508508

509509
s3 = pd.Series(np.random.choice(["abcd", "def", "gh"], size=(1000,)))
510510
assert utils.estimate_pandas_size(s3) != sys.getsizeof(s3)
511+
assert (
512+
pytest.approx(utils.estimate_pandas_size(s3) / sys.getsizeof(s3), abs=0.5) == 1
513+
)
511514

512515
idx1 = pd.MultiIndex.from_arrays(
513516
[np.arange(0, 1000), np.random.choice(["abcd", "def", "gh"], size=(1000,))]
514517
)
515-
assert utils.estimate_pandas_size(idx1) != sys.getsizeof(idx1)
518+
assert utils.estimate_pandas_size(idx1) == sys.getsizeof(idx1)
519+
520+
string_idx = pd.Index(np.random.choice(["a", "bb", "cc"], size=(1000,)))
521+
assert utils.estimate_pandas_size(string_idx) != sys.getsizeof(string_idx)
522+
assert (
523+
pytest.approx(
524+
utils.estimate_pandas_size(string_idx) / sys.getsizeof(string_idx), abs=0.5
525+
)
526+
== 1
527+
)
528+
529+
# dataframe with multi index
530+
idx2 = pd.MultiIndex.from_arrays(
531+
[np.arange(0, 1000), np.random.choice(["abcd", "def", "gh"], size=(1000,))]
532+
)
533+
df4 = pd.DataFrame(
534+
{
535+
"A": np.random.choice(["abcd", "def", "gh"], size=(1000,)),
536+
"B": np.random.rand(1000),
537+
"C": np.random.rand(1000),
538+
},
539+
index=idx2,
540+
)
541+
assert utils.estimate_pandas_size(df4) != sys.getsizeof(df4)
542+
assert (
543+
pytest.approx(utils.estimate_pandas_size(df4) / sys.getsizeof(df4), abs=0.5)
544+
== 1
545+
)
546+
547+
# series with multi index
548+
idx3 = pd.MultiIndex.from_arrays(
549+
[
550+
np.random.choice(["a1", "a2", "a3"], size=(1000,)),
551+
np.random.choice(["abcd", "def", "gh"], size=(1000,)),
552+
]
553+
)
554+
s4 = pd.Series(np.arange(1000), index=idx3)
555+
556+
assert utils.estimate_pandas_size(s4) == sys.getsizeof(s4)
516557

517558

518559
@require_ray

mars/utils.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,13 @@ def calc_data_size(dt: Any, shape: Tuple[int] = None) -> int:
424424

425425

426426
def estimate_pandas_size(
427-
df_obj, max_samples: int = 10, min_sample_rows: int = 100
427+
pd_obj, max_samples: int = 10, min_sample_rows: int = 100
428428
) -> int:
429-
if len(df_obj) <= min_sample_rows or isinstance(df_obj, pd.RangeIndex):
430-
return sys.getsizeof(df_obj)
429+
if len(pd_obj) <= min_sample_rows or isinstance(pd_obj, pd.RangeIndex):
430+
return sys.getsizeof(pd_obj)
431+
if isinstance(pd_obj, pd.MultiIndex):
432+
# MultiIndex's sample size can't be used to estimate
433+
return sys.getsizeof(pd_obj)
431434

432435
from .dataframe.arrays import ArrowDtype
433436

@@ -438,14 +441,16 @@ def _is_fast_dtype(dtype):
438441
return isinstance(dtype, ArrowDtype)
439442

440443
dtypes = []
441-
if isinstance(df_obj, pd.DataFrame):
442-
dtypes.extend(df_obj.dtypes)
443-
index_obj = df_obj.index
444-
elif isinstance(df_obj, pd.Series):
445-
dtypes.append(df_obj.dtype)
446-
index_obj = df_obj.index
444+
is_series = False
445+
if isinstance(pd_obj, pd.DataFrame):
446+
dtypes.extend(pd_obj.dtypes)
447+
index_obj = pd_obj.index
448+
elif isinstance(pd_obj, pd.Series):
449+
dtypes.append(pd_obj.dtype)
450+
index_obj = pd_obj.index
451+
is_series = True
447452
else:
448-
index_obj = df_obj
453+
index_obj = pd_obj
449454

450455
# handling possible MultiIndex
451456
if hasattr(index_obj, "dtypes"):
@@ -454,12 +459,22 @@ def _is_fast_dtype(dtype):
454459
dtypes.append(index_obj.dtype)
455460

456461
if all(_is_fast_dtype(dtype) for dtype in dtypes):
457-
return sys.getsizeof(df_obj)
458-
459-
indices = np.sort(np.random.choice(len(df_obj), size=max_samples, replace=False))
460-
iloc = df_obj if isinstance(df_obj, pd.Index) else df_obj.iloc
461-
sample_size = sys.getsizeof(iloc[indices])
462-
return sample_size * len(df_obj) // max_samples
462+
return sys.getsizeof(pd_obj)
463+
464+
indices = np.sort(np.random.choice(len(pd_obj), size=max_samples, replace=False))
465+
iloc = pd_obj if isinstance(pd_obj, pd.Index) else pd_obj.iloc
466+
if isinstance(index_obj, pd.MultiIndex):
467+
# MultiIndex's sample size is much greater than expected, thus we calculate
468+
# the size separately.
469+
index_size = sys.getsizeof(pd_obj.index)
470+
if is_series:
471+
sample_frame_size = iloc[indices].memory_usage(deep=True, index=False)
472+
else:
473+
sample_frame_size = iloc[indices].memory_usage(deep=True, index=False).sum()
474+
return index_size + sample_frame_size * len(pd_obj) // max_samples
475+
else:
476+
sample_size = sys.getsizeof(iloc[indices])
477+
return sample_size * len(pd_obj) // max_samples
463478

464479

465480
def build_fetch_chunk(

0 commit comments

Comments
 (0)