Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pygwalker/communications/gradio_comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import gc
import warnings

from fastapi import FastAPI
from starlette.routing import Route
Expand Down Expand Up @@ -36,6 +37,14 @@ class GradioCommunication(BaseCommunication):
"""
def __init__(self, gid: str) -> None:
    """Create a Gradio communication channel and register it under *gid*.

    If another channel already claimed the same gid, the later registration
    silently overwrites the earlier one, so emit a warning first.
    """
    super().__init__(gid)
    already_registered = gid in gradio_comm_map
    if already_registered:
        # stacklevel=2 points the warning at the caller, not this __init__.
        warnings.warn(
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this.",
            UserWarning,
            stacklevel=2,
        )
    gradio_comm_map[gid] = self


Expand Down
9 changes: 9 additions & 0 deletions pygwalker/communications/reflex_comm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import warnings
from fastapi import FastAPI, HTTPException
from starlette.routing import Route
from starlette.responses import JSONResponse, Response
Expand Down Expand Up @@ -46,6 +47,14 @@ class ReflexCommunication(BaseCommunication):

def __init__(self, gid: str) -> None:
    """Create a Reflex communication channel and register it under *gid*.

    A duplicate gid means a later channel will shadow an earlier one in
    ``reflex_comm_map``; warn the caller before overwriting.
    """
    super().__init__(gid)
    if gid in reflex_comm_map:
        # stacklevel=2 attributes the warning to the instantiating caller.
        warnings.warn(
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this.",
            UserWarning,
            stacklevel=2,
        )
    reflex_comm_map[gid] = self


Expand Down
9 changes: 9 additions & 0 deletions pygwalker/communications/streamlit_comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import json
import warnings

from tornado.web import Application
from streamlit import config
Expand Down Expand Up @@ -66,4 +67,12 @@ class StreamlitCommunication(BaseCommunication):
"""
def __init__(self, gid: str) -> None:
    """Create a Streamlit communication channel and register it under *gid*.

    Registration overwrites any existing entry for the same gid, so a
    collision is surfaced as a warning rather than failing silently.
    """
    super().__init__(gid)
    collides = gid in streamlit_comm_map
    if collides:
        # stacklevel=2 points the warning at the code that built the channel.
        warnings.warn(
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this.",
            UserWarning,
            stacklevel=2,
        )
    streamlit_comm_map[gid] = self
17 changes: 12 additions & 5 deletions pygwalker/services/data_parsers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import hashlib
import numpy as np
import pandas as pd
from typing import Dict, Optional, Union, Any, List, Tuple
from typing_extensions import Literal
Expand Down Expand Up @@ -90,7 +91,8 @@ def _get_pl_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_polars"
if row_count > 4000:
dataset = pl.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset[indices.tolist()]
hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()

Expand All @@ -100,7 +102,8 @@ def _get_pd_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_pandas"
if row_count > 4000:
dataset = pd.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset.iloc[indices]
hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()

Expand All @@ -111,7 +114,8 @@ def _get_modin_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_modin"
if row_count > 4000:
dataset = mpd.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset.iloc[indices]
dataset = dataset._to_pandas()
hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()
Expand All @@ -123,8 +127,11 @@ def _get_spark_dataset_hash(dataset: DataFrame) -> str:
row_count = shape[0]
other_info = str(shape) + "_pyspark"
if row_count > 4000:
dataset = dataset.limit(4000)
dataset_pd = dataset.toPandas()
dataset_pd = dataset.toPandas()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep Spark hash sampling bounded before toPandas

For Spark datasets larger than 4,000 rows, this now calls dataset.toPandas() on the full DataFrame and only then downsamples, which can pull millions of rows to the driver and cause severe latency or OOM during get_dataset_hash. Any workflow that relies on automatic GID generation for large Spark tables can fail before rendering; the sampling needs to happen in Spark (or otherwise be bounded) prior to converting to pandas.

Useful? React with 👍 / 👎.

indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset_pd = dataset_pd.iloc[indices]
else:
dataset_pd = dataset.toPandas()
hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()

Expand Down
37 changes: 36 additions & 1 deletion tests/test_data_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import polars as pl
import pytest

from pygwalker.services.data_parsers import get_parser
from pygwalker.services.data_parsers import get_parser, get_dataset_hash
from pygwalker.data_parsers.database_parser import Connector, text
from pygwalker.data_parsers.database_parser import _check_view_sql
from pygwalker.errors import ViewSqlSameColumnError
Expand Down Expand Up @@ -108,3 +108,38 @@ def test_connector():
assert connector.dialect_name == "duckdb"
assert connector.view_sql == view_sql
assert connector.url == database_url


def test_dataset_hash_no_interior_collision_pandas():
    """Two frames that differ only in interior rows must hash differently.

    Guards against head/tail-only sampling in ``get_dataset_hash``, which
    would make differences confined to rows 2000-7999 invisible.
    """
    template = pd.DataFrame({"value": range(10000)})
    interior = slice(2000, 7999)  # .loc slices are end-inclusive in pandas

    zeros_inside = template.copy()
    zeros_inside.loc[interior, "value"] = 0

    big_inside = template.copy()
    big_inside.loc[interior, "value"] = 99999

    # Sanity check: the two frames genuinely differ before hashing.
    assert zeros_inside["value"].sum() != big_inside["value"].sum()
    assert get_dataset_hash(zeros_inside) != get_dataset_hash(big_inside)


def test_dataset_hash_no_interior_collision_polars():
    """Polars frames that differ only in interior rows must hash differently.

    Guards against head/tail-only sampling in ``get_dataset_hash`` for the
    polars code path.
    """
    base = pl.DataFrame({"value": list(range(10000))})

    def _fill_interior(frame, fill_value):
        # Replace "value" with fill_value on rows 2000-7999 (inclusive),
        # leaving every other row untouched.
        in_interior = pl.arange(0, pl.count()).is_between(2000, 7999)
        return frame.with_columns(
            pl.when(in_interior)
            .then(pl.lit(fill_value))
            .otherwise(pl.col("value"))
            .alias("value")
        )

    df_a = _fill_interior(base, 0)
    df_b = _fill_interior(base, 99999)

    # Sanity check: the two frames genuinely differ before hashing.
    assert df_a["value"].sum() != df_b["value"].sum()
    assert get_dataset_hash(df_a) != get_dataset_hash(df_b)