Skip to content

Commit d51775a

Browse files
author
Keivan Vosoughi
committed
Change Configure Model ID
1 parent 7b3e133 commit d51775a

File tree

6 files changed

+44
-75
lines changed

6 files changed

+44
-75
lines changed

app/client/src/pages/DataGenerator/Configure.tsx

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import forEach from 'lodash/forEach';
1818
import { useModelProviders } from '../Settings/hooks';
1919
import { ModelProviderType } from '../Settings/AddModelProviderButton';
2020
import { CustomModel } from '../Settings/SettingsPage';
21+
import filter from 'lodash/filter';
2122

2223

2324
const StepContainer = styled(Flex)`
@@ -52,14 +53,21 @@ export const WORKFLOW_OPTIONS = [
5253
export const MODEL_TYPE_OPTIONS: ModelProvidersDropdownOpts = [
5354
{ label: MODEL_PROVIDER_LABELS[ModelProviders.BEDROCK], value: ModelProviders.BEDROCK},
5455
{ label: MODEL_PROVIDER_LABELS[ModelProviders.CAII], value: ModelProviders.CAII },
56+
{ label: MODEL_PROVIDER_LABELS[ModelProviders.OPENAI], value: ModelProviders.OPENAI },
57+
{ label: MODEL_PROVIDER_LABELS[ModelProviders.GEMINI], value: ModelProviders.GEMINI },
5558
];
5659

5760
const Configure: FunctionComponent = () => {
5861
const form = Form.useFormInstance();
5962
const formData = Form.useWatch((values) => values, form);
6063
const location = useLocation();
6164
const { template_name, generate_file_name } = useParams();
65+
const [models, setModels] = useState<string[]>([])
6266
const [wizardModeType, setWizardModeType] = useState(getWizardModeType(location));
67+
const { data } = useFetchModels();
68+
const customModelPrividersReq = useModelProviders();
69+
const customModels = get(customModelPrividersReq, 'data.endpoints', []);
70+
console.log('customModels', customModels);
6371

6472
useEffect(() => {
6573
if (wizardModeType === WizardModeType.DATA_AUGMENTATION) {
@@ -82,11 +90,16 @@ const Configure: FunctionComponent = () => {
8290
}
8391
}, [template_name]);
8492

93+
useEffect(() => {
94+
// set model providers
95+
// set model ids
96+
console.log('useEffect', formData);
97+
98+
}, [customModels, formData]);
99+
85100

86101
// let formData = Form.useWatch((values) => values, form);
87102
const { setIsStepValid } = useWizardCtx();
88-
const { data } = useFetchModels();
89-
const customModelPrividersReq = useModelProviders();
90103
const [selectedFiles, setSelectedFiles] = useState(
91104
!isEmpty(form.getFieldValue('doc_paths')) ? form.getFieldValue('doc_paths') : []);
92105

@@ -110,7 +123,6 @@ const Configure: FunctionComponent = () => {
110123

111124

112125
useEffect(() => {
113-
console.log('useEffect 1');
114126
if (formData && formData?.inference_type === undefined && isEmpty(generate_file_name)) {
115127
form.setFieldValue('inference_type', ModelProviders.CAII);
116128
setTimeout(() => {
@@ -161,42 +173,18 @@ const Configure: FunctionComponent = () => {
161173
}
162174
}
163175

164-
// const getModelsGroupOptions = (models: any) => {
165-
// if (isEmpty(models)) {
166-
// return [];
167-
// }
168-
// const options = [];
169-
// const modelTypes = Object.keys(models);
170-
// console.log('modelTypes', modelTypes);
171-
// forEach(modelTypes, (modelType: string) => {
172-
// const models = get(modelTypes, modelType);
173-
// console.log('models', models);
174-
// if (!isEmpty(models)) {
175-
// const children = models.map((model: string) => ({
176-
// label: <span>${model}</span>,
177-
// value: model
178-
// }))
179-
// const groupOption = {
180-
// label: <span>${modelType}</span>,
181-
// title: modelType,
182-
// options: children,
183-
// };
184-
// options.push(groupOption);
185-
// }
186-
// });
187-
// }
188-
189-
console.log('data?.models', data?.models);
190-
const customModels = get(customModelPrividersReq, 'data.endpoints', []);
191-
const customModelIds: string[] = [];
192-
forEach(customModels, (model: CustomModel) => {
193-
if (model.provider_type === ModelProviderType.GEMINIE ||
194-
model.provider_type === ModelProviderType.OPENAI
195-
) {
196-
customModelIds.push(model.model_id)
176+
const onModelProviderChange = (value: string) => {
177+
form.setFieldValue('model_id', undefined)
178+
console.log('value', value);
179+
if (ModelProviderType.OPENAI === value) {
180+
const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.OPENAI);
181+
setModels(_models);
182+
} else if (ModelProviderType.GEMINIE === value) {
183+
const _models = filter(customModels, (model: CustomModel) => model.provider_type === ModelProviderType.GEMINIE);
184+
setModels(_models);
197185
}
198-
});
199-
console.log('customModelIds', customModelIds);
186+
}
187+
200188

201189
return (
202190
<StepContainer justify='center'>
@@ -221,7 +209,7 @@ const Configure: FunctionComponent = () => {
221209
>
222210
<Select
223211

224-
onChange={() => form.setFieldValue('model_id', undefined)}
212+
onChange={(value: string) => onModelProviderChange(value)}
225213
placeholder={'Select a model provider'}
226214
>
227215
{MODEL_TYPE_OPTIONS.map(({ label, value }, i) =>
@@ -247,31 +235,16 @@ const Configure: FunctionComponent = () => {
247235
placeholder={'Select a Model'}
248236
notFoundContent={'You must select a Model Provider before selecting a Model'}
249237
>
250-
{!isEmpty(data?.models) && !isEmpty(data?.models?.[ModelProviders.BEDROCK]) && (
251-
<Select.OptGroup label="AWS-Bedrock">
252-
{data?.models?.[ModelProviders.BEDROCK]?.map((model, i) => (
253-
<Select.Option key={`${model}-${i}`} value={model}>
254-
{model}
255-
</Select.Option>
256-
))}
257-
</Select.OptGroup>
258-
)}
259-
260-
{/* Add custom model providers here as needed */}
261-
{!isEmpty(customModelIds) && (
262-
<Select.OptGroup label="Custom">
263-
{customModelIds.map((model_id, i) => {
264-
console.log('model', model_id, i)
265-
return (
266-
<Select.Option key={`${model_id}-${i}`} value={model_id}>
267-
{model_id}
268-
</Select.Option>
269-
);
270-
271-
272-
})}
273-
</Select.OptGroup>
274-
)}
238+
{formData?.inference_type === ModelProviders.BEDROCK && data?.models?.[ModelProviders.BEDROCK]?.map((model, i) => (
239+
<Select.Option key={`${model}-${i}`} value={model}>
240+
{model}
241+
</Select.Option>
242+
))}
243+
{(formData?.inference_type === ModelProviders.OPENAI || formData?.inference_type === ModelProviders.GEMINI) && models?.map((model, i) => (
244+
<Select.Option key={`${model}-${i}`} value={model}>
245+
{model}
246+
</Select.Option>
247+
))}
275248
</Select>
276249
)}
277250
</Form.Item>

app/client/src/pages/DataGenerator/constants.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ export const MODEL_PROVIDER_LABELS = {
55
[ModelProviders.CAII]: 'Cloudera AI Inference Service',
66
[ModelProviders.GOOGLE_GEMINI]: 'Google Gemini',
77
[ModelProviders.AZURE_OPENAI]: 'Azure OpenAI',
8+
[ModelProviders.GEMINI]: 'Gemini',
9+
[ModelProviders.OPENAI]: 'OpenAI'
810
};
911

1012
export const MIN_SEED_INSTRUCTIONS = 1

app/client/src/pages/DataGenerator/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ export enum ModelProviders {
1919
CAII = 'CAII',
2020
AZURE_OPENAI = 'AZURE_OPENAI',
2121
GOOGLE_GEMINI = 'GOOGLE_GEMINI',
22+
OPENAI = 'openai',
23+
GEMINI = 'gemini',
2224
}
2325

2426
export type ModelProvidersDropdownOpts = { label: string, value: ModelProviders }[];

app/client/src/pages/Settings/AddModelProviderButton.tsx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ const AddModelProviderButton: React.FC<Props> = ({ refetch }) => {
9696
try {
9797
await form.validateFields();
9898
const values = form.getFieldsValue();
99-
console.log('values', values);
99+
100100
mutation.mutate({
101101
endpoint_config: {
102102
display_name: values.display_name,
@@ -118,9 +118,7 @@ const AddModelProviderButton: React.FC<Props> = ({ refetch }) => {
118118
};
119119

120120
const onChange = (e: any) => {
121-
console.log('onChange', e);
122121
const value = get(e, 'target.value');
123-
console.log('value:', value);
124122
if (value === 'openai' && !isEqual(OPENAI_MODELS_OPTIONS, models)) {
125123
setModels(OPENAI_MODELS_OPTIONS);
126124
} else if (value === 'gemini' && !isEqual(GEMINI_MODELS_OPTIONS, models)) {

app/client/src/pages/Settings/EditModelProvider.tsx

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,13 @@ interface Props {
6767
const EditModelProvider: React.FC<Props> = ({ model, refetch, onClose }) => {
6868
const [form] = Form.useForm();
6969
const modelProviderReq = useGetModelProvider(model.endpoint_id);
70-
console.log('modelProviderReq', modelProviderReq);
7170
const [models, setModels] = useState(OPENAI_MODELS_OPTIONS);
7271
const mutation = useMutation({
7372
mutationFn: addModelProvider
7473
});
7574

7675
useEffect(() => {
7776
if (!isEmpty(modelProviderReq.data)) {
78-
console.log('-------->', modelProviderReq.data);
7977
const endpoint = get(modelProviderReq, 'data.endpoint');
8078
form.setFieldsValue({
8179
...endpoint
@@ -109,7 +107,7 @@ const EditModelProvider: React.FC<Props> = ({ model, refetch, onClose }) => {
109107
try {
110108
await form.validateFields();
111109
const values = form.getFieldsValue();
112-
console.log('values', values);
110+
113111
mutation.mutate({
114112
endpoint_config: {
115113
display_name: values.display_name,
@@ -131,9 +129,7 @@ const EditModelProvider: React.FC<Props> = ({ model, refetch, onClose }) => {
131129
};
132130

133131
const onChange = (e: any) => {
134-
console.log('onChange', e);
135132
const value = get(e, 'target.value');
136-
console.log('value:', value);
137133
if (value === 'openai' && !isEqual(OPENAI_MODELS_OPTIONS, models)) {
138134
setModels(OPENAI_MODELS_OPTIONS);
139135
} else if (value === 'gemini' && !isEqual(GEMINI_MODELS_OPTIONS, models)) {

app/client/src/pages/Settings/SettingsPage.tsx

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,7 @@ const StyledButton = styled(Button)`
7979
const SettingsPage: React.FC = () => {
8080
const [showModal, setShowModal] = useState(false);
8181
const [model, setModel] = useState<CustomModel | null>(null);
82-
const filteredModelsReq = useModelProviders();
83-
console.log('filteredModelsReq', filteredModelsReq);
82+
const filteredModelsReq = useModelProviders();
8483
const customModels = get(filteredModelsReq, 'data.endpoints', []);
8584

8685
const mutation = useMutation({
@@ -164,7 +163,6 @@ const SettingsPage: React.FC = () => {
164163
title: 'Actions',
165164
width: 100,
166165
render: (model: CustomModel) => {
167-
console.log('model', model);
168166
return (
169167
<Flex>
170168
<Tooltip title="Edit">

0 commit comments

Comments
 (0)