@@ -18,6 +18,7 @@ import forEach from 'lodash/forEach';
1818import { useModelProviders } from '../Settings/hooks' ;
1919import { ModelProviderType } from '../Settings/AddModelProviderButton' ;
2020import { CustomModel } from '../Settings/SettingsPage' ;
21+ import filter from 'lodash/filter' ;
2122
2223
2324const StepContainer = styled ( Flex ) `
@@ -52,14 +53,21 @@ export const WORKFLOW_OPTIONS = [
5253export const MODEL_TYPE_OPTIONS : ModelProvidersDropdownOpts = [
5354 { label : MODEL_PROVIDER_LABELS [ ModelProviders . BEDROCK ] , value : ModelProviders . BEDROCK } ,
5455 { label : MODEL_PROVIDER_LABELS [ ModelProviders . CAII ] , value : ModelProviders . CAII } ,
56+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . OPENAI ] , value : ModelProviders . OPENAI } ,
57+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . GEMINI ] , value : ModelProviders . GEMINI } ,
5558] ;
5659
5760const Configure : FunctionComponent = ( ) => {
5861 const form = Form . useFormInstance ( ) ;
5962 const formData = Form . useWatch ( ( values ) => values , form ) ;
6063 const location = useLocation ( ) ;
6164 const { template_name, generate_file_name } = useParams ( ) ;
65+ const [ models , setModels ] = useState < string [ ] > ( [ ] )
6266 const [ wizardModeType , setWizardModeType ] = useState ( getWizardModeType ( location ) ) ;
67+ const { data } = useFetchModels ( ) ;
68+ const customModelPrividersReq = useModelProviders ( ) ;
69+ const customModels = get ( customModelPrividersReq , 'data.endpoints' , [ ] ) ;
70+ console . log ( 'customModels' , customModels ) ;
6371
6472 useEffect ( ( ) => {
6573 if ( wizardModeType === WizardModeType . DATA_AUGMENTATION ) {
@@ -82,11 +90,16 @@ const Configure: FunctionComponent = () => {
8290 }
8391 } , [ template_name ] ) ;
8492
93+ useEffect ( ( ) => {
94+ // set model providers
95+ // set model ids
96+ console . log ( 'useEffect' , formData ) ;
97+
98+ } , [ customModels , formData ] ) ;
99+
85100
86101 // let formData = Form.useWatch((values) => values, form);
87102 const { setIsStepValid } = useWizardCtx ( ) ;
88- const { data } = useFetchModels ( ) ;
89- const customModelPrividersReq = useModelProviders ( ) ;
90103 const [ selectedFiles , setSelectedFiles ] = useState (
91104 ! isEmpty ( form . getFieldValue ( 'doc_paths' ) ) ? form . getFieldValue ( 'doc_paths' ) : [ ] ) ;
92105
@@ -110,7 +123,6 @@ const Configure: FunctionComponent = () => {
110123
111124
112125 useEffect ( ( ) => {
113- console . log ( 'useEffect 1' ) ;
114126 if ( formData && formData ?. inference_type === undefined && isEmpty ( generate_file_name ) ) {
115127 form . setFieldValue ( 'inference_type' , ModelProviders . CAII ) ;
116128 setTimeout ( ( ) => {
@@ -161,42 +173,18 @@ const Configure: FunctionComponent = () => {
161173 }
162174 }
163175
164- // const getModelsGroupOptions = (models: any) => {
165- // if (isEmpty(models)) {
166- // return [];
167- // }
168- // const options = [];
169- // const modelTypes = Object.keys(models);
170- // console.log('modelTypes', modelTypes);
171- // forEach(modelTypes, (modelType: string) => {
172- // const models = get(modelTypes, modelType);
173- // console.log('models', models);
174- // if (!isEmpty(models)) {
175- // const children = models.map((model: string) => ({
176- // label: <span>${model}</span>,
177- // value: model
178- // }))
179- // const groupOption = {
180- // label: <span>${modelType}</span>,
181- // title: modelType,
182- // options: children,
183- // };
184- // options.push(groupOption);
185- // }
186- // });
187- // }
188-
189- console . log ( 'data?.models' , data ?. models ) ;
190- const customModels = get ( customModelPrividersReq , 'data.endpoints' , [ ] ) ;
191- const customModelIds : string [ ] = [ ] ;
192- forEach ( customModels , ( model : CustomModel ) => {
193- if ( model . provider_type === ModelProviderType . GEMINIE ||
194- model . provider_type === ModelProviderType . OPENAI
195- ) {
196- customModelIds . push ( model . model_id )
176+ const onModelProviderChange = ( value : string ) => {
177+ form . setFieldValue ( 'model_id' , undefined )
178+ console . log ( 'value' , value ) ;
179+ if ( ModelProviderType . OPENAI === value ) {
180+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . OPENAI ) ;
181+ setModels ( _models ) ;
182+ } else if ( ModelProviderType . GEMINIE === value ) {
183+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . GEMINIE ) ;
184+ setModels ( _models ) ;
197185 }
198- } ) ;
199- console . log ( 'customModelIds' , customModelIds ) ;
186+ }
187+
200188
201189 return (
202190 < StepContainer justify = 'center' >
@@ -221,7 +209,7 @@ const Configure: FunctionComponent = () => {
221209 >
222210 < Select
223211
224- onChange = { ( ) => form . setFieldValue ( 'model_id' , undefined ) }
212+ onChange = { ( value : string ) => onModelProviderChange ( value ) }
225213 placeholder = { 'Select a model provider' }
226214 >
227215 { MODEL_TYPE_OPTIONS . map ( ( { label, value } , i ) =>
@@ -247,31 +235,16 @@ const Configure: FunctionComponent = () => {
247235 placeholder = { 'Select a Model' }
248236 notFoundContent = { 'You must select a Model Provider before selecting a Model' }
249237 >
250- { ! isEmpty ( data ?. models ) && ! isEmpty ( data ?. models ?. [ ModelProviders . BEDROCK ] ) && (
251- < Select . OptGroup label = "AWS-Bedrock" >
252- { data ?. models ?. [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) => (
253- < Select . Option key = { `${ model } -${ i } ` } value = { model } >
254- { model }
255- </ Select . Option >
256- ) ) }
257- </ Select . OptGroup >
258- ) }
259-
260- { /* Add custom model providers here as needed */ }
261- { ! isEmpty ( customModelIds ) && (
262- < Select . OptGroup label = "Custom" >
263- { customModelIds . map ( ( model_id , i ) => {
264- console . log ( 'model' , model_id , i )
265- return (
266- < Select . Option key = { `${ model_id } -${ i } ` } value = { model_id } >
267- { model_id }
268- </ Select . Option >
269- ) ;
270-
271-
272- } ) }
273- </ Select . OptGroup >
274- ) }
238+ { formData ?. inference_type === ModelProviders . BEDROCK && data ?. models ?. [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) => (
239+ < Select . Option key = { `${ model } -${ i } ` } value = { model } >
240+ { model }
241+ </ Select . Option >
242+ ) ) }
243+ { ( formData ?. inference_type === ModelProviders . OPENAI || formData ?. inference_type === ModelProviders . GEMINI ) && models ?. map ( ( model , i ) => (
244+ < Select . Option key = { `${ model } -${ i } ` } value = { model } >
245+ { model }
246+ </ Select . Option >
247+ ) ) }
275248 </ Select >
276249 ) }
277250 </ Form . Item >
0 commit comments