Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/compass-telemetry/src/telemetry-events.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,7 @@ type ScreenEvent = ConnectionScopedEvent<{
| 'my_queries'
| 'performance'
| 'schema'
| 'vector_visualizer'
| 'validation'
| 'confirm_new_pipeline_modal'
| 'create_collection_modal'
Expand Down
4 changes: 3 additions & 1 deletion packages/compass-vector-embedding-visualizer/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"@leafygreen-ui/tooltip": "^13.0.12",
"@types/plotly.js": "^3.0.0",
"ml-pca": "^4.1.1",
"mongodb": "^6.16.0",
"plotly.js": "^3.0.1",
"react": "^17.0.2",
"react-dom": "^17.0.2"
Expand All @@ -65,6 +66,7 @@
"@types/chai": "^4.2.21",
"@types/chai-dom": "^0.0.10",
"@types/mocha": "^9.0.0",
"@types/mongodb": "^4.0.6",
"@types/react": "^17.0.5",
"@types/react-dom": "^17.0.10",
"@types/sinon-chai": "^3.2.5",
Expand All @@ -74,7 +76,7 @@
"mocha": "^10.2.0",
"nyc": "^15.1.0",
"sinon": "^17.0.1",
"typescript": "^5.0.4",
"typescript": "^5.8.3",
"xvfb-maybe": "^0.2.1"
},
"is_compass_plugin": true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,97 +1,143 @@
import React, { useEffect, useState } from 'react';
import { connect } from 'react-redux';
import Plotly from 'plotly.js';
import * as PCA from 'ml-pca';
import type { Binary } from 'mongodb';
import type { Document } from 'bson';

type HoverInfo = {
x: number;
y: number;
text: string;
} | null;
import type { VectorEmbeddingVisualizerState } from '../stores/reducer';
import { loadDocuments } from '../stores/visualization';
import { ErrorSummary } from '@mongodb-js/compass-components';
import { collectionModelLocator } from '@mongodb-js/compass-app-stores/provider';

export const VectorVisualizer: React.FC = () => {
type HoverInfo = { x: number; y: number; text: string } | null;

export interface VectorVisualizerProps {
onFetchDocs: () => void;
docs: Document[];
loadingDocumentsState: 'initial' | 'loading' | 'loaded' | 'error';
loadingDocumentsError: Error | null;
}

function normalizeTo2D(vectors: Binary[]): { x: number; y: number }[] {
const raw = vectors.map((v) => Array.from(v.toFloat32Array()));
const pca = new PCA.PCA(raw);
const reduced = pca.predict(raw, { nComponents: 2 }).to2DArray();
return reduced.map(([x, y]) => ({ x, y }));
}

const VectorVisualizer: React.FC<VectorVisualizerProps> = ({
onFetchDocs,
docs,
loadingDocumentsState,
loadingDocumentsError,
}) => {
const [hoverInfo, setHoverInfo] = useState<HoverInfo>(null);

useEffect(() => {
if (loadingDocumentsState === 'initial') {
// Fetch the documents when the component mounts when they aren't already loaded.
onFetchDocs();
}
}, [loadingDocumentsState, onFetchDocs]);

useEffect(() => {
const container = document.getElementById('vector-plot');
if (!container) return;

let isMounted = true;
const abortController = new AbortController();

const plot = async () => {
await Plotly.newPlot(
container,
[
{
x: [1, 2, 3, 4, 5],
y: [10, 15, 13, 17, 12],
mode: 'markers',
type: 'scatter',
name: 'baskd',
text: ['doc1', 'doc2', 'doc3', 'doc4', 'doc5'],
hoverinfo: 'none',
marker: {
size: 15,
color: 'teal',
line: { width: 1, color: '#fff' },
try {
const vectors = docs.map((doc) => doc.review_vec).filter(Boolean);

if (!vectors.length) return;

const points = normalizeTo2D(vectors.slice(0, 50));

await Plotly.newPlot(
container,
[
{
x: points.map((p) => p.x),
y: points.map((p) => p.y),
mode: 'markers',
type: 'scatter',
text: docs.map((doc) => doc.review || '[no text]'),
hoverinfo: 'none',
marker: {
size: 12,
color: 'teal',
line: { width: 1, color: '#fff' },
},
},
],
{
hovermode: 'closest',
margin: { l: 40, r: 10, t: 30, b: 30 },
plot_bgcolor: '#f9f9f9',
paper_bgcolor: '#f9f9f9',
},
],
{
margin: { l: 40, r: 10, t: 40, b: 40 },
hovermode: 'closest',
hoverdistance: 30,
dragmode: 'zoom',
plot_bgcolor: '#f7f7f7',
paper_bgcolor: '#f7f7f7',
xaxis: { gridcolor: '#e0e0e0' },
yaxis: { gridcolor: '#e0e0e0' },
},
{ responsive: true }
);

const handleHover = (data: any) => {
const point = data.points?.[0];
if (!point) return;

const containerRect = container.getBoundingClientRect();
const relX = data.event.clientX - containerRect.left;
const relY = data.event.clientY - containerRect.top;

if (isMounted) {
setHoverInfo({ x: relX, y: relY, text: point.text });
}
};

const handleUnhover = () => {
if (isMounted) {
setHoverInfo(null);
}
};

container.addEventListener('plotly_hover', handleHover);
container.addEventListener('plotly_unhover', handleUnhover);

// Cleanup
return () => {
isMounted = false;
container.removeEventListener('plotly_hover', handleHover);
container.removeEventListener('plotly_unhover', handleUnhover);
};
{ responsive: true }
);

const handleHover = (event: Event) => {
const e = event as CustomEvent<{
points: { text: string }[];
event: MouseEvent;
}>;

const point = e.detail?.points?.[0];
const mouse = e.detail?.event;
if (!point || !mouse) return;

const rect = container.getBoundingClientRect();
setHoverInfo({
x: mouse.clientX - rect.left,
y: mouse.clientY - rect.top,
text: point.text,
});
};

const handleUnhover = () => setHoverInfo(null);

container.addEventListener(
'plotly_hover',
handleHover as EventListener
);
container.addEventListener(
'plotly_unhover',
handleUnhover as EventListener
);

return () => {
container.removeEventListener(
'plotly_hover',
handleHover as EventListener
);
container.removeEventListener(
'plotly_unhover',
handleUnhover as EventListener
);
};
} catch (err) {
console.error('VectorVisualizer error:', err);
}
};

let cleanup: (() => void) | undefined;
void plot().then((c) => {
if (typeof c === 'function') cleanup = c;
});
void plot();

return () => {
isMounted = false;
if (cleanup) cleanup();
abortController.abort();
};
}, []);
}, [docs]);

return (
<div style={{ position: 'relative', width: '100%', height: '100%' }}>
<div id="vector-plot" style={{ width: '100%', height: '100%' }} />
{loadingDocumentsError && (
<ErrorSummary errors={loadingDocumentsError.message} />
)}
{hoverInfo && (
<div
style={{
Expand All @@ -103,8 +149,8 @@ export const VectorVisualizer: React.FC = () => {
padding: '4px 8px',
borderRadius: 4,
pointerEvents: 'none',
whiteSpace: 'nowrap',
zIndex: 1000,
whiteSpace: 'nowrap',
}}
>
{hoverInfo.text}
Expand All @@ -113,3 +159,14 @@ export const VectorVisualizer: React.FC = () => {
</div>
);
};

export default connect(
(state: VectorEmbeddingVisualizerState) => ({
docs: state.visualization.docs,
loadingDocumentsState: state.visualization.loadingDocumentsState,
loadingDocumentsError: state.visualization.loadingDocumentsError,
}),
{
onFetchDocs: loadDocuments,
}
)(VectorVisualizer);
Loading
Loading