Skip to content

Commit ceefe3b

Browse files
Merge pull request #79 from code-kern-ai/embedder-rework
Embedder rework
2 parents 89c5d40 + 35189dd commit ceefe3b

File tree

5 files changed

+9
-27
lines changed

5 files changed

+9
-27
lines changed

src/components/projects/projectId/settings/embeddings/AddNewEmbeddingModal.tsx

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ export default function AddNewEmbeddingModal() {
109109
function changePlatformOrGranularity() {
110110
prepareSuggestions();
111111
const savePlatform = platform.platform;
112-
if (savePlatform == PlatformType.COHERE || savePlatform == PlatformType.OPEN_AI || savePlatform == PlatformType.AZURE) {
112+
if (savePlatform == PlatformType.OPEN_AI || savePlatform == PlatformType.AZURE) {
113113
setGranularity(GRANULARITY_TYPES_ARRAY.find((g) => g.value == EmbeddingType.ON_ATTRIBUTE));
114114
if (savePlatform == PlatformType.AZURE) {
115115
const azureUrls = localStorage.getItem('azureUrls');
@@ -137,7 +137,7 @@ export default function AddNewEmbeddingModal() {
137137

138138
function checkIfPlatformHasToken() {
139139
if (!platform) return;
140-
if (platform.name == platformNamesDict[PlatformType.COHERE] || platform.name == platformNamesDict[PlatformType.OPEN_AI] || platform.name == platformNamesDict[PlatformType.AZURE]) {
140+
if (platform.name == platformNamesDict[PlatformType.OPEN_AI] || platform.name == platformNamesDict[PlatformType.AZURE]) {
141141
setGranularityArray(GRANULARITY_TYPES_ARRAY.filter((g) => g.value != EmbeddingType.ON_TOKEN));
142142
} else {
143143
setGranularityArray(GRANULARITY_TYPES_ARRAY);
@@ -179,13 +179,11 @@ export default function AddNewEmbeddingModal() {
179179
filterAttributes: filteredAttributes
180180
}
181181

182-
if (platform.name == platformNamesDict[PlatformType.HUGGING_FACE] || platform.name == platformNamesDict[PlatformType.PYTHON]) {
182+
if (platform.name == platformNamesDict[PlatformType.HUGGING_FACE]) {
183183
config.model = model;
184184
} else if (platform.name == platformNamesDict[PlatformType.OPEN_AI]) {
185185
config.model = model;
186186
config.apiToken = apiToken;
187-
} else if (platform.name == platformNamesDict[PlatformType.COHERE]) {
188-
config.apiToken = apiToken;
189187
} else if (platform.name == platformNamesDict[PlatformType.AZURE]) {
190188
config.model = engine; //note that is handled internally as model so we use the model field for the request
191189
config.apiToken = apiToken;
@@ -251,15 +249,6 @@ export default function AddNewEmbeddingModal() {
251249
<input placeholder="Enter your API token" onChange={(e) => setApiToken(e.target.value)} value={apiToken}
252250
className="h-9 w-full text-sm border-gray-300 rounded-md placeholder-italic border text-gray-900 pl-4 placeholder:text-gray-400 focus:outline-none focus:ring-2 focus:ring-gray-300 focus:ring-offset-2 focus:ring-offset-gray-100" />
253251
</>}
254-
{platform && platform.name == platformNamesDict[PlatformType.COHERE] && <>
255-
<Tooltip content={TOOLTIPS_DICT.PROJECT_SETTINGS.EMBEDDINGS.API_TOKEN} placement="right" color="invert">
256-
<span className="card-title mb-0 label-text flex"><span className="cursor-help underline filtersUnderline">API token</span></span>
257-
</Tooltip>
258-
<input placeholder="Enter your API token" onChange={(e) => setApiToken(e.target.value)} value={apiToken}
259-
className="h-9 w-full text-sm border-gray-300 rounded-md placeholder-italic border text-gray-900 pl-4 placeholder:text-gray-400 focus:outline-none focus:ring-2 focus:ring-gray-300 focus:ring-offset-2 focus:ring-offset-gray-100" />
260-
</>}
261-
{platform && platform.name == platformNamesDict[PlatformType.PYTHON] && <SuggestionsModel options={embeddingHandles[targetAttribute]} selectedOption={(option: string) => setModel(option)} />}
262-
263252
{platform && platform.name == platformNamesDict[PlatformType.AZURE] && <>
264253
<Tooltip content={TOOLTIPS_DICT.PROJECT_SETTINGS.EMBEDDINGS.API_TOKEN} placement="right" color="invert">
265254
<span className="card-title mb-0 label-text flex"><span className="cursor-help underline filtersUnderline">API token</span></span>
@@ -302,11 +291,10 @@ export default function AddNewEmbeddingModal() {
302291
</>}
303292
</>}
304293
</div>
305-
{platform && (platform.name == platformNamesDict[PlatformType.COHERE] || platform.name == platformNamesDict[PlatformType.OPEN_AI] || platform.name == platformNamesDict[PlatformType.AZURE]) && <div className="text-center mt-3">
294+
{platform && (platform.name == platformNamesDict[PlatformType.OPEN_AI] || platform.name == platformNamesDict[PlatformType.AZURE]) && <div className="text-center mt-3">
306295
<div className="border border-gray-300 text-xs text-gray-500 p-2.5 rounded-lg text-justify">
307296
<label ref={gdprText} className="text-gray-700">
308297
{selectedPlatform.splitTerms[0]}
309-
{platform.name == platformNamesDict[PlatformType.COHERE] && <a href={selectedPlatform.link} target="_blank" className="underline">cohere terms of service.</a>}
310298
{platform.name == platformNamesDict[PlatformType.OPEN_AI] && <a href={selectedPlatform.link} target="_blank" className="underline">openai terms of service.</a>}
311299
{platform.name == platformNamesDict[PlatformType.AZURE] && <a href={selectedPlatform.link} target="_blank" className="underline">azure terms of service.</a>}
312300
<div>{selectedPlatform.splitTerms[1]}</div>

src/components/shared/upload/helper-components/UploadWrapper.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ export default function UploadWrapper(props: UploadWrapperProps) {
1919
const uploadFileType = useSelector(selectUploadData).uploadFileType;
2020
const importOptions = useSelector(selectUploadData).importOptions;
2121
const embeddings = useSelector(selectEmbeddings);
22-
const recalculationCosts = embeddings.some((e: Embedding) => e.platform == PlatformType.COHERE || e.platform == PlatformType.OPEN_AI || e.platform == PlatformType.AZURE);
22+
const recalculationCosts = embeddings.some((e: Embedding) => e.platform == PlatformType.OPEN_AI || e.platform == PlatformType.AZURE);
2323

2424
const [selectedFile, setSelectedFile] = useState<File | null>(null);
2525
const [fileEndsWithZip, setFileEndsWithZip] = useState<boolean>(false);

src/types/components/projects/projectId/settings/embeddings.ts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ export type EmbeddingPlatform = {
4646
export enum PlatformType {
4747
HUGGING_FACE = "huggingface",
4848
OPEN_AI = "openai",
49-
COHERE = "cohere",
50-
PYTHON = "python",
5149
AZURE = "azure"
5250
}
5351

src/util/components/projects/projectId/settings/embeddings-helper.ts

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ export function postProcessingEmbeddingPlatforms(platforms: EmbeddingPlatform[],
4646
export const platformNamesDict = {
4747
[PlatformType.HUGGING_FACE]: "Hugging Face",
4848
[PlatformType.OPEN_AI]: "OpenAI",
49-
[PlatformType.COHERE]: "Cohere",
50-
[PlatformType.PYTHON]: "Python",
5149
[PlatformType.AZURE]: "Azure"
5250
}
5351

@@ -73,9 +71,9 @@ function buildExpectedEmbeddingName(data: any): string {
7371
let toReturn = data.targetAttribute.name;
7472
toReturn += "-" + (data.granularity.value == EmbeddingType.ON_ATTRIBUTE ? 'classification' : 'extraction');
7573
const platform = data.platform;
76-
if (platform == PlatformType.HUGGING_FACE || platform == PlatformType.PYTHON) {
74+
if (platform == PlatformType.HUGGING_FACE) {
7775
toReturn += "-" + platform + "-" + data.model;
78-
} else if (platform == PlatformType.OPEN_AI || platform == PlatformType.COHERE || platform == PlatformType.AZURE) {
76+
} else if (platform == PlatformType.OPEN_AI || platform == PlatformType.AZURE) {
7977
toReturn += buildEmbeddingNameWithApiToken(data);
8078
}
8179
return toReturn;
@@ -106,12 +104,10 @@ export function checkIfCreateEmbeddingIsDisabled(props: EmbeddingCreationEnabled
106104
const engine = props.engine;
107105
const version = props.version;
108106
const url = props.url;
109-
if (platform.name == platformNamesDict[PlatformType.HUGGING_FACE] || platform.name == platformNamesDict[PlatformType.PYTHON]) {
107+
if (platform.name == platformNamesDict[PlatformType.HUGGING_FACE]) {
110108
checkFormFields = model == null || model == "";
111109
} else if (platform.name == platformNamesDict[PlatformType.OPEN_AI]) {
112110
checkFormFields = model == null || apiToken == null || apiToken == "" || !termsAccepted;
113-
} else if (platform.name == platformNamesDict[PlatformType.COHERE]) {
114-
checkFormFields = apiToken == null || apiToken == "" || !termsAccepted;
115111
} else if (platform.name == platformNamesDict[PlatformType.AZURE]) {
116112
checkFormFields = apiToken == null || apiToken == "" || url == null || url == "" || version == null || version == "" || !termsAccepted || !engine;
117113
}

src/util/components/projects/projectId/settings/project-export-helper.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ export function postProcessingFormGroups(projectSize: any, embeddings: Embedding
2222
projectSize.forEach((element: any) => {
2323
let hasGdpr = false;
2424
if (element.table == ProjectExportGroup.EMBEDDING_TENSORS) {
25-
hasGdpr = embeddings.some((e: any) => e.name.split("-")[2] == PlatformType.COHERE || e.name.split("-")[2] == PlatformType.OPEN_AI || e.name.split("-")[2] == PlatformType.AZURE);
25+
hasGdpr = embeddings.some((e: any) => e.name.split("-")[2] == PlatformType.OPEN_AI || e.name.split("-")[2] == PlatformType.AZURE);
2626
}
2727
let group = {
2828
export: element.default,

0 commit comments

Comments
 (0)