Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pygwalker/communications/gradio_comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import gc
import warnings

from fastapi import FastAPI
from starlette.routing import Route
Expand Down Expand Up @@ -36,6 +37,14 @@ class GradioCommunication(BaseCommunication):
"""
def __init__(self, gid: str) -> None:
    """Register this Gradio communication channel under *gid*.

    Emits a UserWarning when *gid* is already present in the global
    registry (two different datasets hashed to the same identifier),
    then registers this instance under *gid* regardless.
    """
    super().__init__(gid)
    collided = gid in gradio_comm_map
    if collided:
        message = (
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this."
        )
        # stacklevel=2 points the warning at the caller constructing this object.
        warnings.warn(message, UserWarning, stacklevel=2)
    gradio_comm_map[gid] = self


Expand Down
9 changes: 9 additions & 0 deletions pygwalker/communications/reflex_comm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import warnings
from fastapi import FastAPI, HTTPException
from starlette.routing import Route
from starlette.responses import JSONResponse, Response
Expand Down Expand Up @@ -46,6 +47,14 @@ class ReflexCommunication(BaseCommunication):

def __init__(self, gid: str) -> None:
    """Register this Reflex communication channel under *gid*.

    Emits a UserWarning when *gid* is already present in the global
    registry (two different datasets hashed to the same identifier),
    then registers this instance under *gid* regardless.
    """
    super().__init__(gid)
    collided = gid in reflex_comm_map
    if collided:
        message = (
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this."
        )
        # stacklevel=2 points the warning at the caller constructing this object.
        warnings.warn(message, UserWarning, stacklevel=2)
    reflex_comm_map[gid] = self


Expand Down
9 changes: 9 additions & 0 deletions pygwalker/communications/streamlit_comm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import json
import warnings

from tornado.web import Application
from streamlit import config
Expand Down Expand Up @@ -66,4 +67,12 @@ class StreamlitCommunication(BaseCommunication):
"""
def __init__(self, gid: str) -> None:
    """Register this Streamlit communication channel under *gid*.

    Emits a UserWarning when *gid* is already present in the global
    registry (two different datasets hashed to the same identifier),
    then registers this instance under *gid* regardless.
    """
    super().__init__(gid)
    collided = gid in streamlit_comm_map
    if collided:
        message = (
            f"GID collision detected: {gid}. "
            "Two different datasets produced the same identifier. "
            "Pass an explicit gid= parameter to avoid this."
        )
        # stacklevel=2 points the warning at the caller constructing this object.
        warnings.warn(message, UserWarning, stacklevel=2)
    streamlit_comm_map[gid] = self
14 changes: 10 additions & 4 deletions pygwalker/services/data_parsers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import hashlib
import numpy as np
import pandas as pd
from typing import Dict, Optional, Union, Any, List, Tuple
from typing_extensions import Literal
Expand Down Expand Up @@ -90,7 +91,8 @@ def _get_pl_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_polars"
if row_count > 4000:
dataset = pl.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset[indices.tolist()]
hash_bytes = dataset.hash_rows().to_numpy().tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()

Expand All @@ -100,7 +102,8 @@ def _get_pd_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_pandas"
if row_count > 4000:
dataset = pd.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset.iloc[indices]
hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()

Expand All @@ -111,7 +114,8 @@ def _get_modin_dataset_hash(dataset: DataFrame) -> str:
row_count = dataset.shape[0]
other_info = str(dataset.shape) + "_modin"
if row_count > 4000:
dataset = mpd.concat([dataset[:2000], dataset[-2000:]])
indices = np.linspace(0, row_count - 1, 4000, dtype=int)
dataset = dataset.iloc[indices]
dataset = dataset._to_pandas()
hash_bytes = pd.util.hash_pandas_object(dataset).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()
Expand All @@ -123,7 +127,9 @@ def _get_spark_dataset_hash(dataset: DataFrame) -> str:
row_count = shape[0]
other_info = str(shape) + "_pyspark"
if row_count > 4000:
dataset = dataset.limit(4000)
# Sample uniformly in Spark before toPandas() to avoid pulling full dataset to driver (OOM)
fraction = min(4000 / row_count * 1.5, 1.0)
dataset = dataset.sample(fraction=fraction, seed=42).limit(4000)
dataset_pd = dataset.toPandas()
hash_bytes = pd.util.hash_pandas_object(dataset_pd).values.tobytes() + other_info.encode()
return hashlib.md5(hash_bytes).hexdigest()
Expand Down
37 changes: 36 additions & 1 deletion tests/test_data_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import polars as pl
import pytest

from pygwalker.services.data_parsers import get_parser
from pygwalker.services.data_parsers import get_parser, get_dataset_hash
from pygwalker.data_parsers.database_parser import Connector, text
from pygwalker.data_parsers.database_parser import _check_view_sql
from pygwalker.errors import ViewSqlSameColumnError
Expand Down Expand Up @@ -108,3 +108,38 @@ def test_connector():
assert connector.dialect_name == "duckdb"
assert connector.view_sql == view_sql
assert connector.url == database_url


def test_dataset_hash_no_interior_collision_pandas():
    """Regression test: DataFrames differing only in interior rows must produce different hashes."""
    base = pd.DataFrame({"value": range(10000)})

    def with_interior(fill):
        # Overwrite rows 2000..7999 (inclusive, .loc slicing) with a constant.
        frame = base.copy()
        frame.loc[2000:7999, "value"] = fill
        return frame

    df_low = with_interior(0)
    df_high = with_interior(99999)

    # Sanity check that the two frames really differ in content...
    assert df_low["value"].sum() != df_high["value"].sum()
    # ...and that the dataset hash distinguishes them.
    assert get_dataset_hash(df_low) != get_dataset_hash(df_high)


def test_dataset_hash_no_interior_collision_polars():
    """Regression test: Polars DataFrames differing only in interior rows must produce different hashes."""
    base = pl.DataFrame({"value": list(range(10000))})

    def with_interior(fill):
        # Rows whose index is in [2000, 7999] (is_between is inclusive) get
        # the constant *fill*; every other row keeps its original value.
        interior = pl.arange(0, pl.count()).is_between(2000, 7999)
        replaced = (
            pl.when(interior)
            .then(pl.lit(fill))
            .otherwise(pl.col("value"))
            .alias("value")
        )
        return base.with_columns(replaced)

    df_low = with_interior(0)
    df_high = with_interior(99999)

    # Sanity check that the two frames really differ in content...
    assert df_low["value"].sum() != df_high["value"].sum()
    # ...and that the dataset hash distinguishes them.
    assert get_dataset_hash(df_low) != get_dataset_hash(df_high)