11import React , { useEffect , useState } from 'react' ;
2+ import { connect } from 'react-redux' ;
23import Plotly from 'plotly.js' ;
4+ import * as PCA from 'ml-pca' ;
5+ import { Binary } from 'mongodb' ;
6+ import type { Document } from 'bson' ;
37
4- type HoverInfo = {
5- x : number ;
6- y : number ;
7- text : string ;
8- } | null ;
8+ import type { VectorEmbeddingVisualizerState } from '../stores/reducer' ;
9+ import { loadDocuments , runVectorAggregation } from '../stores/visualization' ;
10+ import { ErrorSummary , SpinLoader } from '@mongodb-js/compass-components' ;
911
10- export const VectorVisualizer : React . FC = ( ) => {
12+ type HoverInfo = { x : number ; y : number ; text : string } | null ;
13+
14+ export interface VectorVisualizerProps {
15+ onFetchDocs : ( ) => void ;
16+ onFetchAgg : ( ) => void ;
17+ docs : Document [ ] ;
18+ aggResults : { candidates : Document [ ] ; limited : Document [ ] } ;
19+ loadingDocumentsState : 'initial' | 'loading' | 'loaded' | 'error' ;
20+ loadingDocumentsError : Error | null ;
21+ }
22+
23+ function normalizeTo2D ( vectors : Binary [ ] ) : { x : number ; y : number } [ ] {
24+ const raw = vectors . map ( ( v ) => Array . from ( v . toFloat32Array ( ) ) ) ;
25+ const pca = new PCA . PCA ( raw ) ;
26+ const reduced = pca . predict ( raw , { nComponents : 2 } ) . to2DArray ( ) ;
27+ return reduced . map ( ( [ x , y ] ) => ( { x, y } ) ) ;
28+ }
29+
30+ const VectorVisualizer : React . FC < VectorVisualizerProps > = ( {
31+ onFetchDocs,
32+ onFetchAgg,
33+ docs,
34+ aggResults,
35+ loadingDocumentsState,
36+ loadingDocumentsError,
37+ } ) => {
1138 const [ hoverInfo , setHoverInfo ] = useState < HoverInfo > ( null ) ;
39+ const [ query , setQuery ] = useState < string > ( '' ) ;
40+ const [ shouldPlot , setShouldPlot ] = useState < boolean > ( false ) ;
41+ const [ loading , setLoading ] = useState < boolean > ( false ) ;
42+
43+ useEffect ( ( ) => {
44+ if ( loadingDocumentsState === 'initial' ) {
45+ onFetchDocs ( ) ;
46+ }
47+ } , [ loadingDocumentsState , onFetchDocs ] ) ;
1248
1349 useEffect ( ( ) => {
50+ if ( query ) {
51+ onFetchAgg ( ) ;
52+ setLoading ( true ) ;
53+ const timeout = setTimeout ( ( ) => {
54+ setShouldPlot ( true ) ;
55+ setLoading ( false ) ;
56+ } , 600 ) ;
57+ return ( ) => clearTimeout ( timeout ) ;
58+ }
59+ } , [ query , onFetchAgg ] ) ;
60+
61+ useEffect ( ( ) => {
62+ if ( ! shouldPlot ) return ;
63+
1464 const container = document . getElementById ( 'vector-plot' ) ;
1565 if ( ! container ) return ;
1666
17- let isMounted = true ;
67+ const abortController = new AbortController ( ) ;
1868
1969 const plot = async ( ) => {
20- await Plotly . newPlot (
21- container ,
22- [
23- {
24- x : [ 1 , 2 , 3 , 4 , 5 ] ,
25- y : [ 10 , 15 , 13 , 17 , 12 ] ,
26- mode : 'markers' ,
27- type : 'scatter' ,
28- name : 'baskd' ,
29- text : [ 'doc1' , 'doc2' , 'doc3' , 'doc4' , 'doc5' ] ,
30- hoverinfo : 'none' ,
31- marker : {
32- size : 15 ,
33- color : 'teal' ,
34- line : { width : 1 , color : '#fff' } ,
70+ try {
71+ if ( docs . length === 0 ) return ;
72+
73+ const points = normalizeTo2D (
74+ docs
75+ . map ( ( doc ) => doc . review_vec )
76+ . filter ( Boolean )
77+ . slice ( 0 , 500 )
78+ ) ;
79+
80+ const candidateIds = new Set (
81+ aggResults . candidates . map ( ( doc ) => doc . _id . toString ( ) )
82+ ) ;
83+ const limitedIds = new Set (
84+ aggResults . limited . map ( ( doc ) => doc . _id . toString ( ) )
85+ ) ;
86+
87+ await Plotly . newPlot (
88+ container ,
89+ [
90+ {
91+ x : points . map ( ( p ) => p . x ) ,
92+ y : points . map ( ( p ) => p . y ) ,
93+ mode : 'markers' ,
94+ type : 'scatter' ,
95+ text : docs . map ( ( doc ) => {
96+ const review = doc . review || '[no text]' ;
97+ return review . length > 50
98+ ? review . match ( / .{ 1 , 50 } / g) ?. join ( '<br>' ) || review
99+ : review ;
100+ } ) ,
101+ hoverinfo : 'text' ,
102+ marker : {
103+ size : 12 ,
104+ color : docs . map ( ( doc ) => {
105+ const hasLimitedId = limitedIds . has ( doc . _id . toString ( ) ) ;
106+ const hasCandidateId = candidateIds . has ( doc . _id . toString ( ) ) ;
107+ if ( hasLimitedId ) return 'red' ;
108+ if ( hasCandidateId ) return 'orange' ;
109+ return 'teal' ;
110+ } ) ,
111+ line : { width : 1 , color : '#fff' } ,
112+ } ,
35113 } ,
114+ ] ,
115+ {
116+ hovermode : 'closest' ,
117+ margin : { l : 40 , r : 10 , t : 30 , b : 30 } ,
118+ plot_bgcolor : '#f9f9f9' ,
119+ paper_bgcolor : '#f9f9f9' ,
36120 } ,
37- ] ,
38- {
39- margin : { l : 40 , r : 10 , t : 40 , b : 40 } ,
40- hovermode : 'closest' ,
41- hoverdistance : 30 ,
42- dragmode : 'zoom' ,
43- plot_bgcolor : '#f7f7f7' ,
44- paper_bgcolor : '#f7f7f7' ,
45- xaxis : { gridcolor : '#e0e0e0' } ,
46- yaxis : { gridcolor : '#e0e0e0' } ,
47- } ,
48- { responsive : true }
49- ) ;
50-
51- const handleHover = ( data : any ) => {
52- const point = data . points ?. [ 0 ] ;
53- if ( ! point ) return ;
54-
55- const containerRect = container . getBoundingClientRect ( ) ;
56- const relX = data . event . clientX - containerRect . left ;
57- const relY = data . event . clientY - containerRect . top ;
58-
59- if ( isMounted ) {
60- setHoverInfo ( { x : relX , y : relY , text : point . text } ) ;
61- }
62- } ;
63-
64- const handleUnhover = ( ) => {
65- if ( isMounted ) {
66- setHoverInfo ( null ) ;
67- }
68- } ;
69-
70- container . addEventListener ( 'plotly_hover' , handleHover ) ;
71- container . addEventListener ( 'plotly_unhover' , handleUnhover ) ;
72-
73- // Cleanup
74- return ( ) => {
75- isMounted = false ;
76- container . removeEventListener ( 'plotly_hover' , handleHover ) ;
77- container . removeEventListener ( 'plotly_unhover' , handleUnhover ) ;
78- } ;
121+ {
122+ responsive : true ,
123+ displayModeBar : false ,
124+ }
125+ ) ;
126+ } catch ( err ) {
127+ console . error ( 'VectorVisualizer error:' , err ) ;
128+ }
79129 } ;
80130
81- let cleanup : ( ( ) => void ) | undefined ;
82- void plot ( ) . then ( ( c ) => {
83- if ( typeof c === 'function' ) cleanup = c ;
84- } ) ;
131+ void plot ( ) ;
85132
86133 return ( ) => {
87- isMounted = false ;
88- if ( cleanup ) cleanup ( ) ;
134+ abortController . abort ( ) ;
89135 } ;
90- } , [ ] ) ;
136+ } , [ docs , aggResults , shouldPlot ] ) ;
137+
138+ const onInput = ( e : React . KeyboardEvent < HTMLInputElement > ) => {
139+ if ( e . key === 'Enter' ) {
140+ const inputQuery = e . currentTarget . value . trim ( ) ;
141+ if ( inputQuery ) {
142+ setQuery ( inputQuery ) ;
143+ setShouldPlot ( false ) ;
144+ }
145+ }
146+ } ;
91147
92148 return (
93149 < div style = { { position : 'relative' , width : '100%' , height : '100%' } } >
94- < div id = "vector-plot" style = { { width : '100%' , height : '100%' } } />
150+ < div
151+ style = { {
152+ marginBottom : '10px' ,
153+ display : 'flex' ,
154+ justifyContent : 'center' ,
155+ zIndex : 10 ,
156+ position : 'absolute' ,
157+ top : '10px' ,
158+ width : '100%' ,
159+ } }
160+ >
161+ < input
162+ id = "vector-input"
163+ type = "text"
164+ placeholder = "Input your vector query"
165+ style = { {
166+ width : '80%' ,
167+ padding : '8px 12px' ,
168+ fontSize : '14px' ,
169+ border : '1px solid #ccc' ,
170+ borderRadius : '4px' ,
171+ boxShadow : '0 1px 3px rgba(0, 0, 0, 0.1)' ,
172+ backgroundColor : 'white' ,
173+ } }
174+ onKeyDown = { onInput }
175+ />
176+ </ div >
177+
178+ { loading && (
179+ < div
180+ style = { {
181+ position : 'absolute' ,
182+ top : '50%' ,
183+ left : '50%' ,
184+ transform : 'translate(-50%, -50%)' ,
185+ zIndex : 1000 ,
186+ } }
187+ >
188+ < SpinLoader />
189+ </ div >
190+ ) }
191+
192+ < div
193+ id = "vector-plot"
194+ style = { { width : '100%' , height : '100%' , cursor : 'default' } }
195+ />
196+
197+ { loadingDocumentsError && (
198+ < ErrorSummary errors = { loadingDocumentsError . message } />
199+ ) }
200+
95201 { hoverInfo && (
96202 < div
97203 style = { {
@@ -103,8 +209,8 @@ export const VectorVisualizer: React.FC = () => {
103209 padding : '4px 8px' ,
104210 borderRadius : 4 ,
105211 pointerEvents : 'none' ,
106- whiteSpace : 'nowrap' ,
107212 zIndex : 1000 ,
213+ whiteSpace : 'nowrap' ,
108214 } }
109215 >
110216 { hoverInfo . text }
@@ -113,3 +219,16 @@ export const VectorVisualizer: React.FC = () => {
113219 </ div >
114220 ) ;
115221} ;
222+
223+ export default connect (
224+ ( state : VectorEmbeddingVisualizerState ) => ( {
225+ docs : state . visualization . docs ,
226+ aggResults : state . visualization . aggResults ,
227+ loadingDocumentsState : state . visualization . loadingDocumentsState ,
228+ loadingDocumentsError : state . visualization . loadingDocumentsError ,
229+ } ) ,
230+ {
231+ onFetchDocs : loadDocuments ,
232+ onFetchAgg : runVectorAggregation ,
233+ }
234+ ) ( VectorVisualizer ) ;
0 commit comments