@@ -2,17 +2,23 @@ import endsWith from 'lodash/endsWith';
22import isEmpty from 'lodash/isEmpty' ;
33import isFunction from 'lodash/isFunction' ;
44import { FunctionComponent , useEffect , useState } from 'react' ;
5- import { Flex , Form , FormInstance , Input , Select , Typography } from 'antd' ;
5+ import { Flex , Form , Input , Select , Typography } from 'antd' ;
66import styled from 'styled-components' ;
77import { File , WorkflowType } from './types' ;
88import { useFetchModels } from '../../api/api' ;
99import { MODEL_PROVIDER_LABELS } from './constants' ;
1010import { ModelProviders , ModelProvidersDropdownOpts } from './types' ;
11- import { getWizardModel , getWizardModeType , useWizardCtx } from './utils' ;
11+ import { getWizardModeType , useWizardCtx } from './utils' ;
1212import FileSelectorButton from './FileSelectorButton' ;
1313import UseCaseSelector from './UseCaseSelector' ;
1414import { useLocation , useParams } from 'react-router-dom' ;
1515import { WizardModeType } from '../../types' ;
16+ import get from 'lodash/get' ;
17+ import forEach from 'lodash/forEach' ;
18+ import { useModelProviders } from '../Settings/hooks' ;
19+ import { ModelProviderType } from '../Settings/AddModelProviderButton' ;
20+ import { CustomModel } from '../Settings/SettingsPage' ;
21+ import filter from 'lodash/filter' ;
1622
1723
1824const StepContainer = styled ( Flex ) `
@@ -47,14 +53,21 @@ export const WORKFLOW_OPTIONS = [
4753export const MODEL_TYPE_OPTIONS : ModelProvidersDropdownOpts = [
4854 { label : MODEL_PROVIDER_LABELS [ ModelProviders . BEDROCK ] , value : ModelProviders . BEDROCK } ,
4955 { label : MODEL_PROVIDER_LABELS [ ModelProviders . CAII ] , value : ModelProviders . CAII } ,
56+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . OPENAI ] , value : ModelProviders . OPENAI } ,
57+ { label : MODEL_PROVIDER_LABELS [ ModelProviders . GEMINI ] , value : ModelProviders . GEMINI } ,
5058] ;
5159
5260const Configure : FunctionComponent = ( ) => {
5361 const form = Form . useFormInstance ( ) ;
5462 const formData = Form . useWatch ( ( values ) => values , form ) ;
5563 const location = useLocation ( ) ;
5664 const { template_name, generate_file_name } = useParams ( ) ;
65+ const [ models , setModels ] = useState < string [ ] > ( [ ] )
5766 const [ wizardModeType , setWizardModeType ] = useState ( getWizardModeType ( location ) ) ;
67+ const { data } = useFetchModels ( ) ;
68+ const customModelPrividersReq = useModelProviders ( ) ;
69+ const customModels = get ( customModelPrividersReq , 'data.endpoints' , [ ] ) ;
70+ console . log ( 'customModels' , customModels ) ;
5871
5972 useEffect ( ( ) => {
6073 if ( wizardModeType === WizardModeType . DATA_AUGMENTATION ) {
@@ -77,10 +90,19 @@ const Configure: FunctionComponent = () => {
7790 }
7891 } , [ template_name ] ) ;
7992
93+ useEffect ( ( ) => {
94+ // set model providers
95+ // set model ids
96+ console . log ( 'useEffect' , formData ) ;
97+ if ( formData && formData ?. inference_type === ModelProviderType . OPENAI && isEmpty ( generate_file_name ) ) {
98+ form . setFieldValue ( 'inference_type' , ModelProviders . OPENAI ) ;
99+ }
100+
101+ } , [ customModels , formData ] ) ;
102+
80103
81104 // let formData = Form.useWatch((values) => values, form);
82105 const { setIsStepValid } = useWizardCtx ( ) ;
83- const { data } = useFetchModels ( ) ;
84106 const [ selectedFiles , setSelectedFiles ] = useState (
85107 ! isEmpty ( form . getFieldValue ( 'doc_paths' ) ) ? form . getFieldValue ( 'doc_paths' ) : [ ] ) ;
86108
@@ -104,7 +126,6 @@ const Configure: FunctionComponent = () => {
104126
105127
106128 useEffect ( ( ) => {
107- console . log ( 'useEffect 1' ) ;
108129 if ( formData && formData ?. inference_type === undefined && isEmpty ( generate_file_name ) ) {
109130 form . setFieldValue ( 'inference_type' , ModelProviders . CAII ) ;
110131 setTimeout ( ( ) => {
@@ -155,6 +176,20 @@ const Configure: FunctionComponent = () => {
155176 }
156177 }
157178
179+ const onModelProviderChange = ( value : string ) => {
180+ form . setFieldValue ( 'model_id' , undefined )
181+ console . log ( 'value' , value ) ;
182+ if ( ModelProviderType . OPENAI === value ) {
183+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . OPENAI ) ;
184+ setModels ( _models . map ( ( _model : CustomModel ) => _model . model_id ) ) ;
185+ } else if ( ModelProviderType . GEMINIE === value ) {
186+ const _models = filter ( customModels , ( model : CustomModel ) => model . provider_type === ModelProviderType . GEMINIE ) ;
187+ setModels ( _models . map ( ( _model : CustomModel ) => _model . model_id ) ) ;
188+ }
189+ }
190+ console . log ( 'models' , models ) ;
191+
192+
158193 return (
159194 < StepContainer justify = 'center' >
160195 < FormContainer vertical >
@@ -178,7 +213,7 @@ const Configure: FunctionComponent = () => {
178213 >
179214 < Select
180215
181- onChange = { ( ) => form . setFieldValue ( 'model_id' , undefined ) }
216+ onChange = { ( value : string ) => onModelProviderChange ( value ) }
182217 placeholder = { 'Select a model provider' }
183218 >
184219 { MODEL_TYPE_OPTIONS . map ( ( { label, value } , i ) =>
@@ -200,15 +235,22 @@ const Configure: FunctionComponent = () => {
200235 { formData ?. inference_type === ModelProviders . CAII ? (
201236 < Input placeholder = { 'Enter Cloudera AI Inference Model ID' } />
202237 ) : (
203- < Select placeholder = { 'Select a Model' } notFoundContent = { 'You must select a Model Provider before selecting a Model' } >
204- { ! isEmpty ( data ?. models ) && data ?. models [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) =>
238+ < Select
239+ placeholder = { 'Select a Model' }
240+ notFoundContent = { 'You must select a Model Provider before selecting a Model' }
241+ >
242+ { formData ?. inference_type === ModelProviders . BEDROCK && data ?. models ?. [ ModelProviders . BEDROCK ] ?. map ( ( model , i ) => (
205243 < Select . Option key = { `${ model } -${ i } ` } value = { model } >
206244 { model }
207245 </ Select . Option >
208- ) }
246+ ) ) }
247+ { ( formData ?. inference_type === ModelProviders . OPENAI || formData ?. inference_type === ModelProviders . GEMINI ) && models ?. map ( ( model , i ) => (
248+ < Select . Option key = { `${ model } -${ i } ` } value = { model } >
249+ { model }
250+ </ Select . Option >
251+ ) ) }
209252 </ Select >
210253 ) }
211-
212254 </ Form . Item >
213255 { formData ?. inference_type === ModelProviders . CAII && (
214256 < >
0 commit comments