Skip to content

Commit c2aabfd

Browse files
committed
Improved embedding for larger datasets
1 parent 841bba0 commit c2aabfd

File tree

7 files changed

+269
-13
lines changed

7 files changed

+269
-13
lines changed

smoosense-gui/src/components/emb/Umap2D.tsx

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { useState, useEffect, useCallback, useMemo } from 'react'
44
import _ from 'lodash'
55
import { useAppSelector } from '@/lib/hooks'
66
import { useSingleColumnRenderType } from '@/lib/hooks/useRenderType'
7+
import { BUBBLE_SIZE_COLOR_VALUE } from '@/lib/features/bubblePlot/BubblePlotMoreControls'
78
import Umap2DControls from './Umap2DControls'
89
import Umap2DScatterPlot from './Umap2DScatterPlot'
910
import TextPlaceHolder from '@/components/common/TextPlaceHolder'
@@ -14,6 +15,9 @@ export interface UmapResult {
1415
columnValues: Record<string, unknown[]>
1516
count: number
1617
runtime: number
18+
sampled: boolean
19+
totalRows: number
20+
maxRows: number
1721
params: {
1822
nNeighbors: number
1923
minDist: number
@@ -37,7 +41,9 @@ export default function Umap2D({ onResultChange, onSelectionChange }: Umap2DProp
3741
const visualColumn = useAppSelector((state) => state.ui.columnForGalleryVisual)
3842
const captionColumn = useAppSelector((state) => state.ui.columnForGalleryCaption)
3943
const breakdownColumn = useAppSelector((state) => state.ui.bubblePlotBreakdownColumn)
40-
const colorColumn = useAppSelector((state) => state.ui.bubblePlotColorColumn)
44+
const rawColorColumn = useAppSelector((state) => state.ui.bubblePlotColorColumn)
45+
// The bubble-plot-only sentinel (__bubble_size__) is not a real column here, so treat it as "no color column" (empty string)
46+
const colorColumn = rawColorColumn === BUBBLE_SIZE_COLOR_VALUE ? '' : rawColorColumn
4147
const nNeighbors = useAppSelector((state) => state.ui.umapNNeighbors)
4248
const minDist = useAppSelector((state) => state.ui.umapMinDist)
4349
const visualRenderType = useSingleColumnRenderType(visualColumn || '')
@@ -85,6 +91,9 @@ export default function Umap2D({ onResultChange, onSelectionChange }: Umap2DProp
8591
columnValues: data.columnValues || {},
8692
count: data.count,
8793
runtime: data.runtime,
94+
sampled: data.sampled || false,
95+
totalRows: data.totalRows || data.count,
96+
maxRows: data.maxRows || 10000,
8897
params: data.params,
8998
}
9099
setResult(newResult)
@@ -175,8 +184,15 @@ export default function Umap2D({ onResultChange, onSelectionChange }: Umap2DProp
175184
visualRenderType={visualRenderType}
176185
onSelectionChange={handleSelectionChange}
177186
/>
178-
<div className="flex-shrink-0 px-4 py-2 text-xs text-muted-foreground border-t">
179-
{result.count} points | n_neighbors={result.params?.nNeighbors ?? nNeighbors} | min_dist={result.params?.minDist ?? minDist} | {result.runtime.toFixed(2)}s
187+
<div className="flex-shrink-0 px-4 py-2 text-xs text-muted-foreground border-t flex justify-between">
188+
<span>
189+
{result.count} points | n_neighbors={result.params?.nNeighbors ?? nNeighbors} | min_dist={result.params?.minDist ?? minDist}{result.runtime != null && ` | ${result.runtime.toFixed(2)}s`}
190+
</span>
191+
{result.sampled && (
192+
<span className="text-amber-600 dark:text-amber-400">
193+
Sampled {result.count.toLocaleString()} of {result.totalRows.toLocaleString()} rows
194+
</span>
195+
)}
180196
</div>
181197
</div>
182198
) : (

smoosense-gui/src/components/emb/Umap2DScatterPlot.tsx

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
6868
setSelectedIndices([])
6969
}, [x, y])
7070

71+
const hasColorValues = colorValues && colorValues.length > 0
72+
7173
const plotData = useMemo((): Partial<PlotData>[] => {
7274
if (!x || !y || x.length === 0) {
7375
return []
@@ -89,11 +91,11 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
8991
const traces: Partial<PlotData>[] = []
9092
const groupNames = Array.from(groups.keys()).sort()
9193

92-
groupNames.forEach((groupName) => {
94+
groupNames.forEach((groupName, groupIndex) => {
9395
const indices = groups.get(groupName)!
9496
const groupX = indices.map(i => x[i])
9597
const groupY = indices.map(i => y[i])
96-
const groupColorValues = colorValues ? indices.map(i => colorValues[i]) : undefined
98+
const groupColorValues = hasColorValues ? indices.map(i => colorValues[i]) : undefined
9799
const customdata = indices.map(i => ({ index: i }))
98100

99101
// Determine marker color
@@ -112,7 +114,12 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
112114
size: markerSize,
113115
opacity: opacity,
114116
color: markerColor,
115-
colorscale: colorValues ? colorScale : undefined,
117+
colorscale: hasColorValues ? colorScale : undefined,
118+
showscale: hasColorValues && groupIndex === 0, // Only show colorbar on first trace
119+
colorbar: hasColorValues ? {
120+
thickness: 15,
121+
len: 0.5
122+
} : undefined,
116123
line: {
117124
width: 0.5,
118125
color: colors.foreground
@@ -130,7 +137,7 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
130137
// Single trace (no breakdown)
131138
// Determine marker color based on colorValues or selection
132139
let markerColor: string[] | number[]
133-
if (colorValues && colorValues.length > 0) {
140+
if (hasColorValues) {
134141
markerColor = colorValues
135142
} else {
136143
markerColor = x.map((_, i) =>
@@ -150,7 +157,12 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
150157
size: markerSize,
151158
opacity: opacity,
152159
color: markerColor,
153-
colorscale: colorValues ? colorScale : undefined,
160+
colorscale: hasColorValues ? colorScale : undefined,
161+
showscale: hasColorValues,
162+
colorbar: hasColorValues ? {
163+
thickness: 15,
164+
len: 0.5
165+
} : undefined,
154166
line: {
155167
width: 0.5,
156168
color: colors.foreground
@@ -160,7 +172,7 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
160172
customdata: customdata as any,
161173
hoverinfo: 'none', // Disable default hover
162174
}]
163-
}, [x, y, selectedIndices, colors.primary, colors.foreground, markerSize, opacity, colorScale, breakdownValues, colorValues])
175+
}, [x, y, selectedIndices, colors.primary, colors.foreground, markerSize, opacity, colorScale, breakdownValues, colorValues, hasColorValues])
164176

165177
const hasBreakdown = breakdownValues && breakdownValues.length > 0
166178

@@ -172,6 +184,7 @@ const Umap2DScatterPlot = React.memo(function Umap2DScatterPlot({
172184
...baseLayout,
173185
dragmode: 'lasso',
174186
hovermode: 'closest',
187+
showlegend: hasBreakdown,
175188
margin: {
176189
l: 0,
177190
r: 0,

smoosense-gui/src/lib/features/ui/uiSlice.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ const initialState: UiState = {
8989
bubblePlotMinMarkerSize: 7,
9090
bubblePlotOpacity: 0.7,
9191
bubblePlotMarkerSizeContrastRatio: 4.2,
92-
bubblePlotColorColumn: '__bubble_size__',
92+
bubblePlotColorColumn: '',
9393
bubblePlotColorScale: 'Jet',
9494
bubblePlotLogScaleX: false,
9595
bubblePlotLogScaleY: false,

smoosense-gui/src/lib/utils/cellRenderers/EmbeddingCellRenderer.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ function SimilarRowsGallery({ embedding, columnName }: SimilarRowsGalleryProps)
9292
<div className="flex flex-col h-full">
9393
<GalleryControls showRandom={false} />
9494
<div className="px-2 py-1 text-xs text-muted-foreground">
95-
Cosine distance (0 = identical, 1 = orthogonal, 2 = opposite; may have small error due to indexing)
95+
Cosine distance (0 = identical, 1 = orthogonal, 2 = opposite; may have small error due to vector indexing)
9696
</div>
9797
<div className="flex-1 overflow-auto p-2">
9898
{_.isEmpty(rows) ? (
@@ -114,7 +114,7 @@ function SimilarRowsGallery({ embedding, columnName }: SimilarRowsGalleryProps)
114114
return (
115115
<div key={index} className="relative">
116116
<div className="absolute top-1 left-1 z-10 bg-black/70 text-white text-xs px-1.5 py-0.5 rounded-full">
117-
{distance.toFixed(3)}
117+
distance: {distance.toFixed(3)}
118118
</div>
119119
<GalleryItem
120120
row={row}

smoosense-py/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ build-backend = "setuptools.build_meta"
100100
[project.scripts]
101101
sense = "smoosense.cli:main"
102102
sense-ingest-images = "smoosense.images.ingest:main"
103+
parquet-to-lance = "smoosense.lance.parquet_to_lance:main"
103104

104105
# Ruff configuration
105106
[tool.ruff]

smoosense-py/smoosense/handlers/umap.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
logger = logging.getLogger(__name__)
1414
umap_bp = Blueprint("umap", __name__)
1515

16+
# Maximum number of rows to compute UMAP on (random sample if exceeded)
17+
UMAP_MAX_ROWS = 1_000
1618

1719
@umap_bp.post("/umap")
1820
@requires_auth_api
@@ -86,19 +88,34 @@ def compute_umap() -> Response:
8688
if len(embeddings) < 2:
8789
raise ValueError("Not enough embeddings to compute UMAP (need at least 2)")
8890

91+
# Random sample if exceeds max rows
92+
sampled = False
93+
total_rows = len(embeddings)
94+
if total_rows > UMAP_MAX_ROWS:
95+
logger.info(f"Random sampling UMAP input from {total_rows} to {UMAP_MAX_ROWS} rows")
96+
rng = np.random.default_rng(seed=42)
97+
indices = rng.choice(total_rows, size=UMAP_MAX_ROWS, replace=False)
98+
indices.sort() # Keep relative order for consistency
99+
embeddings = [embeddings[i] for i in indices]
100+
for col in extra_values:
101+
extra_values[col] = [extra_values[col][i] for i in indices]
102+
sampled = True
103+
89104
# Convert to numpy array
90105
embeddings_array = np.array(embeddings, dtype=np.float32)
91106

92107
# Adjust n_neighbors if larger than dataset
93108
actual_n_neighbors = min(n_neighbors, len(embeddings) - 1)
94109

95-
# Compute UMAP
110+
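# Compute UMAP with performance optimizations (NOTE(review): umap-learn overrides n_jobs to 1 when random_state is set, so n_jobs=-1 below may have no effect; confirm)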
96111
reducer = umap.UMAP(
97112
n_neighbors=actual_n_neighbors,
98113
min_dist=min_dist,
99114
n_components=2,
100115
metric="cosine",
101116
random_state=42,
117+
low_memory=False, # Trade memory for speed
118+
n_jobs=-1, # Use all CPU cores
102119
)
103120
projection = reducer.fit_transform(embeddings_array)
104121

@@ -113,6 +130,9 @@ def compute_umap() -> Response:
113130
"y": y_coords,
114131
"columnValues": {col: serialize(vals) for col, vals in extra_values.items()},
115132
"count": len(x_coords),
133+
"sampled": sampled,
134+
"totalRows": total_rows,
135+
"maxRows": UMAP_MAX_ROWS,
116136
"runtime": default_timer() - time_start,
117137
"params": {
118138
"nNeighbors": actual_n_neighbors,

0 commit comments

Comments
 (0)