Skip to content

Commit a0a7dc9

Browse files
committed
Merge branch 'm-kovalsky/vpaxfix'
2 parents 1740a58 + c0d3d74 commit a0a7dc9

File tree

2 files changed

+113
-59
lines changed

2 files changed

+113
-59
lines changed

src/sempy_labs/_helper_functions.py

Lines changed: 94 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,15 @@ def create_abfss_path(
7474
return path
7575

7676

77+
def create_abfss_path_from_path(
78+
lakehouse_id: UUID, workspace_id: UUID, file_path: str
79+
) -> str:
80+
81+
fp = _get_default_file_path()
82+
83+
return f"abfss://{workspace_id}@{fp}/{lakehouse_id}/{file_path}"
84+
85+
7786
def _get_default_file_path() -> str:
7887

7988
default_file_storage = _get_fabric_context_setting(name="fs.defaultFS")
@@ -1547,31 +1556,10 @@ def _get_column_aggregate(
15471556
path = create_abfss_path(lakehouse_id, workspace_id, table_name, schema_name)
15481557
df = _read_delta_table(path)
15491558

1550-
if isinstance(column_name, str):
1551-
result = _get_aggregate(
1552-
df=df,
1553-
column_name=column_name,
1554-
function=function,
1555-
default_value=default_value,
1556-
)
1557-
elif isinstance(column_name, list):
1558-
result = {}
1559-
for col in column_name:
1560-
result[col] = _get_aggregate(
1561-
df=df,
1562-
column_name=col,
1563-
function=function,
1564-
default_value=default_value,
1565-
)
1566-
else:
1567-
raise TypeError("column_name must be a string or a list of strings.")
1568-
1569-
return result
1570-
1571-
1572-
def _get_aggregate(df, column_name, function, default_value: int = 0) -> int:
1559+
function = function.lower()
15731560

1574-
function = function.upper()
1561+
if isinstance(column_name, str):
1562+
column_name = [column_name]
15751563

15761564
if _pure_python_notebook():
15771565
import polars as pl
@@ -1581,36 +1569,76 @@ def _get_aggregate(df, column_name, function, default_value: int = 0) -> int:
15811569

15821570
df = pl.from_pandas(df)
15831571

1584-
# Perform aggregation
1585-
if "DISTINCT" in function:
1586-
if isinstance(df[column_name].dtype, pl.Decimal):
1587-
result = df[column_name].cast(pl.Float64).n_unique()
1572+
def get_expr(col):
1573+
col_dtype = df.schema[col]
1574+
1575+
if "approx" in function:
1576+
return pl.col(col).unique().count().alias(col)
1577+
elif "distinct" in function:
1578+
if col_dtype == pl.Decimal:
1579+
return pl.col(col).cast(pl.Float64).n_unique().alias(col)
1580+
else:
1581+
return pl.col(col).n_unique().alias(col)
1582+
elif function == "sum":
1583+
return pl.col(col).sum().alias(col)
1584+
elif function == "min":
1585+
return pl.col(col).min().alias(col)
1586+
elif function == "max":
1587+
return pl.col(col).max().alias(col)
1588+
elif function == "count":
1589+
return pl.col(col).count().alias(col)
1590+
elif function in {"avg", "mean"}:
1591+
return pl.col(col).mean().alias(col)
15881592
else:
1589-
result = df[column_name].n_unique()
1590-
elif "APPROX" in function:
1591-
result = df[column_name].unique().shape[0]
1592-
else:
1593-
try:
1594-
result = getattr(df[column_name], function.lower())()
1595-
except AttributeError:
15961593
raise ValueError(f"Unsupported function: {function}")
15971594

1598-
return result if result is not None else default_value
1595+
exprs = [get_expr(col) for col in column_name]
1596+
aggs = df.select(exprs).to_dict(as_series=False)
1597+
1598+
if len(column_name) == 1:
1599+
result = aggs[column_name[0]][0] or default_value
1600+
else:
1601+
result = {col: aggs[col][0] for col in column_name}
15991602
else:
1600-
from pyspark.sql.functions import approx_count_distinct
1601-
from pyspark.sql import functions as F
1603+
from pyspark.sql.functions import (
1604+
count,
1605+
sum,
1606+
min,
1607+
max,
1608+
avg,
1609+
approx_count_distinct,
1610+
countDistinct,
1611+
)
16021612

1603-
if isinstance(df, pd.DataFrame):
1604-
df = _create_spark_dataframe(df)
1613+
result = None
1614+
if "approx" in function:
1615+
spark_func = approx_count_distinct
1616+
elif "distinct" in function:
1617+
spark_func = countDistinct
1618+
elif function == "count":
1619+
spark_func = count
1620+
elif function == "sum":
1621+
spark_func = sum
1622+
elif function == "min":
1623+
spark_func = min
1624+
elif function == "max":
1625+
spark_func = max
1626+
elif function == "avg":
1627+
spark_func = avg
1628+
else:
1629+
raise ValueError(f"Unsupported function: {function}")
1630+
1631+
agg_exprs = []
1632+
for col in column_name:
1633+
agg_exprs.append(spark_func(col).alias(col))
16051634

1606-
if "DISTINCT" in function:
1607-
result = df.select(F.count_distinct(F.col(column_name)))
1608-
elif "APPROX" in function:
1609-
result = df.select(approx_count_distinct(column_name))
1635+
aggs = df.agg(*agg_exprs).collect()[0]
1636+
if len(column_name) == 1:
1637+
result = aggs[0] or default_value
16101638
else:
1611-
result = df.selectExpr(f"{function}({column_name})")
1639+
result = {col: aggs[col] for col in column_name}
16121640

1613-
return result.collect()[0][0] or default_value
1641+
return result
16141642

16151643

16161644
def _create_spark_dataframe(df: pd.DataFrame):
@@ -2222,3 +2250,23 @@ def _xml_to_dict(element):
22222250
element.text.strip() if element.text and element.text.strip() else None
22232251
)
22242252
return data
2253+
2254+
2255+
def file_exists(file_path: str) -> bool:
2256+
"""
2257+
Check if a file exists in the given path.
2258+
2259+
Parameters
2260+
----------
2261+
file_path : str
2262+
The path to the file.
2263+
2264+
Returns
2265+
-------
2266+
bool
2267+
True if the file exists, False otherwise.
2268+
"""
2269+
2270+
import notebookutils
2271+
2272+
return len(notebookutils.fs.ls(file_path)) > 0

src/sempy_labs/_vpax.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
_mount,
1414
_get_column_aggregate,
1515
resolve_item_type,
16+
file_exists,
17+
create_abfss_path_from_path,
1618
)
1719
import sempy_labs._icons as icons
18-
from sempy_labs.lakehouse._blobs import list_blobs
19-
from sempy_labs.tom import connect_semantic_model
2020
import zipfile
2121
import requests
2222

@@ -200,19 +200,20 @@ def create_vpax(
200200
local_path = _mount(lakehouse=lakehouse_id, workspace=lakehouse_workspace_id)
201201
if file_path is None:
202202
file_path = dataset_name
203-
path = f"{local_path}/Files/{file_path}.vpax"
203+
204+
if file_path.endswith(".vpax"):
205+
file_path = file_path[:-5]
206+
save_location = f"Files/{file_path}.vpax"
207+
path = f"{local_path}/{save_location}"
204208

205209
# Check if the .vpax file already exists in the lakehouse
206210
if not overwrite:
207-
df = list_blobs(
208-
lakehouse=lakehouse_id,
209-
workspace=lakehouse_workspace_id,
210-
container="Files",
211+
new_path = create_abfss_path_from_path(
212+
lakehouse_id, lakehouse_workspace_id, save_location
211213
)
212-
df_filt = df[df["Blob Name"] == f"{lakehouse_id}/Files/{file_path}.vpax"]
213-
if not df_filt.empty:
214+
if file_exists(new_path):
214215
print(
215-
f"{icons.warning} The Files/{file_path}.vpax file already exists in the '{lakehouse_name}' lakehouse. Set overwrite=True to overwrite the file."
216+
f"{icons.warning} The {save_location} file already exists in the '{lakehouse_name}' lakehouse. Set overwrite=True to overwrite the file."
216217
)
217218
return
218219

@@ -240,6 +241,8 @@ def create_vpax(
240241
tom_database = TomExtractor.GetDatabase(connection_string)
241242

242243
# Calculate Direct Lake stats for columns which are IsResident=False
244+
from sempy_labs.tom import connect_semantic_model
245+
243246
with connect_semantic_model(dataset=dataset, workspace=workspace) as tom:
244247
is_direct_lake = tom.is_direct_lake()
245248
if read_stats_from_data and is_direct_lake and direct_lake_stats_mode == "Full":
@@ -256,7 +259,6 @@ def create_vpax(
256259

257260
# For SQL endpoints (do once)
258261
dfI = fabric.list_items(workspace=workspace)
259-
260262
# Get list of tables in Direct Lake mode which have columns that are not resident
261263
tbls = [
262264
t
@@ -331,7 +333,7 @@ def create_vpax(
331333
table_name=entity_name,
332334
schema_name=schema_name,
333335
column_name=list(col_dict.values()),
334-
function="distinctcount",
336+
function="distinct",
335337
)
336338
column_cardinalities = {
337339
column_name: col_agg[source_column]
@@ -343,15 +345,19 @@ def create_vpax(
343345
tbl = next(
344346
table
345347
for table in dax_model.Tables
346-
if str(t.TableName) == table_name
348+
if str(table.TableName) == table_name
347349
)
350+
# print(
351+
# f"{icons.in_progress} Calculating column cardinalities for the '{table_name}' table..."
352+
# )
348353
cols = [
349354
col
350355
for col in tbl.Columns
351356
if str(col.ColumnType) != "RowNumber"
352357
and str(col.ColumnName) in column_cardinalities
353358
]
354359
for col in cols:
360+
# print(str(col.ColumnName), col.ColumnCardinality)
355361
col.ColumnCardinality = column_cardinalities.get(
356362
str(col.ColumnName)
357363
)

0 commit comments

Comments
 (0)